Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2025-09-13 01:27:37 +08:00
commit b0e815c3c7
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF
242 changed files with 10968 additions and 2216 deletions

View File

@ -23,11 +23,37 @@ jobs:
uv run ruff check --fix .
# Format code
uv run ruff format .
- name: ast-grep
run: |
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
- name: mdformat
run: |
uvx mdformat .
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
run_install: false
- name: Setup NodeJS
uses: actions/setup-node@v4
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/package.json
- name: Web dependencies
working-directory: ./web
run: pnpm install --frozen-lockfile
- name: oxlint
working-directory: ./web
run: |
pnpx oxlint --fix
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27

4
.gitignore vendored
View File

@ -227,3 +227,7 @@ web/public/fallback-*.js
.roo/
api/.env.backup
/clickzetta
# Benchmark
scripts/stress-test/setup/config/
scripts/stress-test/reports/

View File

@ -540,6 +540,7 @@ ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
# Reset password token expiry minutes
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
EMAIL_REGISTER_TOKEN_EXPIRY_MINUTES=5
CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES=5
OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5

View File

@ -477,12 +477,12 @@ def convert_to_agent_apps():
click.echo(f"Converting app: {app.id}")
try:
app.mode = AppMode.AGENT_CHAT.value
app.mode = AppMode.AGENT_CHAT
db.session.commit()
# update conversation mode to agent
db.session.query(Conversation).where(Conversation.app_id == app.id).update(
{Conversation.mode: AppMode.AGENT_CHAT.value}
{Conversation.mode: AppMode.AGENT_CHAT}
)
db.session.commit()

View File

@ -31,6 +31,12 @@ class SecurityConfig(BaseSettings):
description="Duration in minutes for which a password reset token remains valid",
default=5,
)
EMAIL_REGISTER_TOKEN_EXPIRY_MINUTES: PositiveInt = Field(
description="Duration in minutes for which a email register token remains valid",
default=5,
)
CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES: PositiveInt = Field(
description="Duration in minutes for which a change email token remains valid",
default=5,
@ -661,6 +667,11 @@ class AuthConfig(BaseSettings):
default=86400,
)
EMAIL_REGISTER_LOCKOUT_DURATION: PositiveInt = Field(
description="Time (in seconds) a user must wait before retrying email register after exceeding the rate limit.",
default=86400,
)
class ModerationConfig(BaseSettings):
"""

View File

@ -1,4 +1,4 @@
import enum
from enum import Enum
from typing import Literal, Optional
from pydantic import Field, PositiveInt
@ -10,7 +10,7 @@ class OpenSearchConfig(BaseSettings):
Configuration settings for OpenSearch
"""
class AuthMethod(enum.StrEnum):
class AuthMethod(Enum):
"""
Authentication method for OpenSearch
"""

View File

@ -7,7 +7,7 @@ default_app_templates: Mapping[AppMode, Mapping] = {
# workflow default mode
AppMode.WORKFLOW: {
"app": {
"mode": AppMode.WORKFLOW.value,
"mode": AppMode.WORKFLOW,
"enable_site": True,
"enable_api": True,
}
@ -15,7 +15,7 @@ default_app_templates: Mapping[AppMode, Mapping] = {
# completion default mode
AppMode.COMPLETION: {
"app": {
"mode": AppMode.COMPLETION.value,
"mode": AppMode.COMPLETION,
"enable_site": True,
"enable_api": True,
},
@ -44,7 +44,7 @@ default_app_templates: Mapping[AppMode, Mapping] = {
# chat default mode
AppMode.CHAT: {
"app": {
"mode": AppMode.CHAT.value,
"mode": AppMode.CHAT,
"enable_site": True,
"enable_api": True,
},
@ -60,7 +60,7 @@ default_app_templates: Mapping[AppMode, Mapping] = {
# advanced-chat default mode
AppMode.ADVANCED_CHAT: {
"app": {
"mode": AppMode.ADVANCED_CHAT.value,
"mode": AppMode.ADVANCED_CHAT,
"enable_site": True,
"enable_api": True,
},
@ -68,7 +68,7 @@ default_app_templates: Mapping[AppMode, Mapping] = {
# agent-chat default mode
AppMode.AGENT_CHAT: {
"app": {
"mode": AppMode.AGENT_CHAT.value,
"mode": AppMode.AGENT_CHAT,
"enable_site": True,
"enable_api": True,
},

View File

@ -93,6 +93,7 @@ from .auth import (
activate, # pyright: ignore[reportUnusedImport]
data_source_bearer_auth, # pyright: ignore[reportUnusedImport]
data_source_oauth, # pyright: ignore[reportUnusedImport]
email_register, # pyright: ignore[reportUnusedImport]
forgot_password, # pyright: ignore[reportUnusedImport]
login, # pyright: ignore[reportUnusedImport]
oauth, # pyright: ignore[reportUnusedImport]

View File

@ -1,12 +1,26 @@
from flask_restx import Resource, reqparse
from flask_restx import Resource, fields, reqparse
from controllers.console import api
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
@console_ns.route("/app/prompt-templates")
class AdvancedPromptTemplateList(Resource):
@api.doc("get_advanced_prompt_templates")
@api.doc(description="Get advanced prompt templates based on app mode and model configuration")
@api.expect(
api.parser()
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
.add_argument("has_context", type=str, default="true", location="args", help="Whether has context")
.add_argument("model_name", type=str, required=True, location="args", help="Model name")
)
@api.response(
200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data"))
)
@api.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
@ -19,6 +33,3 @@ class AdvancedPromptTemplateList(Resource):
args = parser.parse_args()
return AdvancedPromptTemplateService.get_prompt(args)
api.add_resource(AdvancedPromptTemplateList, "/app/prompt-templates")

View File

@ -1,6 +1,6 @@
from flask_restx import Resource, reqparse
from flask_restx import Resource, fields, reqparse
from controllers.console import api
from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from libs.helper import uuid_value
@ -9,7 +9,18 @@ from models.model import AppMode
from services.agent_service import AgentService
@console_ns.route("/apps/<uuid:app_id>/agent/logs")
class AgentLogApi(Resource):
@api.doc("get_agent_logs")
@api.doc(description="Get agent execution logs for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("message_id", type=str, required=True, location="args", help="Message UUID")
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation UUID")
)
@api.response(200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries")))
@api.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
@ -23,6 +34,3 @@ class AgentLogApi(Resource):
args = parser.parse_args()
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])
api.add_resource(AgentLogApi, "/apps/<uuid:app_id>/agent/logs")

View File

@ -2,11 +2,11 @@ from typing import Literal
from flask import request
from flask_login import current_user
from flask_restx import Resource, marshal, marshal_with, reqparse
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.console import api
from controllers.console import api, console_ns
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
@ -21,7 +21,23 @@ from libs.login import login_required
from services.annotation_service import AppAnnotationService
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
class AnnotationReplyActionApi(Resource):
@api.doc("annotation_reply_action")
@api.doc(description="Enable or disable annotation reply for an app")
@api.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"})
@api.expect(
api.model(
"AnnotationReplyActionRequest",
{
"score_threshold": fields.Float(required=True, description="Score threshold for annotation matching"),
"embedding_provider_name": fields.String(required=True, description="Embedding provider name"),
"embedding_model_name": fields.String(required=True, description="Embedding model name"),
},
)
)
@api.response(200, "Action completed successfully")
@api.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@ -43,7 +59,13 @@ class AnnotationReplyActionApi(Resource):
return result, 200
@console_ns.route("/apps/<uuid:app_id>/annotation-setting")
class AppAnnotationSettingDetailApi(Resource):
@api.doc("get_annotation_setting")
@api.doc(description="Get annotation settings for an app")
@api.doc(params={"app_id": "Application ID"})
@api.response(200, "Annotation settings retrieved successfully")
@api.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@ -56,7 +78,23 @@ class AppAnnotationSettingDetailApi(Resource):
return result, 200
@console_ns.route("/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>")
class AppAnnotationSettingUpdateApi(Resource):
@api.doc("update_annotation_setting")
@api.doc(description="Update annotation settings for an app")
@api.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"})
@api.expect(
api.model(
"AnnotationSettingUpdateRequest",
{
"score_threshold": fields.Float(required=True, description="Score threshold"),
"embedding_provider_name": fields.String(required=True, description="Embedding provider"),
"embedding_model_name": fields.String(required=True, description="Embedding model"),
},
)
)
@api.response(200, "Settings updated successfully")
@api.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@ -75,7 +113,13 @@ class AppAnnotationSettingUpdateApi(Resource):
return result, 200
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>")
class AnnotationReplyActionStatusApi(Resource):
@api.doc("get_annotation_reply_action_status")
@api.doc(description="Get status of annotation reply action job")
@api.doc(params={"app_id": "Application ID", "job_id": "Job ID", "action": "Action type"})
@api.response(200, "Job status retrieved successfully")
@api.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@ -99,7 +143,19 @@ class AnnotationReplyActionStatusApi(Resource):
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
@console_ns.route("/apps/<uuid:app_id>/annotations")
class AnnotationApi(Resource):
@api.doc("list_annotations")
@api.doc(description="Get annotations for an app with pagination")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size")
.add_argument("keyword", type=str, location="args", default="", help="Search keyword")
)
@api.response(200, "Annotations retrieved successfully")
@api.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@ -122,6 +178,21 @@ class AnnotationApi(Resource):
}
return response, 200
@api.doc("create_annotation")
@api.doc(description="Create a new annotation for an app")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"CreateAnnotationRequest",
{
"question": fields.String(required=True, description="Question text"),
"answer": fields.String(required=True, description="Answer text"),
"annotation_reply": fields.Raw(description="Annotation reply data"),
},
)
)
@api.response(201, "Annotation created successfully", annotation_fields)
@api.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@ -168,7 +239,13 @@ class AnnotationApi(Resource):
return {"result": "success"}, 204
@console_ns.route("/apps/<uuid:app_id>/annotations/export")
class AnnotationExportApi(Resource):
@api.doc("export_annotations")
@api.doc(description="Export all annotations for an app")
@api.doc(params={"app_id": "Application ID"})
@api.response(200, "Annotations exported successfully", fields.List(fields.Nested(annotation_fields)))
@api.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@ -182,7 +259,14 @@ class AnnotationExportApi(Resource):
return response, 200
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
class AnnotationUpdateDeleteApi(Resource):
@api.doc("update_delete_annotation")
@api.doc(description="Update or delete an annotation")
@api.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
@api.response(200, "Annotation updated successfully", annotation_fields)
@api.response(204, "Annotation deleted successfully")
@api.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@ -214,7 +298,14 @@ class AnnotationUpdateDeleteApi(Resource):
return {"result": "success"}, 204
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import")
class AnnotationBatchImportApi(Resource):
@api.doc("batch_import_annotations")
@api.doc(description="Batch import annotations from CSV file")
@api.doc(params={"app_id": "Application ID"})
@api.response(200, "Batch import started successfully")
@api.response(403, "Insufficient permissions")
@api.response(400, "No file uploaded or too many files")
@setup_required
@login_required
@account_initialization_required
@ -239,7 +330,13 @@ class AnnotationBatchImportApi(Resource):
return AppAnnotationService.batch_import_app_annotations(app_id, file)
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>")
class AnnotationBatchImportStatusApi(Resource):
@api.doc("get_batch_import_status")
@api.doc(description="Get status of batch import job")
@api.doc(params={"app_id": "Application ID", "job_id": "Job ID"})
@api.response(200, "Job status retrieved successfully")
@api.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@ -262,7 +359,20 @@ class AnnotationBatchImportStatusApi(Resource):
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories")
class AnnotationHitHistoryListApi(Resource):
@api.doc("list_annotation_hit_histories")
@api.doc(description="Get hit histories for an annotation")
@api.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
@api.expect(
api.parser()
.add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size")
)
@api.response(
200, "Hit histories retrieved successfully", fields.List(fields.Nested(annotation_hit_history_fields))
)
@api.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@ -285,17 +395,3 @@ class AnnotationHitHistoryListApi(Resource):
"page": page,
}
return response
api.add_resource(AnnotationReplyActionApi, "/apps/<uuid:app_id>/annotation-reply/<string:action>")
api.add_resource(
AnnotationReplyActionStatusApi, "/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>"
)
api.add_resource(AnnotationApi, "/apps/<uuid:app_id>/annotations")
api.add_resource(AnnotationExportApi, "/apps/<uuid:app_id>/annotations/export")
api.add_resource(AnnotationUpdateDeleteApi, "/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
api.add_resource(AnnotationBatchImportApi, "/apps/<uuid:app_id>/annotations/batch-import")
api.add_resource(AnnotationBatchImportStatusApi, "/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>")
api.add_resource(AnnotationHitHistoryListApi, "/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories")
api.add_resource(AppAnnotationSettingDetailApi, "/apps/<uuid:app_id>/annotation-setting")
api.add_resource(AppAnnotationSettingUpdateApi, "/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>")

View File

@ -2,12 +2,12 @@ import uuid
from typing import cast
from flask_login import current_user
from flask_restx import Resource, inputs, marshal, marshal_with, reqparse
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, abort
from controllers.console import api
from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
@ -34,7 +34,27 @@ def _validate_description_length(description):
return description
@console_ns.route("/apps")
class AppListApi(Resource):
@api.doc("list_apps")
@api.doc(description="Get list of applications with pagination and filtering")
@api.expect(
api.parser()
.add_argument("page", type=int, location="args", help="Page number (1-99999)", default=1)
.add_argument("limit", type=int, location="args", help="Page size (1-100)", default=20)
.add_argument(
"mode",
type=str,
location="args",
choices=["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"],
default="all",
help="App mode filter",
)
.add_argument("name", type=str, location="args", help="Filter by app name")
.add_argument("tag_ids", type=str, location="args", help="Comma-separated tag IDs")
.add_argument("is_created_by_me", type=bool, location="args", help="Filter by creator")
)
@api.response(200, "Success", app_pagination_fields)
@setup_required
@login_required
@account_initialization_required
@ -91,6 +111,24 @@ class AppListApi(Resource):
return marshal(app_pagination, app_pagination_fields), 200
@api.doc("create_app")
@api.doc(description="Create a new application")
@api.expect(
api.model(
"CreateAppRequest",
{
"name": fields.String(required=True, description="App name"),
"description": fields.String(description="App description (max 400 chars)"),
"mode": fields.String(required=True, enum=ALLOW_CREATE_APP_MODES, description="App mode"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
},
)
)
@api.response(201, "App created successfully", app_detail_fields)
@api.response(403, "Insufficient permissions")
@api.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
@ -124,7 +162,12 @@ class AppListApi(Resource):
return app, 201
@console_ns.route("/apps/<uuid:app_id>")
class AppApi(Resource):
@api.doc("get_app_detail")
@api.doc(description="Get application details")
@api.doc(params={"app_id": "Application ID"})
@api.response(200, "Success", app_detail_fields_with_site)
@setup_required
@login_required
@account_initialization_required
@ -143,6 +186,26 @@ class AppApi(Resource):
return app_model
@api.doc("update_app")
@api.doc(description="Update application details")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"UpdateAppRequest",
{
"name": fields.String(required=True, description="App name"),
"description": fields.String(description="App description (max 400 chars)"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
"use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
"max_active_requests": fields.Integer(description="Maximum active requests"),
},
)
)
@api.response(200, "App updated successfully", app_detail_fields_with_site)
@api.response(403, "Insufficient permissions")
@api.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
@ -181,6 +244,11 @@ class AppApi(Resource):
return app_model
@api.doc("delete_app")
@api.doc(description="Delete application")
@api.doc(params={"app_id": "Application ID"})
@api.response(204, "App deleted successfully")
@api.response(403, "Insufficient permissions")
@get_app_model
@setup_required
@login_required
@ -197,7 +265,25 @@ class AppApi(Resource):
return {"result": "success"}, 204
@console_ns.route("/apps/<uuid:app_id>/copy")
class AppCopyApi(Resource):
@api.doc("copy_app")
@api.doc(description="Create a copy of an existing application")
@api.doc(params={"app_id": "Application ID to copy"})
@api.expect(
api.model(
"CopyAppRequest",
{
"name": fields.String(description="Name for the copied app"),
"description": fields.String(description="Description for the copied app"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
},
)
)
@api.response(201, "App copied successfully", app_detail_fields_with_site)
@api.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@ -239,7 +325,22 @@ class AppCopyApi(Resource):
return app, 201
@console_ns.route("/apps/<uuid:app_id>/export")
class AppExportApi(Resource):
@api.doc("export_app")
@api.doc(description="Export application configuration as DSL")
@api.doc(params={"app_id": "Application ID to export"})
@api.expect(
api.parser()
.add_argument("include_secret", type=bool, location="args", default=False, help="Include secrets in export")
.add_argument("workflow_id", type=str, location="args", help="Specific workflow ID to export")
)
@api.response(
200,
"App exported successfully",
api.model("AppExportResponse", {"data": fields.String(description="DSL export data")}),
)
@api.response(403, "Insufficient permissions")
@get_app_model
@setup_required
@login_required
@ -263,7 +364,13 @@ class AppExportApi(Resource):
}
@console_ns.route("/apps/<uuid:app_id>/name")
class AppNameApi(Resource):
@api.doc("check_app_name")
@api.doc(description="Check if app name is available")
@api.doc(params={"app_id": "Application ID"})
@api.expect(api.parser().add_argument("name", type=str, required=True, location="args", help="Name to check"))
@api.response(200, "Name availability checked")
@setup_required
@login_required
@account_initialization_required
@ -284,7 +391,23 @@ class AppNameApi(Resource):
return app_model
@console_ns.route("/apps/<uuid:app_id>/icon")
class AppIconApi(Resource):
@api.doc("update_app_icon")
@api.doc(description="Update application icon")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"AppIconRequest",
{
"icon": fields.String(required=True, description="Icon data"),
"icon_type": fields.String(description="Icon type"),
"icon_background": fields.String(description="Icon background color"),
},
)
)
@api.response(200, "Icon updated successfully")
@api.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@ -306,7 +429,18 @@ class AppIconApi(Resource):
return app_model
@console_ns.route("/apps/<uuid:app_id>/site-enable")
class AppSiteStatus(Resource):
@api.doc("update_app_site_status")
@api.doc(description="Enable or disable app site")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"AppSiteStatusRequest", {"enable_site": fields.Boolean(required=True, description="Enable or disable site")}
)
)
@api.response(200, "Site status updated successfully", app_detail_fields)
@api.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@ -327,7 +461,18 @@ class AppSiteStatus(Resource):
return app_model
@console_ns.route("/apps/<uuid:app_id>/api-enable")
class AppApiStatus(Resource):
@api.doc("update_app_api_status")
@api.doc(description="Enable or disable app API")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"AppApiStatusRequest", {"enable_api": fields.Boolean(required=True, description="Enable or disable API")}
)
)
@api.response(200, "API status updated successfully", app_detail_fields)
@api.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@ -348,7 +493,12 @@ class AppApiStatus(Resource):
return app_model
@console_ns.route("/apps/<uuid:app_id>/trace")
class AppTraceApi(Resource):
@api.doc("get_app_trace")
@api.doc(description="Get app tracing configuration")
@api.doc(params={"app_id": "Application ID"})
@api.response(200, "Trace configuration retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@ -358,6 +508,20 @@ class AppTraceApi(Resource):
return app_trace_config
@api.doc("update_app_trace")
@api.doc(description="Update app tracing configuration")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"AppTraceRequest",
{
"enabled": fields.Boolean(required=True, description="Enable or disable tracing"),
"tracing_provider": fields.String(required=True, description="Tracing provider"),
},
)
)
@api.response(200, "Trace configuration updated successfully")
@api.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@ -377,14 +541,3 @@ class AppTraceApi(Resource):
)
return {"result": "success"}
api.add_resource(AppListApi, "/apps")
api.add_resource(AppApi, "/apps/<uuid:app_id>")
api.add_resource(AppCopyApi, "/apps/<uuid:app_id>/copy")
api.add_resource(AppExportApi, "/apps/<uuid:app_id>/export")
api.add_resource(AppNameApi, "/apps/<uuid:app_id>/name")
api.add_resource(AppIconApi, "/apps/<uuid:app_id>/icon")
api.add_resource(AppSiteStatus, "/apps/<uuid:app_id>/site-enable")
api.add_resource(AppApiStatus, "/apps/<uuid:app_id>/api-enable")
api.add_resource(AppTraceApi, "/apps/<uuid:app_id>/trace")

