Merge branch 'main' into fix/app-list-walk-nodes-graceful

Authored by fisherOne1 on 2025-12-05 14:45:34 +08:00, committed by GitHub
commit 7166077230
74 changed files with 4807 additions and 4189 deletions

View File

@ -106,7 +106,7 @@ jobs:
- name: Web type check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run type-check
run: pnpm run type-check:tsgo
docker-compose-template:
name: Docker Compose Template

View File

@ -3,7 +3,8 @@ from functools import wraps
from typing import ParamSpec, TypeVar
from flask import request
from flask_restx import Resource, fields, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized
@ -18,6 +19,30 @@ from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, InstalledApp, RecommendedApp
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class InsertExploreAppPayload(BaseModel):
app_id: str = Field(...)
desc: str | None = None
copyright: str | None = None
privacy_policy: str | None = None
custom_disclaimer: str | None = None
language: str = Field(...)
category: str = Field(...)
position: int = Field(...)
@field_validator("language")
@classmethod
def validate_language(cls, value: str) -> str:
return supported_language(value)
console_ns.schema_model(
InsertExploreAppPayload.__name__,
InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def admin_required(view: Callable[P, R]):
@wraps(view)
@ -40,59 +65,34 @@ def admin_required(view: Callable[P, R]):
class InsertExploreAppListApi(Resource):
@console_ns.doc("insert_explore_app")
@console_ns.doc(description="Insert or update an app in the explore list")
@console_ns.expect(
console_ns.model(
"InsertExploreAppRequest",
{
"app_id": fields.String(required=True, description="Application ID"),
"desc": fields.String(description="App description"),
"copyright": fields.String(description="Copyright information"),
"privacy_policy": fields.String(description="Privacy policy"),
"custom_disclaimer": fields.String(description="Custom disclaimer"),
"language": fields.String(required=True, description="Language code"),
"category": fields.String(required=True, description="App category"),
"position": fields.Integer(required=True, description="Display position"),
},
)
)
@console_ns.expect(console_ns.models[InsertExploreAppPayload.__name__])
@console_ns.response(200, "App updated successfully")
@console_ns.response(201, "App inserted successfully")
@console_ns.response(404, "App not found")
@only_edition_cloud
@admin_required
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("app_id", type=str, required=True, nullable=False, location="json")
.add_argument("desc", type=str, location="json")
.add_argument("copyright", type=str, location="json")
.add_argument("privacy_policy", type=str, location="json")
.add_argument("custom_disclaimer", type=str, location="json")
.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
.add_argument("category", type=str, required=True, nullable=False, location="json")
.add_argument("position", type=int, required=True, nullable=False, location="json")
)
args = parser.parse_args()
payload = InsertExploreAppPayload.model_validate(console_ns.payload)
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()
app = db.session.execute(select(App).where(App.id == payload.app_id)).scalar_one_or_none()
if not app:
raise NotFound(f"App '{args['app_id']}' is not found")
raise NotFound(f"App '{payload.app_id}' is not found")
site = app.site
if not site:
desc = args["desc"] or ""
copy_right = args["copyright"] or ""
privacy_policy = args["privacy_policy"] or ""
custom_disclaimer = args["custom_disclaimer"] or ""
desc = payload.desc or ""
copy_right = payload.copyright or ""
privacy_policy = payload.privacy_policy or ""
custom_disclaimer = payload.custom_disclaimer or ""
else:
desc = site.description or args["desc"] or ""
copy_right = site.copyright or args["copyright"] or ""
privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
desc = site.description or payload.desc or ""
copy_right = site.copyright or payload.copyright or ""
privacy_policy = site.privacy_policy or payload.privacy_policy or ""
custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or ""
with Session(db.engine) as session:
recommended_app = session.execute(
select(RecommendedApp).where(RecommendedApp.app_id == args["app_id"])
select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id)
).scalar_one_or_none()
if not recommended_app:
@ -102,9 +102,9 @@ class InsertExploreAppListApi(Resource):
copyright=copy_right,
privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer,
language=args["language"],
category=args["category"],
position=args["position"],
language=payload.language,
category=payload.category,
position=payload.position,
)
db.session.add(recommended_app)
@ -118,9 +118,9 @@ class InsertExploreAppListApi(Resource):
recommended_app.copyright = copy_right
recommended_app.privacy_policy = privacy_policy
recommended_app.custom_disclaimer = custom_disclaimer
recommended_app.language = args["language"]
recommended_app.category = args["category"]
recommended_app.position = args["position"]
recommended_app.language = payload.language
recommended_app.category = payload.category
recommended_app.position = payload.position
app.is_public = True
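The same migration pattern recurs throughout the files below: a pydantic model replaces the reqparse parser, its JSON schema is registered on the namespace so @expect keeps documenting the request body in Swagger, and the handler validates console_ns.payload instead of calling parse_args(). What follows is a minimal, self-contained sketch of that flow, assuming pydantic v2 and flask-restx; the names demo_ns, GreetPayload and the /demo route are illustrative and not part of this change.

from flask import Flask
from flask_restx import Api, Namespace, Resource
from pydantic import BaseModel, Field, ValidationError

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"

app = Flask(__name__)
api = Api(app)
demo_ns = Namespace("demo", path="/demo")
api.add_namespace(demo_ns)

class GreetPayload(BaseModel):
    name: str = Field(..., description="Who to greet")
    loud: bool = Field(default=False, description="Uppercase the greeting")

# Register the pydantic JSON schema so @expect can still reference a named model.
demo_ns.schema_model(
    GreetPayload.__name__,
    GreetPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)

@demo_ns.route("/greet")
class Greet(Resource):
    @demo_ns.expect(demo_ns.models[GreetPayload.__name__])
    def post(self):
        # model_validate raises on bad input instead of silently coercing it.
        try:
            payload = GreetPayload.model_validate(demo_ns.payload or {})
        except ValidationError as e:
            return {"errors": e.errors()}, 400
        greeting = f"Hello, {payload.name}!"
        return {"message": greeting.upper() if payload.loud else greeting}

Compared with a chain of add_argument calls, the request schema now lives in one typed class, and invalid input surfaces as an explicit 400 with structured error details.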

View File

@ -1,4 +1,6 @@
from flask_restx import Resource, fields, reqparse
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
@ -8,10 +10,21 @@ from libs.login import login_required
from models.model import AppMode
from services.agent_service import AgentService
parser = (
reqparse.RequestParser()
.add_argument("message_id", type=uuid_value, required=True, location="args", help="Message UUID")
.add_argument("conversation_id", type=uuid_value, required=True, location="args", help="Conversation UUID")
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AgentLogQuery(BaseModel):
message_id: str = Field(..., description="Message UUID")
conversation_id: str = Field(..., description="Conversation UUID")
@field_validator("message_id", "conversation_id")
@classmethod
def validate_uuid(cls, value: str) -> str:
return uuid_value(value)
console_ns.schema_model(
AgentLogQuery.__name__, AgentLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@ -20,7 +33,7 @@ class AgentLogApi(Resource):
@console_ns.doc("get_agent_logs")
@console_ns.doc(description="Get agent execution logs for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser)
@console_ns.expect(console_ns.models[AgentLogQuery.__name__])
@console_ns.response(
200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries"))
)
@ -31,6 +44,6 @@ class AgentLogApi(Resource):
@get_app_model(mode=[AppMode.AGENT_CHAT])
def get(self, app_model):
"""Get agent logs"""
args = parser.parse_args()
args = AgentLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])
return AgentService.get_agent_logs(app_model, args.conversation_id, args.message_id)
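For GET endpoints the same idea is applied to the query string: request.args.to_dict(flat=True) yields a plain dict of strings, which the pydantic model then validates. A short sketch under the same assumptions; AgentLogQueryExample and the inline UUID check stand in for the real model and libs.helper.uuid_value.

import uuid

from pydantic import BaseModel, Field, field_validator

class AgentLogQueryExample(BaseModel):
    message_id: str = Field(..., description="Message UUID")
    conversation_id: str = Field(..., description="Conversation UUID")

    @field_validator("message_id", "conversation_id")
    @classmethod
    def must_be_uuid(cls, value: str) -> str:
        # Stand-in for uuid_value(): accept the string only if it parses as a UUID.
        return str(uuid.UUID(value))

# What request.args.to_dict(flat=True) would hand over for ?message_id=...&conversation_id=...
raw_args = {
    "message_id": "2f1b6c1e-8c1a-4c62-9a49-1d2f3e4a5b6c",
    "conversation_id": "0e9d8c7b-6a5f-4e3d-9c1b-0a9f8e7d6c5b",
}
query = AgentLogQueryExample.model_validate(raw_args)
print(query.message_id, query.conversation_id)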

View File

@ -1,7 +1,8 @@
from typing import Literal
from typing import Any, Literal
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.console import console_ns
@ -21,22 +22,79 @@ from libs.helper import uuid_value
from libs.login import login_required
from services.annotation_service import AppAnnotationService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AnnotationReplyPayload(BaseModel):
score_threshold: float = Field(..., description="Score threshold for annotation matching")
embedding_provider_name: str = Field(..., description="Embedding provider name")
embedding_model_name: str = Field(..., description="Embedding model name")
class AnnotationSettingUpdatePayload(BaseModel):
score_threshold: float = Field(..., description="Score threshold")
class AnnotationListQuery(BaseModel):
page: int = Field(default=1, ge=1, description="Page number")
limit: int = Field(default=20, ge=1, description="Page size")
keyword: str = Field(default="", description="Search keyword")
class CreateAnnotationPayload(BaseModel):
message_id: str | None = Field(default=None, description="Message ID")
question: str | None = Field(default=None, description="Question text")
answer: str | None = Field(default=None, description="Answer text")
content: str | None = Field(default=None, description="Content text")
annotation_reply: dict[str, Any] | None = Field(default=None, description="Annotation reply data")
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class UpdateAnnotationPayload(BaseModel):
question: str | None = None
answer: str | None = None
content: str | None = None
annotation_reply: dict[str, Any] | None = None
class AnnotationReplyStatusQuery(BaseModel):
action: Literal["enable", "disable"]
class AnnotationFilePayload(BaseModel):
message_id: str = Field(..., description="Message ID")
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str) -> str:
return uuid_value(value)
def reg(model: type[BaseModel]) -> None:
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(AnnotationReplyPayload)
reg(AnnotationSettingUpdatePayload)
reg(AnnotationListQuery)
reg(CreateAnnotationPayload)
reg(UpdateAnnotationPayload)
reg(AnnotationReplyStatusQuery)
reg(AnnotationFilePayload)
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
class AnnotationReplyActionApi(Resource):
@console_ns.doc("annotation_reply_action")
@console_ns.doc(description="Enable or disable annotation reply for an app")
@console_ns.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"})
@console_ns.expect(
console_ns.model(
"AnnotationReplyActionRequest",
{
"score_threshold": fields.Float(required=True, description="Score threshold for annotation matching"),
"embedding_provider_name": fields.String(required=True, description="Embedding provider name"),
"embedding_model_name": fields.String(required=True, description="Embedding model name"),
},
)
)
@console_ns.expect(console_ns.models[AnnotationReplyPayload.__name__])
@console_ns.response(200, "Action completed successfully")
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -46,15 +104,9 @@ class AnnotationReplyActionApi(Resource):
@edit_permission_required
def post(self, app_id, action: Literal["enable", "disable"]):
app_id = str(app_id)
parser = (
reqparse.RequestParser()
.add_argument("score_threshold", required=True, type=float, location="json")
.add_argument("embedding_provider_name", required=True, type=str, location="json")
.add_argument("embedding_model_name", required=True, type=str, location="json")
)
args = parser.parse_args()
args = AnnotationReplyPayload.model_validate(console_ns.payload)
if action == "enable":
result = AppAnnotationService.enable_app_annotation(args, app_id)
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
return result, 200
@ -82,16 +134,7 @@ class AppAnnotationSettingUpdateApi(Resource):
@console_ns.doc("update_annotation_setting")
@console_ns.doc(description="Update annotation settings for an app")
@console_ns.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"})
@console_ns.expect(
console_ns.model(
"AnnotationSettingUpdateRequest",
{
"score_threshold": fields.Float(required=True, description="Score threshold"),
"embedding_provider_name": fields.String(required=True, description="Embedding provider"),
"embedding_model_name": fields.String(required=True, description="Embedding model"),
},
)
)
@console_ns.expect(console_ns.models[AnnotationSettingUpdatePayload.__name__])
@console_ns.response(200, "Settings updated successfully")
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -102,10 +145,9 @@ class AppAnnotationSettingUpdateApi(Resource):
app_id = str(app_id)
annotation_setting_id = str(annotation_setting_id)
parser = reqparse.RequestParser().add_argument("score_threshold", required=True, type=float, location="json")
args = parser.parse_args()
args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump())
return result, 200
@ -142,12 +184,7 @@ class AnnotationApi(Resource):
@console_ns.doc("list_annotations")
@console_ns.doc(description="Get annotations for an app with pagination")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.parser()
.add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size")
.add_argument("keyword", type=str, location="args", default="", help="Search keyword")
)
@console_ns.expect(console_ns.models[AnnotationListQuery.__name__])
@console_ns.response(200, "Annotations retrieved successfully")
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -155,9 +192,10 @@ class AnnotationApi(Resource):
@account_initialization_required
@edit_permission_required
def get(self, app_id):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
keyword = request.args.get("keyword", default="", type=str)
args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
page = args.page
limit = args.limit
keyword = args.keyword
app_id = str(app_id)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
@ -173,18 +211,7 @@ class AnnotationApi(Resource):
@console_ns.doc("create_annotation")
@console_ns.doc(description="Create a new annotation for an app")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.model(
"CreateAnnotationRequest",
{
"message_id": fields.String(description="Message ID (optional)"),
"question": fields.String(description="Question text (required when message_id not provided)"),
"answer": fields.String(description="Answer text (use 'answer' or 'content')"),
"content": fields.String(description="Content text (use 'answer' or 'content')"),
"annotation_reply": fields.Raw(description="Annotation reply data"),
},
)
)
@console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__])
@console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns))
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -195,16 +222,9 @@ class AnnotationApi(Resource):
@edit_permission_required
def post(self, app_id):
app_id = str(app_id)
parser = (
reqparse.RequestParser()
.add_argument("message_id", required=False, type=uuid_value, location="json")
.add_argument("question", required=False, type=str, location="json")
.add_argument("answer", required=False, type=str, location="json")
.add_argument("content", required=False, type=str, location="json")
.add_argument("annotation_reply", required=False, type=dict, location="json")
)
args = parser.parse_args()
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
args = CreateAnnotationPayload.model_validate(console_ns.payload)
data = args.model_dump(exclude_none=True)
annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
return annotation
@setup_required
@ -256,13 +276,6 @@ class AnnotationExportApi(Resource):
return response, 200
parser = (
reqparse.RequestParser()
.add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
)
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
class AnnotationUpdateDeleteApi(Resource):
@console_ns.doc("update_delete_annotation")
@ -271,7 +284,7 @@ class AnnotationUpdateDeleteApi(Resource):
@console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns))
@console_ns.response(204, "Annotation deleted successfully")
@console_ns.response(403, "Insufficient permissions")
@console_ns.expect(parser)
@console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -281,8 +294,10 @@ class AnnotationUpdateDeleteApi(Resource):
def post(self, app_id, annotation_id):
app_id = str(app_id)
annotation_id = str(annotation_id)
args = parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
args = UpdateAnnotationPayload.model_validate(console_ns.payload)
annotation = AppAnnotationService.update_app_annotation_directly(
args.model_dump(exclude_none=True), app_id, annotation_id
)
return annotation
@setup_required
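One detail worth noting in the update handler above: model_dump(exclude_none=True) drops the optional fields the client did not send, so a partial update does not overwrite stored values with None. A small, hedged illustration; UpdateExample is made up for this sketch.

from pydantic import BaseModel

class UpdateExample(BaseModel):
    question: str | None = None
    answer: str | None = None
    content: str | None = None

patch = UpdateExample.model_validate({"answer": "42"})
print(patch.model_dump())                   # {'question': None, 'answer': '42', 'content': None}
print(patch.model_dump(exclude_none=True))  # {'answer': '42'}

Note that an explicit null in the request is treated the same as an omitted field here; model_dump(exclude_unset=True) would be the way to distinguish the two.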

View File

@ -146,7 +146,14 @@ class AppApiStatusPayload(BaseModel):
class AppTracePayload(BaseModel):
enabled: bool = Field(..., description="Enable or disable tracing")
tracing_provider: str = Field(..., description="Tracing provider")
tracing_provider: str | None = Field(default=None, description="Tracing provider")
@field_validator("tracing_provider")
@classmethod
def validate_tracing_provider(cls, value: str | None, info) -> str | None:
if info.data.get("enabled") and not value:
raise ValueError("tracing_provider is required when enabled is True")
return value
def reg(cls: type[BaseModel]):
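The new validator makes tracing_provider conditionally required rather than always required. Field validators run in field declaration order, so by the time tracing_provider is checked, info.data already holds the validated enabled flag. A minimal sketch of the mechanism, with TraceToggleExample standing in for the real payload class.

from pydantic import BaseModel, Field, ValidationError, ValidationInfo, field_validator

class TraceToggleExample(BaseModel):
    enabled: bool = Field(..., description="Enable or disable tracing")
    tracing_provider: str | None = Field(default=None, description="Tracing provider")

    @field_validator("tracing_provider")
    @classmethod
    def provider_required_when_enabled(cls, value: str | None, info: ValidationInfo) -> str | None:
        # info.data contains fields validated earlier in declaration order, so `enabled` is visible.
        if info.data.get("enabled") and not value:
            raise ValueError("tracing_provider is required when enabled is True")
        return value

print(TraceToggleExample(enabled=False))                                      # provider may be omitted
print(TraceToggleExample(enabled=True, tracing_provider="example_provider"))  # accepted
try:
    TraceToggleExample(enabled=True)                                          # rejected
except ValidationError as e:
    print(e)

A model_validator(mode="after") would express the same cross-field rule without depending on field order; the field validator keeps the error attached to tracing_provider specifically.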
@ -324,10 +331,13 @@ class AppListApi(Resource):
NodeType.TRIGGER_PLUGIN,
}
for workflow in draft_workflows:
for _, node_data in workflow.walk_nodes():
if node_data.get("type") in trigger_node_types:
draft_trigger_app_ids.add(str(workflow.app_id))
break
try:
for _, node_data in workflow.walk_nodes():
if node_data.get("type") in trigger_node_types:
draft_trigger_app_ids.add(str(workflow.app_id))
break
except Exception:
continue
for app in app_pagination.items:
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
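This hunk is the change the branch name (fix/app-list-walk-nodes-graceful) refers to: a single draft workflow whose graph cannot be walked is now skipped instead of failing the whole app list request. A rough sketch of the same defensive pattern, with the workflow objects and walk_nodes stubbed out since their real shape lives elsewhere in the codebase.

class BrokenGraphError(Exception):
    """Stand-in for whatever walk_nodes raises on a malformed workflow graph."""

class FakeWorkflow:
    def __init__(self, app_id: str, nodes: list[dict] | None):
        self.app_id = app_id
        self._nodes = nodes

    def walk_nodes(self):
        if self._nodes is None:
            raise BrokenGraphError("unreadable graph")
        for node in self._nodes:
            yield node["id"], node

draft_workflows = [
    FakeWorkflow("app-1", [{"id": "n1", "type": "trigger_plugin"}]),
    FakeWorkflow("app-2", None),  # malformed: walking it raises
    FakeWorkflow("app-3", [{"id": "n1", "type": "llm"}]),
]

trigger_node_types = {"trigger_plugin"}
draft_trigger_app_ids: set[str] = set()
for workflow in draft_workflows:
    try:
        for _, node_data in workflow.walk_nodes():
            if node_data.get("type") in trigger_node_types:
                draft_trigger_app_ids.add(str(workflow.app_id))
                break
    except Exception:
        # One bad workflow must not take down the whole listing; skip it and move on.
        continue

print(draft_trigger_app_ids)  # {'app-1'}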

View File

@ -1,4 +1,5 @@
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from controllers.console.app.wraps import get_app_model
@ -35,23 +36,29 @@ app_import_check_dependencies_model = console_ns.model(
"AppImportCheckDependencies", app_import_check_dependencies_fields_copy
)
parser = (
reqparse.RequestParser()
.add_argument("mode", type=str, required=True, location="json")
.add_argument("yaml_content", type=str, location="json")
.add_argument("yaml_url", type=str, location="json")
.add_argument("name", type=str, location="json")
.add_argument("description", type=str, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
.add_argument("app_id", type=str, location="json")
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppImportPayload(BaseModel):
mode: str = Field(..., description="Import mode")
yaml_content: str | None = None
yaml_url: str | None = None
name: str | None = None
description: str | None = None
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
app_id: str | None = None
console_ns.schema_model(
AppImportPayload.__name__, AppImportPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/apps/imports")
class AppImportApi(Resource):
@console_ns.expect(parser)
@console_ns.expect(console_ns.models[AppImportPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -61,7 +68,7 @@ class AppImportApi(Resource):
def post(self):
# Check user role first
current_user, _ = current_account_with_tenant()
args = parser.parse_args()
args = AppImportPayload.model_validate(console_ns.payload)
# Create service with session
with Session(db.engine) as session:
@ -70,15 +77,15 @@ class AppImportApi(Resource):
account = current_user
result = import_service.import_app(
account=account,
import_mode=args["mode"],
yaml_content=args.get("yaml_content"),
yaml_url=args.get("yaml_url"),
name=args.get("name"),
description=args.get("description"),
icon_type=args.get("icon_type"),
icon=args.get("icon"),
icon_background=args.get("icon_background"),
app_id=args.get("app_id"),
import_mode=args.mode,
yaml_content=args.yaml_content,
yaml_url=args.yaml_url,
name=args.name,
description=args.description,
icon_type=args.icon_type,
icon=args.icon,
icon_background=args.icon_background,
app_id=args.app_id,
)
session.commit()
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:

View File

@ -1,7 +1,8 @@
import logging
from flask import request
from flask_restx import Resource, fields, reqparse
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError
import services
@ -32,6 +33,27 @@ from services.errors.audio import (
)
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class TextToSpeechPayload(BaseModel):
message_id: str | None = Field(default=None, description="Message ID")
text: str = Field(..., description="Text to convert")
voice: str | None = Field(default=None, description="Voice name")
streaming: bool | None = Field(default=None, description="Whether to stream audio")
class TextToSpeechVoiceQuery(BaseModel):
language: str = Field(..., description="Language code")
console_ns.schema_model(
TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
TextToSpeechVoiceQuery.__name__,
TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/apps/<uuid:app_id>/audio-to-text")
@ -92,17 +114,7 @@ class ChatMessageTextApi(Resource):
@console_ns.doc("chat_message_text_to_speech")
@console_ns.doc(description="Convert text to speech for chat messages")
@console_ns.doc(params={"app_id": "App ID"})
@console_ns.expect(
console_ns.model(
"TextToSpeechRequest",
{
"message_id": fields.String(description="Message ID"),
"text": fields.String(required=True, description="Text to convert to speech"),
"voice": fields.String(description="Voice to use for TTS"),
"streaming": fields.Boolean(description="Whether to stream the audio"),
},
)
)
@console_ns.expect(console_ns.models[TextToSpeechPayload.__name__])
@console_ns.response(200, "Text to speech conversion successful")
@console_ns.response(400, "Bad request - Invalid parameters")
@get_app_model
@ -111,21 +123,14 @@ class ChatMessageTextApi(Resource):
@account_initialization_required
def post(self, app_model: App):
try:
parser = (
reqparse.RequestParser()
.add_argument("message_id", type=str, location="json")
.add_argument("text", type=str, location="json")
.add_argument("voice", type=str, location="json")
.add_argument("streaming", type=bool, location="json")
)
args = parser.parse_args()
message_id = args.get("message_id", None)
text = args.get("text", None)
voice = args.get("voice", None)
payload = TextToSpeechPayload.model_validate(console_ns.payload)
response = AudioService.transcript_tts(
app_model=app_model, text=text, voice=voice, message_id=message_id, is_draft=True
app_model=app_model,
text=payload.text,
voice=payload.voice,
message_id=payload.message_id,
is_draft=True,
)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
@ -159,9 +164,7 @@ class TextModesApi(Resource):
@console_ns.doc("get_text_to_speech_voices")
@console_ns.doc(description="Get available TTS voices for a specific language")
@console_ns.doc(params={"app_id": "App ID"})
@console_ns.expect(
console_ns.parser().add_argument("language", type=str, required=True, location="args", help="Language code")
)
@console_ns.expect(console_ns.models[TextToSpeechVoiceQuery.__name__])
@console_ns.response(
200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices"))
)
@ -172,12 +175,11 @@ class TextModesApi(Resource):
@account_initialization_required
def get(self, app_model):
try:
parser = reqparse.RequestParser().add_argument("language", type=str, required=True, location="args")
args = parser.parse_args()
args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
response = AudioService.transcript_tts_voices(
tenant_id=app_model.tenant_id,
language=args["language"],
language=args.language,
)
return response

View File

@ -1,7 +1,8 @@
import json
from enum import StrEnum
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
@ -12,6 +13,8 @@ from fields.app_fields import app_server_fields
from libs.login import current_account_with_tenant, login_required
from models.model import AppMCPServer
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
# Register model for flask_restx to avoid dict type issues in Swagger
app_server_model = console_ns.model("AppServer", app_server_fields)
@ -21,6 +24,22 @@ class AppMCPServerStatus(StrEnum):
INACTIVE = "inactive"
class MCPServerCreatePayload(BaseModel):
description: str | None = Field(default=None, description="Server description")
parameters: dict = Field(..., description="Server parameters configuration")
class MCPServerUpdatePayload(BaseModel):
id: str = Field(..., description="Server ID")
description: str | None = Field(default=None, description="Server description")
parameters: dict = Field(..., description="Server parameters configuration")
status: str | None = Field(default=None, description="Server status")
for model in (MCPServerCreatePayload, MCPServerUpdatePayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/apps/<uuid:app_id>/server")
class AppMCPServerController(Resource):
@console_ns.doc("get_app_mcp_server")
@ -39,15 +58,7 @@ class AppMCPServerController(Resource):
@console_ns.doc("create_app_mcp_server")
@console_ns.doc(description="Create MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.model(
"MCPServerCreateRequest",
{
"description": fields.String(description="Server description"),
"parameters": fields.Raw(required=True, description="Server parameters configuration"),
},
)
)
@console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__])
@console_ns.response(201, "MCP server configuration created successfully", app_server_model)
@console_ns.response(403, "Insufficient permissions")
@account_initialization_required
@ -58,21 +69,16 @@ class AppMCPServerController(Resource):
@edit_permission_required
def post(self, app_model):
_, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("description", type=str, required=False, location="json")
.add_argument("parameters", type=dict, required=True, location="json")
)
args = parser.parse_args()
payload = MCPServerCreatePayload.model_validate(console_ns.payload or {})
description = args.get("description")
description = payload.description
if not description:
description = app_model.description or ""
server = AppMCPServer(
name=app_model.name,
description=description,
parameters=json.dumps(args["parameters"], ensure_ascii=False),
parameters=json.dumps(payload.parameters, ensure_ascii=False),
status=AppMCPServerStatus.ACTIVE,
app_id=app_model.id,
tenant_id=current_tenant_id,
@ -85,17 +91,7 @@ class AppMCPServerController(Resource):
@console_ns.doc("update_app_mcp_server")
@console_ns.doc(description="Update MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.model(
"MCPServerUpdateRequest",
{
"id": fields.String(required=True, description="Server ID"),
"description": fields.String(description="Server description"),
"parameters": fields.Raw(required=True, description="Server parameters configuration"),
"status": fields.String(description="Server status"),
},
)
)
@console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__])
@console_ns.response(200, "MCP server configuration updated successfully", app_server_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Server not found")
@ -106,19 +102,12 @@ class AppMCPServerController(Resource):
@marshal_with(app_server_model)
@edit_permission_required
def put(self, app_model):
parser = (
reqparse.RequestParser()
.add_argument("id", type=str, required=True, location="json")
.add_argument("description", type=str, required=False, location="json")
.add_argument("parameters", type=dict, required=True, location="json")
.add_argument("status", type=str, required=False, location="json")
)
args = parser.parse_args()
server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first()
if not server:
raise NotFound()
description = args.get("description")
description = payload.description
if description is None:
pass
elif not description:
@ -126,11 +115,11 @@ class AppMCPServerController(Resource):
else:
server.description = description
server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
if args["status"]:
if args["status"] not in [status.value for status in AppMCPServerStatus]:
server.parameters = json.dumps(payload.parameters, ensure_ascii=False)
if payload.status:
if payload.status not in [status.value for status in AppMCPServerStatus]:
raise ValueError("Invalid status")
server.status = args["status"]
server.status = payload.status
db.session.commit()
return server

View File

@ -1,4 +1,8 @@
from flask_restx import Resource, fields, reqparse
from typing import Any
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from werkzeug.exceptions import BadRequest
from controllers.console import console_ns
@ -7,6 +11,26 @@ from controllers.console.wraps import account_initialization_required, setup_req
from libs.login import login_required
from services.ops_service import OpsService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class TraceProviderQuery(BaseModel):
tracing_provider: str = Field(..., description="Tracing provider name")
class TraceConfigPayload(BaseModel):
tracing_provider: str = Field(..., description="Tracing provider name")
tracing_config: dict[str, Any] = Field(..., description="Tracing configuration data")
console_ns.schema_model(
TraceProviderQuery.__name__,
TraceProviderQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
TraceConfigPayload.__name__, TraceConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/apps/<uuid:app_id>/trace-config")
class TraceAppConfigApi(Resource):
@ -17,11 +41,7 @@ class TraceAppConfigApi(Resource):
@console_ns.doc("get_trace_app_config")
@console_ns.doc(description="Get tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.parser().add_argument(
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
)
)
@console_ns.expect(console_ns.models[TraceProviderQuery.__name__])
@console_ns.response(
200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data")
)
@ -30,11 +50,10 @@ class TraceAppConfigApi(Resource):
@login_required
@account_initialization_required
def get(self, app_id):
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
args = parser.parse_args()
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
if not trace_config:
return {"has_not_configured": True}
return trace_config
@ -44,15 +63,7 @@ class TraceAppConfigApi(Resource):
@console_ns.doc("create_trace_app_config")
@console_ns.doc(description="Create a new tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.model(
"TraceConfigCreateRequest",
{
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
"tracing_config": fields.Raw(required=True, description="Tracing configuration data"),
},
)
)
@console_ns.expect(console_ns.models[TraceConfigPayload.__name__])
@console_ns.response(
201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data")
)
@ -62,16 +73,11 @@ class TraceAppConfigApi(Resource):
@account_initialization_required
def post(self, app_id):
"""Create a new trace app configuration"""
parser = (
reqparse.RequestParser()
.add_argument("tracing_provider", type=str, required=True, location="json")
.add_argument("tracing_config", type=dict, required=True, location="json")
)
args = parser.parse_args()
args = TraceConfigPayload.model_validate(console_ns.payload)
try:
result = OpsService.create_tracing_app_config(
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
)
if not result:
raise TracingConfigIsExist()
@ -84,15 +90,7 @@ class TraceAppConfigApi(Resource):
@console_ns.doc("update_trace_app_config")
@console_ns.doc(description="Update an existing tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.model(
"TraceConfigUpdateRequest",
{
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
"tracing_config": fields.Raw(required=True, description="Updated tracing configuration data"),
},
)
)
@console_ns.expect(console_ns.models[TraceConfigPayload.__name__])
@console_ns.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response"))
@console_ns.response(400, "Invalid request parameters or configuration not found")
@setup_required
@ -100,16 +98,11 @@ class TraceAppConfigApi(Resource):
@account_initialization_required
def patch(self, app_id):
"""Update an existing trace app configuration"""
parser = (
reqparse.RequestParser()
.add_argument("tracing_provider", type=str, required=True, location="json")
.add_argument("tracing_config", type=dict, required=True, location="json")
)
args = parser.parse_args()
args = TraceConfigPayload.model_validate(console_ns.payload)
try:
result = OpsService.update_tracing_app_config(
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
)
if not result:
raise TracingConfigNotExist()
@ -120,11 +113,7 @@ class TraceAppConfigApi(Resource):
@console_ns.doc("delete_trace_app_config")
@console_ns.doc(description="Delete an existing tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.parser().add_argument(
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
)
)
@console_ns.expect(console_ns.models[TraceProviderQuery.__name__])
@console_ns.response(204, "Tracing configuration deleted successfully")
@console_ns.response(400, "Invalid request parameters or configuration not found")
@setup_required
@ -132,11 +121,10 @@ class TraceAppConfigApi(Resource):
@account_initialization_required
def delete(self, app_id):
"""Delete an existing trace app configuration"""
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
args = parser.parse_args()
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
if not result:
raise TracingConfigNotExist()
return {"result": "success"}, 204

View File

@ -1,4 +1,7 @@
from flask_restx import Resource, fields, marshal_with, reqparse
from typing import Literal
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import NotFound
from constants.languages import supported_language
@ -16,69 +19,50 @@ from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import Site
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppSiteUpdatePayload(BaseModel):
title: str | None = Field(default=None)
icon_type: str | None = Field(default=None)
icon: str | None = Field(default=None)
icon_background: str | None = Field(default=None)
description: str | None = Field(default=None)
default_language: str | None = Field(default=None)
chat_color_theme: str | None = Field(default=None)
chat_color_theme_inverted: bool | None = Field(default=None)
customize_domain: str | None = Field(default=None)
copyright: str | None = Field(default=None)
privacy_policy: str | None = Field(default=None)
custom_disclaimer: str | None = Field(default=None)
customize_token_strategy: Literal["must", "allow", "not_allow"] | None = Field(default=None)
prompt_public: bool | None = Field(default=None)
show_workflow_steps: bool | None = Field(default=None)
use_icon_as_answer_icon: bool | None = Field(default=None)
@field_validator("default_language")
@classmethod
def validate_language(cls, value: str | None) -> str | None:
if value is None:
return value
return supported_language(value)
console_ns.schema_model(
AppSiteUpdatePayload.__name__,
AppSiteUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
# Register model for flask_restx to avoid dict type issues in Swagger
app_site_model = console_ns.model("AppSite", app_site_fields)
def parse_app_site_args():
parser = (
reqparse.RequestParser()
.add_argument("title", type=str, required=False, location="json")
.add_argument("icon_type", type=str, required=False, location="json")
.add_argument("icon", type=str, required=False, location="json")
.add_argument("icon_background", type=str, required=False, location="json")
.add_argument("description", type=str, required=False, location="json")
.add_argument("default_language", type=supported_language, required=False, location="json")
.add_argument("chat_color_theme", type=str, required=False, location="json")
.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
.add_argument("customize_domain", type=str, required=False, location="json")
.add_argument("copyright", type=str, required=False, location="json")
.add_argument("privacy_policy", type=str, required=False, location="json")
.add_argument("custom_disclaimer", type=str, required=False, location="json")
.add_argument(
"customize_token_strategy",
type=str,
choices=["must", "allow", "not_allow"],
required=False,
location="json",
)
.add_argument("prompt_public", type=bool, required=False, location="json")
.add_argument("show_workflow_steps", type=bool, required=False, location="json")
.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
)
return parser.parse_args()
@console_ns.route("/apps/<uuid:app_id>/site")
class AppSite(Resource):
@console_ns.doc("update_app_site")
@console_ns.doc(description="Update application site configuration")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.model(
"AppSiteRequest",
{
"title": fields.String(description="Site title"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
"description": fields.String(description="Site description"),
"default_language": fields.String(description="Default language"),
"chat_color_theme": fields.String(description="Chat color theme"),
"chat_color_theme_inverted": fields.Boolean(description="Inverted chat color theme"),
"customize_domain": fields.String(description="Custom domain"),
"copyright": fields.String(description="Copyright text"),
"privacy_policy": fields.String(description="Privacy policy"),
"custom_disclaimer": fields.String(description="Custom disclaimer"),
"customize_token_strategy": fields.String(
enum=["must", "allow", "not_allow"], description="Token strategy"
),
"prompt_public": fields.Boolean(description="Make prompt public"),
"show_workflow_steps": fields.Boolean(description="Show workflow steps"),
"use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
},
)
)
@console_ns.expect(console_ns.models[AppSiteUpdatePayload.__name__])
@console_ns.response(200, "Site configuration updated successfully", app_site_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "App not found")
@ -89,7 +73,7 @@ class AppSite(Resource):
@get_app_model
@marshal_with(app_site_model)
def post(self, app_model):
args = parse_app_site_args()
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
current_user, _ = current_account_with_tenant()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
@ -113,7 +97,7 @@ class AppSite(Resource):
"show_workflow_steps",
"use_icon_as_answer_icon",
]:
value = args.get(attr_name)
value = getattr(args, attr_name)
if value is not None:
setattr(site, attr_name, value)

View File

@ -1,10 +1,11 @@
import logging
from collections.abc import Callable
from functools import wraps
from typing import NoReturn, ParamSpec, TypeVar
from typing import Any, NoReturn, ParamSpec, TypeVar
from flask import Response
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from flask import Response, request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from controllers.console import console_ns
@ -29,6 +30,27 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList,
from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowDraftVariableListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=100_000, description="Page number")
limit: int = Field(default=20, ge=1, le=100, description="Items per page")
class WorkflowDraftVariableUpdatePayload(BaseModel):
name: str | None = Field(default=None, description="Variable name")
value: Any | None = Field(default=None, description="Variable value")
console_ns.schema_model(
WorkflowDraftVariableListQuery.__name__,
WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
WorkflowDraftVariableUpdatePayload.__name__,
WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def _convert_values_to_json_serializable_object(value: Segment):
@ -57,22 +79,6 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
return _convert_values_to_json_serializable_object(value)
def _create_pagination_parser():
parser = (
reqparse.RequestParser()
.add_argument(
"page",
type=inputs.int_range(1, 100_000),
required=False,
default=1,
location="args",
help="the page of data requested",
)
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
)
return parser
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
value_type = workflow_draft_var.value_type
return value_type.exposed_type().value
@ -201,7 +207,7 @@ def _api_prerequisite(f: Callable[P, R]):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables")
class WorkflowVariableCollectionApi(Resource):
@console_ns.expect(_create_pagination_parser())
@console_ns.expect(console_ns.models[WorkflowDraftVariableListQuery.__name__])
@console_ns.doc("get_workflow_variables")
@console_ns.doc(description="Get draft workflow variables")
@console_ns.doc(params={"app_id": "Application ID"})
@ -215,8 +221,7 @@ class WorkflowVariableCollectionApi(Resource):
"""
Get draft workflow
"""
parser = _create_pagination_parser()
args = parser.parse_args()
args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
# fetch draft workflow by app_model
workflow_service = WorkflowService()
@ -323,15 +328,7 @@ class VariableApi(Resource):
@console_ns.doc("update_variable")
@console_ns.doc(description="Update a workflow variable")
@console_ns.expect(
console_ns.model(
"UpdateVariableRequest",
{
"name": fields.String(description="Variable name"),
"value": fields.Raw(description="Variable value"),
},
)
)
@console_ns.expect(console_ns.models[WorkflowDraftVariableUpdatePayload.__name__])
@console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
@console_ns.response(404, "Variable not found")
@_api_prerequisite
@ -358,16 +355,10 @@ class VariableApi(Resource):
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
# }
parser = (
reqparse.RequestParser()
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
)
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
args = parser.parse_args(strict=True)
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
@ -375,8 +366,8 @@ class VariableApi(Resource):
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
new_name = args.get(self._PATCH_NAME_FIELD, None)
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
new_name = args_model.name
raw_value = args_model.value
if new_name is None and raw_value is None:
return variable
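The hand-built pagination parser with inputs.int_range is replaced by ge/le bounds on the query model, so out-of-range values still fail validation and string values from the query string are coerced to int. A short sketch of how those constraints behave; PageQueryExample is illustrative.

from pydantic import BaseModel, Field, ValidationError

class PageQueryExample(BaseModel):
    page: int = Field(default=1, ge=1, le=100_000, description="Page number")
    limit: int = Field(default=20, ge=1, le=100, description="Items per page")

print(PageQueryExample.model_validate({}))                            # defaults: page=1 limit=20
print(PageQueryExample.model_validate({"page": "3", "limit": "50"}))  # query-string values are coerced
try:
    PageQueryExample.model_validate({"limit": 500})                   # outside le=100
except ValidationError as e:
    print(e.error_count(), "validation error")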

View File

@ -1,28 +1,53 @@
from flask import request
from flask_restx import Resource, fields, reqparse
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from constants.languages import supported_language
from controllers.console import console_ns
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import StrLen, email, extract_remote_ip, timezone
from libs.helper import EmailStr, extract_remote_ip, timezone
from models import AccountStatus
from services.account_service import AccountService, RegisterService
active_check_parser = (
reqparse.RequestParser()
.add_argument("workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID")
.add_argument("email", type=email, required=False, nullable=True, location="args", help="Email address")
.add_argument("token", type=str, required=True, nullable=False, location="args", help="Activation token")
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ActivateCheckQuery(BaseModel):
workspace_id: str | None = Field(default=None)
email: EmailStr | None = Field(default=None)
token: str
class ActivatePayload(BaseModel):
workspace_id: str | None = Field(default=None)
email: EmailStr | None = Field(default=None)
token: str
name: str = Field(..., max_length=30)
interface_language: str = Field(...)
timezone: str = Field(...)
@field_validator("interface_language")
@classmethod
def validate_lang(cls, value: str) -> str:
return supported_language(value)
@field_validator("timezone")
@classmethod
def validate_tz(cls, value: str) -> str:
return timezone(value)
for model in (ActivateCheckQuery, ActivatePayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/activate/check")
class ActivateCheckApi(Resource):
@console_ns.doc("check_activation_token")
@console_ns.doc(description="Check if activation token is valid")
@console_ns.expect(active_check_parser)
@console_ns.expect(console_ns.models[ActivateCheckQuery.__name__])
@console_ns.response(
200,
"Success",
@ -35,11 +60,11 @@ class ActivateCheckApi(Resource):
),
)
def get(self):
args = active_check_parser.parse_args()
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
workspaceId = args["workspace_id"]
reg_email = args["email"]
token = args["token"]
workspaceId = args.workspace_id
reg_email = args.email
token = args.token
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
if invitation:
@ -56,22 +81,11 @@ class ActivateCheckApi(Resource):
return {"is_valid": False}
active_parser = (
reqparse.RequestParser()
.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
.add_argument("email", type=email, required=False, nullable=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
.add_argument("interface_language", type=supported_language, required=True, nullable=False, location="json")
.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
)
@console_ns.route("/activate")
class ActivateApi(Resource):
@console_ns.doc("activate_account")
@console_ns.doc(description="Activate account with invitation token")
@console_ns.expect(active_parser)
@console_ns.expect(console_ns.models[ActivatePayload.__name__])
@console_ns.response(
200,
"Account activated successfully",
@ -85,19 +99,19 @@ class ActivateApi(Resource):
)
@console_ns.response(400, "Already activated or invalid token")
def post(self):
args = active_parser.parse_args()
args = ActivatePayload.model_validate(console_ns.payload)
invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"])
invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
if invitation is None:
raise AlreadyActivateError()
RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"])
RegisterService.revoke_token(args.workspace_id, args.email, args.token)
account = invitation["account"]
account.name = args["name"]
account.name = args.name
account.interface_language = args["interface_language"]
account.timezone = args["timezone"]
account.interface_language = args.interface_language
account.timezone = args.timezone
account.interface_theme = "light"
account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()
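The legacy reqparse type= callables do not disappear in this migration: supported_language and timezone are reused as the bodies of field validators, EmailStr from libs.helper covers the email check, and max_length=30 replaces StrLen(30). A hedged sketch of the wrapping pattern with a stand-in helper, since the real helpers live in constants.languages and libs.helper.

from pydantic import BaseModel, Field, field_validator

SUPPORTED_LANGUAGES = {"en-US", "ja-JP", "zh-Hans"}  # stand-in for the real language list

def supported_language(value: str) -> str:
    # Mirrors a reqparse type= callable: return the value unchanged or raise ValueError.
    if value not in SUPPORTED_LANGUAGES:
        raise ValueError(f"{value!r} is not a supported language")
    return value

class ActivateExample(BaseModel):
    name: str = Field(..., max_length=30)  # replaces StrLen(30)
    interface_language: str = Field(...)

    @field_validator("interface_language")
    @classmethod
    def validate_lang(cls, value: str) -> str:
        # The ValueError raised by the helper becomes a normal pydantic validation error.
        return supported_language(value)

print(ActivateExample(name="Ada", interface_language="en-US"))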

View File

@ -1,12 +1,26 @@
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.console import console_ns
from controllers.console.auth.error import ApiKeyAuthFailedError
from controllers.console.wraps import is_admin_or_owner_required
from libs.login import current_account_with_tenant, login_required
from services.auth.api_key_auth_service import ApiKeyAuthService
from ..wraps import account_initialization_required, setup_required
from .. import console_ns
from ..auth.error import ApiKeyAuthFailedError
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ApiKeyAuthBindingPayload(BaseModel):
category: str = Field(...)
provider: str = Field(...)
credentials: dict = Field(...)
console_ns.schema_model(
ApiKeyAuthBindingPayload.__name__,
ApiKeyAuthBindingPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/api-key-auth/data-source")
@ -40,19 +54,15 @@ class ApiKeyAuthDataSourceBinding(Resource):
@login_required
@account_initialization_required
@is_admin_or_owner_required
@console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__])
def post(self):
# The role of the current user in the table must be admin or owner
_, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("category", type=str, required=True, nullable=False, location="json")
.add_argument("provider", type=str, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
ApiKeyAuthService.validate_api_key_auth_args(args)
payload = ApiKeyAuthBindingPayload.model_validate(console_ns.payload)
data = payload.model_dump()
ApiKeyAuthService.validate_api_key_auth_args(data)
try:
ApiKeyAuthService.create_provider_auth(current_tenant_id, args)
ApiKeyAuthService.create_provider_auth(current_tenant_id, data)
except Exception as e:
raise ApiKeyAuthFailedError(str(e))
return {"result": "success"}, 200

View File

@ -5,12 +5,11 @@ from flask import current_app, redirect, request
from flask_restx import Resource, fields
from configs import dify_config
from controllers.console import console_ns
from controllers.console.wraps import is_admin_or_owner_required
from libs.login import login_required
from libs.oauth_data_source import NotionOAuth
from ..wraps import account_initialization_required, setup_required
from .. import console_ns
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
logger = logging.getLogger(__name__)

View File

@ -1,5 +1,6 @@
from flask import request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -14,16 +15,45 @@ from controllers.console.auth.error import (
InvalidTokenError,
PasswordMismatchError,
)
from controllers.console.error import AccountInFreezeError, EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, email_register_enabled, setup_required
from extensions.ext_database import db
from libs.helper import email, extract_remote_ip
from libs.helper import EmailStr, extract_remote_ip
from libs.password import valid_password
from models import Account
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import AccountNotFoundError, AccountRegisterError
from ..error import AccountInFreezeError, EmailSendIpLimitError
from ..wraps import email_password_login_enabled, email_register_enabled, setup_required
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class EmailRegisterSendPayload(BaseModel):
email: EmailStr = Field(..., description="Email address")
language: str | None = Field(default=None, description="Language code")
class EmailRegisterValidityPayload(BaseModel):
email: EmailStr = Field(...)
code: str = Field(...)
token: str = Field(...)
class EmailRegisterResetPayload(BaseModel):
token: str = Field(...)
new_password: str = Field(...)
password_confirm: str = Field(...)
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
for model in (EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/email-register/send-email")
class EmailRegisterSendEmailApi(Resource):
@ -31,27 +61,22 @@ class EmailRegisterSendEmailApi(Resource):
@email_password_login_enabled
@email_register_enabled
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
args = EmailRegisterSendPayload.model_validate(console_ns.payload)
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
language = "en-US"
if args["language"] in languages:
language = args["language"]
if args.language in languages:
language = args.language
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
raise AccountInFreezeError()
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
token = None
token = AccountService.send_email_register_email(email=args["email"], account=account, language=language)
token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
return {"result": "success", "data": token}
@ -61,40 +86,34 @@ class EmailRegisterCheckApi(Resource):
@email_password_login_enabled
@email_register_enabled
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
user_email = args["email"]
user_email = args.email
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args["email"])
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
if is_email_register_error_rate_limit:
raise EmailRegisterLimitError()
token_data = AccountService.get_email_register_data(args["token"])
token_data = AccountService.get_email_register_data(args.token)
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args["code"] != token_data.get("code"):
AccountService.add_email_register_error_rate_limit(args["email"])
if args.code != token_data.get("code"):
AccountService.add_email_register_error_rate_limit(args.email)
raise EmailCodeError()
# Verified, revoke the first token
AccountService.revoke_email_register_token(args["token"])
AccountService.revoke_email_register_token(args.token)
# Refresh token data by generating a new token
_, new_token = AccountService.generate_email_register_token(
user_email, code=args["code"], additional_data={"phase": "register"}
user_email, code=args.code, additional_data={"phase": "register"}
)
AccountService.reset_email_register_error_rate_limit(args["email"])
AccountService.reset_email_register_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@ -104,20 +123,14 @@ class EmailRegisterResetApi(Resource):
@email_password_login_enabled
@email_register_enabled
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, nullable=False, location="json")
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = EmailRegisterResetPayload.model_validate(console_ns.payload)
# Validate passwords match
if args["new_password"] != args["password_confirm"]:
if args.new_password != args.password_confirm:
raise PasswordMismatchError()
# Validate token and get register data
register_data = AccountService.get_email_register_data(args["token"])
register_data = AccountService.get_email_register_data(args.token)
if not register_data:
raise InvalidTokenError()
# Must use token in reset phase
@ -125,7 +138,7 @@ class EmailRegisterResetApi(Resource):
raise InvalidTokenError()
# Revoke token to prevent reuse
AccountService.revoke_email_register_token(args["token"])
AccountService.revoke_email_register_token(args.token)
email = register_data.get("email", "")
@ -135,7 +148,7 @@ class EmailRegisterResetApi(Resource):
if account:
raise EmailAlreadyInUseError()
else:
account = self._create_new_account(email, args["password_confirm"])
account = self._create_new_account(email, args.password_confirm)
if not account:
raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))

View File

@ -2,7 +2,8 @@ import base64
import secrets
from flask import request
from flask_restx import Resource, fields, reqparse
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -18,26 +19,46 @@ from controllers.console.error import AccountNotFound, EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import email, extract_remote_ip
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
from models import Account
from services.account_service import AccountService, TenantService
from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ForgotPasswordSendPayload(BaseModel):
email: EmailStr = Field(...)
language: str | None = Field(default=None)
class ForgotPasswordCheckPayload(BaseModel):
email: EmailStr = Field(...)
code: str = Field(...)
token: str = Field(...)
class ForgotPasswordResetPayload(BaseModel):
token: str = Field(...)
new_password: str = Field(...)
password_confirm: str = Field(...)
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/forgot-password")
class ForgotPasswordSendEmailApi(Resource):
@console_ns.doc("send_forgot_password_email")
@console_ns.doc(description="Send password reset email")
@console_ns.expect(
console_ns.model(
"ForgotPasswordEmailRequest",
{
"email": fields.String(required=True, description="Email address"),
"language": fields.String(description="Language for email (zh-Hans/en-US)"),
},
)
)
@console_ns.expect(console_ns.models[ForgotPasswordSendPayload.__name__])
@console_ns.response(
200,
"Email sent successfully",
@ -54,28 +75,23 @@ class ForgotPasswordSendEmailApi(Resource):
@setup_required
@email_password_login_enabled
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
if args["language"] is not None and args["language"] == "zh-Hans":
if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
token = AccountService.send_reset_password_email(
account=account,
email=args["email"],
email=args.email,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
)
@ -87,16 +103,7 @@ class ForgotPasswordSendEmailApi(Resource):
class ForgotPasswordCheckApi(Resource):
@console_ns.doc("check_forgot_password_code")
@console_ns.doc(description="Verify password reset code")
@console_ns.expect(
console_ns.model(
"ForgotPasswordCheckRequest",
{
"email": fields.String(required=True, description="Email address"),
"code": fields.String(required=True, description="Verification code"),
"token": fields.String(required=True, description="Reset token"),
},
)
)
@console_ns.expect(console_ns.models[ForgotPasswordCheckPayload.__name__])
@console_ns.response(
200,
"Code verified successfully",
@ -113,40 +120,34 @@ class ForgotPasswordCheckApi(Resource):
@setup_required
@email_password_login_enabled
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
user_email = args["email"]
user_email = args.email
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
token_data = AccountService.get_reset_password_data(args["token"])
token_data = AccountService.get_reset_password_data(args.token)
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args["code"] != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(args["email"])
if args.code != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(args.email)
raise EmailCodeError()
# Verified, revoke the first token
AccountService.revoke_reset_password_token(args["token"])
AccountService.revoke_reset_password_token(args.token)
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
user_email, code=args["code"], additional_data={"phase": "reset"}
user_email, code=args.code, additional_data={"phase": "reset"}
)
AccountService.reset_forgot_password_error_rate_limit(args["email"])
AccountService.reset_forgot_password_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@ -154,16 +155,7 @@ class ForgotPasswordCheckApi(Resource):
class ForgotPasswordResetApi(Resource):
@console_ns.doc("reset_password")
@console_ns.doc(description="Reset password with verification token")
@console_ns.expect(
console_ns.model(
"ForgotPasswordResetRequest",
{
"token": fields.String(required=True, description="Verification token"),
"new_password": fields.String(required=True, description="New password"),
"password_confirm": fields.String(required=True, description="Password confirmation"),
},
)
)
@console_ns.expect(console_ns.models[ForgotPasswordResetPayload.__name__])
@console_ns.response(
200,
"Password reset successfully",
@ -173,20 +165,14 @@ class ForgotPasswordResetApi(Resource):
@setup_required
@email_password_login_enabled
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, nullable=False, location="json")
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = ForgotPasswordResetPayload.model_validate(console_ns.payload)
# Validate passwords match
if args["new_password"] != args["password_confirm"]:
if args.new_password != args.password_confirm:
raise PasswordMismatchError()
# Validate token and get reset data
reset_data = AccountService.get_reset_password_data(args["token"])
reset_data = AccountService.get_reset_password_data(args.token)
if not reset_data:
raise InvalidTokenError()
# Must use token in reset phase
@ -194,11 +180,11 @@ class ForgotPasswordResetApi(Resource):
raise InvalidTokenError()
# Revoke token to prevent reuse
AccountService.revoke_reset_password_token(args["token"])
AccountService.revoke_reset_password_token(args.token)
# Generate secure salt and hash password
salt = secrets.token_bytes(16)
password_hashed = hash_password(args["new_password"], salt)
password_hashed = hash_password(args.new_password, salt)
email = reset_data.get("email", "")

View File

@ -1,6 +1,7 @@
import flask_login
from flask import make_response, request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field
import services
from configs import dify_config
@ -23,7 +24,7 @@ from controllers.console.error import (
)
from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created
from libs.helper import email, extract_remote_ip
from libs.helper import EmailStr, extract_remote_ip
from libs.login import current_account_with_tenant
from libs.token import (
clear_access_token_from_cookie,
@ -40,6 +41,36 @@ from services.errors.account import AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class LoginPayload(BaseModel):
email: EmailStr = Field(..., description="Email address")
password: str = Field(..., description="Password")
remember_me: bool = Field(default=False, description="Remember me flag")
invite_token: str | None = Field(default=None, description="Invitation token")
class EmailPayload(BaseModel):
email: EmailStr = Field(...)
language: str | None = Field(default=None)
class EmailCodeLoginPayload(BaseModel):
email: EmailStr = Field(...)
code: str = Field(...)
token: str = Field(...)
language: str | None = Field(default=None)
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(LoginPayload)
reg(EmailPayload)
reg(EmailCodeLoginPayload)
@console_ns.route("/login")
class LoginApi(Resource):
@ -47,41 +78,36 @@ class LoginApi(Resource):
@setup_required
@email_password_login_enabled
@console_ns.expect(console_ns.models[LoginPayload.__name__])
def post(self):
"""Authenticate user and login."""
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("password", type=str, required=True, location="json")
.add_argument("remember_me", type=bool, required=False, default=False, location="json")
.add_argument("invite_token", type=str, required=False, default=None, location="json")
)
args = parser.parse_args()
args = LoginPayload.model_validate(console_ns.payload)
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
raise AccountInFreezeError()
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"])
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
if is_login_error_rate_limit:
raise EmailPasswordLoginLimitError()
invitation = args["invite_token"]
# TODO: why is invitation re-assigned with a different type?
invitation = args.invite_token # type: ignore
if invitation:
invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation)
invitation = RegisterService.get_invitation_if_token_valid(None, args.email, invitation) # type: ignore
try:
if invitation:
data = invitation.get("data", {})
data = invitation.get("data", {}) # type: ignore
invitee_email = data.get("email") if data else None
if invitee_email != args["email"]:
if invitee_email != args.email:
raise InvalidEmailError()
account = AccountService.authenticate(args["email"], args["password"], args["invite_token"])
account = AccountService.authenticate(args.email, args.password, args.invite_token)
else:
account = AccountService.authenticate(args["email"], args["password"])
account = AccountService.authenticate(args.email, args.password)
except services.errors.account.AccountLoginError:
raise AccountBannedError()
except services.errors.account.AccountPasswordError:
AccountService.add_login_error_rate_limit(args["email"])
AccountService.add_login_error_rate_limit(args.email)
raise AuthenticationFailedError()
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
@ -97,7 +123,7 @@ class LoginApi(Resource):
}
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"])
AccountService.reset_login_error_rate_limit(args.email)
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
@ -134,25 +160,21 @@ class LogoutApi(Resource):
class ResetPasswordSendEmailApi(Resource):
@setup_required
@email_password_login_enabled
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
args = EmailPayload.model_validate(console_ns.payload)
if args["language"] is not None and args["language"] == "zh-Hans":
if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
try:
account = AccountService.get_user_through_email(args["email"])
account = AccountService.get_user_through_email(args.email)
except AccountRegisterError:
raise AccountInFreezeError()
token = AccountService.send_reset_password_email(
email=args["email"],
email=args.email,
account=account,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
@ -164,30 +186,26 @@ class ResetPasswordSendEmailApi(Resource):
@console_ns.route("/email-code-login")
class EmailCodeLoginSendEmailApi(Resource):
@setup_required
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
args = EmailPayload.model_validate(console_ns.payload)
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
if args["language"] is not None and args["language"] == "zh-Hans":
if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
try:
account = AccountService.get_user_through_email(args["email"])
account = AccountService.get_user_through_email(args.email)
except AccountRegisterError:
raise AccountInFreezeError()
if account is None:
if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_email_code_login_email(email=args["email"], language=language)
token = AccountService.send_email_code_login_email(email=args.email, language=language)
else:
raise AccountNotFound()
else:
@ -199,30 +217,24 @@ class EmailCodeLoginSendEmailApi(Resource):
@console_ns.route("/email-code-login/validity")
class EmailCodeLoginApi(Resource):
@setup_required
@console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__])
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
user_email = args["email"]
language = args["language"]
user_email = args.email
language = args.language
token_data = AccountService.get_email_code_login_data(args["token"])
token_data = AccountService.get_email_code_login_data(args.token)
if token_data is None:
raise InvalidTokenError()
if token_data["email"] != args["email"]:
if token_data["email"] != args.email:
raise InvalidEmailError()
if token_data["code"] != args["code"]:
if token_data["code"] != args.code:
raise EmailCodeError()
AccountService.revoke_email_code_login_token(args["token"])
AccountService.revoke_email_code_login_token(args.token)
try:
account = AccountService.get_user_through_email(user_email)
except AccountRegisterError:
@ -255,7 +267,7 @@ class EmailCodeLoginApi(Resource):
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"])
AccountService.reset_login_error_rate_limit(args.email)
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
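LoginPayload above relies on field defaults (remember_me=False, invite_token=None) to tolerate missing keys the way reqparse's default= did, while attribute access replaces the old args[...] lookups. A small illustration with a hypothetical LoginExample model (plain str instead of EmailStr to stay dependency-free):

from pydantic import BaseModel, Field

class LoginExample(BaseModel):
    # hypothetical mirror of the payload shape used above
    email: str
    password: str
    remember_me: bool = Field(default=False)
    invite_token: str | None = Field(default=None)

args = LoginExample.model_validate({"email": "user@example.com", "password": "secret"})
print(args.remember_me, args.invite_token)  # False None
# attribute access replaces the old args["remember_me"] dict-style lookups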

View File

@ -3,7 +3,8 @@ from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar
from flask import jsonify, request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel
from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.wraps import account_initialization_required, setup_required
@ -20,15 +21,34 @@ R = TypeVar("R")
T = TypeVar("T")
class OAuthClientPayload(BaseModel):
client_id: str
class OAuthProviderRequest(BaseModel):
client_id: str
redirect_uri: str
class OAuthTokenRequest(BaseModel):
client_id: str
grant_type: str
code: str | None = None
client_secret: str | None = None
redirect_uri: str | None = None
refresh_token: str | None = None
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
@wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
parser = reqparse.RequestParser().add_argument("client_id", type=str, required=True, location="json")
parsed_args = parser.parse_args()
client_id = parsed_args.get("client_id")
if not client_id:
json_data = request.get_json()
if json_data is None:
raise BadRequest("client_id is required")
payload = OAuthClientPayload.model_validate(json_data)
client_id = payload.client_id
oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id)
if not oauth_provider_app:
raise NotFound("client_id is invalid")
@ -89,9 +109,8 @@ class OAuthServerAppApi(Resource):
@setup_required
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp):
parser = reqparse.RequestParser().add_argument("redirect_uri", type=str, required=True, location="json")
parsed_args = parser.parse_args()
redirect_uri = parsed_args.get("redirect_uri")
payload = OAuthProviderRequest.model_validate(request.get_json())
redirect_uri = payload.redirect_uri
# check if redirect_uri is valid
if redirect_uri not in oauth_provider_app.redirect_uris:
@ -130,33 +149,25 @@ class OAuthServerUserTokenApi(Resource):
@setup_required
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp):
parser = (
reqparse.RequestParser()
.add_argument("grant_type", type=str, required=True, location="json")
.add_argument("code", type=str, required=False, location="json")
.add_argument("client_secret", type=str, required=False, location="json")
.add_argument("redirect_uri", type=str, required=False, location="json")
.add_argument("refresh_token", type=str, required=False, location="json")
)
parsed_args = parser.parse_args()
payload = OAuthTokenRequest.model_validate(request.get_json())
try:
grant_type = OAuthGrantType(parsed_args["grant_type"])
grant_type = OAuthGrantType(payload.grant_type)
except ValueError:
raise BadRequest("invalid grant_type")
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
if not parsed_args["code"]:
if not payload.code:
raise BadRequest("code is required")
if parsed_args["client_secret"] != oauth_provider_app.client_secret:
if payload.client_secret != oauth_provider_app.client_secret:
raise BadRequest("client_secret is invalid")
if parsed_args["redirect_uri"] not in oauth_provider_app.redirect_uris:
if payload.redirect_uri not in oauth_provider_app.redirect_uris:
raise BadRequest("redirect_uri is invalid")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, code=parsed_args["code"], client_id=oauth_provider_app.client_id
grant_type, code=payload.code, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
@ -167,11 +178,11 @@ class OAuthServerUserTokenApi(Resource):
}
)
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
if not parsed_args["refresh_token"]:
if not payload.refresh_token:
raise BadRequest("refresh_token is required")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, refresh_token=parsed_args["refresh_token"], client_id=oauth_provider_app.client_id
grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
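The token endpoint above dispatches on grant_type after Pydantic has accepted a body where every grant-specific field is optional. A compact sketch of that dispatch, using a local GrantType enum and plain ValueError in place of the project's OAuthGrantType, BadRequest, and OAuthServerService:

from enum import Enum
from pydantic import BaseModel

class GrantType(str, Enum):
    # local stand-in for the project's OAuthGrantType
    AUTHORIZATION_CODE = "authorization_code"
    REFRESH_TOKEN = "refresh_token"

class TokenRequest(BaseModel):
    # shape mirrors OAuthTokenRequest above; grant-specific fields are optional
    client_id: str
    grant_type: str
    code: str | None = None
    client_secret: str | None = None
    redirect_uri: str | None = None
    refresh_token: str | None = None

def handle(req: TokenRequest) -> str:
    try:
        grant = GrantType(req.grant_type)
    except ValueError:
        raise ValueError("invalid grant_type")
    if grant is GrantType.AUTHORIZATION_CODE:
        if not req.code:
            raise ValueError("code is required")
        return "exchange code for tokens"
    if not req.refresh_token:
        raise ValueError("refresh_token is required")
    return "rotate refresh token"

print(handle(TokenRequest(client_id="abc", grant_type="refresh_token", refresh_token="r1")))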

View File

@ -1,6 +1,8 @@
import base64
from flask_restx import Resource, fields, reqparse
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import BadRequest
from controllers.console import console_ns
@ -9,6 +11,35 @@ from enums.cloud_plan import CloudPlan
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class SubscriptionQuery(BaseModel):
plan: str = Field(..., description="Subscription plan")
interval: str = Field(..., description="Billing interval")
@field_validator("plan")
@classmethod
def validate_plan(cls, value: str) -> str:
if value not in [CloudPlan.PROFESSIONAL, CloudPlan.TEAM]:
raise ValueError("Invalid plan")
return value
@field_validator("interval")
@classmethod
def validate_interval(cls, value: str) -> str:
if value not in {"month", "year"}:
raise ValueError("Invalid interval")
return value
class PartnerTenantsPayload(BaseModel):
click_id: str = Field(..., description="Click Id from partner referral link")
for model in (SubscriptionQuery, PartnerTenantsPayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/billing/subscription")
class Subscription(Resource):
@ -18,20 +49,9 @@ class Subscription(Resource):
@only_edition_cloud
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument(
"plan",
type=str,
required=True,
location="args",
choices=[CloudPlan.PROFESSIONAL, CloudPlan.TEAM],
)
.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
)
args = parser.parse_args()
args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_subscription(args["plan"], args["interval"], current_user.email, current_tenant_id)
return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id)
@console_ns.route("/billing/invoices")
@ -65,11 +85,10 @@ class PartnerTenants(Resource):
@only_edition_cloud
def put(self, partner_key: str):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("click_id", required=True, type=str, location="json")
args = parser.parse_args()
try:
click_id = args["click_id"]
args = PartnerTenantsPayload.model_validate(console_ns.payload or {})
click_id = args.click_id
decoded_partner_key = base64.b64decode(partner_key).decode("utf-8")
except Exception:
raise BadRequest("Invalid partner_key")

View File

@ -1,5 +1,6 @@
from flask import request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant, login_required
@ -9,16 +10,28 @@ from .. import console_ns
from ..wraps import account_initialization_required, only_edition_cloud, setup_required
class ComplianceDownloadQuery(BaseModel):
doc_name: str = Field(..., description="Compliance document name")
console_ns.schema_model(
ComplianceDownloadQuery.__name__,
ComplianceDownloadQuery.model_json_schema(ref_template="#/definitions/{model}"),
)
@console_ns.route("/compliance/download")
class ComplianceApi(Resource):
@console_ns.expect(console_ns.models[ComplianceDownloadQuery.__name__])
@console_ns.doc("download_compliance_document")
@console_ns.doc(description="Get compliance document download link")
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("doc_name", type=str, required=True, location="args")
args = parser.parse_args()
args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
ip_address = extract_remote_ip(request)
device_info = request.headers.get("User-Agent", "Unknown device")

View File

@ -1,4 +1,6 @@
from flask_restx import Resource, fields, marshal_with, reqparse
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from constants.languages import languages
from controllers.console import console_ns
@ -35,20 +37,26 @@ recommended_app_list_fields = {
}
parser_apps = reqparse.RequestParser().add_argument("language", type=str, location="args")
class RecommendedAppsQuery(BaseModel):
language: str | None = Field(default=None)
console_ns.schema_model(
RecommendedAppsQuery.__name__,
RecommendedAppsQuery.model_json_schema(ref_template="#/definitions/{model}"),
)
@console_ns.route("/explore/apps")
class RecommendedAppListApi(Resource):
@console_ns.expect(parser_apps)
@console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__])
@login_required
@account_initialization_required
@marshal_with(recommended_app_list_fields)
def get(self):
# language args
args = parser_apps.parse_args()
language = args.get("language")
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
language = args.language
if language and language in languages:
language_prefix = language
elif current_user and current_user.interface_language:

View File

@ -1,13 +1,13 @@
import os
from flask import session
from flask_restx import Resource, fields, reqparse
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from extensions.ext_database import db
from libs.helper import StrLen
from models.model import DifySetup
from services.account_service import TenantService
@ -15,6 +15,18 @@ from . import console_ns
from .error import AlreadySetupError, InitValidateFailedError
from .wraps import only_edition_self_hosted
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class InitValidatePayload(BaseModel):
password: str = Field(..., max_length=30)
console_ns.schema_model(
InitValidatePayload.__name__,
InitValidatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/init")
class InitValidateAPI(Resource):
@ -37,12 +49,7 @@ class InitValidateAPI(Resource):
@console_ns.doc("validate_init_password")
@console_ns.doc(description="Validate initialization password for self-hosted edition")
@console_ns.expect(
console_ns.model(
"InitValidateRequest",
{"password": fields.String(required=True, description="Initialization password", max_length=30)},
)
)
@console_ns.expect(console_ns.models[InitValidatePayload.__name__])
@console_ns.response(
201,
"Success",
@ -57,8 +64,8 @@ class InitValidateAPI(Resource):
if tenant_count > 0:
raise AlreadySetupError()
parser = reqparse.RequestParser().add_argument("password", type=StrLen(30), required=True, location="json")
input_password = parser.parse_args()["password"]
payload = InitValidatePayload.model_validate(console_ns.payload)
input_password = payload.password
if input_password != os.environ.get("INIT_PASSWORD"):
session["is_init_validated"] = False

View File

@ -1,7 +1,8 @@
import urllib.parse
import httpx
from flask_restx import Resource, marshal_with, reqparse
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
import services
from controllers.common import helpers
@ -36,17 +37,23 @@ class RemoteFileInfoApi(Resource):
}
parser_upload = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required")
class RemoteFileUploadPayload(BaseModel):
url: str = Field(..., description="URL to fetch")
console_ns.schema_model(
RemoteFileUploadPayload.__name__,
RemoteFileUploadPayload.model_json_schema(ref_template="#/definitions/{model}"),
)
@console_ns.route("/remote-files/upload")
class RemoteFileUploadApi(Resource):
@console_ns.expect(parser_upload)
@console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
@marshal_with(file_fields_with_signed_url)
def post(self):
args = parser_upload.parse_args()
url = args["url"]
args = RemoteFileUploadPayload.model_validate(console_ns.payload)
url = args.url
try:
resp = ssrf_proxy.head(url=url)

View File

@ -1,8 +1,9 @@
from flask import request
from flask_restx import Resource, fields, reqparse
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from configs import dify_config
from libs.helper import StrLen, email, extract_remote_ip
from libs.helper import EmailStr, extract_remote_ip
from libs.password import valid_password
from models.model import DifySetup, db
from services.account_service import RegisterService, TenantService
@ -12,6 +13,26 @@ from .error import AlreadySetupError, NotInitValidateError
from .init_validate import get_init_validate_status
from .wraps import only_edition_self_hosted
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class SetupRequestPayload(BaseModel):
email: EmailStr = Field(..., description="Admin email address")
name: str = Field(..., max_length=30, description="Admin name (max 30 characters)")
password: str = Field(..., description="Admin password")
language: str | None = Field(default=None, description="Admin language")
@field_validator("password")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
console_ns.schema_model(
SetupRequestPayload.__name__,
SetupRequestPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/setup")
class SetupApi(Resource):
@ -42,17 +63,7 @@ class SetupApi(Resource):
@console_ns.doc("setup_system")
@console_ns.doc(description="Initialize system setup with admin account")
@console_ns.expect(
console_ns.model(
"SetupRequest",
{
"email": fields.String(required=True, description="Admin email address"),
"name": fields.String(required=True, description="Admin name (max 30 characters)"),
"password": fields.String(required=True, description="Admin password"),
"language": fields.String(required=False, description="Admin language"),
},
)
)
@console_ns.expect(console_ns.models[SetupRequestPayload.__name__])
@console_ns.response(
201, "Success", console_ns.model("SetupResponse", {"result": fields.String(description="Setup result")})
)
@ -72,22 +83,15 @@ class SetupApi(Resource):
if not get_init_validate_status():
raise NotInitValidateError()
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("name", type=StrLen(30), required=True, location="json")
.add_argument("password", type=valid_password, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
args = SetupRequestPayload.model_validate(console_ns.payload)
# setup
RegisterService.setup(
email=args["email"],
name=args["name"],
password=args["password"],
email=args.email,
name=args.name,
password=args.password,
ip_address=extract_remote_ip(request),
language=args["language"],
language=args.language,
)
return {"result": "success"}, 201

View File

@ -2,8 +2,10 @@ import json
import logging
import httpx
from flask_restx import Resource, fields, reqparse
from flask import request
from flask_restx import Resource, fields
from packaging import version
from pydantic import BaseModel, Field
from configs import dify_config
@ -11,8 +13,14 @@ from . import console_ns
logger = logging.getLogger(__name__)
parser = reqparse.RequestParser().add_argument(
"current_version", type=str, required=True, location="args", help="Current application version"
class VersionQuery(BaseModel):
current_version: str = Field(..., description="Current application version")
console_ns.schema_model(
VersionQuery.__name__,
VersionQuery.model_json_schema(ref_template="#/definitions/{model}"),
)
@ -20,7 +28,7 @@ parser = reqparse.RequestParser().add_argument(
class VersionApi(Resource):
@console_ns.doc("check_version_update")
@console_ns.doc(description="Check for application version updates")
@console_ns.expect(parser)
@console_ns.expect(console_ns.models[VersionQuery.__name__])
@console_ns.response(
200,
"Success",
@ -37,7 +45,7 @@ class VersionApi(Resource):
)
def get(self):
"""Check for application version updates"""
args = parser.parse_args()
args = VersionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
check_update_url = dify_config.CHECK_UPDATE_URL
result = {
@ -57,16 +65,16 @@ class VersionApi(Resource):
try:
response = httpx.get(
check_update_url,
params={"current_version": args["current_version"]},
params={"current_version": args.current_version},
timeout=httpx.Timeout(timeout=10.0, connect=3.0),
)
except Exception as error:
logger.warning("Check update version error: %s.", str(error))
result["version"] = args["current_version"]
result["version"] = args.current_version
return result
content = json.loads(response.content)
if _has_new_version(latest_version=content["version"], current_version=f"{args['current_version']}"):
if _has_new_version(latest_version=content["version"], current_version=f"{args.current_version}"):
result["version"] = content["version"]
result["release_date"] = content["releaseDate"]
result["release_notes"] = content["releaseNotes"]

View File

@ -37,7 +37,7 @@ from controllers.console.wraps import (
from extensions.ext_database import db
from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now
from libs.helper import TimestampField, email, extract_remote_ip, timezone
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required
from models import Account, AccountIntegrate, InvitationCode
from services.account_service import AccountService
@ -111,14 +111,9 @@ class AccountDeletePayload(BaseModel):
class AccountDeletionFeedbackPayload(BaseModel):
email: str
email: EmailStr
feedback: str
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class EducationActivatePayload(BaseModel):
token: str
@ -133,45 +128,25 @@ class EducationAutocompleteQuery(BaseModel):
class ChangeEmailSendPayload(BaseModel):
email: str
email: EmailStr
language: str | None = None
phase: str | None = None
token: str | None = None
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class ChangeEmailValidityPayload(BaseModel):
email: str
email: EmailStr
code: str
token: str
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class ChangeEmailResetPayload(BaseModel):
new_email: str
new_email: EmailStr
token: str
@field_validator("new_email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class CheckEmailUniquePayload(BaseModel):
email: str
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
email: EmailStr
def reg(cls: type[BaseModel]):

View File

@ -1,7 +1,8 @@
from urllib.parse import quote
from flask import Response, request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import NotFound
import services
@ -11,6 +12,26 @@ from extensions.ext_database import db
from services.account_service import TenantService
from services.file_service import FileService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class FileSignatureQuery(BaseModel):
timestamp: str = Field(..., description="Unix timestamp used in the signature")
nonce: str = Field(..., description="Random string for signature")
sign: str = Field(..., description="HMAC signature")
class FilePreviewQuery(FileSignatureQuery):
as_attachment: bool = Field(default=False, description="Whether to download as attachment")
files_ns.schema_model(
FileSignatureQuery.__name__, FileSignatureQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
files_ns.schema_model(
FilePreviewQuery.__name__, FilePreviewQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@files_ns.route("/<uuid:file_id>/image-preview")
class ImagePreviewApi(Resource):
@ -36,12 +57,10 @@ class ImagePreviewApi(Resource):
def get(self, file_id):
file_id = str(file_id)
timestamp = request.args.get("timestamp")
nonce = request.args.get("nonce")
sign = request.args.get("sign")
if not timestamp or not nonce or not sign:
return {"content": "Invalid request."}, 400
args = FileSignatureQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
timestamp = args.timestamp
nonce = args.nonce
sign = args.sign
try:
generator, mimetype = FileService(db.engine).get_image_preview(
@ -80,25 +99,14 @@ class FilePreviewApi(Resource):
def get(self, file_id):
file_id = str(file_id)
parser = (
reqparse.RequestParser()
.add_argument("timestamp", type=str, required=True, location="args")
.add_argument("nonce", type=str, required=True, location="args")
.add_argument("sign", type=str, required=True, location="args")
.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
)
args = parser.parse_args()
if not args["timestamp"] or not args["nonce"] or not args["sign"]:
return {"content": "Invalid request."}, 400
args = FilePreviewQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
generator, upload_file = FileService(db.engine).get_file_generator_by_file_id(
file_id=file_id,
timestamp=args["timestamp"],
nonce=args["nonce"],
sign=args["sign"],
timestamp=args.timestamp,
nonce=args.nonce,
sign=args.sign,
)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
@ -125,7 +133,7 @@ class FilePreviewApi(Resource):
response.headers["Accept-Ranges"] = "bytes"
if upload_file.size > 0:
response.headers["Content-Length"] = str(upload_file.size)
if args["as_attachment"]:
if args.as_attachment:
encoded_filename = quote(upload_file.name)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
response.headers["Content-Type"] = "application/octet-stream"

View File

@ -1,7 +1,8 @@
from urllib.parse import quote
from flask import Response
from flask_restx import Resource, reqparse
from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, NotFound
from controllers.common.errors import UnsupportedFileTypeError
@ -10,6 +11,20 @@ from core.tools.signature import verify_tool_file_signature
from core.tools.tool_file_manager import ToolFileManager
from extensions.ext_database import db as global_db
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ToolFileQuery(BaseModel):
timestamp: str = Field(..., description="Unix timestamp")
nonce: str = Field(..., description="Random nonce")
sign: str = Field(..., description="HMAC signature")
as_attachment: bool = Field(default=False, description="Download as attachment")
files_ns.schema_model(
ToolFileQuery.__name__, ToolFileQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@files_ns.route("/tools/<uuid:file_id>.<string:extension>")
class ToolFileApi(Resource):
@ -36,18 +51,8 @@ class ToolFileApi(Resource):
def get(self, file_id, extension):
file_id = str(file_id)
parser = (
reqparse.RequestParser()
.add_argument("timestamp", type=str, required=True, location="args")
.add_argument("nonce", type=str, required=True, location="args")
.add_argument("sign", type=str, required=True, location="args")
.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
)
args = parser.parse_args()
if not verify_tool_file_signature(
file_id=file_id, timestamp=args["timestamp"], nonce=args["nonce"], sign=args["sign"]
):
args = ToolFileQuery.model_validate(request.args.to_dict())
if not verify_tool_file_signature(file_id=file_id, timestamp=args.timestamp, nonce=args.nonce, sign=args.sign):
raise Forbidden("Invalid request.")
try:
@ -69,7 +74,7 @@ class ToolFileApi(Resource):
)
if tool_file.size > 0:
response.headers["Content-Length"] = str(tool_file.size)
if args["as_attachment"]:
if args.as_attachment:
encoded_filename = quote(tool_file.name)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"

View File

@ -1,40 +1,45 @@
from mimetypes import guess_extension
from flask_restx import Resource, reqparse
from flask import request
from flask_restx import Resource
from flask_restx.api import HTTPStatus
from pydantic import BaseModel, Field
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import Forbidden
import services
from controllers.common.errors import (
FileTooLargeError,
UnsupportedFileTypeError,
)
from controllers.console.wraps import setup_required
from controllers.files import files_ns
from controllers.inner_api.plugin.wraps import get_user
from core.file.helpers import verify_plugin_file_signature
from core.tools.tool_file_manager import ToolFileManager
from fields.file_fields import build_file_model
# Define parser for both documentation and validation
upload_parser = (
reqparse.RequestParser()
.add_argument("file", location="files", type=FileStorage, required=True, help="File to upload")
.add_argument(
"timestamp", type=str, required=True, location="args", help="Unix timestamp for signature verification"
)
.add_argument("nonce", type=str, required=True, location="args", help="Random string for signature verification")
.add_argument("sign", type=str, required=True, location="args", help="HMAC signature for request validation")
.add_argument("tenant_id", type=str, required=True, location="args", help="Tenant identifier")
.add_argument("user_id", type=str, required=False, location="args", help="User identifier")
from ..common.errors import (
FileTooLargeError,
UnsupportedFileTypeError,
)
from ..console.wraps import setup_required
from ..files import files_ns
from ..inner_api.plugin.wraps import get_user
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class PluginUploadQuery(BaseModel):
timestamp: str = Field(..., description="Unix timestamp for signature verification")
nonce: str = Field(..., description="Random nonce for signature verification")
sign: str = Field(..., description="HMAC signature")
tenant_id: str = Field(..., description="Tenant identifier")
user_id: str | None = Field(default=None, description="User identifier")
files_ns.schema_model(
PluginUploadQuery.__name__, PluginUploadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@files_ns.route("/upload/for-plugin")
class PluginUploadFileApi(Resource):
@setup_required
@files_ns.expect(upload_parser)
@files_ns.expect(files_ns.models[PluginUploadQuery.__name__])
@files_ns.doc("upload_plugin_file")
@files_ns.doc(description="Upload a file for plugin usage with signature verification")
@files_ns.doc(
@ -62,15 +67,17 @@ class PluginUploadFileApi(Resource):
FileTooLargeError: File exceeds size limit
UnsupportedFileTypeError: File type not supported
"""
# Parse and validate all arguments
args = upload_parser.parse_args()
args = PluginUploadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
file: FileStorage = args["file"]
timestamp: str = args["timestamp"]
nonce: str = args["nonce"]
sign: str = args["sign"]
tenant_id: str = args["tenant_id"]
user_id: str | None = args.get("user_id")
file: FileStorage | None = request.files.get("file")
if file is None:
raise Forbidden("File is required.")
timestamp = args.timestamp
nonce = args.nonce
sign = args.sign
tenant_id = args.tenant_id
user_id = args.user_id
user = get_user(tenant_id, user_id)
filename: str | None = file.filename

View File

@ -2,6 +2,7 @@ from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any, Literal
from jsonschema import Draft7Validator, SchemaError
from pydantic import BaseModel, Field, field_validator
from core.file import FileTransferMethod, FileType, FileUploadConfig
@ -98,6 +99,7 @@ class VariableEntityType(StrEnum):
FILE = "file"
FILE_LIST = "file-list"
CHECKBOX = "checkbox"
JSON_OBJECT = "json_object"
class VariableEntity(BaseModel):
@ -118,6 +120,7 @@ class VariableEntity(BaseModel):
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
json_schema: dict[str, Any] | None = Field(default=None)
@field_validator("description", mode="before")
@classmethod
@ -129,6 +132,17 @@ class VariableEntity(BaseModel):
def convert_none_options(cls, v: Any) -> Sequence[str]:
return v or []
@field_validator("json_schema")
@classmethod
def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None:
if schema is None:
return None
try:
Draft7Validator.check_schema(schema)
except SchemaError as e:
raise ValueError(f"Invalid JSON schema: {e.message}")
return schema
class RagPipelineVariableEntity(VariableEntity):
"""

View File

@ -770,7 +770,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
tts_publisher.publish(None)
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
logger.debug("Conversation name generation running as daemon thread")
def _save_message(
self,

View File

@ -99,6 +99,15 @@ class BaseAppGenerator:
if value is None:
return None
# Treat empty placeholders for optional file inputs as unset
if (
variable_entity.type in {VariableEntityType.FILE, VariableEntityType.FILE_LIST}
and not variable_entity.required
):
# Treat empty string (frontend default) or empty list as unset
if not value and isinstance(value, (str, list)):
return None
if variable_entity.type in {
VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT,
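The guard added above treats an empty string or empty list on an optional file / file-list input as unset. A tiny stand-alone function mirroring that condition (names are illustrative, not the project's API):

def normalize_optional_file_input(value, *, is_file_input: bool, required: bool):
    # mirrors the guard added above: empty string or empty list on an optional
    # file / file-list input is treated as "not provided"
    if value is None:
        return None
    if is_file_input and not required:
        if not value and isinstance(value, (str, list)):
            return None
    return value

print(normalize_optional_file_input("", is_file_input=True, required=False))  # None
print(normalize_optional_file_input([], is_file_input=True, required=False))  # None
print(normalize_optional_file_input("", is_file_input=True, required=True))   # '' (left for later validation)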

View File

@ -156,79 +156,86 @@ class MessageBasedAppGenerator(BaseAppGenerator):
query = application_generate_entity.query or "New conversation"
conversation_name = (query[:20] + "…") if len(query) > 20 else query
if not conversation:
conversation = Conversation(
try:
if not conversation:
conversation = Conversation(
app_id=app_config.app_id,
app_model_config_id=app_model_config_id,
model_provider=model_provider,
model_id=model_id,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
mode=app_config.app_mode.value,
name=conversation_name,
inputs=application_generate_entity.inputs,
introduction=introduction,
system_instruction="",
system_instruction_tokens=0,
status="normal",
invoke_from=application_generate_entity.invoke_from.value,
from_source=from_source,
from_end_user_id=end_user_id,
from_account_id=account_id,
)
db.session.add(conversation)
db.session.flush()
db.session.refresh(conversation)
else:
conversation.updated_at = naive_utc_now()
message = Message(
app_id=app_config.app_id,
app_model_config_id=app_model_config_id,
model_provider=model_provider,
model_id=model_id,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
mode=app_config.app_mode.value,
name=conversation_name,
conversation_id=conversation.id,
inputs=application_generate_entity.inputs,
introduction=introduction,
system_instruction="",
system_instruction_tokens=0,
status="normal",
query=application_generate_entity.query,
message="",
message_tokens=0,
message_unit_price=0,
message_price_unit=0,
answer="",
answer_tokens=0,
answer_unit_price=0,
answer_price_unit=0,
parent_message_id=getattr(application_generate_entity, "parent_message_id", None),
provider_response_latency=0,
total_price=0,
currency="USD",
invoke_from=application_generate_entity.invoke_from.value,
from_source=from_source,
from_end_user_id=end_user_id,
from_account_id=account_id,
app_mode=app_config.app_mode,
)
db.session.add(conversation)
db.session.add(message)
db.session.flush()
db.session.refresh(message)
message_files = []
for file in application_generate_entity.files:
message_file = MessageFile(
message_id=message.id,
type=file.type,
transfer_method=file.transfer_method,
belongs_to="user",
url=file.remote_url,
upload_file_id=file.related_id,
created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),
created_by=account_id or end_user_id or "",
)
message_files.append(message_file)
if message_files:
db.session.add_all(message_files)
db.session.commit()
db.session.refresh(conversation)
else:
conversation.updated_at = naive_utc_now()
db.session.commit()
message = Message(
app_id=app_config.app_id,
model_provider=model_provider,
model_id=model_id,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
conversation_id=conversation.id,
inputs=application_generate_entity.inputs,
query=application_generate_entity.query,
message="",
message_tokens=0,
message_unit_price=0,
message_price_unit=0,
answer="",
answer_tokens=0,
answer_unit_price=0,
answer_price_unit=0,
parent_message_id=getattr(application_generate_entity, "parent_message_id", None),
provider_response_latency=0,
total_price=0,
currency="USD",
invoke_from=application_generate_entity.invoke_from.value,
from_source=from_source,
from_end_user_id=end_user_id,
from_account_id=account_id,
app_mode=app_config.app_mode,
)
db.session.add(message)
db.session.commit()
db.session.refresh(message)
for file in application_generate_entity.files:
message_file = MessageFile(
message_id=message.id,
type=file.type,
transfer_method=file.transfer_method,
belongs_to="user",
url=file.remote_url,
upload_file_id=file.related_id,
created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),
created_by=account_id or end_user_id or "",
)
db.session.add(message_file)
db.session.commit()
return conversation, message
return conversation, message
except Exception:
db.session.rollback()
raise
def _get_conversation_introduction(self, application_generate_entity: AppGenerateEntity) -> str:
"""

View File

@ -366,7 +366,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if publisher:
publisher.publish(None)
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
logger.debug("Conversation name generation running as daemon thread")
def _save_message(self, *, session: Session, trace_manager: TraceQueueManager | None = None):
"""

View File

@ -1,4 +1,6 @@
import hashlib
import logging
import time
from threading import Thread
from typing import Union
@ -31,6 +33,7 @@ from core.app.entities.task_entities import (
from core.llm_generator.llm_generator import LLMGenerator
from core.tools.signature import sign_tool_file
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
from services.annotation_service import AppAnnotationService
@ -68,6 +71,8 @@ class MessageCycleManager:
if auto_generate_conversation_name and is_first_message:
# start generate thread
# time.sleep should not block other logic
time.sleep(1)
thread = Thread(
target=self._generate_conversation_name_worker,
kwargs={
@ -76,7 +81,7 @@ class MessageCycleManager:
"query": query,
},
)
thread.daemon = True
thread.start()
return thread
@ -98,15 +103,23 @@ class MessageCycleManager:
return
# generate conversation name
try:
name = LLMGenerator.generate_conversation_name(
app_model.tenant_id, query, conversation_id, conversation.app_id
)
conversation.name = name
except Exception:
if dify_config.DEBUG:
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
query_hash = hashlib.md5(query.encode()).hexdigest()[:16]
cache_key = f"conv_name:{conversation_id}:{query_hash}"
cached_name = redis_client.get(cache_key)
if cached_name:
name = cached_name.decode("utf-8")
else:
try:
name = LLMGenerator.generate_conversation_name(
app_model.tenant_id, query, conversation_id, conversation.app_id
)
redis_client.setex(cache_key, 3600, name)
except Exception:
if dify_config.DEBUG:
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
name = query[:47] + "..." if len(query) > 50 else query
conversation.name = name
db.session.commit()
db.session.close()
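The hunk above replaces a bare LLM call with a Redis-backed cache keyed on the conversation id plus an MD5 digest of the query, a one-hour TTL, and a truncated-query fallback. A minimal sketch of that pattern, with `redis_client` and the name generator passed in as placeholders rather than the project's own objects:

```python
import hashlib


def get_or_generate_conversation_name(redis_client, conversation_id: str, query: str, generate_name) -> str:
    """Return a cached conversation name, generating and caching it on a miss."""
    # Key the cache on the conversation and the first 16 hex chars of the query digest.
    query_hash = hashlib.md5(query.encode()).hexdigest()[:16]
    cache_key = f"conv_name:{conversation_id}:{query_hash}"

    cached = redis_client.get(cache_key)
    if cached:
        return cached.decode("utf-8") if isinstance(cached, bytes) else cached

    try:
        name = generate_name(query)
        redis_client.setex(cache_key, 3600, name)  # keep the generated name for one hour
    except Exception:
        # Fall back to a truncated query when generation fails.
        name = query[:47] + "..." if len(query) > 50 else query
    return name
```

Keying on the query hash means a retried first message reuses the cached name instead of paying for a second generation call.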

View File

@ -296,7 +296,7 @@ class AliyunDataTrace(BaseTraceInstance):
node_span = self.build_workflow_task_span(trace_info, node_execution, trace_metadata)
return node_span
except Exception as e:
logger.debug("Error occurred in build_workflow_node_span: %s", e, exc_info=True)
logger.warning("Error occurred in build_workflow_node_span: %s", e, exc_info=True)
return None
def build_workflow_task_span(

View File

@ -21,6 +21,7 @@ from opentelemetry.trace import Link, SpanContext, TraceFlags
from configs import dify_config
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
from core.ops.aliyun_trace.entities.semconv import ACS_ARMS_SERVICE_FEATURE
INVALID_SPAN_ID: Final[int] = 0x0000000000000000
INVALID_TRACE_ID: Final[int] = 0x00000000000000000000000000000000
@ -48,6 +49,7 @@ class TraceClient:
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
ResourceAttributes.HOST_NAME: socket.gethostname(),
ACS_ARMS_SERVICE_FEATURE: "genai_app",
}
)
self.span_builder = SpanBuilder(self.resource)
@ -75,10 +77,10 @@ class TraceClient:
if response.status_code == 405:
return True
else:
logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
logger.warning("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
return False
except httpx.RequestError as e:
logger.debug("AliyunTrace API check failed: %s", str(e))
logger.warning("AliyunTrace API check failed: %s", str(e))
raise ValueError(f"AliyunTrace API check failed: {str(e)}")
def get_project_url(self) -> str:
@ -116,7 +118,7 @@ class TraceClient:
try:
self.exporter.export(spans_to_export)
except Exception as e:
logger.debug("Error exporting spans: %s", e)
logger.warning("Error exporting spans: %s", e)
def shutdown(self) -> None:
with self.condition:

View File

@ -1,6 +1,8 @@
from enum import StrEnum
from typing import Final
ACS_ARMS_SERVICE_FEATURE: Final[str] = "acs.arms.service.feature"
# Public attributes
GEN_AI_SESSION_ID: Final[str] = "gen_ai.session.id"
GEN_AI_USER_ID: Final[str] = "gen_ai.user.id"

View File

@ -377,20 +377,20 @@ class OpsTraceManager:
return app_model_config
@classmethod
def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider: str):
def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider: str | None):
"""
Update app tracing config
:param app_id: app id
:param enabled: enabled
:param tracing_provider: tracing provider
:param tracing_provider: tracing provider (None when disabling)
:return:
"""
# auth check
try:
if enabled or tracing_provider is not None:
if tracing_provider is not None:
try:
provider_config_map[tracing_provider]
except KeyError:
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
except KeyError:
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
app_config: App | None = db.session.query(App).where(App.id == app_id).first()
if not app_config:
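Spelled out, the new None-aware check amounts to the small guard below; the provider keys in the illustrative map are assumptions for the sketch, not an exhaustive list of what `provider_config_map` actually contains.

```python
provider_config_map: dict[str, object] = {"langfuse": object(), "langsmith": object(), "aliyun": object()}  # illustrative


def validate_tracing_provider(tracing_provider: str | None) -> None:
    """Allow None (tracing is being disabled); reject any provider not in the config map."""
    if tracing_provider is None:
        return
    if tracing_provider not in provider_config_map:
        raise ValueError(f"Invalid tracing provider: {tracing_provider}")
```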

View File

@ -1,7 +1,9 @@
"""Document loader helpers."""
import concurrent.futures
from typing import NamedTuple, cast
from typing import NamedTuple
import charset_normalizer
class FileEncoding(NamedTuple):
@ -27,14 +29,14 @@ def detect_file_encodings(file_path: str, timeout: int = 5, sample_size: int = 1
sample_size: The number of bytes to read for encoding detection. Default is 1MB.
For large files, reading only a sample is sufficient and prevents timeout.
"""
import chardet
def read_and_detect(file_path: str):
with open(file_path, "rb") as f:
# Read only a sample of the file for encoding detection
# This prevents timeout on large files while still providing accurate encoding detection
rawdata = f.read(sample_size)
return cast(list[dict], chardet.detect_all(rawdata))
def read_and_detect(filename: str):
rst = charset_normalizer.from_path(filename)
best = rst.best()
if best is None:
return []
file_encoding = FileEncoding(encoding=best.encoding, confidence=best.coherence, language=best.language)
return [file_encoding]
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(read_and_detect, file_path)
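Pieced together, the charset_normalizer-based detector in the hunk above behaves roughly like the sketch below; the `FileEncoding` tuple mirrors the module's, and waiting on the executor's future to enforce the timeout is an assumption about the part the hunk truncates.

```python
import concurrent.futures
from typing import NamedTuple

import charset_normalizer


class FileEncoding(NamedTuple):
    encoding: str | None
    confidence: float
    language: str | None


def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding]:
    """Detect a file's encoding with charset_normalizer, bounded by a timeout."""

    def read_and_detect(filename: str) -> list[FileEncoding]:
        best = charset_normalizer.from_path(filename).best()
        if best is None:
            return []
        # coherence plays the role chardet's confidence used to play
        return [FileEncoding(encoding=best.encoding, confidence=best.coherence, language=best.language)]

    with concurrent.futures.ThreadPoolExecutor() as executor:
        future = executor.submit(read_and_detect, file_path)
        return future.result(timeout=timeout)
```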

View File

@ -5,7 +5,7 @@ from dataclasses import dataclass
from typing import Any, cast
from urllib.parse import unquote
import chardet
import charset_normalizer
import cloudscraper
from readabilipy import simple_json_from_html_string
@ -69,9 +69,12 @@ def get_url(url: str, user_agent: str | None = None) -> str:
if response.status_code != 200:
return f"URL returned status code {response.status_code}."
# Detect encoding using chardet
detected_encoding = chardet.detect(response.content)
encoding = detected_encoding["encoding"]
# Detect encoding using charset_normalizer
detected_encoding = charset_normalizer.from_bytes(response.content).best()
if detected_encoding:
encoding = detected_encoding.encoding
else:
encoding = "utf-8"
if encoding:
try:
content = response.content.decode(encoding)

View File

@ -7,7 +7,7 @@ import tempfile
from collections.abc import Mapping, Sequence
from typing import Any
import chardet
import charset_normalizer
import docx
import pandas as pd
import pypandoc
@ -228,9 +228,12 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str)
def _extract_text_from_plain_text(file_content: bytes) -> str:
try:
# Detect encoding using chardet
result = chardet.detect(file_content)
encoding = result["encoding"]
# Detect encoding using charset_normalizer
result = charset_normalizer.from_bytes(file_content, cp_isolation=["utf_8", "latin_1", "cp1252"]).best()
if result:
encoding = result.encoding
else:
encoding = "utf-8"
# Fallback to utf-8 if detection fails
if not encoding:
@ -247,9 +250,12 @@ def _extract_text_from_plain_text(file_content: bytes) -> str:
def _extract_text_from_json(file_content: bytes) -> str:
try:
# Detect encoding using chardet
result = chardet.detect(file_content)
encoding = result["encoding"]
# Detect encoding using charset_normalizer
result = charset_normalizer.from_bytes(file_content).best()
if result:
encoding = result.encoding
else:
encoding = "utf-8"
# Fallback to utf-8 if detection fails
if not encoding:
@ -269,9 +275,12 @@ def _extract_text_from_json(file_content: bytes) -> str:
def _extract_text_from_yaml(file_content: bytes) -> str:
"""Extract the content from yaml file"""
try:
# Detect encoding using chardet
result = chardet.detect(file_content)
encoding = result["encoding"]
# Detect encoding using charset_normalizer
result = charset_normalizer.from_bytes(file_content).best()
if result:
encoding = result.encoding
else:
encoding = "utf-8"
# Fallback to utf-8 if detection fails
if not encoding:
@ -424,9 +433,12 @@ def _extract_text_from_file(file: File):
def _extract_text_from_csv(file_content: bytes) -> str:
try:
# Detect encoding using chardet
result = chardet.detect(file_content)
encoding = result["encoding"]
# Detect encoding using charset_normalizer
result = charset_normalizer.from_bytes(file_content).best()
if result:
encoding = result.encoding
else:
encoding = "utf-8"
# Fallback to utf-8 if detection fails
if not encoding:
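The same detection-with-fallback block is repeated above for plain text, JSON, YAML, and CSV; it could be factored into one helper along the lines of this sketch (the helper name is illustrative, not part of the codebase):

```python
import charset_normalizer


def decode_with_detected_encoding(file_content: bytes, cp_isolation: list[str] | None = None) -> str:
    """Decode bytes using charset_normalizer's best guess, falling back to UTF-8."""
    best = charset_normalizer.from_bytes(file_content, cp_isolation=cp_isolation).best()
    encoding = best.encoding if best else "utf-8"
    # Guard against a detector that returns an empty encoding string.
    return file_content.decode(encoding or "utf-8")
```

The `cp_isolation` list seen in the plain-text hunk restricts the candidate codepages, which is why it is threaded through as an optional argument here.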

View File

@ -1,3 +1,8 @@
from typing import Any
from jsonschema import Draft7Validator, ValidationError
from core.app.app_config.entities import VariableEntityType
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
@ -15,6 +20,7 @@ class StartNode(Node[StartNodeData]):
def _run(self) -> NodeRunResult:
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
self._validate_and_normalize_json_object_inputs(node_inputs)
system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
# TODO: System variables should be directly accessible, no need for special handling
@ -24,3 +30,27 @@ class StartNode(Node[StartNodeData]):
outputs = dict(node_inputs)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs)
def _validate_and_normalize_json_object_inputs(self, node_inputs: dict[str, Any]) -> None:
for variable in self.node_data.variables:
if variable.type != VariableEntityType.JSON_OBJECT:
continue
key = variable.variable
value = node_inputs.get(key)
if value is None and variable.required:
raise ValueError(f"{key} is required in input form")
if not isinstance(value, dict):
raise ValueError(f"{key} must be a JSON object")
schema = variable.json_schema
if not schema:
continue
try:
Draft7Validator(schema).validate(value)
except ValidationError as e:
raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
node_inputs[key] = value
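As a standalone illustration of what the new `_validate_and_normalize_json_object_inputs` method does per variable, this is the `Draft7Validator` call shape with a made-up schema and inputs:

```python
from jsonschema import Draft7Validator, ValidationError

# Hypothetical schema, only to show the call shape.
schema = {
    "type": "object",
    "properties": {"age": {"type": "number"}, "name": {"type": "string"}},
    "required": ["age"],
}


def validate_json_object(key: str, value: object) -> dict:
    if not isinstance(value, dict):
        raise ValueError(f"{key} must be a JSON object")
    try:
        Draft7Validator(schema).validate(value)
    except ValidationError as e:
        raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
    return value


validate_json_object("profile", {"age": 20, "name": "Tom"})  # passes
# validate_json_object("profile", {"name": "Tom"})           # raises: 'age' is a required property
```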

View File

@ -256,7 +256,7 @@ def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]
now = datetime_utils.naive_utc_now()
last_update = _get_last_update_timestamp(cache_key)
if last_update is None or (now - last_update).total_seconds() > LAST_USED_UPDATE_WINDOW_SECONDS:
if last_update is None or (now - last_update).total_seconds() > LAST_USED_UPDATE_WINDOW_SECONDS: # type: ignore
update_values["last_used"] = values.last_used
_set_last_update_timestamp(cache_key, now)
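Behind the `# type: ignore`, the logic is a plain write-throttle: skip the UPDATE unless the cached timestamp is missing or older than the window. A toy version with an in-process dict standing in for the real timestamp cache and an assumed window size:

```python
from datetime import datetime

LAST_USED_UPDATE_WINDOW_SECONDS = 60 * 10  # assumed window size, for the sketch only

_last_update_cache: dict[str, datetime] = {}


def should_update_last_used(cache_key: str, now: datetime) -> bool:
    """Write last_used only when the previous write falls outside the throttle window."""
    last_update = _last_update_cache.get(cache_key)
    if last_update is None or (now - last_update).total_seconds() > LAST_USED_UPDATE_WINDOW_SECONDS:
        _last_update_cache[cache_key] = now
        return True
    return False
```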

View File

@ -3,7 +3,7 @@ import logging
import ssl
from collections.abc import Callable
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Union
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, Union
import redis
from redis import RedisError
@ -245,7 +245,12 @@ def init_app(app: DifyApp):
app.extensions["redis"] = redis_client
def redis_fallback(default_return: Any | None = None):
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def redis_fallback(default_return: T | None = None): # type: ignore
"""
decorator to handle Redis operation exceptions and return a default value when Redis is unavailable.
@ -253,9 +258,9 @@ def redis_fallback(default_return: Any | None = None):
default_return: The value to return when a Redis operation fails. Defaults to None.
"""
def decorator(func: Callable):
def decorator(func: Callable[P, R]):
@functools.wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args: P.args, **kwargs: P.kwargs):
try:
return func(*args, **kwargs)
except RedisError as e:
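Assembled from the hunk above, the ParamSpec-typed decorator looks roughly like this; the warning log line is an assumption, and the return annotation is widened to include the fallback value:

```python
import functools
import logging
from collections.abc import Callable
from typing import ParamSpec, TypeVar

from redis import RedisError

logger = logging.getLogger(__name__)

P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")


def redis_fallback(default_return: T | None = None):
    """Return ``default_return`` instead of raising when a Redis operation fails."""

    def decorator(func: Callable[P, R]) -> Callable[P, R | T | None]:
        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | T | None:
            try:
                return func(*args, **kwargs)
            except RedisError as e:
                logger.warning("Redis operation failed in %s: %s", func.__name__, e)
                return default_return

        return wrapper

    return decorator
```

Typing the decorator with `ParamSpec` preserves the wrapped function's signature for callers and type checkers instead of collapsing it to `Callable[..., Any]`.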

View File

@ -10,12 +10,13 @@ import uuid
from collections.abc import Generator, Mapping
from datetime import datetime
from hashlib import sha256
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast
from zoneinfo import available_timezones
from flask import Response, stream_with_context
from flask_restx import fields
from pydantic import BaseModel
from pydantic.functional_validators import AfterValidator
from configs import dify_config
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
@ -103,6 +104,9 @@ def email(email):
raise ValueError(error)
EmailStr = Annotated[str, AfterValidator(email)]
def uuid_value(value):
if value == "":
return str(value)
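The new `EmailStr = Annotated[str, AfterValidator(email)]` alias makes the existing `email()` check run as part of Pydantic validation; a self-contained sketch with a simplified stand-in validator and a hypothetical model:

```python
import re
from typing import Annotated

from pydantic import BaseModel
from pydantic.functional_validators import AfterValidator


def email(value: str) -> str:
    # Simplified stand-in for the project's stricter validator.
    if not re.match(r"^[^@\s]+@[^@\s]+\.[^@\s]+$", value):
        raise ValueError(f"{value} is not a valid email.")
    return value


EmailStr = Annotated[str, AfterValidator(email)]


class InvitePayload(BaseModel):  # hypothetical model, only to exercise the alias
    email: EmailStr


InvitePayload(email="user@example.com")   # passes
# InvitePayload(email="not-an-email")     # would raise a ValidationError
```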

View File

@ -11,7 +11,7 @@ dependencies = [
"bs4~=0.0.1",
"cachetools~=5.3.0",
"celery~=5.5.2",
"chardet~=5.1.0",
"charset-normalizer>=3.4.4",
"flask~=3.1.2",
"flask-compress>=1.17,<1.18",
"flask-cors~=6.0.0",
@ -91,6 +91,7 @@ dependencies = [
"weaviate-client==4.17.0",
"apscheduler>=3.11.0",
"weave>=0.52.16",
"jsonschema>=4.25.1",
]
# Before adding new dependency, consider place it in
# alphabet order (a-z) and suitable group.

api/pyrefly.toml Normal file
View File

@ -0,0 +1,10 @@
project-includes = ["."]
project-excludes = [
"tests/",
".venv",
"migrations/",
"core/rag",
]
python-platform = "linux"
python-version = "3.11.0"
infer-with-first-use = false

View File

@ -1259,7 +1259,7 @@ class RegisterService:
return f"member_invite:token:{token}"
@classmethod
def setup(cls, email: str, name: str, password: str, ip_address: str, language: str):
def setup(cls, email: str, name: str, password: str, ip_address: str, language: str | None):
"""
Setup dify
@ -1267,6 +1267,7 @@ class RegisterService:
:param name: username
:param password: password
:param ip_address: ip address
:param language: language
"""
try:
account = AccountService.create_account(
@ -1414,7 +1415,7 @@ class RegisterService:
return data is not None
@classmethod
def revoke_token(cls, workspace_id: str, email: str, token: str):
def revoke_token(cls, workspace_id: str | None, email: str | None, token: str):
if workspace_id and email:
email_hash = sha256(email.encode()).hexdigest()
cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}"
@ -1423,7 +1424,9 @@ class RegisterService:
redis_client.delete(cls._get_invitation_token_key(token))
@classmethod
def get_invitation_if_token_valid(cls, workspace_id: str | None, email: str, token: str) -> dict[str, Any] | None:
def get_invitation_if_token_valid(
cls, workspace_id: str | None, email: str | None, token: str
) -> dict[str, Any] | None:
invitation_data = cls.get_invitation_by_token(token, workspace_id, email)
if not invitation_data:
return None

View File

@ -265,3 +265,82 @@ def test_validate_inputs_with_default_value():
)
assert result == [{"id": "file1", "name": "doc1.pdf"}, {"id": "file2", "name": "doc2.pdf"}]
def test_validate_inputs_optional_file_with_empty_string():
"""Test that optional FILE variable with empty string returns None"""
base_app_generator = BaseAppGenerator()
var_file = VariableEntity(
variable="test_file",
label="test_file",
type=VariableEntityType.FILE,
required=False,
)
result = base_app_generator._validate_inputs(
variable_entity=var_file,
value="",
)
assert result is None
def test_validate_inputs_optional_file_list_with_empty_list():
"""Test that optional FILE_LIST variable with empty list returns None"""
base_app_generator = BaseAppGenerator()
var_file_list = VariableEntity(
variable="test_file_list",
label="test_file_list",
type=VariableEntityType.FILE_LIST,
required=False,
)
result = base_app_generator._validate_inputs(
variable_entity=var_file_list,
value=[],
)
assert result is None
def test_validate_inputs_required_file_with_empty_string_fails():
"""Test that required FILE variable with empty string still fails validation"""
base_app_generator = BaseAppGenerator()
var_file = VariableEntity(
variable="test_file",
label="test_file",
type=VariableEntityType.FILE,
required=True,
)
with pytest.raises(ValueError) as exc_info:
base_app_generator._validate_inputs(
variable_entity=var_file,
value="",
)
assert "must be a file" in str(exc_info.value)
def test_validate_inputs_optional_file_with_empty_string_ignores_default():
"""Test that optional FILE variable with empty string returns None, not the default"""
base_app_generator = BaseAppGenerator()
var_file = VariableEntity(
variable="test_file",
label="test_file",
type=VariableEntityType.FILE,
required=False,
default={"id": "file123", "name": "default.pdf"},
)
# When value is empty string (from frontend), should return None, not default
result = base_app_generator._validate_inputs(
variable_entity=var_file,
value="",
)
assert result is None

View File

@ -1,3 +1,5 @@
from types import SimpleNamespace
import pytest
from core.tools.utils.web_reader_tool import (
@ -103,7 +105,10 @@ def test_get_url_html_flow_with_chardet_and_readability(monkeypatch: pytest.Monk
monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
monkeypatch.setattr(mod.ssrf_proxy, "get", fake_get)
monkeypatch.setattr(mod.chardet, "detect", lambda b: {"encoding": "utf-8"})
mock_best = SimpleNamespace(encoding="utf-8")
mock_from_bytes = SimpleNamespace(best=lambda: mock_best)
monkeypatch.setattr(mod.charset_normalizer, "from_bytes", lambda _: mock_from_bytes)
# readability → a dict that maps to Article, then FULL_TEMPLATE
def fake_simple_json_from_html_string(html, use_readability=True):
@ -134,7 +139,9 @@ def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch: pytest.
monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
monkeypatch.setattr(mod.ssrf_proxy, "get", fake_get)
monkeypatch.setattr(mod.chardet, "detect", lambda b: {"encoding": "utf-8"})
mock_best = SimpleNamespace(encoding="utf-8")
mock_from_bytes = SimpleNamespace(best=lambda: mock_best)
monkeypatch.setattr(mod.charset_normalizer, "from_bytes", lambda _: mock_from_bytes)
# readability returns empty plain_text
monkeypatch.setattr(mod, "simple_json_from_html_string", lambda html, use_readability=True: {"plain_text": []})
@ -162,7 +169,9 @@ def test_get_url_403_cloudscraper_fallback(monkeypatch: pytest.MonkeyPatch, stub
monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
monkeypatch.setattr(mod.cloudscraper, "create_scraper", lambda: FakeScraper())
monkeypatch.setattr(mod.chardet, "detect", lambda b: {"encoding": "utf-8"})
mock_best = SimpleNamespace(encoding="utf-8")
mock_from_bytes = SimpleNamespace(best=lambda: mock_best)
monkeypatch.setattr(mod.charset_normalizer, "from_bytes", lambda _: mock_from_bytes)
monkeypatch.setattr(
mod,
"simple_json_from_html_string",
@ -234,7 +243,10 @@ def test_get_url_html_encoding_fallback_when_decode_fails(monkeypatch: pytest.Mo
monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
monkeypatch.setattr(mod.ssrf_proxy, "get", fake_get)
monkeypatch.setattr(mod.chardet, "detect", lambda b: {"encoding": "utf-8"})
mock_best = SimpleNamespace(encoding="utf-8")
mock_from_bytes = SimpleNamespace(best=lambda: mock_best)
monkeypatch.setattr(mod.charset_normalizer, "from_bytes", lambda _: mock_from_bytes)
monkeypatch.setattr(
mod,
"simple_json_from_html_string",

View File

@ -0,0 +1,227 @@
import time
import pytest
from pydantic import ValidationError as PydanticValidationError
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.workflow.entities import GraphInitParams
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
def make_start_node(user_inputs, variables):
variable_pool = VariablePool(
system_variables=SystemVariable(),
user_inputs=user_inputs,
conversation_variables=[],
)
config = {
"id": "start",
"data": StartNodeData(title="Start", variables=variables).model_dump(),
}
graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
start_at=time.perf_counter(),
)
return StartNode(
id="start",
config=config,
graph_init_params=GraphInitParams(
tenant_id="tenant",
app_id="app",
workflow_id="wf",
graph_config={},
user_id="u",
user_from="account",
invoke_from="debugger",
call_depth=0,
),
graph_runtime_state=graph_runtime_state,
)
def test_json_object_valid_schema():
schema = {
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age"],
}
variables = [
VariableEntity(
variable="profile",
label="profile",
type=VariableEntityType.JSON_OBJECT,
required=True,
json_schema=schema,
)
]
user_inputs = {"profile": {"age": 20, "name": "Tom"}}
node = make_start_node(user_inputs, variables)
result = node._run()
assert result.outputs["profile"] == {"age": 20, "name": "Tom"}
def test_json_object_invalid_json_string():
variables = [
VariableEntity(
variable="profile",
label="profile",
type=VariableEntityType.JSON_OBJECT,
required=True,
)
]
# Missing closing brace makes this invalid JSON
user_inputs = {"profile": '{"age": 20, "name": "Tom"'}
node = make_start_node(user_inputs, variables)
with pytest.raises(ValueError, match="profile must be a JSON object"):
node._run()
@pytest.mark.parametrize("value", ["[1, 2, 3]", "123"])
def test_json_object_valid_json_but_not_object(value):
variables = [
VariableEntity(
variable="profile",
label="profile",
type=VariableEntityType.JSON_OBJECT,
required=True,
)
]
user_inputs = {"profile": value}
node = make_start_node(user_inputs, variables)
with pytest.raises(ValueError, match="profile must be a JSON object"):
node._run()
def test_json_object_does_not_match_schema():
schema = {
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age", "name"],
}
variables = [
VariableEntity(
variable="profile",
label="profile",
type=VariableEntityType.JSON_OBJECT,
required=True,
json_schema=schema,
)
]
# age is a string, which violates the schema (expects number)
user_inputs = {"profile": {"age": "twenty", "name": "Tom"}}
node = make_start_node(user_inputs, variables)
with pytest.raises(ValueError, match=r"JSON object for 'profile' does not match schema:"):
node._run()
def test_json_object_missing_required_schema_field():
schema = {
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age", "name"],
}
variables = [
VariableEntity(
variable="profile",
label="profile",
type=VariableEntityType.JSON_OBJECT,
required=True,
json_schema=schema,
)
]
# Missing required field "name"
user_inputs = {"profile": {"age": 20}}
node = make_start_node(user_inputs, variables)
with pytest.raises(
ValueError, match=r"JSON object for 'profile' does not match schema: 'name' is a required property"
):
node._run()
def test_json_object_required_variable_missing_from_inputs():
variables = [
VariableEntity(
variable="profile",
label="profile",
type=VariableEntityType.JSON_OBJECT,
required=True,
)
]
user_inputs = {}
node = make_start_node(user_inputs, variables)
with pytest.raises(ValueError, match="profile is required in input form"):
node._run()
def test_json_object_invalid_json_schema_string():
variable = VariableEntity(
variable="profile",
label="profile",
type=VariableEntityType.JSON_OBJECT,
required=True,
)
# Bypass pydantic type validation on assignment to simulate an invalid JSON schema string
variable.json_schema = "{invalid-json-schema"
variables = [variable]
user_inputs = {"profile": '{"age": 20}'}
# Invalid json_schema string should be rejected during node data hydration
with pytest.raises(PydanticValidationError):
make_start_node(user_inputs, variables)
def test_json_object_optional_variable_not_provided():
variables = [
VariableEntity(
variable="profile",
label="profile",
type=VariableEntityType.JSON_OBJECT,
required=False,
)
]
user_inputs = {}
node = make_start_node(user_inputs, variables)
# Current implementation raises a validation error even when the variable is optional
with pytest.raises(ValueError, match="profile must be a JSON object"):
node._run()

File diff suppressed because it is too large

View File

@ -233,7 +233,7 @@ NEXT_PUBLIC_ENABLE_SINGLE_DOLLAR_LATEX=false
# Database type, supported values are `postgresql` and `mysql`
DB_TYPE=postgresql
# For MySQL, only `root` user is supported for now
DB_USERNAME=postgres
DB_PASSWORD=difyai123456
DB_HOST=db_postgres
@ -1076,24 +1076,10 @@ MAX_TREE_DEPTH=50
# ------------------------------
# Environment Variables for database Service
# ------------------------------
# The name of the default postgres user.
POSTGRES_USER=${DB_USERNAME}
# The password for the default postgres user.
POSTGRES_PASSWORD=${DB_PASSWORD}
# The name of the default postgres database.
POSTGRES_DB=${DB_DATABASE}
# Postgres data directory
PGDATA=/var/lib/postgresql/data/pgdata
# MySQL Default Configuration
# The name of the default mysql user.
MYSQL_USERNAME=${DB_USERNAME}
# The password for the default mysql user.
MYSQL_PASSWORD=${DB_PASSWORD}
# The name of the default mysql database.
MYSQL_DATABASE=${DB_DATABASE}
# MySQL data directory
MYSQL_HOST_VOLUME=./volumes/mysql/data
# ------------------------------

View File

@ -2,7 +2,7 @@ x-shared-env: &shared-api-worker-env
services:
# API service
api:
image: langgenius/dify-api:1.10.1
image: langgenius/dify-api:1.10.1-fix.1
restart: always
environment:
# Use the shared environment variables.
@ -41,7 +41,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
image: langgenius/dify-api:1.10.1
image: langgenius/dify-api:1.10.1-fix.1
restart: always
environment:
# Use the shared environment variables.
@ -78,7 +78,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.10.1
image: langgenius/dify-api:1.10.1-fix.1
restart: always
environment:
# Use the shared environment variables.
@ -106,7 +106,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.10.1
image: langgenius/dify-web:1.10.1-fix.1
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
@ -139,9 +139,9 @@ services:
- postgresql
restart: always
environment:
POSTGRES_USER: ${POSTGRES_USER:-postgres}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-difyai123456}
POSTGRES_DB: ${POSTGRES_DB:-dify}
POSTGRES_USER: ${DB_USERNAME:-postgres}
POSTGRES_PASSWORD: ${DB_PASSWORD:-difyai123456}
POSTGRES_DB: ${DB_DATABASE:-dify}
PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata}
command: >
postgres -c 'max_connections=${POSTGRES_MAX_CONNECTIONS:-100}'
@ -161,7 +161,7 @@ services:
"-h",
"db_postgres",
"-U",
"${PGUSER:-postgres}",
"${DB_USERNAME:-postgres}",
"-d",
"${DB_DATABASE:-dify}",
]
@ -176,8 +176,8 @@ services:
- mysql
restart: always
environment:
MYSQL_ROOT_PASSWORD: ${MYSQL_PASSWORD:-difyai123456}
MYSQL_DATABASE: ${MYSQL_DATABASE:-dify}
MYSQL_ROOT_PASSWORD: ${DB_PASSWORD:-difyai123456}
MYSQL_DATABASE: ${DB_DATABASE:-dify}
command: >
--max_connections=1000
--innodb_buffer_pool_size=${MYSQL_INNODB_BUFFER_POOL_SIZE:-512M}
@ -193,7 +193,7 @@ services:
"ping",
"-u",
"root",
"-p${MYSQL_PASSWORD:-difyai123456}",
"-p${DB_PASSWORD:-difyai123456}",
]
interval: 1s
timeout: 3s

View File

@ -9,8 +9,8 @@ services:
env_file:
- ./middleware.env
environment:
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-difyai123456}
POSTGRES_DB: ${POSTGRES_DB:-dify}
POSTGRES_PASSWORD: ${DB_PASSWORD:-difyai123456}
POSTGRES_DB: ${DB_DATABASE:-dify}
PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata}
command: >
postgres -c 'max_connections=${POSTGRES_MAX_CONNECTIONS:-100}'
@ -32,9 +32,9 @@ services:
"-h",
"db_postgres",
"-U",
"${PGUSER:-postgres}",
"${DB_USERNAME:-postgres}",
"-d",
"${POSTGRES_DB:-dify}",
"${DB_DATABASE:-dify}",
]
interval: 1s
timeout: 3s
@ -48,8 +48,8 @@ services:
env_file:
- ./middleware.env
environment:
MYSQL_ROOT_PASSWORD: ${MYSQL_PASSWORD:-difyai123456}
MYSQL_DATABASE: ${MYSQL_DATABASE:-dify}
MYSQL_ROOT_PASSWORD: ${DB_PASSWORD:-difyai123456}
MYSQL_DATABASE: ${DB_DATABASE:-dify}
command: >
--max_connections=1000
--innodb_buffer_pool_size=${MYSQL_INNODB_BUFFER_POOL_SIZE:-512M}
@ -67,7 +67,7 @@ services:
"ping",
"-u",
"root",
"-p${MYSQL_PASSWORD:-difyai123456}",
"-p${DB_PASSWORD:-difyai123456}",
]
interval: 1s
timeout: 3s

View File

@ -455,13 +455,7 @@ x-shared-env: &shared-api-worker-env
TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000}
ALLOW_UNSAFE_DATA_SCHEME: ${ALLOW_UNSAFE_DATA_SCHEME:-false}
MAX_TREE_DEPTH: ${MAX_TREE_DEPTH:-50}
POSTGRES_USER: ${POSTGRES_USER:-${DB_USERNAME}}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-${DB_PASSWORD}}
POSTGRES_DB: ${POSTGRES_DB:-${DB_DATABASE}}
PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata}
MYSQL_USERNAME: ${MYSQL_USERNAME:-${DB_USERNAME}}
MYSQL_PASSWORD: ${MYSQL_PASSWORD:-${DB_PASSWORD}}
MYSQL_DATABASE: ${MYSQL_DATABASE:-${DB_DATABASE}}
MYSQL_HOST_VOLUME: ${MYSQL_HOST_VOLUME:-./volumes/mysql/data}
SANDBOX_API_KEY: ${SANDBOX_API_KEY:-dify-sandbox}
SANDBOX_GIN_MODE: ${SANDBOX_GIN_MODE:-release}
@ -637,7 +631,7 @@ x-shared-env: &shared-api-worker-env
services:
# API service
api:
image: langgenius/dify-api:1.10.1
image: langgenius/dify-api:1.10.1-fix.1
restart: always
environment:
# Use the shared environment variables.
@ -676,7 +670,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
image: langgenius/dify-api:1.10.1
image: langgenius/dify-api:1.10.1-fix.1
restart: always
environment:
# Use the shared environment variables.
@ -713,7 +707,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.10.1
image: langgenius/dify-api:1.10.1-fix.1
restart: always
environment:
# Use the shared environment variables.
@ -741,7 +735,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.10.1
image: langgenius/dify-web:1.10.1-fix.1
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
@ -774,9 +768,9 @@ services:
- postgresql
restart: always
environment:
POSTGRES_USER: ${POSTGRES_USER:-postgres}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-difyai123456}
POSTGRES_DB: ${POSTGRES_DB:-dify}
POSTGRES_USER: ${DB_USERNAME:-postgres}
POSTGRES_PASSWORD: ${DB_PASSWORD:-difyai123456}
POSTGRES_DB: ${DB_DATABASE:-dify}
PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata}
command: >
postgres -c 'max_connections=${POSTGRES_MAX_CONNECTIONS:-100}'
@ -796,7 +790,7 @@ services:
"-h",
"db_postgres",
"-U",
"${PGUSER:-postgres}",
"${DB_USERNAME:-postgres}",
"-d",
"${DB_DATABASE:-dify}",
]
@ -811,8 +805,8 @@ services:
- mysql
restart: always
environment:
MYSQL_ROOT_PASSWORD: ${MYSQL_PASSWORD:-difyai123456}
MYSQL_DATABASE: ${MYSQL_DATABASE:-dify}
MYSQL_ROOT_PASSWORD: ${DB_PASSWORD:-difyai123456}
MYSQL_DATABASE: ${DB_DATABASE:-dify}
command: >
--max_connections=1000
--innodb_buffer_pool_size=${MYSQL_INNODB_BUFFER_POOL_SIZE:-512M}
@ -828,7 +822,7 @@ services:
"ping",
"-u",
"root",
"-p${MYSQL_PASSWORD:-difyai123456}",
"-p${DB_PASSWORD:-difyai123456}",
]
interval: 1s
timeout: 3s

View File

@ -4,6 +4,7 @@
# Database Configuration
# Database type, supported values are `postgresql` and `mysql`
DB_TYPE=postgresql
# For MySQL, only `root` user is supported for now
DB_USERNAME=postgres
DB_PASSWORD=difyai123456
DB_HOST=db_postgres
@ -11,11 +12,6 @@ DB_PORT=5432
DB_DATABASE=dify
# PostgreSQL Configuration
POSTGRES_USER=${DB_USERNAME}
# The password for the default postgres user.
POSTGRES_PASSWORD=${DB_PASSWORD}
# The name of the default postgres database.
POSTGRES_DB=${DB_DATABASE}
# postgres data directory
PGDATA=/var/lib/postgresql/data/pgdata
PGDATA_HOST_VOLUME=./volumes/db/data
@ -65,11 +61,6 @@ POSTGRES_STATEMENT_TIMEOUT=0
POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT=0
# MySQL Configuration
MYSQL_USERNAME=${DB_USERNAME}
# MySQL password
MYSQL_PASSWORD=${DB_PASSWORD}
# MySQL database name
MYSQL_DATABASE=${DB_DATABASE}
# MySQL data directory host volume
MYSQL_HOST_VOLUME=./volumes/mysql/data

View File

@ -61,13 +61,13 @@ if $web_modified; then
lint-staged
if $web_ts_modified; then
echo "Running TypeScript type-check"
if ! pnpm run type-check; then
echo "Type check failed. Please run 'pnpm run type-check' to fix the errors."
echo "Running TypeScript type-check:tsgo"
if ! pnpm run type-check:tsgo; then
echo "Type check failed. Please run 'pnpm run type-check:tsgo' to fix the errors."
exit 1
fi
else
echo "No staged TypeScript changes detected, skipping type-check"
echo "No staged TypeScript changes detected, skipping type-check:tsgo"
fi
echo "Running unit tests check"

View File

@ -251,6 +251,7 @@ const AgentTools: FC = () => {
{!item.notAuthor && (
<Tooltip
popupContent={t('tools.setBuiltInTools.infoAndSetting')}
needsDelay={false}
>
<div className='cursor-pointer rounded-md p-1 hover:bg-black/5' onClick={() => {
setCurrentTool(item)

View File

@ -0,0 +1,57 @@
import React from 'react'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { render, screen, waitFor } from '@testing-library/react'
import nock from 'nock'
import GithubStar from './index'
const GITHUB_HOST = 'https://api.github.com'
const GITHUB_PATH = '/repos/langgenius/dify'
const renderWithQueryClient = () => {
const queryClient = new QueryClient({
defaultOptions: { queries: { retry: false } },
})
return render(
<QueryClientProvider client={queryClient}>
<GithubStar className='test-class' />
</QueryClientProvider>,
)
}
const mockGithubStar = (status: number, body: Record<string, unknown>, delayMs = 0) => {
return nock(GITHUB_HOST).get(GITHUB_PATH).delay(delayMs).reply(status, body)
}
describe('GithubStar', () => {
beforeEach(() => {
nock.cleanAll()
})
// Shows fetched star count when request succeeds
it('should render fetched star count', async () => {
mockGithubStar(200, { stargazers_count: 123456 })
renderWithQueryClient()
expect(await screen.findByText('123,456')).toBeInTheDocument()
})
// Falls back to default star count when request fails
it('should render default star count on error', async () => {
mockGithubStar(500, {})
renderWithQueryClient()
expect(await screen.findByText('110,918')).toBeInTheDocument()
})
// Renders loader while fetching data
it('should show loader while fetching', async () => {
mockGithubStar(200, { stargazers_count: 222222 }, 50)
const { container } = renderWithQueryClient()
expect(container.querySelector('.animate-spin')).toBeInTheDocument()
await waitFor(() => expect(screen.getByText('222,222')).toBeInTheDocument())
})
})

View File

@ -76,7 +76,7 @@ const ProviderCard: FC<Props> = ({
className='grow'
variant='secondary'
>
<a href={`${getPluginLinkInMarketplace(payload)}?language=${locale}${theme ? `&theme=${theme}` : ''}`} target='_blank' className='flex items-center gap-0.5'>
<a href={getPluginLinkInMarketplace(payload, { language: locale, theme })} target='_blank' className='flex items-center gap-0.5'>
{t('plugin.detailPanel.operation.detail')}
<RiArrowRightUpLine className='h-4 w-4' />
</a>

View File

@ -94,7 +94,6 @@ export default combine(
// original ts/no-var-requires
'ts/no-require-imports': 'off',
'no-console': 'off',
'react-hooks/exhaustive-deps': 'warn',
'react/display-name': 'off',
'array-callback-return': ['error', {
allowImplicit: false,
@ -257,4 +256,9 @@ export default combine(
},
},
oxlint.configs['flat/recommended'],
{
rules: {
'react-hooks/exhaustive-deps': 'warn',
},
},
)

View File

@ -197,258 +197,6 @@ const translation = {
},
contentEnableLabel: 'مدیریت محتوا فعال شده است',
},
generate: {
title: 'تولید کننده دستورالعمل',
description: 'تولید کننده دستورالعمل از مدل تنظیم شده برای بهینه سازی دستورالعمل‌ها برای کیفیت بالاتر و ساختار بهتر استفاده می‌کند. لطفاً دستورالعمل‌های واضح و دقیقی بنویسید.',
tryIt: 'امتحان کنید',
instruction: 'دستورالعمل‌ها',
instructionPlaceHolder: 'دستورالعمل‌های واضح و خاصی بنویسید.',
generate: 'تولید',
resTitle: 'دستورالعمل تولید شده',
noDataLine1: 'موارد استفاده خود را در سمت چپ توصیف کنید،',
noDataLine2: 'پیش‌نمایش ارکستراسیون در اینجا نشان داده خواهد شد.',
apply: 'اعمال',
loading: 'در حال ارکستراسیون برنامه برای شما...',
overwriteTitle: 'آیا تنظیمات موجود را لغو می‌کنید؟',
overwriteMessage: 'اعمال این دستورالعمل تنظیمات موجود را لغو خواهد کرد.',
template: {
pythonDebugger: {
name: 'اشکال‌زدای پایتون',
instruction: 'یک بات که می‌تواند بر اساس دستورالعمل شما کد تولید و اشکال‌زدایی کند',
},
translation: {
name: 'ترجمه',
instruction: 'یک مترجم که می‌تواند چندین زبان را ترجمه کند',
},
professionalAnalyst: {
name: 'تحلیلگر حرفه‌ای',
instruction: 'استخراج بینش‌ها، شناسایی ریسک و خلاصه‌سازی اطلاعات کلیدی از گزارش‌های طولانی به یک یادداشت کوتاه',
},
excelFormulaExpert: {
name: 'کارشناس فرمول اکسل',
instruction: 'یک چت‌بات که می‌تواند به کاربران مبتدی کمک کند فرمول‌های اکسل را بر اساس دستورالعمل‌های کاربر درک، استفاده و ایجاد کنند',
},
travelPlanning: {
name: 'برنامه‌ریزی سفر',
instruction: 'دستیار برنامه‌ریزی سفر یک ابزار هوشمند است که به کاربران کمک می‌کند سفرهای خود را به راحتی برنامه‌ریزی کنند',
},
SQLSorcerer: {
name: 'جادوگر SQL',
instruction: 'تبدیل زبان روزمره به پرس و جوهای SQL',
},
GitGud: {
name: 'Git gud',
instruction: 'تولید دستورات مناسب Git بر اساس اقدامات توصیف شده توسط کاربر در کنترل نسخه',
},
meetingTakeaways: {
name: 'نتایج جلسات',
instruction: 'خلاصه‌سازی جلسات به صورت مختصر شامل موضوعات بحث، نکات کلیدی و موارد اقدام',
},
writingsPolisher: {
name: 'پولیش‌گر نوشته‌ها',
instruction: 'استفاده از تکنیک‌های ویرایش پیشرفته برای بهبود نوشته‌های شما',
},
},
},
resetConfig: {
title: 'بازنشانی تأیید می‌شود؟',
message: 'بازنشانی تغییرات را لغو کرده و تنظیمات منتشر شده آخر را بازیابی می‌کند.',
},
errorMessage: {
nameOfKeyRequired: 'نام کلید: {{key}} مورد نیاز است',
valueOfVarRequired: 'مقدار {{key}} نمی‌تواند خالی باشد',
queryRequired: 'متن درخواست مورد نیاز است.',
waitForResponse: 'لطفاً منتظر پاسخ به پیام قبلی بمانید.',
waitForBatchResponse: 'لطفاً منتظر پاسخ به کار دسته‌ای بمانید.',
notSelectModel: 'لطفاً یک مدل را انتخاب کنید',
waitForImgUpload: 'لطفاً منتظر بارگذاری تصویر بمانید',
},
chatSubTitle: 'دستورالعمل‌ها',
completionSubTitle: 'پیشوند پرس و جو',
promptTip: 'دستورالعمل‌ها و محدودیت‌ها پاسخ‌های AI را هدایت می‌کنند. متغیرهایی مانند {{input}} را درج کنید. این دستورالعمل برای کاربران قابل مشاهده نخواهد بود.',
formattingChangedTitle: 'قالب‌بندی تغییر کرد',
formattingChangedText: 'تغییر قالب‌بندی منطقه اشکال‌زدایی را بازنشانی خواهد کرد، آیا مطمئن هستید؟',
variableTitle: 'متغیرها',
variableTip: 'کاربران متغیرها را در فرم پر می‌کنند و به طور خودکار متغیرها را در دستورالعمل‌ها جایگزین می‌کنند.',
notSetVar: 'متغیرها به کاربران اجازه می‌دهند که کلمات پرس و جو یا جملات ابتدایی را هنگام پر کردن فرم معرفی کنند. شما می‌توانید سعی کنید "{{input}}" را در کلمات پرس و جو وارد کنید.',
autoAddVar: 'متغیرهای تعریف نشده‌ای که در پیش‌پرسش ذکر شده‌اند، آیا می‌خواهید آنها را به فرم ورودی کاربر اضافه کنید؟',
variableTable: {
key: 'کلید متغیر',
name: 'نام فیلد ورودی کاربر',
optional: 'اختیاری',
type: 'نوع ورودی',
action: 'اقدامات',
typeString: 'رشته',
typeSelect: 'انتخاب',
},
varKeyError: {
canNoBeEmpty: '{{key}} مطلوب',
tooLong: '{{key}} طولانی است. نمی‌تواند بیش از 30 کاراکتر باشد',
notValid: '{{key}} نامعتبر است. فقط می‌تواند شامل حروف، اعداد و زیرخط باشد',
notStartWithNumber: '{{key}} نمی‌تواند با عدد شروع شود',
keyAlreadyExists: '{{key}} از قبل وجود دارد',
},
otherError: {
promptNoBeEmpty: 'پرس و جو نمی‌تواند خالی باشد',
historyNoBeEmpty: 'تاریخچه مکالمه باید در پرس و جو تنظیم شود',
queryNoBeEmpty: 'پرس و جو باید در پرس و جو تنظیم شود',
},
variableConfig: {
'addModalTitle': 'افزودن فیلد ورودی',
'editModalTitle': 'ویرایش فیلد ورودی',
'description': 'تنظیم برای متغیر {{varName}}',
'fieldType': 'نوع فیلد',
'string': 'متن کوتاه',
'text-input': 'متن کوتاه',
'paragraph': 'پاراگراف',
'select': 'انتخاب',
'number': 'عدد',
'notSet': 'تنظیم نشده، سعی کنید {{input}} را در پرس و جو وارد کنید',
'stringTitle': 'گزینه‌های جعبه متن فرم',
'maxLength': 'حداکثر طول',
'options': 'گزینه‌ها',
'addOption': 'افزودن گزینه',
'apiBasedVar': 'متغیر مبتنی بر API',
'varName': 'نام متغیر',
'labelName': 'نام برچسب',
'inputPlaceholder': 'لطفاً وارد کنید',
'content': 'محتوا',
'required': 'مورد نیاز',
'hide': 'مخفی کردن',
'errorMsg': {
labelNameRequired: 'نام برچسب مورد نیاز است',
varNameCanBeRepeat: 'نام متغیر نمی‌تواند تکراری باشد',
atLeastOneOption: 'حداقل یک گزینه مورد نیاز است',
optionRepeat: 'گزینه‌های تکراری وجود دارد',
},
},
vision: {
name: 'بینایی',
description: 'فعال کردن بینایی به مدل اجازه می‌دهد تصاویر را دریافت کند و به سوالات مربوط به آنها پاسخ دهد.',
settings: 'تنظیمات',
visionSettings: {
title: 'تنظیمات بینایی',
resolution: 'وضوح',
resolutionTooltip: `وضوح پایین به مدل اجازه می‌دهد نسخه 512x512 کم‌وضوح تصویر را دریافت کند و تصویر را با بودجه 65 توکن نمایش دهد. این به API اجازه می‌دهد پاسخ‌های سریع‌تری بدهد و توکن‌های ورودی کمتری برای موارد استفاده که نیاز به جزئیات بالا ندارند مصرف کند.
\n
وضوح بالا ابتدا به مدل اجازه می‌دهد تصویر کم‌وضوح را ببیند و سپس قطعات جزئیات تصویر ورودی را به عنوان مربع‌های 512px ایجاد کند. هر کدام از قطعات جزئیات از بودجه توکن دو برابر استفاده می‌کنند که در مجموع 129 توکن است.`,
high: 'بالا',
low: 'پایین',
uploadMethod: 'روش بارگذاری',
both: 'هر دو',
localUpload: 'بارگذاری محلی',
url: 'URL',
uploadLimit: 'محدودیت بارگذاری',
},
},
voice: {
name: 'صدا',
defaultDisplay: 'صدا پیش فرض',
description: 'تنظیمات تبدیل متن به گفتار',
settings: 'تنظیمات',
voiceSettings: {
title: 'تنظیمات صدا',
language: 'زبان',
resolutionTooltip: 'پشتیبانی از زبان صدای تبدیل متن به گفتار.',
voice: 'صدا',
autoPlay: 'پخش خودکار',
autoPlayEnabled: 'روشن کردن',
autoPlayDisabled: 'خاموش کردن',
},
},
openingStatement: {
title: 'شروع مکالمه',
add: 'افزودن',
writeOpener: 'نوشتن آغازگر',
placeholder: 'پیام آغازگر خود را اینجا بنویسید، می‌توانید از متغیرها استفاده کنید، سعی کنید {{variable}} را تایپ کنید.',
openingQuestion: 'سوالات آغازین',
openingQuestionPlaceholder: 'می‌توانید از متغیرها استفاده کنید، سعی کنید {{variable}} را تایپ کنید.',
noDataPlaceHolder: 'شروع مکالمه با کاربر می‌تواند به AI کمک کند تا ارتباط نزدیک‌تری با آنها برقرار کند.',
varTip: 'می‌توانید از متغیرها استفاده کنید، سعی کنید {{variable}} را تایپ کنید',
tooShort: 'حداقل 20 کلمه از پرسش اولیه برای تولید نظرات آغازین مکالمه مورد نیاز است.',
notIncludeKey: 'پرسش اولیه شامل متغیر: {{key}} نمی‌شود. لطفاً آن را به پرسش اولیه اضافه کنید.',
},
modelConfig: {
model: 'مدل',
setTone: 'تنظیم لحن پاسخ‌ها',
title: 'مدل و پارامترها',
modeType: {
chat: 'چت',
completion: 'تکمیل',
},
},
inputs: {
title: 'اشکال‌زدایی و پیش‌نمایش',
noPrompt: 'سعی کنید پرسش‌هایی را در ورودی پیش‌پرسش بنویسید',
userInputField: 'فیلد ورودی کاربر',
noVar: 'مقدار متغیر را پر کنید، که به طور خودکار در کلمات پرس و جو در هر بار شروع یک جلسه جدید جایگزین می‌شود.',
chatVarTip: 'مقدار متغیر را پر کنید، که به طور خودکار در کلمات پرس و جو در هر بار شروع یک جلسه جدید جایگزین می‌شود',
completionVarTip: 'مقدار متغیر را پر کنید، که به طور خودکار در کلمات پرس و جو در هر بار ارسال سوال جایگزین می‌شود.',
previewTitle: 'پیش‌نمایش پرس و جو',
queryTitle: 'محتوای پرس و جو',
queryPlaceholder: 'لطفاً متن درخواست را وارد کنید.',
run: 'اجرا',
},
result: 'متن خروجی',
datasetConfig: {
settingTitle: 'تنظیمات بازیابی',
knowledgeTip: 'روی دکمه "+" کلیک کنید تا دانش اضافه شود',
retrieveOneWay: {
title: 'بازیابی N به 1',
description: 'بر اساس نیت کاربر و توصیفات دانش، عامل بهترین دانش را برای پرس و جو به طور خودکار انتخاب می‌کند. بهترین برای برنامه‌هایی با دانش محدود و مشخص.',
},
retrieveMultiWay: {
title: 'بازیابی چند مسیره',
description: 'بر اساس نیت کاربر، از تمام دانش پرس و جو می‌کند، متن‌های مرتبط از منابع چندگانه بازیابی می‌کند و بهترین نتایج مطابقت با پرس و جوی کاربر را پس از مرتب‌سازی مجدد انتخاب می‌کند.',
},
rerankModelRequired: 'مدل مرتب‌سازی مجدد مورد نیاز است',
params: 'پارامترها',
top_k: 'Top K',
top_kTip: 'برای فیلتر کردن تکه‌هایی که بیشترین شباهت به سوالات کاربر دارند استفاده می‌شود. سیستم همچنین به طور دینامیک مقدار Top K را بر اساس max_tokens مدل انتخاب شده تنظیم می‌کند.',
score_threshold: 'آستانه نمره',
score_thresholdTip: 'برای تنظیم آستانه شباهت برای فیلتر کردن تکه‌ها استفاده می‌شود.',
retrieveChangeTip: 'تغییر حالت شاخص و حالت بازیابی ممکن است بر برنامه‌های مرتبط با این دانش تأثیر بگذارد.',
},
debugAsSingleModel: 'اشکال‌زدایی به عنوان مدل تک',
debugAsMultipleModel: 'اشکال‌زدایی به عنوان مدل چندگانه',
duplicateModel: 'تکراری',
publishAs: 'انتشار به عنوان',
assistantType: {
name: 'نوع دستیار',
chatAssistant: {
name: 'دستیار پایه',
description: 'ساخت دستیار مبتنی بر چت با استفاده از مدل زبان بزرگ',
},
agentAssistant: {
name: 'دستیار عامل',
description: 'ساخت یک عامل هوشمند که می‌تواند ابزارها را به طور خودکار برای تکمیل وظایف انتخاب کند',
},
},
agent: {
agentMode: 'حالت عامل',
agentModeDes: 'تنظیم نوع حالت استنتاج برای عامل',
agentModeType: {
ReACT: 'ReAct',
functionCall: 'فراخوانی تابع',
},
setting: {
name: 'تنظیمات عامل',
description: 'تنظیمات دستیار عامل به شما اجازه می‌دهد حالت عامل و ویژگی‌های پیشرفته مانند پرسش‌های ساخته شده را تنظیم کنید، فقط در نوع عامل موجود است.',
maximumIterations: {
name: 'حداکثر تکرارها',
description: 'محدود کردن تعداد تکرارهایی که دستیار عامل می‌تواند اجرا کند',
},
},
buildInPrompt: 'پرسش‌های ساخته شده',
firstPrompt: 'اولین پرسش',
nextIteration: 'تکرار بعدی',
promptPlaceholder: 'پرسش خود را اینجا بنویسید',
tools: {
name: 'ابزارها',
description: 'استفاده از ابزارها می‌تواند قابلیت‌های LLM را گسترش دهد، مانند جستجو در اینترنت یا انجام محاسبات علمی',
enabled: 'فعال',
},
},
fileUpload: {
title: 'آپلود فایل',
description: 'جعبه ورودی چت امکان آپلود تصاویر، اسناد و سایر فایل‌ها را فراهم می‌کند.',
@ -536,13 +284,10 @@ const translation = {
resTitle: 'اعلان تولید شده',
overwriteTitle: 'پیکربندی موجود را لغو کنید؟',
generate: 'تولید',
noDataLine1: 'مورد استفاده خود را در سمت چپ شرح دهید،',
apply: 'درخواست',
instruction: 'دستورالعمل',
overwriteMessage: 'اعمال این اعلان پیکربندی موجود را لغو می کند.',
instructionPlaceHolder: 'دستورالعمل های واضح و مشخص بنویسید.',
tryIt: 'آن را امتحان کنید',
noDataLine2: 'پیش نمایش ارکستراسیون در اینجا نشان داده می شود.',
loading: 'هماهنگ کردن برنامه برای شما...',
description: 'Prompt Generator از مدل پیکربندی شده برای بهینه سازی درخواست ها برای کیفیت بالاتر و ساختار بهتر استفاده می کند. لطفا دستورالعمل های واضح و دقیق بنویسید.',
press: 'فشار',

View File

@ -393,6 +393,7 @@ const translation = {
writeOpener: 'Scrieți deschizătorul',
placeholder: 'Scrieți aici mesajul de deschidere, puteți utiliza variabile, încercați să tastați {{variable}}.',
openingQuestion: 'Întrebări de deschidere',
openingQuestionPlaceholder: 'Puteți utiliza variabile, încercați să tastați {{variable}}.',
noDataPlaceHolder:
'Începerea conversației cu utilizatorul poate ajuta AI să stabilească o conexiune mai strânsă cu ei în aplicațiile conversaționale.',
varTip: 'Puteți utiliza variabile, încercați să tastați {{variable}}',

View File

@ -479,87 +479,6 @@ const translation = {
loadBalancingLeastKeyWarning: 'Za omogočanje uravnoteženja obremenitev morata biti omogočena vsaj 2 ključa.',
loadBalancingInfo: 'Privzeto uravnoteženje obremenitev uporablja strategijo Round-robin. Če se sproži omejitev hitrosti, se uporabi 1-minutno obdobje ohlajanja.',
upgradeForLoadBalancing: 'Nadgradite svoj načrt, da omogočite uravnoteženje obremenitev.',
dataSource: {
notion: {
selector: {
},
},
website: {
},
},
plugin: {
serpapi: {
},
},
apiBasedExtension: {
selector: {
},
modal: {
name: {
},
apiEndpoint: {
},
apiKey: {
},
},
},
about: {
},
appMenus: {
},
environment: {
},
appModes: {
},
datasetMenus: {
},
voiceInput: {
},
modelName: {
'gpt-3.5-turbo': 'GPT-3.5-Turbo',
'gpt-3.5-turbo-16k': 'GPT-3.5-Turbo-16K',
'gpt-4': 'GPT-4',
'gpt-4-32k': 'GPT-4-32K',
'text-davinci-003': 'Text-Davinci-003',
'text-embedding-ada-002': 'Text-Embedding-Ada-002',
'whisper-1': 'Whisper-1',
'claude-instant-1': 'Claude-Instant',
'claude-2': 'Claude-2',
},
chat: {
citation: {
},
},
promptEditor: {
context: {
item: {
},
modal: {
},
},
history: {
item: {
},
modal: {
},
},
variable: {
item: {
},
outputToolDisabledItem: {
},
modal: {
},
},
query: {
item: {
},
},
},
imageUploader: {
},
tag: {
},
discoverMore: 'Odkrijte več v',
installProvider: 'Namestitev ponudnikov modelov',
emptyProviderTitle: 'Ponudnik modelov ni nastavljen',

View File

@ -348,7 +348,6 @@ const translation = {
'description': 'Değişken ayarı {{varName}}',
'fieldType': 'Alan türü',
'string': 'Kısa Metin',
'textInput': 'Kısa Metin',
'paragraph': 'Paragraf',
'select': 'Seçim',
'number': 'Numara',
@ -364,7 +363,6 @@ const translation = {
'content': 'İçerik',
'required': 'Gerekli',
'errorMsg': {
varNameRequired: 'Değişken adı gereklidir',
labelNameRequired: 'Etiket adı gereklidir',
varNameCanBeRepeat: 'Değişken adı tekrar edemez',
atLeastOneOption: 'En az bir seçenek gereklidir',

View File

@ -596,7 +596,6 @@ const translation = {
'authorizationType': 'Yetkilendirme Türü',
'no-auth': 'Yok',
'api-key': 'API Anahtarı',
'authType': 'Yetki Türü',
'basic': 'Temel',
'bearer': 'Bearer',
'custom': 'Özel',

View File

@ -28,6 +28,7 @@
"lint:quiet": "eslint --cache --cache-location node_modules/.cache/eslint/.eslint-cache --quiet",
"lint:complexity": "eslint --cache --cache-location node_modules/.cache/eslint/.eslint-cache --rule 'complexity: [error, {max: 15}]' --quiet",
"type-check": "tsc --noEmit",
"type-check:tsgo": "tsgo --noEmit",
"prepare": "cd ../ && node -e \"if (process.env.NODE_ENV !== 'production'){process.exit(1)} \" || husky ./web/.husky",
"gen-icons": "node ./app/components/base/icons/script.mjs",
"uglify-embed": "node ./bin/uglify-embed",
@ -110,9 +111,9 @@
"pinyin-pro": "^3.27.0",
"qrcode.react": "^4.2.0",
"qs": "^6.14.0",
"react": "19.1.1",
"react": "19.2.1",
"react-18-input-autosize": "^3.0.0",
"react-dom": "19.1.1",
"react-dom": "19.2.1",
"react-easy-crop": "^5.5.3",
"react-hook-form": "^7.65.0",
"react-hotkeys-hook": "^4.6.2",
@ -153,9 +154,9 @@
"@happy-dom/jest-environment": "^20.0.8",
"@mdx-js/loader": "^3.1.1",
"@mdx-js/react": "^3.1.1",
"@next/bundle-analyzer": "15.5.4",
"@next/eslint-plugin-next": "15.5.4",
"@next/mdx": "15.5.4",
"@next/bundle-analyzer": "15.5.7",
"@next/eslint-plugin-next": "15.5.7",
"@next/mdx": "15.5.7",
"@rgrove/parse-xml": "^4.2.0",
"@storybook/addon-docs": "9.1.13",
"@storybook/addon-links": "9.1.13",
@ -173,8 +174,8 @@
"@types/negotiator": "^0.6.4",
"@types/node": "18.15.0",
"@types/qs": "^6.14.0",
"@types/react": "~19.1.17",
"@types/react-dom": "~19.1.11",
"@types/react": "~19.2.7",
"@types/react-dom": "~19.2.3",
"@types/react-slider": "^1.3.6",
"@types/react-syntax-highlighter": "^15.5.13",
"@types/react-window": "^1.8.8",
@ -200,18 +201,20 @@
"lint-staged": "^15.5.2",
"lodash": "^4.17.21",
"magicast": "^0.3.5",
"nock": "^14.0.10",
"postcss": "^8.5.6",
"react-scan": "^0.4.3",
"sass": "^1.93.2",
"storybook": "9.1.13",
"tailwindcss": "^3.4.18",
"@typescript/native-preview": "^7.0.0-dev",
"ts-node": "^10.9.2",
"typescript": "^5.9.3",
"uglify-js": "^3.19.3"
},
"resolutions": {
"@types/react": "~19.1.17",
"@types/react-dom": "~19.1.11",
"@types/react": "~19.2.7",
"@types/react-dom": "~19.2.3",
"string-width": "~4.2.3",
"@eslint/plugin-kit": "~0.3",
"canvas": "^3.2.0",
@ -282,4 +285,4 @@
"sharp"
]
}
}
}

File diff suppressed because it is too large

View File

@ -144,7 +144,7 @@ function requiredWebSSOLogin(message?: string, code?: number) {
params.append('message', message)
if (code)
params.append('code', String(code))
globalThis.location.href = `${globalThis.location.origin}${basePath}/${WBB_APP_LOGIN_PATH}?${params.toString()}`
globalThis.location.href = `${globalThis.location.origin}${basePath}${WBB_APP_LOGIN_PATH}?${params.toString()}`
}
export function format(text: string) {

View File

@ -612,12 +612,11 @@ export const usePluginTaskList = (category?: PluginCategoryEnum | string) => {
const taskAllFailed = lastData?.tasks.every(task => task.status === TaskStatus.failed)
if (taskDone && lastData?.tasks.length && !taskAllFailed)
refreshPluginList(category ? { category } as any : undefined, !category)
}, [initialized, isRefetching, data, category, refreshPluginList])
}, [isRefetching])
useEffect(() => {
if (isFetched && !initialized)
setInitialized(true)
}, [isFetched, initialized])
setInitialized(true)
}, [])
const handleRefetch = useCallback(() => {
refetch()

View File

@ -123,7 +123,7 @@ export const useInvalidLastRun = (flowType: FlowType, flowId: string, nodeId: st
// Rerun workflow or change the version of workflow
export const useInvalidAllLastRun = (flowType?: FlowType, flowId?: string) => {
return useInvalid([NAME_SPACE, flowType, 'last-run', flowId])
return useInvalid([...useLastRunKey, flowType, flowId])
}
export const useConversationVarValues = (flowType?: FlowType, flowId?: string) => {

View File

@ -202,6 +202,16 @@ Reserve snapshots for static, deterministic fragments (icons, badges, layout chr
**Note**: Dify is a desktop application. **No need for** responsive/mobile testing.
### 12. Mock API
Use Nock to mock API calls. Example:
```ts
const mockGithubStar = (status: number, body: Record<string, unknown>, delayMs = 0) => {
return nock(GITHUB_HOST).get(GITHUB_PATH).delay(delayMs).reply(status, body)
}
```
## Code Style
### Example Structure