View File

@ -1,11 +1,11 @@
import logging
from flask import request
from flask_restx import Resource, reqparse
from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import InternalServerError
import services
from controllers.console import api
from controllers.console import api, console_ns
from controllers.console.app.error import (
AppUnavailableError,
AudioTooLargeError,
@ -34,7 +34,18 @@ from services.errors.audio import (
logger = logging.getLogger(__name__)
@console_ns.route("/apps/<uuid:app_id>/audio-to-text")
class ChatMessageAudioApi(Resource):
@api.doc("chat_message_audio_transcript")
@api.doc(description="Transcript audio to text for chat messages")
@api.doc(params={"app_id": "App ID"})
@api.response(
200,
"Audio transcription successful",
api.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}),
)
@api.response(400, "Bad request - No audio uploaded or unsupported type")
@api.response(413, "Audio file too large")
@setup_required
@login_required
@account_initialization_required
@ -76,7 +87,24 @@ class ChatMessageAudioApi(Resource):
raise InternalServerError()
@console_ns.route("/apps/<uuid:app_id>/text-to-audio")
class ChatMessageTextApi(Resource):
@api.doc("chat_message_text_to_speech")
@api.doc(description="Convert text to speech for chat messages")
@api.doc(params={"app_id": "App ID"})
@api.expect(
api.model(
"TextToSpeechRequest",
{
"message_id": fields.String(description="Message ID"),
"text": fields.String(required=True, description="Text to convert to speech"),
"voice": fields.String(description="Voice to use for TTS"),
"streaming": fields.Boolean(description="Whether to stream the audio"),
},
)
)
@api.response(200, "Text to speech conversion successful")
@api.response(400, "Bad request - Invalid parameters")
@get_app_model
@setup_required
@login_required
@ -124,7 +152,14 @@ class ChatMessageTextApi(Resource):
raise InternalServerError()
@console_ns.route("/apps/<uuid:app_id>/text-to-audio/voices")
class TextModesApi(Resource):
@api.doc("get_text_to_speech_voices")
@api.doc(description="Get available TTS voices for a specific language")
@api.doc(params={"app_id": "App ID"})
@api.expect(api.parser().add_argument("language", type=str, required=True, location="args", help="Language code"))
@api.response(200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices")))
@api.response(400, "Invalid language parameter")
@get_app_model
@setup_required
@login_required
@ -164,8 +199,3 @@ class TextModesApi(Resource):
except Exception as e:
logger.exception("Failed to handle get request to TextModesApi")
raise InternalServerError()
api.add_resource(ChatMessageAudioApi, "/apps/<uuid:app_id>/audio-to-text")
api.add_resource(ChatMessageTextApi, "/apps/<uuid:app_id>/text-to-audio")
api.add_resource(TextModesApi, "/apps/<uuid:app_id>/text-to-audio/voices")

View File

@ -1,11 +1,11 @@
import logging
from flask import request
from flask_restx import Resource, reqparse
from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.console import api
from controllers.console import api, console_ns
from controllers.console.app.error import (
AppUnavailableError,
CompletionRequestError,
@ -38,7 +38,27 @@ logger = logging.getLogger(__name__)
# define completion message api for user
@console_ns.route("/apps/<uuid:app_id>/completion-messages")
class CompletionMessageApi(Resource):
@api.doc("create_completion_message")
@api.doc(description="Generate completion message for debugging")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"CompletionMessageRequest",
{
"inputs": fields.Raw(required=True, description="Input variables"),
"query": fields.String(description="Query text", default=""),
"files": fields.List(fields.Raw(), description="Uploaded files"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
"retriever_from": fields.String(default="dev", description="Retriever source"),
},
)
)
@api.response(200, "Completion generated successfully")
@api.response(400, "Invalid request parameters")
@api.response(404, "App not found")
@setup_required
@login_required
@account_initialization_required
@ -86,7 +106,12 @@ class CompletionMessageApi(Resource):
raise InternalServerError()
@console_ns.route("/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop")
class CompletionMessageStopApi(Resource):
@api.doc("stop_completion_message")
@api.doc(description="Stop a running completion message generation")
@api.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
@api.response(200, "Task stopped successfully")
@setup_required
@login_required
@account_initialization_required
@ -99,7 +124,29 @@ class CompletionMessageStopApi(Resource):
return {"result": "success"}, 200
@console_ns.route("/apps/<uuid:app_id>/chat-messages")
class ChatMessageApi(Resource):
@api.doc("create_chat_message")
@api.doc(description="Generate chat message for debugging")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"ChatMessageRequest",
{
"inputs": fields.Raw(required=True, description="Input variables"),
"query": fields.String(required=True, description="User query"),
"files": fields.List(fields.Raw(), description="Uploaded files"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"conversation_id": fields.String(description="Conversation ID"),
"parent_message_id": fields.String(description="Parent message ID"),
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
"retriever_from": fields.String(default="dev", description="Retriever source"),
},
)
)
@api.response(200, "Chat message generated successfully")
@api.response(400, "Invalid request parameters")
@api.response(404, "App or conversation not found")
@setup_required
@login_required
@account_initialization_required
@ -161,7 +208,12 @@ class ChatMessageApi(Resource):
raise InternalServerError()
@console_ns.route("/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop")
class ChatMessageStopApi(Resource):
@api.doc("stop_chat_message")
@api.doc(description="Stop a running chat message generation")
@api.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
@api.response(200, "Task stopped successfully")
@setup_required
@login_required
@account_initialization_required
@ -172,9 +224,3 @@ class ChatMessageStopApi(Resource):
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
return {"result": "success"}, 200
api.add_resource(CompletionMessageApi, "/apps/<uuid:app_id>/completion-messages")
api.add_resource(CompletionMessageStopApi, "/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop")
api.add_resource(ChatMessageApi, "/apps/<uuid:app_id>/chat-messages")
api.add_resource(ChatMessageStopApi, "/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop")

View File

@ -8,7 +8,7 @@ from sqlalchemy import func, or_
from sqlalchemy.orm import joinedload
from werkzeug.exceptions import Forbidden, NotFound
from controllers.console import api
from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom
@ -28,7 +28,29 @@ from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError
@console_ns.route("/apps/<uuid:app_id>/completion-conversations")
class CompletionConversationApi(Resource):
@api.doc("list_completion_conversations")
@api.doc(description="Get completion conversations with pagination and filtering")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("keyword", type=str, location="args", help="Search keyword")
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
.add_argument(
"annotation_status",
type=str,
location="args",
choices=["annotated", "not_annotated", "all"],
default="all",
help="Annotation status filter",
)
.add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
)
@api.response(200, "Success", conversation_pagination_fields)
@api.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@ -101,7 +123,14 @@ class CompletionConversationApi(Resource):
return conversations
@console_ns.route("/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>")
class CompletionConversationDetailApi(Resource):
@api.doc("get_completion_conversation")
@api.doc(description="Get completion conversation details with messages")
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
@api.response(200, "Success", conversation_message_detail_fields)
@api.response(403, "Insufficient permissions")
@api.response(404, "Conversation not found")
@setup_required
@login_required
@account_initialization_required
@ -114,6 +143,12 @@ class CompletionConversationDetailApi(Resource):
return _get_conversation(app_model, conversation_id)
@api.doc("delete_completion_conversation")
@api.doc(description="Delete a completion conversation")
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
@api.response(204, "Conversation deleted successfully")
@api.response(403, "Insufficient permissions")
@api.response(404, "Conversation not found")
@setup_required
@login_required
@account_initialization_required
@ -133,7 +168,38 @@ class CompletionConversationDetailApi(Resource):
return {"result": "success"}, 204
@console_ns.route("/apps/<uuid:app_id>/chat-conversations")
class ChatConversationApi(Resource):
@api.doc("list_chat_conversations")
@api.doc(description="Get chat conversations with pagination, filtering and summary")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("keyword", type=str, location="args", help="Search keyword")
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
.add_argument(
"annotation_status",
type=str,
location="args",
choices=["annotated", "not_annotated", "all"],
default="all",
help="Annotation status filter",
)
.add_argument("message_count_gte", type=int, location="args", help="Minimum message count")
.add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
.add_argument(
"sort_by",
type=str,
location="args",
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
default="-updated_at",
help="Sort field and direction",
)
)
@api.response(200, "Success", conversation_with_summary_pagination_fields)
@api.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@ -241,7 +307,7 @@ class ChatConversationApi(Resource):
.having(func.count(Message.id) >= args["message_count_gte"])
)
if app_model.mode == AppMode.ADVANCED_CHAT.value:
if app_model.mode == AppMode.ADVANCED_CHAT:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)
match args["sort_by"]:
@ -261,7 +327,14 @@ class ChatConversationApi(Resource):
return conversations
@console_ns.route("/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>")
class ChatConversationDetailApi(Resource):
@api.doc("get_chat_conversation")
@api.doc(description="Get chat conversation details")
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
@api.response(200, "Success", conversation_detail_fields)
@api.response(403, "Insufficient permissions")
@api.response(404, "Conversation not found")
@setup_required
@login_required
@account_initialization_required
@ -274,6 +347,12 @@ class ChatConversationDetailApi(Resource):
return _get_conversation(app_model, conversation_id)
@api.doc("delete_chat_conversation")
@api.doc(description="Delete a chat conversation")
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
@api.response(204, "Conversation deleted successfully")
@api.response(403, "Insufficient permissions")
@api.response(404, "Conversation not found")
@setup_required
@login_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@ -293,12 +372,6 @@ class ChatConversationDetailApi(Resource):
return {"result": "success"}, 204
api.add_resource(CompletionConversationApi, "/apps/<uuid:app_id>/completion-conversations")
api.add_resource(CompletionConversationDetailApi, "/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>")
api.add_resource(ChatConversationApi, "/apps/<uuid:app_id>/chat-conversations")
api.add_resource(ChatConversationDetailApi, "/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>")
def _get_conversation(app_model, conversation_id):
conversation = (
db.session.query(Conversation)

View File

@ -2,7 +2,7 @@ from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.console import api
from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
@ -12,7 +12,17 @@ from models import ConversationVariable
from models.model import AppMode
@console_ns.route("/apps/<uuid:app_id>/conversation-variables")
class ConversationVariablesApi(Resource):
@api.doc("get_conversation_variables")
@api.doc(description="Get conversation variables for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser().add_argument(
"conversation_id", type=str, location="args", help="Conversation ID to filter variables"
)
)
@api.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_fields)
@setup_required
@login_required
@account_initialization_required
@ -55,6 +65,3 @@ class ConversationVariablesApi(Resource):
for row in rows
],
}
api.add_resource(ConversationVariablesApi, "/apps/<uuid:app_id>/conversation-variables")

View File

@ -1,9 +1,9 @@
from collections.abc import Sequence
from flask_login import current_user
from flask_restx import Resource, reqparse
from flask_restx import Resource, fields, reqparse
from controllers.console import api
from controllers.console import api, console_ns
from controllers.console.app.error import (
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
@ -22,7 +22,23 @@ from models import App
from services.workflow_service import WorkflowService
@console_ns.route("/rule-generate")
class RuleGenerateApi(Resource):
@api.doc("generate_rule_config")
@api.doc(description="Generate rule configuration using LLM")
@api.expect(
api.model(
"RuleGenerateRequest",
{
"instruction": fields.String(required=True, description="Rule generation instruction"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
},
)
)
@api.response(200, "Rule configuration generated successfully")
@api.response(400, "Invalid request parameters")
@api.response(402, "Provider quota exceeded")
@setup_required
@login_required
@account_initialization_required
@ -53,7 +69,26 @@ class RuleGenerateApi(Resource):
return rules
@console_ns.route("/rule-code-generate")
class RuleCodeGenerateApi(Resource):
@api.doc("generate_rule_code")
@api.doc(description="Generate code rules using LLM")
@api.expect(
api.model(
"RuleCodeGenerateRequest",
{
"instruction": fields.String(required=True, description="Code generation instruction"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
"code_language": fields.String(
default="javascript", description="Programming language for code generation"
),
},
)
)
@api.response(200, "Code rules generated successfully")
@api.response(400, "Invalid request parameters")
@api.response(402, "Provider quota exceeded")
@setup_required
@login_required
@account_initialization_required
@ -85,7 +120,22 @@ class RuleCodeGenerateApi(Resource):
return code_result
@console_ns.route("/rule-structured-output-generate")
class RuleStructuredOutputGenerateApi(Resource):
@api.doc("generate_structured_output")
@api.doc(description="Generate structured output rules using LLM")
@api.expect(
api.model(
"StructuredOutputGenerateRequest",
{
"instruction": fields.String(required=True, description="Structured output generation instruction"),
"model_config": fields.Raw(required=True, description="Model configuration"),
},
)
)
@api.response(200, "Structured output generated successfully")
@api.response(400, "Invalid request parameters")
@api.response(402, "Provider quota exceeded")
@setup_required
@login_required
@account_initialization_required
@ -114,7 +164,27 @@ class RuleStructuredOutputGenerateApi(Resource):
return structured_output
@console_ns.route("/instruction-generate")
class InstructionGenerateApi(Resource):
@api.doc("generate_instruction")
@api.doc(description="Generate instruction for workflow nodes or general use")
@api.expect(
api.model(
"InstructionGenerateRequest",
{
"flow_id": fields.String(required=True, description="Workflow/Flow ID"),
"node_id": fields.String(description="Node ID for workflow context"),
"current": fields.String(description="Current instruction text"),
"language": fields.String(default="javascript", description="Programming language (javascript/python)"),
"instruction": fields.String(required=True, description="Instruction for generation"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"ideal_output": fields.String(description="Expected ideal output"),
},
)
)
@api.response(200, "Instruction generated successfully")
@api.response(400, "Invalid request parameters or flow/workflow not found")
@api.response(402, "Provider quota exceeded")
@setup_required
@login_required
@account_initialization_required
@ -203,7 +273,21 @@ class InstructionGenerateApi(Resource):
raise CompletionRequestError(e.description)
@console_ns.route("/instruction-generate/template")
class InstructionGenerationTemplateApi(Resource):
@api.doc("get_instruction_template")
@api.doc(description="Get instruction generation template")
@api.expect(
api.model(
"InstructionTemplateRequest",
{
"instruction": fields.String(required=True, description="Template instruction"),
"ideal_output": fields.String(description="Expected ideal output"),
},
)
)
@api.response(200, "Template retrieved successfully")
@api.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
@ -222,10 +306,3 @@ class InstructionGenerationTemplateApi(Resource):
return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
case _:
raise ValueError(f"Invalid type: {args['type']}")
api.add_resource(RuleGenerateApi, "/rule-generate")
api.add_resource(RuleCodeGenerateApi, "/rule-code-generate")
api.add_resource(RuleStructuredOutputGenerateApi, "/rule-structured-output-generate")
api.add_resource(InstructionGenerateApi, "/instruction-generate")
api.add_resource(InstructionGenerationTemplateApi, "/instruction-generate/template")

View File

@ -2,10 +2,10 @@ import json
from enum import StrEnum
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from flask_restx import Resource, fields, marshal_with, reqparse
from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
@ -19,7 +19,12 @@ class AppMCPServerStatus(StrEnum):
INACTIVE = "inactive"
@console_ns.route("/apps/<uuid:app_id>/server")
class AppMCPServerController(Resource):
@api.doc("get_app_mcp_server")
@api.doc(description="Get MCP server configuration for an application")
@api.doc(params={"app_id": "Application ID"})
@api.response(200, "MCP server configuration retrieved successfully", app_server_fields)
@setup_required
@login_required
@account_initialization_required
@ -29,6 +34,20 @@ class AppMCPServerController(Resource):
server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first()
return server
@api.doc("create_app_mcp_server")
@api.doc(description="Create MCP server configuration for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"MCPServerCreateRequest",
{
"description": fields.String(description="Server description"),
"parameters": fields.Raw(required=True, description="Server parameters configuration"),
},
)
)
@api.response(201, "MCP server configuration created successfully", app_server_fields)
@api.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@ -59,6 +78,23 @@ class AppMCPServerController(Resource):
db.session.commit()
return server
@api.doc("update_app_mcp_server")
@api.doc(description="Update MCP server configuration for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"MCPServerUpdateRequest",
{
"id": fields.String(required=True, description="Server ID"),
"description": fields.String(description="Server description"),
"parameters": fields.Raw(required=True, description="Server parameters configuration"),
"status": fields.String(description="Server status"),
},
)
)
@api.response(200, "MCP server configuration updated successfully", app_server_fields)
@api.response(403, "Insufficient permissions")
@api.response(404, "Server not found")
@setup_required
@login_required
@account_initialization_required
@ -94,7 +130,14 @@ class AppMCPServerController(Resource):
return server
@console_ns.route("/apps/<uuid:server_id>/server/refresh")
class AppMCPServerRefreshController(Resource):
@api.doc("refresh_app_mcp_server")
@api.doc(description="Refresh MCP server configuration and regenerate server code")
@api.doc(params={"server_id": "Server ID"})
@api.response(200, "MCP server refreshed successfully", app_server_fields)
@api.response(403, "Insufficient permissions")
@api.response(404, "Server not found")
@setup_required
@login_required
@account_initialization_required
@ -113,7 +156,3 @@ class AppMCPServerRefreshController(Resource):
server.server_code = AppMCPServer.generate_server_code(16)
db.session.commit()
return server
api.add_resource(AppMCPServerController, "/apps/<uuid:app_id>/server")
api.add_resource(AppMCPServerRefreshController, "/apps/<uuid:server_id>/server/refresh")

View File

@ -5,7 +5,7 @@ from flask_restx.inputs import int_range
from sqlalchemy import exists, select
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
from controllers.console import api
from controllers.console import api, console_ns
from controllers.console.app.error import (
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
@ -37,6 +37,7 @@ from services.message_service import MessageService
logger = logging.getLogger(__name__)
@console_ns.route("/apps/<uuid:app_id>/chat-messages")
class ChatMessageListApi(Resource):
message_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
@ -44,6 +45,17 @@ class ChatMessageListApi(Resource):
"data": fields.List(fields.Nested(message_detail_fields)),
}
@api.doc("list_chat_messages")
@api.doc(description="Get chat messages for a conversation with pagination")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation ID")
.add_argument("first_id", type=str, location="args", help="First message ID for pagination")
.add_argument("limit", type=int, location="args", default=20, help="Number of messages to return (1-100)")
)
@api.response(200, "Success", message_infinite_scroll_pagination_fields)
@api.response(404, "Conversation not found")
@setup_required
@login_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@ -117,7 +129,23 @@ class ChatMessageListApi(Resource):
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
@console_ns.route("/apps/<uuid:app_id>/feedbacks")
class MessageFeedbackApi(Resource):
@api.doc("create_message_feedback")
@api.doc(description="Create or update message feedback (like/dislike)")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"MessageFeedbackRequest",
{
"message_id": fields.String(required=True, description="Message ID"),
"rating": fields.String(enum=["like", "dislike"], description="Feedback rating"),
},
)
)
@api.response(200, "Feedback updated successfully")
@api.response(404, "Message not found")
@api.response(403, "Insufficient permissions")
@get_app_model
@setup_required
@login_required
@ -162,7 +190,24 @@ class MessageFeedbackApi(Resource):
return {"result": "success"}
@console_ns.route("/apps/<uuid:app_id>/annotations")
class MessageAnnotationApi(Resource):
@api.doc("create_message_annotation")
@api.doc(description="Create message annotation")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"MessageAnnotationRequest",
{
"message_id": fields.String(description="Message ID"),
"question": fields.String(required=True, description="Question text"),
"answer": fields.String(required=True, description="Answer text"),
"annotation_reply": fields.Raw(description="Annotation reply"),
},
)
)
@api.response(200, "Annotation created successfully", annotation_fields)
@api.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@ -186,7 +231,16 @@ class MessageAnnotationApi(Resource):
return annotation
@console_ns.route("/apps/<uuid:app_id>/annotations/count")
class MessageAnnotationCountApi(Resource):
@api.doc("get_annotation_count")
@api.doc(description="Get count of message annotations for the app")
@api.doc(params={"app_id": "Application ID"})
@api.response(
200,
"Annotation count retrieved successfully",
api.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}),
)
@get_app_model
@setup_required
@login_required
@ -197,7 +251,17 @@ class MessageAnnotationCountApi(Resource):
return {"count": count}
@console_ns.route("/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions")
class MessageSuggestedQuestionApi(Resource):
@api.doc("get_message_suggested_questions")
@api.doc(description="Get suggested questions for a message")
@api.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
@api.response(
200,
"Suggested questions retrieved successfully",
api.model("SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}),
)
@api.response(404, "Message or conversation not found")
@setup_required
@login_required
@account_initialization_required
@ -230,7 +294,13 @@ class MessageSuggestedQuestionApi(Resource):
return {"data": questions}
@console_ns.route("/apps/<uuid:app_id>/messages/<uuid:message_id>")
class MessageApi(Resource):
@api.doc("get_message")
@api.doc(description="Get message details by ID")
@api.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
@api.response(200, "Message retrieved successfully", message_detail_fields)
@api.response(404, "Message not found")
@setup_required
@login_required
@account_initialization_required
@ -245,11 +315,3 @@ class MessageApi(Resource):
raise NotFound("Message Not Exists.")
return message
api.add_resource(MessageSuggestedQuestionApi, "/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions")
api.add_resource(ChatMessageListApi, "/apps/<uuid:app_id>/chat-messages", endpoint="console_chat_messages")
api.add_resource(MessageFeedbackApi, "/apps/<uuid:app_id>/feedbacks")
api.add_resource(MessageAnnotationApi, "/apps/<uuid:app_id>/annotations")
api.add_resource(MessageAnnotationCountApi, "/apps/<uuid:app_id>/annotations/count")
api.add_resource(MessageApi, "/apps/<uuid:app_id>/messages/<uuid:message_id>", endpoint="console_message")

View File

@ -2,10 +2,11 @@ import json
from typing import cast
from flask import request
from flask_restx import Resource
from flask_login import current_user
from flask_restx import Resource, fields
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.agent.entities import AgentToolEntity
@ -13,13 +14,39 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from libs.login import current_user, login_required
from libs.login import login_required
from models.account import Account
from models.model import AppMode, AppModelConfig
from services.app_model_config_service import AppModelConfigService
@console_ns.route("/apps/<uuid:app_id>/model-config")
class ModelConfigResource(Resource):
@api.doc("update_app_model_config")
@api.doc(description="Update application model configuration")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"ModelConfigRequest",
{
"provider": fields.String(description="Model provider"),
"model": fields.String(description="Model name"),
"configs": fields.Raw(description="Model configuration parameters"),
"opening_statement": fields.String(description="Opening statement"),
"suggested_questions": fields.List(fields.String(), description="Suggested questions"),
"more_like_this": fields.Raw(description="More like this configuration"),
"speech_to_text": fields.Raw(description="Speech to text configuration"),
"text_to_speech": fields.Raw(description="Text to speech configuration"),
"retrieval_model": fields.Raw(description="Retrieval model configuration"),
"tools": fields.List(fields.Raw(), description="Available tools"),
"dataset_configs": fields.Raw(description="Dataset configurations"),
"agent_mode": fields.Raw(description="Agent mode configuration"),
},
)
)
@api.response(200, "Model configuration updated successfully")
@api.response(400, "Invalid configuration")
@api.response(404, "App not found")
@setup_required
@login_required
@account_initialization_required
@ -47,7 +74,7 @@ class ModelConfigResource(Resource):
)
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
# get original app model config
original_app_model_config = (
db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first()
@ -150,6 +177,3 @@ class ModelConfigResource(Resource):
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)
return {"result": "success"}
api.add_resource(ModelConfigResource, "/apps/<uuid:app_id>/model-config")

View File

@ -1,18 +1,31 @@
from flask_restx import Resource, reqparse
from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import BadRequest
from controllers.console import api
from controllers.console import api, console_ns
from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
from services.ops_service import OpsService
@console_ns.route("/apps/<uuid:app_id>/trace-config")
class TraceAppConfigApi(Resource):
"""
Manage trace app configurations
"""
@api.doc("get_trace_app_config")
@api.doc(description="Get tracing configuration for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser().add_argument(
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
)
)
@api.response(
200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data")
)
@api.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
@ -29,6 +42,22 @@ class TraceAppConfigApi(Resource):
except Exception as e:
raise BadRequest(str(e))
@api.doc("create_trace_app_config")
@api.doc(description="Create a new tracing configuration for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"TraceConfigCreateRequest",
{
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
"tracing_config": fields.Raw(required=True, description="Tracing configuration data"),
},
)
)
@api.response(
201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data")
)
@api.response(400, "Invalid request parameters or configuration already exists")
@setup_required
@login_required
@account_initialization_required
@ -51,6 +80,20 @@ class TraceAppConfigApi(Resource):
except Exception as e:
raise BadRequest(str(e))
@api.doc("update_trace_app_config")
@api.doc(description="Update an existing tracing configuration for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"TraceConfigUpdateRequest",
{
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
"tracing_config": fields.Raw(required=True, description="Updated tracing configuration data"),
},
)
)
@api.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response"))
@api.response(400, "Invalid request parameters or configuration not found")
@setup_required
@login_required
@account_initialization_required
@ -71,6 +114,16 @@ class TraceAppConfigApi(Resource):
except Exception as e:
raise BadRequest(str(e))
@api.doc("delete_trace_app_config")
@api.doc(description="Delete an existing tracing configuration for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser().add_argument(
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
)
)
@api.response(204, "Tracing configuration deleted successfully")
@api.response(400, "Invalid request parameters or configuration not found")
@setup_required
@login_required
@account_initialization_required
@ -87,6 +140,3 @@ class TraceAppConfigApi(Resource):
return {"result": "success"}, 204
except Exception as e:
raise BadRequest(str(e))
api.add_resource(TraceAppConfigApi, "/apps/<uuid:app_id>/trace-config")

View File

@ -1,9 +1,9 @@
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from flask_restx import Resource, fields, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound
from constants.languages import supported_language
from controllers.console import api
from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
@ -36,7 +36,39 @@ def parse_app_site_args():
return parser.parse_args()
@console_ns.route("/apps/<uuid:app_id>/site")
class AppSite(Resource):
@api.doc("update_app_site")
@api.doc(description="Update application site configuration")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"AppSiteRequest",
{
"title": fields.String(description="Site title"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
"description": fields.String(description="Site description"),
"default_language": fields.String(description="Default language"),
"chat_color_theme": fields.String(description="Chat color theme"),
"chat_color_theme_inverted": fields.Boolean(description="Inverted chat color theme"),
"customize_domain": fields.String(description="Custom domain"),
"copyright": fields.String(description="Copyright text"),
"privacy_policy": fields.String(description="Privacy policy"),
"custom_disclaimer": fields.String(description="Custom disclaimer"),
"customize_token_strategy": fields.String(
enum=["must", "allow", "not_allow"], description="Token strategy"
),
"prompt_public": fields.Boolean(description="Make prompt public"),
"show_workflow_steps": fields.Boolean(description="Show workflow steps"),
"use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
},
)
)
@api.response(200, "Site configuration updated successfully", app_site_fields)
@api.response(403, "Insufficient permissions")
@api.response(404, "App not found")
@setup_required
@login_required
@account_initialization_required
@ -84,7 +116,14 @@ class AppSite(Resource):
return site
@console_ns.route("/apps/<uuid:app_id>/site/access-token-reset")
class AppSiteAccessTokenReset(Resource):
@api.doc("reset_app_site_access_token")
@api.doc(description="Reset access token for application site")
@api.doc(params={"app_id": "Application ID"})
@api.response(200, "Access token reset successfully", app_site_fields)
@api.response(403, "Insufficient permissions (admin/owner required)")
@api.response(404, "App or site not found")
@setup_required
@login_required
@account_initialization_required
@ -108,7 +147,3 @@ class AppSiteAccessTokenReset(Resource):
db.session.commit()
return site
api.add_resource(AppSite, "/apps/<uuid:app_id>/site")
api.add_resource(AppSiteAccessTokenReset, "/apps/<uuid:app_id>/site/access-token-reset")

View File

@ -5,9 +5,9 @@ import pytz
import sqlalchemy as sa
from flask import jsonify
from flask_login import current_user
from flask_restx import Resource, reqparse
from flask_restx import Resource, fields, reqparse
from controllers.console import api
from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom
@ -17,7 +17,21 @@ from libs.login import login_required
from models import AppMode, Message
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages")
class DailyMessageStatistic(Resource):
@api.doc("get_daily_message_statistics")
@api.doc(description="Get daily message statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.response(
200,
"Daily message statistics retrieved successfully",
fields.List(fields.Raw(description="Daily message count data")),
)
@get_app_model
@setup_required
@login_required
@ -74,7 +88,21 @@ WHERE
return jsonify({"data": response_data})
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations")
class DailyConversationStatistic(Resource):
@api.doc("get_daily_conversation_statistics")
@api.doc(description="Get daily conversation statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.response(
200,
"Daily conversation statistics retrieved successfully",
fields.List(fields.Raw(description="Daily conversation count data")),
)
@get_app_model
@setup_required
@login_required
@ -126,7 +154,21 @@ class DailyConversationStatistic(Resource):
return jsonify({"data": response_data})
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-end-users")
class DailyTerminalsStatistic(Resource):
@api.doc("get_daily_terminals_statistics")
@api.doc(description="Get daily terminal/end-user statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.response(
200,
"Daily terminal statistics retrieved successfully",
fields.List(fields.Raw(description="Daily terminal count data")),
)
@get_app_model
@setup_required
@login_required
@ -183,7 +225,21 @@ WHERE
return jsonify({"data": response_data})
@console_ns.route("/apps/<uuid:app_id>/statistics/token-costs")
class DailyTokenCostStatistic(Resource):
@api.doc("get_daily_token_cost_statistics")
@api.doc(description="Get daily token cost statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.response(
200,
"Daily token cost statistics retrieved successfully",
fields.List(fields.Raw(description="Daily token cost data")),
)
@get_app_model
@setup_required
@login_required
@ -243,7 +299,21 @@ WHERE
return jsonify({"data": response_data})
@console_ns.route("/apps/<uuid:app_id>/statistics/average-session-interactions")
class AverageSessionInteractionStatistic(Resource):
@api.doc("get_average_session_interaction_statistics")
@api.doc(description="Get average session interaction statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.response(
200,
"Average session interaction statistics retrieved successfully",
fields.List(fields.Raw(description="Average session interaction data")),
)
@setup_required
@login_required
@account_initialization_required
@ -319,7 +389,21 @@ ORDER BY
return jsonify({"data": response_data})
@console_ns.route("/apps/<uuid:app_id>/statistics/user-satisfaction-rate")
class UserSatisfactionRateStatistic(Resource):
@api.doc("get_user_satisfaction_rate_statistics")
@api.doc(description="Get user satisfaction rate statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.response(
200,
"User satisfaction rate statistics retrieved successfully",
fields.List(fields.Raw(description="User satisfaction rate data")),
)
@get_app_model
@setup_required
@login_required
@ -385,7 +469,21 @@ WHERE
return jsonify({"data": response_data})
@console_ns.route("/apps/<uuid:app_id>/statistics/average-response-time")
class AverageResponseTimeStatistic(Resource):
@api.doc("get_average_response_time_statistics")
@api.doc(description="Get average response time statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.response(
200,
"Average response time statistics retrieved successfully",
fields.List(fields.Raw(description="Average response time data")),
)
@setup_required
@login_required
@account_initialization_required
@ -442,7 +540,21 @@ WHERE
return jsonify({"data": response_data})
@console_ns.route("/apps/<uuid:app_id>/statistics/tokens-per-second")
class TokensPerSecondStatistic(Resource):
@api.doc("get_tokens_per_second_statistics")
@api.doc(description="Get tokens per second statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.response(
200,
"Tokens per second statistics retrieved successfully",
fields.List(fields.Raw(description="Tokens per second data")),
)
@get_app_model
@setup_required
@login_required
@ -500,13 +612,3 @@ WHERE
response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)})
return jsonify({"data": response_data})
api.add_resource(DailyMessageStatistic, "/apps/<uuid:app_id>/statistics/daily-messages")
api.add_resource(DailyConversationStatistic, "/apps/<uuid:app_id>/statistics/daily-conversations")
api.add_resource(DailyTerminalsStatistic, "/apps/<uuid:app_id>/statistics/daily-end-users")
api.add_resource(DailyTokenCostStatistic, "/apps/<uuid:app_id>/statistics/token-costs")
api.add_resource(AverageSessionInteractionStatistic, "/apps/<uuid:app_id>/statistics/average-session-interactions")
api.add_resource(UserSatisfactionRateStatistic, "/apps/<uuid:app_id>/statistics/user-satisfaction-rate")
api.add_resource(AverageResponseTimeStatistic, "/apps/<uuid:app_id>/statistics/average-response-time")
api.add_resource(TokensPerSecondStatistic, "/apps/<uuid:app_id>/statistics/tokens-per-second")

View File

@ -721,7 +721,6 @@ class WorkflowByIdApi(Resource):
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")
args = parser.parse_args()
# Prepare update data
update_data = {}

View File

@ -0,0 +1,155 @@
from flask import request
from flask_restx import Resource, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from constants.languages import languages
from controllers.console import api
from controllers.console.auth.error import (
EmailAlreadyInUseError,
EmailCodeError,
EmailRegisterLimitError,
InvalidEmailError,
InvalidTokenError,
PasswordMismatchError,
)
from controllers.console.error import AccountInFreezeError, EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, email_register_enabled, setup_required
from extensions.ext_database import db
from libs.helper import email, extract_remote_ip
from libs.password import valid_password
from models.account import Account
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import AccountNotFoundError, AccountRegisterError
class EmailRegisterSendEmailApi(Resource):
@setup_required
@email_password_login_enabled
@email_register_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
args = parser.parse_args()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
language = "en-US"
if args["language"] in languages:
language = args["language"]
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
raise AccountInFreezeError()
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
token = None
token = AccountService.send_email_register_email(email=args["email"], account=account, language=language)
return {"result": "success", "data": token}
class EmailRegisterCheckApi(Resource):
@setup_required
@email_password_login_enabled
@email_register_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
user_email = args["email"]
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args["email"])
if is_email_register_error_rate_limit:
raise EmailRegisterLimitError()
token_data = AccountService.get_email_register_data(args["token"])
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args["code"] != token_data.get("code"):
AccountService.add_email_register_error_rate_limit(args["email"])
raise EmailCodeError()
# Verified, revoke the first token
AccountService.revoke_email_register_token(args["token"])
# Refresh token data by generating a new token
_, new_token = AccountService.generate_email_register_token(
user_email, code=args["code"], additional_data={"phase": "register"}
)
AccountService.reset_email_register_error_rate_limit(args["email"])
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
class EmailRegisterResetApi(Resource):
@setup_required
@email_password_login_enabled
@email_register_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
args = parser.parse_args()
# Validate passwords match
if args["new_password"] != args["password_confirm"]:
raise PasswordMismatchError()
# Validate token and get register data
register_data = AccountService.get_email_register_data(args["token"])
if not register_data:
raise InvalidTokenError()
# Must use token in reset phase
if register_data.get("phase", "") != "register":
raise InvalidTokenError()
# Revoke token to prevent reuse
AccountService.revoke_email_register_token(args["token"])
email = register_data.get("email", "")
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account:
raise EmailAlreadyInUseError()
else:
account = self._create_new_account(email, args["password_confirm"])
if not account:
raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(email)
return {"result": "success", "data": token_pair.model_dump()}
def _create_new_account(self, email, password) -> Account | None:
# Create new account if allowed
account = None
try:
account = AccountService.create_account_and_tenant(
email=email,
name=email,
password=password,
interface_language=languages[0],
)
except AccountRegisterError:
raise AccountInFreezeError()
return account
api.add_resource(EmailRegisterSendEmailApi, "/email-register/send-email")
api.add_resource(EmailRegisterCheckApi, "/email-register/validity")
api.add_resource(EmailRegisterResetApi, "/email-register")

View File

@ -27,21 +27,43 @@ class InvalidTokenError(BaseHTTPException):
class PasswordResetRateLimitExceededError(BaseHTTPException):
error_code = "password_reset_rate_limit_exceeded"
description = "Too many password reset emails have been sent. Please try again in 1 minute."
description = "Too many password reset emails have been sent. Please try again in {minutes} minutes."
code = 429
def __init__(self, minutes: int = 1):
description = self.description.format(minutes=int(minutes)) if self.description else None
super().__init__(description=description)
class EmailRegisterRateLimitExceededError(BaseHTTPException):
error_code = "email_register_rate_limit_exceeded"
description = "Too many email register emails have been sent. Please try again in {minutes} minutes."
code = 429
def __init__(self, minutes: int = 1):
description = self.description.format(minutes=int(minutes)) if self.description else None
super().__init__(description=description)
class EmailChangeRateLimitExceededError(BaseHTTPException):
error_code = "email_change_rate_limit_exceeded"
description = "Too many email change emails have been sent. Please try again in 1 minute."
description = "Too many email change emails have been sent. Please try again in {minutes} minutes."
code = 429
def __init__(self, minutes: int = 1):
description = self.description.format(minutes=int(minutes)) if self.description else None
super().__init__(description=description)
class OwnerTransferRateLimitExceededError(BaseHTTPException):
error_code = "owner_transfer_rate_limit_exceeded"
description = "Too many owner transfer emails have been sent. Please try again in 1 minute."
description = "Too many owner transfer emails have been sent. Please try again in {minutes} minutes."
code = 429
def __init__(self, minutes: int = 1):
description = self.description.format(minutes=int(minutes)) if self.description else None
super().__init__(description=description)
class EmailCodeError(BaseHTTPException):
error_code = "email_code_error"
@ -69,15 +91,23 @@ class EmailPasswordLoginLimitError(BaseHTTPException):
class EmailCodeLoginRateLimitExceededError(BaseHTTPException):
error_code = "email_code_login_rate_limit_exceeded"
description = "Too many login emails have been sent. Please try again in 5 minutes."
description = "Too many login emails have been sent. Please try again in {minutes} minutes."
code = 429
def __init__(self, minutes: int = 5):
description = self.description.format(minutes=int(minutes)) if self.description else None
super().__init__(description=description)
class EmailCodeAccountDeletionRateLimitExceededError(BaseHTTPException):
error_code = "email_code_account_deletion_rate_limit_exceeded"
description = "Too many account deletion emails have been sent. Please try again in 5 minutes."
description = "Too many account deletion emails have been sent. Please try again in {minutes} minutes."
code = 429
def __init__(self, minutes: int = 5):
description = self.description.format(minutes=int(minutes)) if self.description else None
super().__init__(description=description)
class EmailPasswordResetLimitError(BaseHTTPException):
error_code = "email_password_reset_limit"
@ -85,6 +115,12 @@ class EmailPasswordResetLimitError(BaseHTTPException):
code = 429
class EmailRegisterLimitError(BaseHTTPException):
error_code = "email_register_limit"
description = "Too many failed email register attempts. Please try again in 24 hours."
code = 429
class EmailChangeLimitError(BaseHTTPException):
error_code = "email_change_limit"
description = "Too many failed email change attempts. Please try again in 24 hours."

View File

@ -6,7 +6,6 @@ from flask_restx import Resource, fields, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from constants.languages import languages
from controllers.console import api, console_ns
from controllers.console.auth.error import (
EmailCodeError,
@ -15,7 +14,7 @@ from controllers.console.auth.error import (
InvalidTokenError,
PasswordMismatchError,
)
from controllers.console.error import AccountInFreezeError, AccountNotFound, EmailSendIpLimitError
from controllers.console.error import AccountNotFound, EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
@ -23,8 +22,6 @@ from libs.helper import email, extract_remote_ip
from libs.password import hash_password, valid_password
from models.account import Account
from services.account_service import AccountService, TenantService
from services.errors.account import AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
from services.feature_service import FeatureService
@ -73,15 +70,13 @@ class ForgotPasswordSendEmailApi(Resource):
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
token = None
if account is None:
if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_reset_password_email(email=args["email"], language=language)
return {"result": "fail", "data": token, "code": "account_not_found"}
else:
raise AccountNotFound()
else:
token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language)
token = AccountService.send_reset_password_email(
account=account,
email=args["email"],
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
)
return {"result": "success", "data": token}
@ -207,7 +202,7 @@ class ForgotPasswordResetApi(Resource):
if account:
self._update_existing_account(account, password_hashed, salt, session)
else:
self._create_new_account(email, args["password_confirm"])
raise AccountNotFound()
return {"result": "success"}
@ -227,18 +222,7 @@ class ForgotPasswordResetApi(Resource):
account.current_tenant = tenant
tenant_was_created.send(tenant)
def _create_new_account(self, email, password):
# Create new account if allowed
try:
AccountService.create_account_and_tenant(
email=email,
name=email,
password=password,
interface_language=languages[0],
)
except WorkSpaceNotAllowedCreateError:
pass
except WorkspacesLimitExceededError:
pass
except AccountRegisterError:
raise AccountInFreezeError()
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")

View File

@ -26,7 +26,6 @@ from controllers.console.error import (
from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created
from libs.helper import email, extract_remote_ip
from libs.password import valid_password
from models.account import Account
from services.account_service import AccountService, RegisterService, TenantService
from services.billing_service import BillingService
@ -44,10 +43,9 @@ class LoginApi(Resource):
"""Authenticate user and login."""
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("password", type=valid_password, required=True, location="json")
parser.add_argument("password", type=str, required=True, location="json")
parser.add_argument("remember_me", type=bool, required=False, default=False, location="json")
parser.add_argument("invite_token", type=str, required=False, default=None, location="json")
parser.add_argument("language", type=str, required=False, default="en-US", location="json")
args = parser.parse_args()
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
@ -61,11 +59,6 @@ class LoginApi(Resource):
if invitation:
invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation)
if args["language"] is not None and args["language"] == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
try:
if invitation:
data = invitation.get("data", {})
@ -80,12 +73,6 @@ class LoginApi(Resource):
except services.errors.account.AccountPasswordError:
AccountService.add_login_error_rate_limit(args["email"])
raise AuthenticationFailedError()
except services.errors.account.AccountNotFoundError:
if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_reset_password_email(email=args["email"], language=language)
return {"result": "fail", "data": token, "code": "account_not_found"}
else:
raise AccountNotFound()
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0:
@ -133,13 +120,12 @@ class ResetPasswordSendEmailApi(Resource):
except AccountRegisterError:
raise AccountInFreezeError()
if account is None:
if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_reset_password_email(email=args["email"], language=language)
else:
raise AccountNotFound()
else:
token = AccountService.send_reset_password_email(account=account, language=language)
token = AccountService.send_reset_password_email(
email=args["email"],
account=account,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
)
return {"result": "success", "data": token}

View File

@ -18,6 +18,7 @@ from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models import Account
from models.account import AccountStatus
from services.account_service import AccountService, RegisterService, TenantService
from services.billing_service import BillingService
from services.errors.account import AccountNotFoundError, AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
from services.feature_service import FeatureService
@ -183,7 +184,15 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
if not account:
if not FeatureService.get_system_features().is_allow_register:
raise AccountNotFoundError()
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email):
raise AccountRegisterError(
description=(
"This email account has been deleted within the past "
"30 days and is temporarily unavailable for new account registration"
)
)
else:
raise AccountRegisterError(description=("Invalid email or password"))
account_name = user_info.name or "Dify"
account = RegisterService.register(
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider

View File

@ -20,7 +20,7 @@ class AppParameterApi(InstalledAppResource):
if app_model is None:
raise AppUnavailableError()
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()

View File

@ -242,6 +242,19 @@ def email_password_login_enabled(view: Callable[P, R]):
return decorated
def email_register_enabled(view):
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_system_features()
if features.is_allow_register:
return view(*args, **kwargs)
# otherwise, return 403
abort(403)
return decorated
def enable_change_email(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):

View File

@ -86,7 +86,7 @@ class PluginUploadFileApi(Resource):
filename=filename,
mimetype=mimetype,
tenant_id=tenant_id,
user_id=user_id,
user_id=user.id,
timestamp=timestamp,
nonce=nonce,
sign=sign,

View File

@ -8,11 +8,10 @@ from flask_restx import reqparse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from core.file.constants import DEFAULT_SERVICE_API_USER_ID
from extensions.ext_database import db
from libs.login import current_user
from models.account import Tenant
from models.model import EndUser
from models.model import DefaultEndUserSessionID, EndUser
P = ParamSpec("P")
R = TypeVar("R")
@ -28,7 +27,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
try:
with Session(db.engine) as session:
if not user_id:
user_id = DEFAULT_SERVICE_API_USER_ID
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
user_model = (
session.query(EndUser)
@ -42,7 +41,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
user_model = EndUser(
tenant_id=tenant_id,
type="service_api",
is_anonymous=user_id == DEFAULT_SERVICE_API_USER_ID,
is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value,
session_id=user_id,
)
session.add(user_model)
@ -73,7 +72,7 @@ def get_user_tenant(view: Optional[Callable[P, R]] = None):
raise ValueError("tenant_id is required")
if not user_id:
user_id = DEFAULT_SERVICE_API_USER_ID
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
try:
tenant_model = (

View File

@ -150,7 +150,7 @@ class MCPAppApi(Resource):
def _get_user_input_form(self, app: App) -> list[VariableEntity]:
"""Get and convert user input form"""
# Get raw user input form based on app mode
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
if not app.workflow:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable")
raw_user_input_form = app.workflow.user_input_form(to_old_structure=True)

View File

@ -33,7 +33,6 @@ from .dataset import (
hit_testing, # pyright: ignore[reportUnusedImport]
metadata, # pyright: ignore[reportUnusedImport]
segment, # pyright: ignore[reportUnusedImport]
upload_file, # pyright: ignore[reportUnusedImport]
)
from .workspace import models # pyright: ignore[reportUnusedImport]

View File

@ -29,7 +29,7 @@ class AppParameterApi(Resource):
Returns the input form parameters and configuration for the application.
"""
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()

View File

@ -340,6 +340,9 @@ class DatasetApi(DatasetApiResource):
else:
data["embedding_available"] = True
# force update search method to keyword_search if indexing_technique is economic
data["retrieval_model_dict"]["search_method"] = "keyword_search"
if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})

View File

@ -1,65 +0,0 @@
from werkzeug.exceptions import NotFound
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import (
DatasetApiResource,
)
from core.file import helpers as file_helpers
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import UploadFile
from services.dataset_service import DocumentService
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/upload-file")
class UploadFileApi(DatasetApiResource):
@service_api_ns.doc("get_upload_file")
@service_api_ns.doc(description="Get upload file information and download URL")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@service_api_ns.doc(
responses={
200: "Upload file information retrieved successfully",
401: "Unauthorized - invalid API token",
404: "Dataset, document, or upload file not found",
}
)
def get(self, tenant_id, dataset_id, document_id):
"""Get upload file information and download URL.
Returns information about an uploaded file including its download URL.
"""
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
# check upload file
if document.data_source_type != "upload_file":
raise ValueError(f"Document data source type ({document.data_source_type}) is not upload_file.")
data_source_info = document.data_source_info_dict
if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"]
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("UploadFile not found.")
else:
raise ValueError("Upload file id not found in document data source info.")
url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"url": url,
"download_url": f"{url}&as_attachment=true",
"mime_type": upload_file.mime_type,
"created_by": upload_file.created_by,
"created_at": upload_file.created_at.timestamp(),
}, 200

View File

@ -13,14 +13,13 @@ from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from core.file.constants import DEFAULT_SERVICE_API_USER_ID
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
from models.dataset import Dataset, RateLimitLog
from models.model import ApiToken, App, EndUser
from models.model import ApiToken, App, DefaultEndUserSessionID, EndUser
from services.feature_service import FeatureService
P = ParamSpec("P")
@ -273,7 +272,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
Create or update session terminal based on user ID.
"""
if not user_id:
user_id = DEFAULT_SERVICE_API_USER_ID
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
with Session(db.engine, expire_on_commit=False) as session:
end_user = (
@ -292,7 +291,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="service_api",
is_anonymous=user_id == DEFAULT_SERVICE_API_USER_ID,
is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value,
session_id=user_id,
)
session.add(end_user)

View File

@ -38,7 +38,7 @@ class AppParameterApi(WebApiResource):
@marshal_with(fields.parameters_fields)
def get(self, app_model: App, end_user):
"""Retrieve app parameters."""
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()

View File

@ -1,4 +1,4 @@
import enum
from enum import StrEnum
from typing import Any, Optional
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
@ -26,25 +26,25 @@ class AgentStrategyProviderIdentity(ToolProviderIdentity):
class AgentStrategyParameter(PluginParameter):
class AgentStrategyParameterType(enum.StrEnum):
class AgentStrategyParameterType(StrEnum):
"""
Keep all the types from PluginParameterType
"""
STRING = CommonParameterType.STRING.value
NUMBER = CommonParameterType.NUMBER.value
BOOLEAN = CommonParameterType.BOOLEAN.value
SELECT = CommonParameterType.SELECT.value
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
FILE = CommonParameterType.FILE.value
FILES = CommonParameterType.FILES.value
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
ANY = CommonParameterType.ANY.value
STRING = CommonParameterType.STRING
NUMBER = CommonParameterType.NUMBER
BOOLEAN = CommonParameterType.BOOLEAN
SELECT = CommonParameterType.SELECT
SECRET_INPUT = CommonParameterType.SECRET_INPUT
FILE = CommonParameterType.FILE
FILES = CommonParameterType.FILES
APP_SELECTOR = CommonParameterType.APP_SELECTOR
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
ANY = CommonParameterType.ANY
# deprecated, should not use.
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES
def as_normal_type(self):
return as_normal_type(self)
@ -72,7 +72,7 @@ class AgentStrategyIdentity(ToolIdentity):
pass
class AgentFeature(enum.StrEnum):
class AgentFeature(StrEnum):
"""
Agent Feature, used to describe the features of the agent strategy.
"""

View File

@ -70,7 +70,7 @@ class PromptTemplateConfigManager:
:param config: app model config args
"""
if not config.get("prompt_type"):
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE
prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType]
if config["prompt_type"] not in prompt_type_vals:
@ -90,7 +90,7 @@ class PromptTemplateConfigManager:
if not isinstance(config["completion_prompt_config"], dict):
raise ValueError("completion_prompt_config must be of object type")
if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value:
if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED:
if not config["chat_prompt_config"] and not config["completion_prompt_config"]:
raise ValueError(
"chat_prompt_config or completion_prompt_config is required when prompt_type is advanced"

View File

@ -1,5 +1,5 @@
from collections.abc import Sequence
from enum import Enum, StrEnum
from enum import StrEnum, auto
from typing import Any, Literal, Optional
from pydantic import BaseModel, Field, field_validator
@ -61,14 +61,14 @@ class PromptTemplateEntity(BaseModel):
Prompt Template Entity.
"""
class PromptType(Enum):
class PromptType(StrEnum):
"""
Prompt Type.
'simple', 'advanced'
"""
SIMPLE = "simple"
ADVANCED = "advanced"
SIMPLE = auto()
ADVANCED = auto()
@classmethod
def value_of(cls, value: str):
@ -195,14 +195,14 @@ class DatasetRetrieveConfigEntity(BaseModel):
Dataset Retrieve Config Entity.
"""
class RetrieveStrategy(Enum):
class RetrieveStrategy(StrEnum):
"""
Dataset Retrieve Strategy.
'single' or 'multiple'
"""
SINGLE = "single"
MULTIPLE = "multiple"
SINGLE = auto()
MULTIPLE = auto()
@classmethod
def value_of(cls, value: str):
@ -293,12 +293,12 @@ class AppConfig(BaseModel):
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
class EasyUIBasedAppModelConfigFrom(Enum):
class EasyUIBasedAppModelConfigFrom(StrEnum):
"""
App Model Config From.
"""
ARGS = "args"
ARGS = auto()
APP_LATEST_CONFIG = "app-latest-config"
CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config"

View File

@ -1,6 +1,6 @@
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import Enum, StrEnum
from enum import StrEnum, auto
from typing import Any
from pydantic import BaseModel, Field
@ -510,15 +510,15 @@ class QueueStopEvent(AppQueueEvent):
QueueStopEvent entity
"""
class StopBy(Enum):
class StopBy(StrEnum):
"""
Stop by enum
"""
USER_MANUAL = "user-manual"
ANNOTATION_REPLY = "annotation-reply"
OUTPUT_MODERATION = "output-moderation"
INPUT_MODERATION = "input-moderation"
USER_MANUAL = auto()
ANNOTATION_REPLY = auto()
OUTPUT_MODERATION = auto()
INPUT_MODERATION = auto()
event: QueueEvent = QueueEvent.STOP
stopped_by: StopBy

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from enum import Enum
from enum import StrEnum, auto
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
@ -50,35 +50,37 @@ class WorkflowTaskState(TaskState):
answer: str = ""
class StreamEvent(Enum):
class StreamEvent(StrEnum):
"""
Stream event
"""
PING = "ping"
ERROR = "error"
MESSAGE = "message"
MESSAGE_END = "message_end"
TTS_MESSAGE = "tts_message"
TTS_MESSAGE_END = "tts_message_end"
MESSAGE_FILE = "message_file"
MESSAGE_REPLACE = "message_replace"
AGENT_THOUGHT = "agent_thought"
AGENT_MESSAGE = "agent_message"
WORKFLOW_STARTED = "workflow_started"
WORKFLOW_FINISHED = "workflow_finished"
NODE_STARTED = "node_started"
NODE_FINISHED = "node_finished"
NODE_RETRY = "node_retry"
ITERATION_STARTED = "iteration_started"
ITERATION_NEXT = "iteration_next"
ITERATION_COMPLETED = "iteration_completed"
LOOP_STARTED = "loop_started"
LOOP_NEXT = "loop_next"
LOOP_COMPLETED = "loop_completed"
TEXT_CHUNK = "text_chunk"
TEXT_REPLACE = "text_replace"
AGENT_LOG = "agent_log"
PING = auto()
ERROR = auto()
MESSAGE = auto()
MESSAGE_END = auto()
TTS_MESSAGE = auto()
TTS_MESSAGE_END = auto()
MESSAGE_FILE = auto()
MESSAGE_REPLACE = auto()
AGENT_THOUGHT = auto()
AGENT_MESSAGE = auto()
WORKFLOW_STARTED = auto()
WORKFLOW_FINISHED = auto()
NODE_STARTED = auto()
NODE_FINISHED = auto()
NODE_RETRY = auto()
PARALLEL_BRANCH_STARTED = auto()
PARALLEL_BRANCH_FINISHED = auto()
ITERATION_STARTED = auto()
ITERATION_NEXT = auto()
ITERATION_COMPLETED = auto()
LOOP_STARTED = auto()
LOOP_NEXT = auto()
LOOP_COMPLETED = auto()
TEXT_CHUNK = auto()
TEXT_REPLACE = auto()
AGENT_LOG = auto()
class StreamResponse(BaseModel):

View File

@ -145,7 +145,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata.model_dump()
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
if self._conversation_mode == AppMode.COMPLETION.value:
if self._conversation_mode == AppMode.COMPLETION:
response = CompletionAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
data=CompletionAppBlockingResponse.Data(

View File

@ -92,7 +92,7 @@ class MessageCycleManager:
if not conversation:
return
if conversation.mode != AppMode.COMPLETION.value:
if conversation.mode != AppMode.COMPLETION:
app_model = conversation.app
if not app_model:
return

View File

@ -1,8 +1,8 @@
from enum import Enum
from enum import StrEnum, auto
class PlanningStrategy(Enum):
ROUTER = "router"
REACT_ROUTER = "react_router"
REACT = "react"
FUNCTION_CALL = "function_call"
class PlanningStrategy(StrEnum):
ROUTER = auto()
REACT_ROUTER = auto()
REACT = auto()
FUNCTION_CALL = auto()

View File

@ -1,10 +1,10 @@
from enum import Enum
from enum import StrEnum, auto
class EmbeddingInputType(Enum):
class EmbeddingInputType(StrEnum):
"""
Enum for embedding input type.
"""
DOCUMENT = "document"
QUERY = "query"
DOCUMENT = auto()
QUERY = auto()

View File

@ -1,5 +1,5 @@
from collections.abc import Sequence
from enum import Enum
from enum import StrEnum, auto
from typing import Optional
from pydantic import BaseModel, ConfigDict
@ -9,16 +9,16 @@ from core.model_runtime.entities.model_entities import ModelType, ProviderModel
from core.model_runtime.entities.provider_entities import ProviderEntity
class ModelStatus(Enum):
class ModelStatus(StrEnum):
"""
Enum class for model status.
"""
ACTIVE = "active"
ACTIVE = auto()
NO_CONFIGURE = "no-configure"
QUOTA_EXCEEDED = "quota-exceeded"
NO_PERMISSION = "no-permission"
DISABLED = "disabled"
DISABLED = auto()
CREDENTIAL_REMOVED = "credential-removed"

View File

@ -1,20 +1,20 @@
from enum import StrEnum
from enum import StrEnum, auto
class CommonParameterType(StrEnum):
SECRET_INPUT = "secret-input"
TEXT_INPUT = "text-input"
SELECT = "select"
STRING = "string"
NUMBER = "number"
FILE = "file"
FILES = "files"
SELECT = auto()
STRING = auto()
NUMBER = auto()
FILE = auto()
FILES = auto()
SYSTEM_FILES = "system-files"
BOOLEAN = "boolean"
BOOLEAN = auto()
APP_SELECTOR = "app-selector"
MODEL_SELECTOR = "model-selector"
TOOLS_SELECTOR = "array[tools]"
ANY = "any"
ANY = auto()
# Dynamic select parameter
# Once you are not sure about the available options until authorization is done
@ -23,29 +23,29 @@ class CommonParameterType(StrEnum):
# TOOL_SELECTOR = "tool-selector"
# MCP object and array type parameters
ARRAY = "array"
OBJECT = "object"
ARRAY = auto()
OBJECT = auto()
class AppSelectorScope(StrEnum):
ALL = "all"
CHAT = "chat"
WORKFLOW = "workflow"
COMPLETION = "completion"
ALL = auto()
CHAT = auto()
WORKFLOW = auto()
COMPLETION = auto()
class ModelSelectorScope(StrEnum):
LLM = "llm"
LLM = auto()
TEXT_EMBEDDING = "text-embedding"
RERANK = "rerank"
TTS = "tts"
SPEECH2TEXT = "speech2text"
MODERATION = "moderation"
VISION = "vision"
RERANK = auto()
TTS = auto()
SPEECH2TEXT = auto()
MODERATION = auto()
VISION = auto()
class ToolSelectorScope(StrEnum):
ALL = "all"
CUSTOM = "custom"
BUILTIN = "builtin"
WORKFLOW = "workflow"
ALL = auto()
CUSTOM = auto()
BUILTIN = auto()
WORKFLOW = auto()

View File

@ -1,4 +1,4 @@
from enum import Enum
from enum import StrEnum, auto
from typing import Optional, Union
from pydantic import BaseModel, ConfigDict, Field
@ -13,14 +13,14 @@ from core.model_runtime.entities.model_entities import ModelType
from core.tools.entities.common_entities import I18nObject
class ProviderQuotaType(Enum):
PAID = "paid"
class ProviderQuotaType(StrEnum):
PAID = auto()
"""hosted paid quota"""
FREE = "free"
FREE = auto()
"""third-party free quota"""
TRIAL = "trial"
TRIAL = auto()
"""hosted trial quota"""
@staticmethod
@ -31,20 +31,20 @@ class ProviderQuotaType(Enum):
raise ValueError(f"No matching enum found for value '{value}'")
class QuotaUnit(Enum):
TIMES = "times"
TOKENS = "tokens"
CREDITS = "credits"
class QuotaUnit(StrEnum):
TIMES = auto()
TOKENS = auto()
CREDITS = auto()
class SystemConfigurationStatus(Enum):
class SystemConfigurationStatus(StrEnum):
"""
Enum class for system configuration status.
"""
ACTIVE = "active"
ACTIVE = auto()
QUOTA_EXCEEDED = "quota-exceeded"
UNSUPPORTED = "unsupported"
UNSUPPORTED = auto()
class RestrictModel(BaseModel):
@ -168,14 +168,14 @@ class BasicProviderConfig(BaseModel):
Base model class for common provider settings like credentials
"""
class Type(Enum):
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
TEXT_INPUT = CommonParameterType.TEXT_INPUT.value
SELECT = CommonParameterType.SELECT.value
BOOLEAN = CommonParameterType.BOOLEAN.value
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
class Type(StrEnum):
SECRET_INPUT = CommonParameterType.SECRET_INPUT
TEXT_INPUT = CommonParameterType.TEXT_INPUT
SELECT = CommonParameterType.SELECT
BOOLEAN = CommonParameterType.BOOLEAN
APP_SELECTOR = CommonParameterType.APP_SELECTOR
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
@classmethod
def value_of(cls, value: str) -> "ProviderConfig.Type":

View File

@ -1,8 +1,8 @@
import enum
import importlib.util
import json
import logging
import os
from enum import StrEnum, auto
from pathlib import Path
from typing import Any, Optional
@ -13,9 +13,9 @@ from core.helper.position_helper import sort_to_dict_by_position_map
logger = logging.getLogger(__name__)
class ExtensionModule(enum.Enum):
MODERATION = "moderation"
EXTERNAL_DATA_TOOL = "external_data_tool"
class ExtensionModule(StrEnum):
MODERATION = auto()
EXTERNAL_DATA_TOOL = auto()
class ModuleExtension(BaseModel):

View File

@ -9,7 +9,3 @@ FILE_MODEL_IDENTITY = "__dify__file__"
def maybe_file_object(o: Any) -> bool:
return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY
# The default user ID for service API calls.
DEFAULT_SERVICE_API_USER_ID = "DEFAULT-USER"

View File

@ -5,7 +5,6 @@ import os
import time
from configs import dify_config
from core.file.constants import DEFAULT_SERVICE_API_USER_ID
def get_signed_file_url(upload_file_id: str) -> str:
@ -25,10 +24,6 @@ def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str,
# Plugin access should use internal URL for Docker network communication
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
url = f"{base_url}/files/upload/for-plugin"
if user_id is None:
user_id = DEFAULT_SERVICE_API_USER_ID
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
key = dify_config.SECRET_KEY.encode()
@ -40,11 +35,8 @@ def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str,
def verify_plugin_file_signature(
*, filename: str, mimetype: str, tenant_id: str, user_id: str | None, timestamp: str, nonce: str, sign: str
*, filename: str, mimetype: str, tenant_id: str, user_id: str, timestamp: str, nonce: str, sign: str
) -> bool:
if user_id is None:
user_id = DEFAULT_SERVICE_API_USER_ID
data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode()
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()

View File

@ -1,12 +1,12 @@
import json
from enum import Enum
from enum import StrEnum
from json import JSONDecodeError
from typing import Optional
from extensions.ext_redis import redis_client
class ProviderCredentialsCacheType(Enum):
class ProviderCredentialsCacheType(StrEnum):
PROVIDER = "provider"
MODEL = "provider_model"
LOAD_BALANCING_MODEL = "load_balancing_provider_model"
@ -14,7 +14,7 @@ class ProviderCredentialsCacheType(Enum):
class ProviderCredentialsCache:
def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType):
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
self.cache_key = f"{cache_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
def get(self) -> Optional[dict]:
"""

View File

@ -1,12 +1,14 @@
import os
from collections import OrderedDict
from collections.abc import Callable
from functools import lru_cache
from typing import TypeVar
from configs import dify_config
from core.tools.utils.yaml_utils import load_yaml_file
from core.tools.utils.yaml_utils import load_yaml_file_cached
@lru_cache(maxsize=128)
def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> dict[str, int]:
"""
Get the mapping from name to index from a YAML file
@ -14,12 +16,17 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") ->
:param file_name: the YAML file name, default to '_position.yaml'
:return: a dict with name as key and index as value
"""
# FIXME(-LAN-): Cache position maps to prevent file descriptor exhaustion during high-load benchmarks
position_file_path = os.path.join(folder_path, file_name)
yaml_content = load_yaml_file(file_path=position_file_path, default_value=[])
try:
yaml_content = load_yaml_file_cached(file_path=position_file_path)
except Exception:
yaml_content = []
positions = [item.strip() for item in yaml_content if item and isinstance(item, str) and item.strip()]
return {name: index for index, name in enumerate(positions)}
@lru_cache(maxsize=128)
def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
"""
Get the mapping for tools from name to index from a YAML file.
@ -35,20 +42,6 @@ def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -
)
def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
"""
Get the mapping for providers from name to index from a YAML file.
:param folder_path:
:param file_name: the YAML file name, default to '_position.yaml'
:return: a dict with name as key and index as value
"""
position_map = get_position_map(folder_path, file_name=file_name)
return pin_position_map(
position_map,
pin_list=dify_config.POSITION_PROVIDER_PINS_LIST,
)
def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]:
"""
Pin the items in the pin list to the beginning of the position map.

View File

@ -1,12 +1,12 @@
import json
from enum import Enum
from enum import StrEnum
from json import JSONDecodeError
from typing import Optional
from extensions.ext_redis import redis_client
class ToolParameterCacheType(Enum):
class ToolParameterCacheType(StrEnum):
PARAMETER = "tool_parameter"
@ -15,7 +15,7 @@ class ToolParameterCache:
self, tenant_id: str, provider: str, tool_name: str, cache_type: ToolParameterCacheType, identity_id: str
):
self.cache_key = (
f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
f"{cache_type}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
f":identity_id:{identity_id}"
)

View File

@ -142,7 +142,7 @@ def handle_call_tool(
end_user,
args,
InvokeFrom.SERVICE_API,
streaming=app.mode == AppMode.AGENT_CHAT.value,
streaming=app.mode == AppMode.AGENT_CHAT,
)
answer = extract_answer_from_response(app, response)
@ -157,7 +157,7 @@ def build_parameter_schema(
"""Build parameter schema for the tool"""
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
if app_mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
if app_mode in {AppMode.COMPLETION, AppMode.WORKFLOW}:
return {
"type": "object",
"properties": parameters,
@ -175,9 +175,9 @@ def build_parameter_schema(
def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> dict[str, Any]:
"""Prepare arguments based on app mode"""
if app.mode == AppMode.WORKFLOW.value:
if app.mode == AppMode.WORKFLOW:
return {"inputs": arguments}
elif app.mode == AppMode.COMPLETION.value:
elif app.mode == AppMode.COMPLETION:
return {"query": "", "inputs": arguments}
else:
# Chat modes - create a copy to avoid modifying original dict
@ -218,13 +218,13 @@ def process_streaming_response(response: RateLimitGenerator) -> str:
def process_mapping_response(app: App, response: Mapping) -> str:
"""Process mapping response based on app mode"""
if app.mode in {
AppMode.ADVANCED_CHAT.value,
AppMode.COMPLETION.value,
AppMode.CHAT.value,
AppMode.AGENT_CHAT.value,
AppMode.ADVANCED_CHAT,
AppMode.COMPLETION,
AppMode.CHAT,
AppMode.AGENT_CHAT,
}:
return response.get("answer", "")
elif app.mode == AppMode.WORKFLOW.value:
elif app.mode == AppMode.WORKFLOW:
return json.dumps(response["data"]["outputs"], ensure_ascii=False)
else:
raise ValueError("Invalid app mode: " + str(app.mode))

View File

@ -1,20 +1,20 @@
from abc import ABC
from collections.abc import Mapping, Sequence
from enum import Enum, StrEnum
from enum import StrEnum, auto
from typing import Annotated, Any, Literal, Optional, Union
from pydantic import BaseModel, Field, field_serializer, field_validator
class PromptMessageRole(Enum):
class PromptMessageRole(StrEnum):
"""
Enum class for prompt message.
"""
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"
SYSTEM = auto()
USER = auto()
ASSISTANT = auto()
TOOL = auto()
@classmethod
def value_of(cls, value: str) -> "PromptMessageRole":
@ -54,11 +54,11 @@ class PromptMessageContentType(StrEnum):
Enum class for prompt message content type.
"""
TEXT = "text"
IMAGE = "image"
AUDIO = "audio"
VIDEO = "video"
DOCUMENT = "document"
TEXT = auto()
IMAGE = auto()
AUDIO = auto()
VIDEO = auto()
DOCUMENT = auto()
class PromptMessageContent(ABC, BaseModel):
@ -108,8 +108,8 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent):
"""
class DETAIL(StrEnum):
LOW = "low"
HIGH = "high"
LOW = auto()
HIGH = auto()
type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE
detail: DETAIL = DETAIL.LOW

View File

@ -1,5 +1,5 @@
from decimal import Decimal
from enum import Enum, StrEnum
from enum import StrEnum, auto
from typing import Any, Optional
from pydantic import BaseModel, ConfigDict, model_validator
@ -7,17 +7,17 @@ from pydantic import BaseModel, ConfigDict, model_validator
from core.model_runtime.entities.common_entities import I18nObject
class ModelType(Enum):
class ModelType(StrEnum):
"""
Enum class for model type.
"""
LLM = "llm"
LLM = auto()
TEXT_EMBEDDING = "text-embedding"
RERANK = "rerank"
SPEECH2TEXT = "speech2text"
MODERATION = "moderation"
TTS = "tts"
RERANK = auto()
SPEECH2TEXT = auto()
MODERATION = auto()
TTS = auto()
@classmethod
def value_of(cls, origin_model_type: str) -> "ModelType":
@ -26,17 +26,17 @@ class ModelType(Enum):
:return: model type
"""
if origin_model_type in {"text-generation", cls.LLM.value}:
if origin_model_type in {"text-generation", cls.LLM}:
return cls.LLM
elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING.value}:
elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING}:
return cls.TEXT_EMBEDDING
elif origin_model_type in {"reranking", cls.RERANK.value}:
elif origin_model_type in {"reranking", cls.RERANK}:
return cls.RERANK
elif origin_model_type in {"speech2text", cls.SPEECH2TEXT.value}:
elif origin_model_type in {"speech2text", cls.SPEECH2TEXT}:
return cls.SPEECH2TEXT
elif origin_model_type in {"tts", cls.TTS.value}:
elif origin_model_type in {"tts", cls.TTS}:
return cls.TTS
elif origin_model_type == cls.MODERATION.value:
elif origin_model_type == cls.MODERATION:
return cls.MODERATION
else:
raise ValueError(f"invalid origin model type {origin_model_type}")
@ -63,7 +63,7 @@ class ModelType(Enum):
raise ValueError(f"invalid model type {self}")
class FetchFrom(Enum):
class FetchFrom(StrEnum):
"""
Enum class for fetch from.
"""
@ -72,7 +72,7 @@ class FetchFrom(Enum):
CUSTOMIZABLE_MODEL = "customizable-model"
class ModelFeature(Enum):
class ModelFeature(StrEnum):
"""
Enum class for llm feature.
"""
@ -80,11 +80,11 @@ class ModelFeature(Enum):
TOOL_CALL = "tool-call"
MULTI_TOOL_CALL = "multi-tool-call"
AGENT_THOUGHT = "agent-thought"
VISION = "vision"
VISION = auto()
STREAM_TOOL_CALL = "stream-tool-call"
DOCUMENT = "document"
VIDEO = "video"
AUDIO = "audio"
DOCUMENT = auto()
VIDEO = auto()
AUDIO = auto()
STRUCTURED_OUTPUT = "structured-output"
@ -93,14 +93,14 @@ class DefaultParameterName(StrEnum):
Enum class for parameter template variable.
"""
TEMPERATURE = "temperature"
TOP_P = "top_p"
TOP_K = "top_k"
PRESENCE_PENALTY = "presence_penalty"
FREQUENCY_PENALTY = "frequency_penalty"
MAX_TOKENS = "max_tokens"
RESPONSE_FORMAT = "response_format"
JSON_SCHEMA = "json_schema"
TEMPERATURE = auto()
TOP_P = auto()
TOP_K = auto()
PRESENCE_PENALTY = auto()
FREQUENCY_PENALTY = auto()
MAX_TOKENS = auto()
RESPONSE_FORMAT = auto()
JSON_SCHEMA = auto()
@classmethod
def value_of(cls, value: Any) -> "DefaultParameterName":
@ -116,34 +116,34 @@ class DefaultParameterName(StrEnum):
raise ValueError(f"invalid parameter name {value}")
class ParameterType(Enum):
class ParameterType(StrEnum):
"""
Enum class for parameter type.
"""
FLOAT = "float"
INT = "int"
STRING = "string"
BOOLEAN = "boolean"
TEXT = "text"
FLOAT = auto()
INT = auto()
STRING = auto()
BOOLEAN = auto()
TEXT = auto()
class ModelPropertyKey(Enum):
class ModelPropertyKey(StrEnum):
"""
Enum class for model property key.
"""
MODE = "mode"
CONTEXT_SIZE = "context_size"
MAX_CHUNKS = "max_chunks"
FILE_UPLOAD_LIMIT = "file_upload_limit"
SUPPORTED_FILE_EXTENSIONS = "supported_file_extensions"
MAX_CHARACTERS_PER_CHUNK = "max_characters_per_chunk"
DEFAULT_VOICE = "default_voice"
VOICES = "voices"
WORD_LIMIT = "word_limit"
AUDIO_TYPE = "audio_type"
MAX_WORKERS = "max_workers"
MODE = auto()
CONTEXT_SIZE = auto()
MAX_CHUNKS = auto()
FILE_UPLOAD_LIMIT = auto()
SUPPORTED_FILE_EXTENSIONS = auto()
MAX_CHARACTERS_PER_CHUNK = auto()
DEFAULT_VOICE = auto()
VOICES = auto()
WORD_LIMIT = auto()
AUDIO_TYPE = auto()
MAX_WORKERS = auto()
class ProviderModel(BaseModel):
@ -220,13 +220,13 @@ class ModelUsage(BaseModel):
pass
class PriceType(Enum):
class PriceType(StrEnum):
"""
Enum class for price type.
"""
INPUT = "input"
OUTPUT = "output"
INPUT = auto()
OUTPUT = auto()
class PriceInfo(BaseModel):

View File

@ -1,5 +1,5 @@
from collections.abc import Sequence
from enum import Enum
from enum import Enum, StrEnum, auto
from typing import Optional
from pydantic import BaseModel, ConfigDict, Field, field_validator
@ -17,16 +17,16 @@ class ConfigurateMethod(Enum):
CUSTOMIZABLE_MODEL = "customizable-model"
class FormType(Enum):
class FormType(StrEnum):
"""
Enum class for form type.
"""
TEXT_INPUT = "text-input"
SECRET_INPUT = "secret-input"
SELECT = "select"
RADIO = "radio"
SWITCH = "switch"
SELECT = auto()
RADIO = auto()
SWITCH = auto()
class FormShowOnObject(BaseModel):

View File

@ -48,7 +48,7 @@ class TextEmbeddingModel(AIModel):
model=model,
credentials=credentials,
texts=texts,
input_type=input_type.value,
input_type=input_type,
)
except Exception as e:
raise self._transform_invoke_error(e)

View File

@ -1,14 +1,10 @@
import hashlib
import logging
import os
from collections.abc import Sequence
from threading import Lock
from typing import Optional
from pydantic import BaseModel
import contexts
from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
from core.model_runtime.model_providers.__base.ai_model import AIModel
@ -26,50 +22,22 @@ from models.provider_ids import ModelProviderID
logger = logging.getLogger(__name__)
class ModelProviderExtension(BaseModel):
plugin_model_provider_entity: PluginModelProviderEntity
position: Optional[int] = None
class ModelProviderFactory:
provider_position_map: dict[str, int]
def __init__(self, tenant_id: str) -> None:
def __init__(self, tenant_id: str):
from core.plugin.impl.model import PluginModelClient
self.provider_position_map = {}
self.tenant_id = tenant_id
self.plugin_model_manager = PluginModelClient()
if not self.provider_position_map:
# get the path of current classes
current_path = os.path.abspath(__file__)
model_providers_path = os.path.dirname(current_path)
# get _position.yaml file path
self.provider_position_map = get_provider_position_map(model_providers_path)
def get_providers(self) -> Sequence[ProviderEntity]:
"""
Get all providers
:return: list of providers
"""
# Fetch plugin model providers
# FIXME(-LAN-): Removed position map sorting since providers are fetched from plugin server
# The plugin server should return providers in the desired order
plugin_providers = self.get_plugin_model_providers()
# Convert PluginModelProviderEntity to ModelProviderExtension
model_provider_extensions = []
for provider in plugin_providers:
model_provider_extensions.append(ModelProviderExtension(plugin_model_provider_entity=provider))
sorted_extensions = sort_to_dict_by_position_map(
position_map=self.provider_position_map,
data=model_provider_extensions,
name_func=lambda x: x.plugin_model_provider_entity.declaration.provider,
)
return [extension.plugin_model_provider_entity.declaration for extension in sorted_extensions.values()]
return [provider.declaration for provider in plugin_providers]
def get_plugin_model_providers(self) -> Sequence["PluginModelProviderEntity"]:
"""

View File

@ -18,7 +18,7 @@ from pydantic_core import Url
from pydantic_extra_types.color import Color
def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any):
def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any:
return model.model_dump(mode=mode, **kwargs)
@ -100,7 +100,7 @@ def jsonable_encoder(
exclude_none: bool = False,
custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None,
sqlalchemy_safe: bool = True,
):
) -> Any:
custom_encoder = custom_encoder or {}
if custom_encoder:
if type(obj) in custom_encoder:

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from enum import Enum
from enum import StrEnum, auto
from typing import Optional
from pydantic import BaseModel, Field
@ -7,9 +7,9 @@ from pydantic import BaseModel, Field
from core.extension.extensible import Extensible, ExtensionModule
class ModerationAction(Enum):
DIRECT_OUTPUT = "direct_output"
OVERRIDDEN = "overridden"
class ModerationAction(StrEnum):
DIRECT_OUTPUT = auto()
OVERRIDDEN = auto()
class ModerationInputsResult(BaseModel):

View File

@ -1,4 +1,4 @@
from enum import Enum
from enum import StrEnum
# public
GEN_AI_SESSION_ID = "gen_ai.session.id"
@ -53,7 +53,7 @@ TOOL_DESCRIPTION = "tool.description"
TOOL_PARAMETERS = "tool.parameters"
class GenAISpanKind(Enum):
class GenAISpanKind(StrEnum):
CHAIN = "CHAIN"
RETRIEVER = "RETRIEVER"
RERANKER = "RERANKER"

View File

@ -27,7 +27,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
app = cls._get_app(app_id, tenant_id)
"""Retrieve app parameters."""
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app.workflow
if workflow is None:
raise ValueError("unexpected app type")
@ -70,7 +70,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
conversation_id = conversation_id or ""
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.AGENT_CHAT.value, AppMode.CHAT.value}:
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.AGENT_CHAT, AppMode.CHAT}:
if not query:
raise ValueError("missing query")
@ -96,7 +96,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke chat app
"""
if app.mode == AppMode.ADVANCED_CHAT.value:
if app.mode == AppMode.ADVANCED_CHAT:
workflow = app.workflow
if not workflow:
raise ValueError("unexpected app type")
@ -114,7 +114,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
invoke_from=InvokeFrom.SERVICE_API,
streaming=stream,
)
elif app.mode == AppMode.AGENT_CHAT.value:
elif app.mode == AppMode.AGENT_CHAT:
return AgentChatAppGenerator().generate(
app_model=app,
user=user,
@ -127,7 +127,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
invoke_from=InvokeFrom.SERVICE_API,
streaming=stream,
)
elif app.mode == AppMode.CHAT.value:
elif app.mode == AppMode.CHAT:
return ChatAppGenerator().generate(
app_model=app,
user=user,

View File

@ -1,5 +1,5 @@
import enum
import json
from enum import StrEnum, auto
from typing import Any, Optional, Union
from pydantic import BaseModel, Field, field_validator
@ -24,44 +24,44 @@ class PluginParameterOption(BaseModel):
return value
class PluginParameterType(enum.StrEnum):
class PluginParameterType(StrEnum):
"""
all available parameter types
"""
STRING = CommonParameterType.STRING.value
NUMBER = CommonParameterType.NUMBER.value
BOOLEAN = CommonParameterType.BOOLEAN.value
SELECT = CommonParameterType.SELECT.value
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
FILE = CommonParameterType.FILE.value
FILES = CommonParameterType.FILES.value
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
ANY = CommonParameterType.ANY.value
DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT.value
STRING = CommonParameterType.STRING
NUMBER = CommonParameterType.NUMBER
BOOLEAN = CommonParameterType.BOOLEAN
SELECT = CommonParameterType.SELECT
SECRET_INPUT = CommonParameterType.SECRET_INPUT
FILE = CommonParameterType.FILE
FILES = CommonParameterType.FILES
APP_SELECTOR = CommonParameterType.APP_SELECTOR
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
ANY = CommonParameterType.ANY
DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT
# deprecated, should not use.
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES
# MCP object and array type parameters
ARRAY = CommonParameterType.ARRAY.value
OBJECT = CommonParameterType.OBJECT.value
ARRAY = CommonParameterType.ARRAY
OBJECT = CommonParameterType.OBJECT
class MCPServerParameterType(enum.StrEnum):
class MCPServerParameterType(StrEnum):
"""
MCP server got complex parameter types
"""
ARRAY = "array"
OBJECT = "object"
ARRAY = auto()
OBJECT = auto()
class PluginParameterAutoGenerate(BaseModel):
class Type(enum.StrEnum):
PROMPT_INSTRUCTION = "prompt_instruction"
class Type(StrEnum):
PROMPT_INSTRUCTION = auto()
type: Type
@ -92,7 +92,7 @@ class PluginParameter(BaseModel):
return v
def as_normal_type(typ: enum.StrEnum):
def as_normal_type(typ: StrEnum):
if typ.value in {
PluginParameterType.SECRET_INPUT,
PluginParameterType.SELECT,
@ -101,7 +101,7 @@ def as_normal_type(typ: enum.StrEnum):
return typ.value
def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
def cast_parameter_value(typ: StrEnum, value: Any, /):
try:
match typ.value:
case PluginParameterType.STRING | PluginParameterType.SECRET_INPUT | PluginParameterType.SELECT:
@ -189,7 +189,7 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
raise ValueError(f"The tool parameter value {value} is not in correct type of {as_normal_type(typ)}.")
def init_frontend_parameter(rule: PluginParameter, type: enum.StrEnum, value: Any):
def init_frontend_parameter(rule: PluginParameter, type: StrEnum, value: Any):
"""
init frontend parameter by rule
"""

View File

@ -1,6 +1,6 @@
import datetime
import enum
from collections.abc import Mapping
from enum import StrEnum, auto
from typing import Any, Optional
from packaging.version import InvalidVersion, Version
@ -14,11 +14,11 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderEntity
class PluginInstallationSource(enum.StrEnum):
Github = "github"
Marketplace = "marketplace"
Package = "package"
Remote = "remote"
class PluginInstallationSource(StrEnum):
Github = auto()
Marketplace = auto()
Package = auto()
Remote = auto()
class PluginResourceRequirements(BaseModel):
@ -56,10 +56,10 @@ class PluginResourceRequirements(BaseModel):
permission: Optional[Permission] = Field(default=None)
class PluginCategory(enum.StrEnum):
Tool = "tool"
Model = "model"
Extension = "extension"
class PluginCategory(StrEnum):
Tool = auto()
Model = auto()
Extension = auto()
AgentStrategy = "agent-strategy"
@ -155,10 +155,10 @@ class PluginEntity(PluginInstallation):
class PluginDependency(BaseModel):
class Type(enum.StrEnum):
Github = PluginInstallationSource.Github.value
Marketplace = PluginInstallationSource.Marketplace.value
Package = PluginInstallationSource.Package.value
class Type(StrEnum):
Github = PluginInstallationSource.Github
Marketplace = PluginInstallationSource.Marketplace
Package = PluginInstallationSource.Package
class Github(BaseModel):
repo: str

View File

@ -1,7 +1,7 @@
import enum
import json
import os
from collections.abc import Mapping, Sequence
from enum import StrEnum, auto
from typing import TYPE_CHECKING, Any, Optional, cast
from core.app.app_config.entities import PromptTemplateEntity
@ -25,9 +25,9 @@ if TYPE_CHECKING:
from core.file.models import File
class ModelMode(enum.StrEnum):
COMPLETION = "completion"
CHAT = "chat"
class ModelMode(StrEnum):
COMPLETION = auto()
CHAT = auto()
prompt_file_contents: dict[str, Any] = {}

View File

@ -1,13 +1,13 @@
from enum import Enum
from enum import StrEnum, auto
class Field(Enum):
class Field(StrEnum):
CONTENT_KEY = "page_content"
METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
VECTOR = "vector"
VECTOR = auto()
# Sparse Vector aims to support full text search
SPARSE_VECTOR = "sparse_vector"
SPARSE_VECTOR = auto()
TEXT_KEY = "text"
PRIMARY_KEY = "id"
DOC_ID = "metadata.doc_id"

View File

@ -1,7 +1,7 @@
import json
import logging
import uuid
from enum import Enum
from enum import StrEnum
from typing import Any
from clickhouse_connect import get_client
@ -27,7 +27,7 @@ class MyScaleConfig(BaseModel):
fts_params: str
class SortOrder(Enum):
class SortOrder(StrEnum):
ASC = "ASC"
DESC = "DESC"

View File

@ -1,7 +1,7 @@
from enum import Enum
from enum import StrEnum
class DatasourceType(Enum):
class DatasourceType(StrEnum):
FILE = "upload_file"
NOTION = "notion_import"
WEBSITE = "website_crawl"

View File

@ -1,15 +1,15 @@
from enum import Enum, StrEnum
from enum import StrEnum, auto
class BuiltInField(StrEnum):
document_name = "document_name"
uploader = "uploader"
upload_date = "upload_date"
last_update_date = "last_update_date"
source = "source"
document_name = auto()
uploader = auto()
upload_date = auto()
last_update_date = auto()
source = auto()
class MetadataDataSource(Enum):
class MetadataDataSource(StrEnum):
upload_file = "file_upload"
website_crawl = "website"
notion_import = "notion"

View File

@ -113,21 +113,33 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
# node_ids is segment's node_ids
if dataset.indexing_technique == "high_quality":
delete_child_chunks = kwargs.get("delete_child_chunks") or False
precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids")
vector = Vector(dataset)
if node_ids:
child_node_ids = (
db.session.query(ChildChunk.index_node_id)
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
.where(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
ChildChunk.dataset_id == dataset.id,
# Use precomputed child_node_ids if available (to avoid race conditions)
if precomputed_child_node_ids is not None:
child_node_ids = precomputed_child_node_ids
else:
# Fallback to original query (may fail if segments are already deleted)
child_node_ids = (
db.session.query(ChildChunk.index_node_id)
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
.where(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
ChildChunk.dataset_id == dataset.id,
)
.all()
)
.all()
)
child_node_ids = [child_node_id[0] for child_node_id in child_node_ids]
vector.delete_by_ids(child_node_ids)
if delete_child_chunks:
child_node_ids = [child_node_id[0] for child_node_id in child_node_ids if child_node_id[0]]
# Delete from vector index
if child_node_ids:
vector.delete_by_ids(child_node_ids)
# Delete from database
if delete_child_chunks and child_node_ids:
db.session.query(ChildChunk).where(
ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
).delete(synchronize_session=False)

View File

@ -18,7 +18,7 @@ from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
from core.tools.errors import (
ToolProviderNotFoundError,
)
from core.tools.utils.yaml_utils import load_yaml_file
from core.tools.utils.yaml_utils import load_yaml_file_cached
class BuiltinToolProviderController(ToolProviderController):
@ -31,7 +31,7 @@ class BuiltinToolProviderController(ToolProviderController):
provider = self.__class__.__module__.split(".")[-1]
yaml_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, f"{provider}.yaml")
try:
provider_yaml = load_yaml_file(yaml_path, ignore_error=False)
provider_yaml = load_yaml_file_cached(yaml_path)
except Exception as e:
raise ToolProviderNotFoundError(f"can not load provider yaml for {provider}: {e}")
@ -71,7 +71,7 @@ class BuiltinToolProviderController(ToolProviderController):
for tool_file in tool_files:
# get tool name
tool_name = tool_file.split(".")[0]
tool = load_yaml_file(path.join(tool_path, tool_file), ignore_error=False)
tool = load_yaml_file_cached(path.join(tool_path, tool_file))
# get tool class, import the module
assistant_tool_class: type = load_single_subclass_from_source(

View File

@ -1,8 +1,7 @@
import base64
import contextlib
import enum
from collections.abc import Mapping
from enum import Enum
from enum import StrEnum, auto
from typing import Any, Union
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator
@ -22,37 +21,37 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY
class ToolLabelEnum(Enum):
SEARCH = "search"
IMAGE = "image"
VIDEOS = "videos"
WEATHER = "weather"
FINANCE = "finance"
DESIGN = "design"
TRAVEL = "travel"
SOCIAL = "social"
NEWS = "news"
MEDICAL = "medical"
PRODUCTIVITY = "productivity"
EDUCATION = "education"
BUSINESS = "business"
ENTERTAINMENT = "entertainment"
UTILITIES = "utilities"
OTHER = "other"
class ToolLabelEnum(StrEnum):
SEARCH = auto()
IMAGE = auto()
VIDEOS = auto()
WEATHER = auto()
FINANCE = auto()
DESIGN = auto()
TRAVEL = auto()
SOCIAL = auto()
NEWS = auto()
MEDICAL = auto()
PRODUCTIVITY = auto()
EDUCATION = auto()
BUSINESS = auto()
ENTERTAINMENT = auto()
UTILITIES = auto()
OTHER = auto()
class ToolProviderType(enum.StrEnum):
class ToolProviderType(StrEnum):
"""
Enum class for tool provider
"""
PLUGIN = "plugin"
PLUGIN = auto()
BUILT_IN = "builtin"
WORKFLOW = "workflow"
API = "api"
APP = "app"
WORKFLOW = auto()
API = auto()
APP = auto()
DATASET_RETRIEVAL = "dataset-retrieval"
MCP = "mcp"
MCP = auto()
@classmethod
def value_of(cls, value: str) -> "ToolProviderType":
@ -68,15 +67,15 @@ class ToolProviderType(enum.StrEnum):
raise ValueError(f"invalid mode value {value}")
class ApiProviderSchemaType(Enum):
class ApiProviderSchemaType(StrEnum):
"""
Enum class for api provider schema type.
"""
OPENAPI = "openapi"
SWAGGER = "swagger"
OPENAI_PLUGIN = "openai_plugin"
OPENAI_ACTIONS = "openai_actions"
OPENAPI = auto()
SWAGGER = auto()
OPENAI_PLUGIN = auto()
OPENAI_ACTIONS = auto()
@classmethod
def value_of(cls, value: str) -> "ApiProviderSchemaType":
@ -92,14 +91,14 @@ class ApiProviderSchemaType(Enum):
raise ValueError(f"invalid mode value {value}")
class ApiProviderAuthType(Enum):
class ApiProviderAuthType(StrEnum):
"""
Enum class for api provider auth type.
"""
NONE = "none"
API_KEY_HEADER = "api_key_header"
API_KEY_QUERY = "api_key_query"
NONE = auto()
API_KEY_HEADER = auto()
API_KEY_QUERY = auto()
@classmethod
def value_of(cls, value: str) -> "ApiProviderAuthType":
@ -176,10 +175,10 @@ class ToolInvokeMessage(BaseModel):
return value
class LogMessage(BaseModel):
class LogStatus(Enum):
START = "start"
ERROR = "error"
SUCCESS = "success"
class LogStatus(StrEnum):
START = auto()
ERROR = auto()
SUCCESS = auto()
id: str
label: str = Field(..., description="The label of the log")
@ -193,19 +192,19 @@ class ToolInvokeMessage(BaseModel):
retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
class MessageType(Enum):
TEXT = "text"
IMAGE = "image"
LINK = "link"
BLOB = "blob"
JSON = "json"
IMAGE_LINK = "image_link"
BINARY_LINK = "binary_link"
VARIABLE = "variable"
FILE = "file"
LOG = "log"
BLOB_CHUNK = "blob_chunk"
RETRIEVER_RESOURCES = "retriever_resources"
class MessageType(StrEnum):
TEXT = auto()
IMAGE = auto()
LINK = auto()
BLOB = auto()
JSON = auto()
IMAGE_LINK = auto()
BINARY_LINK = auto()
VARIABLE = auto()
FILE = auto()
LOG = auto()
BLOB_CHUNK = auto()
RETRIEVER_RESOURCES = auto()
type: MessageType = MessageType.TEXT
"""
@ -250,29 +249,29 @@ class ToolParameter(PluginParameter):
Overrides type
"""
class ToolParameterType(enum.StrEnum):
class ToolParameterType(StrEnum):
"""
removes TOOLS_SELECTOR from PluginParameterType
"""
STRING = PluginParameterType.STRING.value
NUMBER = PluginParameterType.NUMBER.value
BOOLEAN = PluginParameterType.BOOLEAN.value
SELECT = PluginParameterType.SELECT.value
SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
FILE = PluginParameterType.FILE.value
FILES = PluginParameterType.FILES.value
APP_SELECTOR = PluginParameterType.APP_SELECTOR.value
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value
ANY = PluginParameterType.ANY.value
DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT.value
STRING = PluginParameterType.STRING
NUMBER = PluginParameterType.NUMBER
BOOLEAN = PluginParameterType.BOOLEAN
SELECT = PluginParameterType.SELECT
SECRET_INPUT = PluginParameterType.SECRET_INPUT
FILE = PluginParameterType.FILE
FILES = PluginParameterType.FILES
APP_SELECTOR = PluginParameterType.APP_SELECTOR
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR
ANY = PluginParameterType.ANY
DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT
# MCP object and array type parameters
ARRAY = MCPServerParameterType.ARRAY.value
OBJECT = MCPServerParameterType.OBJECT.value
ARRAY = MCPServerParameterType.ARRAY
OBJECT = MCPServerParameterType.OBJECT
# deprecated, should not use.
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES
def as_normal_type(self):
return as_normal_type(self)
@ -280,10 +279,10 @@ class ToolParameter(PluginParameter):
def cast_value(self, value: Any):
return cast_parameter_value(self, value)
class ToolParameterForm(Enum):
SCHEMA = "schema" # should be set while adding tool
FORM = "form" # should be set before invoking tool
LLM = "llm" # will be set by LLM
class ToolParameterForm(StrEnum):
SCHEMA = auto() # should be set while adding tool
FORM = auto() # should be set before invoking tool
LLM = auto() # will be set by LLM
type: ToolParameterType = Field(..., description="The type of the parameter")
human_description: I18nObject | None = Field(default=None, description="The description presented to the user")
@ -448,14 +447,14 @@ class ToolLabel(BaseModel):
icon: str = Field(..., description="The icon of the tool")
class ToolInvokeFrom(Enum):
class ToolInvokeFrom(StrEnum):
"""
Enum class for tool invoke
"""
WORKFLOW = "workflow"
AGENT = "agent"
PLUGIN = "plugin"
WORKFLOW = auto()
AGENT = auto()
PLUGIN = auto()
class ToolSelector(BaseModel):
@ -480,9 +479,9 @@ class ToolSelector(BaseModel):
return self.model_dump()
class CredentialType(enum.StrEnum):
class CredentialType(StrEnum):
API_KEY = "api-key"
OAUTH2 = "oauth2"
OAUTH2 = auto()
def get_name(self):
if self == CredentialType.API_KEY:

View File

@ -1,4 +1,5 @@
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any
@ -8,28 +9,25 @@ from yaml import YAMLError
logger = logging.getLogger(__name__)
def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}):
"""
Safe loading a YAML file
:param file_path: the path of the YAML file
:param ignore_error:
if True, return default_value if error occurs and the error will be logged in debug level
if False, raise error if error occurs
:param default_value: the value returned when errors ignored
:return: an object of the YAML content
"""
def _load_yaml_file(*, file_path: str):
if not file_path or not Path(file_path).exists():
if ignore_error:
return default_value
else:
raise FileNotFoundError(f"File not found: {file_path}")
raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, encoding="utf-8") as yaml_file:
try:
yaml_content = yaml.safe_load(yaml_file)
return yaml_content or default_value
return yaml_content
except Exception as e:
if ignore_error:
return default_value
else:
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e
@lru_cache(maxsize=128)
def load_yaml_file_cached(file_path: str) -> Any:
"""
Cached version of load_yaml_file for static configuration files.
Only use for files that don't change during runtime (e.g., position files)
:param file_path: the path of the YAML file
:return: an object of the YAML content
"""
return _load_yaml_file(file_path=file_path)

View File

@ -1,4 +1,4 @@
from enum import Enum, StrEnum
from enum import IntEnum, StrEnum, auto
from typing import Any, Literal, Union
from pydantic import BaseModel
@ -25,9 +25,9 @@ class AgentNodeData(BaseNodeData):
agent_parameters: dict[str, AgentInput]
class ParamsAutoGenerated(Enum):
CLOSE = 0
OPEN = 1
class ParamsAutoGenerated(IntEnum):
CLOSE = auto()
OPEN = auto()
class AgentOldVersionModelFeatures(StrEnum):
@ -38,8 +38,8 @@ class AgentOldVersionModelFeatures(StrEnum):
TOOL_CALL = "tool-call"
MULTI_TOOL_CALL = "multi-tool-call"
AGENT_THOUGHT = "agent-thought"
VISION = "vision"
VISION = auto()
STREAM_TOOL_CALL = "stream-tool-call"
DOCUMENT = "document"
VIDEO = "video"
AUDIO = "audio"
DOCUMENT = auto()
VIDEO = auto()
AUDIO = auto()

View File

@ -1,5 +1,5 @@
from collections.abc import Sequence
from enum import Enum
from enum import StrEnum, auto
from pydantic import BaseModel, Field
@ -19,9 +19,9 @@ class GenerateRouteChunk(BaseModel):
Generate Route Chunk.
"""
class ChunkType(Enum):
VAR = "var"
TEXT = "text"
class ChunkType(StrEnum):
VAR = auto()
TEXT = auto()
type: ChunkType = Field(..., description="generate route chunk type")

View File

@ -250,7 +250,7 @@ class KnowledgeRetrievalNode(Node):
)
all_documents = []
dataset_retrieval = DatasetRetrieval()
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
# fetch model config
if node_data.single_retrieval_config is None:
raise ValueError("single_retrieval_config is required")
@ -282,7 +282,7 @@ class KnowledgeRetrievalNode(Node):
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
)
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
if node_data.multiple_retrieval_config is None:
raise ValueError("multiple_retrieval_config is required")
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":

View File

@ -397,7 +397,7 @@ class WorkflowEntry:
raise ValueError(f"Variable key {node_variable} not found in user inputs.")
# environment variable already exist in variable pool, not from user inputs
if variable_pool.get(variable_selector):
if variable_pool.get(variable_selector) and variable_selector[0] == ENVIRONMENT_VARIABLE_NODE_ID:
continue
# fetch variable node id from variable selector

View File

@ -9,19 +9,19 @@ import json
import logging
from dataclasses import asdict, dataclass
from datetime import datetime
from enum import Enum
from enum import StrEnum, auto
from typing import Any, Optional
logger = logging.getLogger(__name__)
class FileStatus(Enum):
class FileStatus(StrEnum):
"""File status enumeration"""
ACTIVE = "active" # Active status
ARCHIVED = "archived" # Archived
DELETED = "deleted" # Deleted (soft delete)
BACKUP = "backup" # Backup file
ACTIVE = auto() # Active status
ARCHIVED = auto() # Archived
DELETED = auto() # Deleted (soft delete)
BACKUP = auto() # Backup file
@dataclass

View File

@ -5,13 +5,13 @@ According to ClickZetta's permission model, different Volume types have differen
"""
import logging
from enum import Enum
from enum import StrEnum
from typing import Optional
logger = logging.getLogger(__name__)
class VolumePermission(Enum):
class VolumePermission(StrEnum):
"""Volume permission type enumeration"""
READ = "SELECT" # Corresponds to ClickZetta's SELECT permission

View File

@ -7,7 +7,7 @@ eliminates the need for repetitive language switching logic.
"""
from dataclasses import dataclass
from enum import Enum
from enum import StrEnum, auto
from typing import Any, Optional, Protocol
from flask import render_template
@ -17,26 +17,30 @@ from extensions.ext_mail import mail
from services.feature_service import BrandingModel, FeatureService
class EmailType(Enum):
class EmailType(StrEnum):
"""Enumeration of supported email types."""
RESET_PASSWORD = "reset_password"
INVITE_MEMBER = "invite_member"
EMAIL_CODE_LOGIN = "email_code_login"
CHANGE_EMAIL_OLD = "change_email_old"
CHANGE_EMAIL_NEW = "change_email_new"
CHANGE_EMAIL_COMPLETED = "change_email_completed"
OWNER_TRANSFER_CONFIRM = "owner_transfer_confirm"
OWNER_TRANSFER_OLD_NOTIFY = "owner_transfer_old_notify"
OWNER_TRANSFER_NEW_NOTIFY = "owner_transfer_new_notify"
ACCOUNT_DELETION_SUCCESS = "account_deletion_success"
ACCOUNT_DELETION_VERIFICATION = "account_deletion_verification"
ENTERPRISE_CUSTOM = "enterprise_custom"
QUEUE_MONITOR_ALERT = "queue_monitor_alert"
DOCUMENT_CLEAN_NOTIFY = "document_clean_notify"
RESET_PASSWORD = auto()
RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST = auto()
INVITE_MEMBER = auto()
EMAIL_CODE_LOGIN = auto()
CHANGE_EMAIL_OLD = auto()
CHANGE_EMAIL_NEW = auto()
CHANGE_EMAIL_COMPLETED = auto()
OWNER_TRANSFER_CONFIRM = auto()
OWNER_TRANSFER_OLD_NOTIFY = auto()
OWNER_TRANSFER_NEW_NOTIFY = auto()
ACCOUNT_DELETION_SUCCESS = auto()
ACCOUNT_DELETION_VERIFICATION = auto()
ENTERPRISE_CUSTOM = auto()
QUEUE_MONITOR_ALERT = auto()
DOCUMENT_CLEAN_NOTIFY = auto()
EMAIL_REGISTER = auto()
EMAIL_REGISTER_WHEN_ACCOUNT_EXIST = auto()
RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER = auto()
class EmailLanguage(Enum):
class EmailLanguage(StrEnum):
"""Supported email languages with fallback handling."""
EN_US = "en-US"
@ -441,6 +445,54 @@ def create_default_email_config() -> EmailI18nConfig:
branded_template_path="clean_document_job_mail_template_zh-CN.html",
),
},
EmailType.EMAIL_REGISTER: {
EmailLanguage.EN_US: EmailTemplate(
subject="Register Your {application_title} Account",
template_path="register_email_template_en-US.html",
branded_template_path="without-brand/register_email_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="注册您的 {application_title} 账户",
template_path="register_email_template_zh-CN.html",
branded_template_path="without-brand/register_email_template_zh-CN.html",
),
},
EmailType.EMAIL_REGISTER_WHEN_ACCOUNT_EXIST: {
EmailLanguage.EN_US: EmailTemplate(
subject="Register Your {application_title} Account",
template_path="register_email_when_account_exist_template_en-US.html",
branded_template_path="without-brand/register_email_when_account_exist_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="注册您的 {application_title} 账户",
template_path="register_email_when_account_exist_template_zh-CN.html",
branded_template_path="without-brand/register_email_when_account_exist_template_zh-CN.html",
),
},
EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST: {
EmailLanguage.EN_US: EmailTemplate(
subject="Reset Your {application_title} Password",
template_path="reset_password_mail_when_account_not_exist_template_en-US.html",
branded_template_path="without-brand/reset_password_mail_when_account_not_exist_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="重置您的 {application_title} 密码",
template_path="reset_password_mail_when_account_not_exist_template_zh-CN.html",
branded_template_path="without-brand/reset_password_mail_when_account_not_exist_template_zh-CN.html",
),
},
EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER: {
EmailLanguage.EN_US: EmailTemplate(
subject="Reset Your {application_title} Password",
template_path="reset_password_mail_when_account_not_exist_no_register_template_en-US.html",
branded_template_path="without-brand/reset_password_mail_when_account_not_exist_no_register_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="重置您的 {application_title} 密码",
template_path="reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html",
branded_template_path="without-brand/reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html",
),
},
}
return EmailI18nConfig(templates=templates)

View File

@ -68,7 +68,7 @@ class AppIconUrlField(fields.Raw):
if isinstance(obj, dict) and "app" in obj:
obj = obj["app"]
if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE.value:
if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE:
return file_helpers.get_signed_file_url(obj.icon)
return None

View File

@ -224,35 +224,35 @@ class Dataset(Base):
doc_metadata.append(
{
"id": "built-in",
"name": BuiltInField.document_name.value,
"name": BuiltInField.document_name,
"type": "string",
}
)
doc_metadata.append(
{
"id": "built-in",
"name": BuiltInField.uploader.value,
"name": BuiltInField.uploader,
"type": "string",
}
)
doc_metadata.append(
{
"id": "built-in",
"name": BuiltInField.upload_date.value,
"name": BuiltInField.upload_date,
"type": "time",
}
)
doc_metadata.append(
{
"id": "built-in",
"name": BuiltInField.last_update_date.value,
"name": BuiltInField.last_update_date,
"type": "time",
}
)
doc_metadata.append(
{
"id": "built-in",
"name": BuiltInField.source.value,
"name": BuiltInField.source,
"type": "string",
}
)
@ -544,7 +544,7 @@ class Document(Base):
"id": "built-in",
"name": BuiltInField.source,
"type": "string",
"value": MetadataDataSource[self.data_source_type].value,
"value": MetadataDataSource[self.data_source_type],
}
)
return built_in_fields

View File

@ -3,7 +3,7 @@ import re
import uuid
from collections.abc import Mapping
from datetime import datetime
from enum import Enum, StrEnum
from enum import StrEnum, auto
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
import sqlalchemy as sa
@ -60,9 +60,9 @@ class AppMode(StrEnum):
raise ValueError(f"invalid mode value {value}")
class IconType(Enum):
IMAGE = "image"
EMOJI = "emoji"
class IconType(StrEnum):
IMAGE = auto()
EMOJI = auto()
class App(Base):
@ -147,15 +147,15 @@ class App(Base):
if app_model_config.agent_mode_dict.get("enabled", False) and app_model_config.agent_mode_dict.get(
"strategy", ""
) in {"function_call", "react"}:
self.mode = AppMode.AGENT_CHAT.value
self.mode = AppMode.AGENT_CHAT
db.session.commit()
return True
return False
@property
def mode_compatible_with_agent(self) -> str:
if self.mode == AppMode.CHAT.value and self.is_agent:
return AppMode.AGENT_CHAT.value
if self.mode == AppMode.CHAT and self.is_agent:
return AppMode.AGENT_CHAT
return str(self.mode)
@ -712,7 +712,7 @@ class Conversation(Base):
model_config = {}
app_model_config: Optional[AppModelConfig] = None
if self.mode == AppMode.ADVANCED_CHAT.value:
if self.mode == AppMode.ADVANCED_CHAT:
if self.override_model_configs:
override_model_configs = json.loads(self.override_model_configs)
model_config = override_model_configs
@ -1459,6 +1459,14 @@ class OperationLog(Base):
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
class DefaultEndUserSessionID(StrEnum):
"""
End User Session ID enum.
"""
DEFAULT_SESSION_ID = "DEFAULT-USER"
class EndUser(Base, UserMixin):
__tablename__ = "end_users"
__table_args__ = (

View File

@ -1,5 +1,5 @@
from datetime import datetime
from enum import Enum
from enum import StrEnum, auto
from functools import cached_property
from typing import Optional
@ -12,9 +12,9 @@ from .engine import db
from .types import StringUUID
class ProviderType(Enum):
CUSTOM = "custom"
SYSTEM = "system"
class ProviderType(StrEnum):
CUSTOM = auto()
SYSTEM = auto()
@staticmethod
def value_of(value: str) -> "ProviderType":
@ -24,14 +24,14 @@ class ProviderType(Enum):
raise ValueError(f"No matching enum found for value '{value}'")
class ProviderQuotaType(Enum):
PAID = "paid"
class ProviderQuotaType(StrEnum):
PAID = auto()
"""hosted paid quota"""
FREE = "free"
FREE = auto()
"""third-party free quota"""
TRIAL = "trial"
TRIAL = auto()
"""hosted trial quota"""
@staticmethod

View File

@ -2,7 +2,7 @@ import json
import logging
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import Enum, StrEnum
from enum import StrEnum, auto
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from uuid import uuid4
@ -41,13 +41,13 @@ from .types import EnumText, StringUUID
logger = logging.getLogger(__name__)
class WorkflowType(Enum):
class WorkflowType(StrEnum):
"""
Workflow Type Enum
"""
WORKFLOW = "workflow"
CHAT = "chat"
WORKFLOW = auto()
CHAT = auto()
@classmethod
def value_of(cls, value: str) -> "WorkflowType":
@ -777,7 +777,7 @@ class WorkflowNodeExecutionModel(Base):
return extras
class WorkflowAppLogCreatedFrom(Enum):
class WorkflowAppLogCreatedFrom(StrEnum):
"""
Workflow App Log Created From Enum
"""

View File

@ -169,6 +169,8 @@ dev = [
"types-redis>=4.6.0.20241004",
"celery-types>=0.23.0",
"mypy~=1.17.1",
"locust>=2.40.4",
"sseclient-py>=1.8.0",
]
############################################################

View File

@ -37,7 +37,6 @@ from services.billing_service import BillingService
from services.errors.account import (
AccountAlreadyInTenantError,
AccountLoginError,
AccountNotFoundError,
AccountNotLinkTenantError,
AccountPasswordError,
AccountRegisterError,
@ -65,7 +64,11 @@ from tasks.mail_owner_transfer_task import (
send_old_owner_transfer_notify_email_task,
send_owner_transfer_confirm_task,
)
from tasks.mail_reset_password_task import send_reset_password_mail_task
from tasks.mail_register_task import send_email_register_mail_task, send_email_register_mail_task_when_account_exist
from tasks.mail_reset_password_task import (
send_reset_password_mail_task,
send_reset_password_mail_task_when_account_not_exist,
)
logger = logging.getLogger(__name__)
@ -82,8 +85,9 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
class AccountService:
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
email_register_rate_limiter = RateLimiter(prefix="email_register_rate_limit", max_attempts=1, time_window=60 * 1)
email_code_login_rate_limiter = RateLimiter(
prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1
prefix="email_code_login_rate_limit", max_attempts=3, time_window=300 * 1
)
email_code_account_deletion_rate_limiter = RateLimiter(
prefix="email_code_account_deletion_rate_limit", max_attempts=1, time_window=60 * 1
@ -95,6 +99,7 @@ class AccountService:
FORGOT_PASSWORD_MAX_ERROR_LIMITS = 5
CHANGE_EMAIL_MAX_ERROR_LIMITS = 5
OWNER_TRANSFER_MAX_ERROR_LIMITS = 5
EMAIL_REGISTER_MAX_ERROR_LIMITS = 5
@staticmethod
def _get_refresh_token_key(refresh_token: str) -> str:
@ -171,7 +176,7 @@ class AccountService:
account = db.session.query(Account).filter_by(email=email).first()
if not account:
raise AccountNotFoundError()
raise AccountPasswordError("Invalid email or password.")
if account.status == AccountStatus.BANNED.value:
raise AccountLoginError("Account is banned.")
@ -296,7 +301,9 @@ class AccountService:
if cls.email_code_account_deletion_rate_limiter.is_rate_limited(email):
from controllers.console.auth.error import EmailCodeAccountDeletionRateLimitExceededError
raise EmailCodeAccountDeletionRateLimitExceededError()
raise EmailCodeAccountDeletionRateLimitExceededError(
int(cls.email_code_account_deletion_rate_limiter.time_window / 60)
)
send_account_deletion_verification_code.delay(to=email, code=code)
@ -435,6 +442,7 @@ class AccountService:
account: Optional[Account] = None,
email: Optional[str] = None,
language: str = "en-US",
is_allow_register: bool = False,
):
account_email = account.email if account else email
if account_email is None:
@ -443,18 +451,59 @@ class AccountService:
if cls.reset_password_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import PasswordResetRateLimitExceededError
raise PasswordResetRateLimitExceededError()
raise PasswordResetRateLimitExceededError(int(cls.reset_password_rate_limiter.time_window / 60))
code, token = cls.generate_reset_password_token(account_email, account)
send_reset_password_mail_task.delay(
language=language,
to=account_email,
code=code,
)
if account:
send_reset_password_mail_task.delay(
language=language,
to=account_email,
code=code,
)
else:
send_reset_password_mail_task_when_account_not_exist.delay(
language=language,
to=account_email,
is_allow_register=is_allow_register,
)
cls.reset_password_rate_limiter.increment_rate_limit(account_email)
return token
@classmethod
def send_email_register_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
language: str = "en-US",
):
account_email = account.email if account else email
if account_email is None:
raise ValueError("Email must be provided.")
if cls.email_register_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import EmailRegisterRateLimitExceededError
raise EmailRegisterRateLimitExceededError(int(cls.email_register_rate_limiter.time_window / 60))
code, token = cls.generate_email_register_token(account_email)
if account:
send_email_register_mail_task_when_account_exist.delay(
language=language,
to=account_email,
account_name=account.name,
)
else:
send_email_register_mail_task.delay(
language=language,
to=account_email,
code=code,
)
cls.email_register_rate_limiter.increment_rate_limit(account_email)
return token
@classmethod
def send_change_email_email(
cls,
@ -473,7 +522,7 @@ class AccountService:
if cls.change_email_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import EmailChangeRateLimitExceededError
raise EmailChangeRateLimitExceededError()
raise EmailChangeRateLimitExceededError(int(cls.change_email_rate_limiter.time_window / 60))
code, token = cls.generate_change_email_token(account_email, account, old_email=old_email)
@ -517,7 +566,7 @@ class AccountService:
if cls.owner_transfer_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import OwnerTransferRateLimitExceededError
raise OwnerTransferRateLimitExceededError()
raise OwnerTransferRateLimitExceededError(int(cls.owner_transfer_rate_limiter.time_window / 60))
code, token = cls.generate_owner_transfer_token(account_email, account)
workspace_name = workspace_name or ""
@ -587,6 +636,19 @@ class AccountService:
)
return code, token
@classmethod
def generate_email_register_token(
cls,
email: str,
code: Optional[str] = None,
additional_data: dict[str, Any] = {},
):
if not code:
code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
additional_data["code"] = code
token = TokenManager.generate_token(email=email, token_type="email_register", additional_data=additional_data)
return code, token
@classmethod
def generate_change_email_token(
cls,
@ -625,6 +687,10 @@ class AccountService:
def revoke_reset_password_token(cls, token: str):
TokenManager.revoke_token(token, "reset_password")
@classmethod
def revoke_email_register_token(cls, token: str):
TokenManager.revoke_token(token, "email_register")
@classmethod
def revoke_change_email_token(cls, token: str):
TokenManager.revoke_token(token, "change_email")
@ -637,6 +703,10 @@ class AccountService:
def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]:
return TokenManager.get_token_data(token, "reset_password")
@classmethod
def get_email_register_data(cls, token: str) -> Optional[dict[str, Any]]:
return TokenManager.get_token_data(token, "email_register")
@classmethod
def get_change_email_data(cls, token: str) -> Optional[dict[str, Any]]:
return TokenManager.get_token_data(token, "change_email")
@ -658,7 +728,7 @@ class AccountService:
if cls.email_code_login_rate_limiter.is_rate_limited(email):
from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError
raise EmailCodeLoginRateLimitExceededError()
raise EmailCodeLoginRateLimitExceededError(int(cls.email_code_login_rate_limiter.time_window / 60))
code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
token = TokenManager.generate_token(
@ -744,6 +814,16 @@ class AccountService:
count = int(count) + 1
redis_client.setex(key, dify_config.FORGOT_PASSWORD_LOCKOUT_DURATION, count)
@staticmethod
@redis_fallback(default_return=None)
def add_email_register_error_rate_limit(email: str) -> None:
key = f"email_register_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
count = 0
count = int(count) + 1
redis_client.setex(key, dify_config.EMAIL_REGISTER_LOCKOUT_DURATION, count)
@staticmethod
@redis_fallback(default_return=False)
def is_forgot_password_error_rate_limit(email: str) -> bool:
@ -763,6 +843,24 @@ class AccountService:
key = f"forgot_password_error_rate_limit:{email}"
redis_client.delete(key)
@staticmethod
@redis_fallback(default_return=False)
def is_email_register_error_rate_limit(email: str) -> bool:
key = f"email_register_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
return False
count = int(count)
if count > AccountService.EMAIL_REGISTER_MAX_ERROR_LIMITS:
return True
return False
@staticmethod
@redis_fallback(default_return=None)
def reset_email_register_error_rate_limit(email: str):
key = f"email_register_error_rate_limit:{email}"
redis_client.delete(key)
@staticmethod
@redis_fallback(default_return=None)
def add_change_email_error_rate_limit(email: str):

View File

@ -32,14 +32,14 @@ class AdvancedPromptTemplateService:
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str):
context_prompt = copy.deepcopy(CONTEXT)
if app_mode == AppMode.CHAT.value:
if app_mode == AppMode.CHAT:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
)
elif model_mode == "chat":
return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
elif app_mode == AppMode.COMPLETION.value:
elif app_mode == AppMode.COMPLETION:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
@ -73,7 +73,7 @@ class AdvancedPromptTemplateService:
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str):
baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)
if app_mode == AppMode.CHAT.value:
if app_mode == AppMode.CHAT:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt
@ -82,7 +82,7 @@ class AdvancedPromptTemplateService:
return cls.get_chat_prompt(
copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
)
elif app_mode == AppMode.COMPLETION.value:
elif app_mode == AppMode.COMPLETION:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG),

View File

@ -60,7 +60,7 @@ class AppGenerateService:
request_id = RateLimit.gen_request_key()
try:
request_id = rate_limit.enter(request_id)
if app_model.mode == AppMode.COMPLETION.value:
if app_model.mode == AppMode.COMPLETION:
return rate_limit.generate(
CompletionAppGenerator.convert_to_event_stream(
CompletionAppGenerator().generate(
@ -69,7 +69,7 @@ class AppGenerateService:
),
request_id=request_id,
)
elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
elif app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
return rate_limit.generate(
AgentChatAppGenerator.convert_to_event_stream(
AgentChatAppGenerator().generate(
@ -78,7 +78,7 @@ class AppGenerateService:
),
request_id,
)
elif app_model.mode == AppMode.CHAT.value:
elif app_model.mode == AppMode.CHAT:
return rate_limit.generate(
ChatAppGenerator.convert_to_event_stream(
ChatAppGenerator().generate(
@ -87,7 +87,7 @@ class AppGenerateService:
),
request_id=request_id,
)
elif app_model.mode == AppMode.ADVANCED_CHAT.value:
elif app_model.mode == AppMode.ADVANCED_CHAT:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
return rate_limit.generate(
@ -103,7 +103,7 @@ class AppGenerateService:
),
request_id=request_id,
)
elif app_model.mode == AppMode.WORKFLOW.value:
elif app_model.mode == AppMode.WORKFLOW:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
return rate_limit.generate(
@ -154,14 +154,14 @@ class AppGenerateService:
@classmethod
def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
if app_model.mode == AppMode.ADVANCED_CHAT.value:
if app_model.mode == AppMode.ADVANCED_CHAT:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().single_iteration_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
)
elif app_model.mode == AppMode.WORKFLOW.value:
elif app_model.mode == AppMode.WORKFLOW:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_iteration_generate(
@ -173,14 +173,14 @@ class AppGenerateService:
@classmethod
def generate_single_loop(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
if app_model.mode == AppMode.ADVANCED_CHAT.value:
if app_model.mode == AppMode.ADVANCED_CHAT:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().single_loop_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
)
elif app_model.mode == AppMode.WORKFLOW.value:
elif app_model.mode == AppMode.WORKFLOW:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_loop_generate(

View File

@ -40,15 +40,15 @@ class AppService:
filters = [App.tenant_id == tenant_id, App.is_universal == False]
if args["mode"] == "workflow":
filters.append(App.mode == AppMode.WORKFLOW.value)
filters.append(App.mode == AppMode.WORKFLOW)
elif args["mode"] == "completion":
filters.append(App.mode == AppMode.COMPLETION.value)
filters.append(App.mode == AppMode.COMPLETION)
elif args["mode"] == "chat":
filters.append(App.mode == AppMode.CHAT.value)
filters.append(App.mode == AppMode.CHAT)
elif args["mode"] == "advanced-chat":
filters.append(App.mode == AppMode.ADVANCED_CHAT.value)
filters.append(App.mode == AppMode.ADVANCED_CHAT)
elif args["mode"] == "agent-chat":
filters.append(App.mode == AppMode.AGENT_CHAT.value)
filters.append(App.mode == AppMode.AGENT_CHAT)
if args.get("is_created_by_me", False):
filters.append(App.created_by == user_id)
@ -171,7 +171,7 @@ class AppService:
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# get original app model config
if app.mode == AppMode.AGENT_CHAT.value or app.is_agent:
if app.mode == AppMode.AGENT_CHAT or app.is_agent:
model_config = app.app_model_config
if not model_config:
return app

View File

@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)
class AudioService:
@classmethod
def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[str] = None):
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow
if workflow is None:
raise ValueError("Speech to text is not enabled")
@ -88,7 +88,7 @@ class AudioService:
def invoke_tts(text_content: str, app_model: App, voice: Optional[str] = None, is_draft: bool = False):
with app.app_context():
if voice is None:
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
if is_draft:
workflow = WorkflowService().get_draft_workflow(app_model=app_model)
else:

View File

@ -222,8 +222,8 @@ class ConversationService:
# Filter for variables created after the last_id
stmt = stmt.where(ConversationVariable.created_at > last_variable.created_at)
# Apply limit to query
query_stmt = stmt.limit(limit) # Get one extra to check if there are more
# Apply limit to query: fetch one extra row to determine has_more
query_stmt = stmt.limit(limit + 1)
rows = session.scalars(query_stmt).all()
has_more = False

View File

@ -1004,7 +1004,7 @@ class DocumentService:
if dataset.built_in_field_enabled:
if document.doc_metadata:
doc_metadata = copy.deepcopy(document.doc_metadata)
doc_metadata[BuiltInField.document_name.value] = name
doc_metadata[BuiltInField.document_name] = name
document.doc_metadata = doc_metadata
document.name = name
@ -2365,7 +2365,22 @@ class SegmentService:
if segment.enabled:
# send delete segment index task
redis_client.setex(indexing_cache_key, 600, 1)
delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id)
# Get child chunk IDs before parent segment is deleted
child_node_ids = []
if segment.index_node_id:
child_chunks = (
db.session.query(ChildChunk.index_node_id)
.where(
ChildChunk.segment_id == segment.id,
ChildChunk.dataset_id == dataset.id,
)
.all()
)
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id, child_node_ids)
db.session.delete(segment)
# update document word count
assert document.word_count is not None
@ -2375,9 +2390,13 @@ class SegmentService:
@classmethod
def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
segments = (
db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count)
assert current_user is not None
# Check if segment_ids is not empty to avoid WHERE false condition
if not segment_ids or len(segment_ids) == 0:
return
segments_info = (
db.session.query(DocumentSegment)
.with_entities(DocumentSegment.index_node_id, DocumentSegment.id, DocumentSegment.word_count)
.where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
@ -2387,18 +2406,36 @@ class SegmentService:
.all()
)
if not segments:
if not segments_info:
return
index_node_ids = [seg.index_node_id for seg in segments]
total_words = sum(seg.word_count for seg in segments)
index_node_ids = [info[0] for info in segments_info]
segment_db_ids = [info[1] for info in segments_info]
total_words = sum(info[2] for info in segments_info if info[2] is not None)
# Get child chunk IDs before parent segments are deleted
child_node_ids = []
if index_node_ids:
child_chunks = (
db.session.query(ChildChunk.index_node_id)
.where(
ChildChunk.segment_id.in_(segment_db_ids),
ChildChunk.dataset_id == dataset.id,
)
.all()
)
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
# Start async cleanup with both parent and child node IDs
if index_node_ids or child_node_ids:
delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id, child_node_ids)
document.word_count = (
document.word_count - total_words if document.word_count and document.word_count > total_words else 0
)
db.session.add(document)
delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id)
# Delete database records
db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).delete()
db.session.commit()

View File

@ -229,7 +229,7 @@ class MessageService:
model_manager = ModelManager()
if app_model.mode == AppMode.ADVANCED_CHAT.value:
if app_model.mode == AppMode.ADVANCED_CHAT:
workflow_service = WorkflowService()
if invoke_from == InvokeFrom.DEBUGGER:
workflow = workflow_service.get_draft_workflow(app_model=app_model)

Some files were not shown because too many files have changed in this diff Show More