mirror of https://github.com/langgenius/dify.git
Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
commit
b0e815c3c7
|
|
@ -23,11 +23,37 @@ jobs:
|
|||
uv run ruff check --fix .
|
||||
# Format code
|
||||
uv run ruff format .
|
||||
|
||||
- name: ast-grep
|
||||
run: |
|
||||
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
|
||||
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
|
||||
|
||||
|
||||
- name: mdformat
|
||||
run: |
|
||||
uvx mdformat .
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup NodeJS
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/package.json
|
||||
|
||||
- name: Web dependencies
|
||||
working-directory: ./web
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: oxlint
|
||||
working-directory: ./web
|
||||
run: |
|
||||
pnpx oxlint --fix
|
||||
|
||||
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27
|
||||
|
|
|
|||
|
|
@ -227,3 +227,7 @@ web/public/fallback-*.js
|
|||
.roo/
|
||||
api/.env.backup
|
||||
/clickzetta
|
||||
|
||||
# Benchmark
|
||||
scripts/stress-test/setup/config/
|
||||
scripts/stress-test/reports/
|
||||
|
|
@ -540,6 +540,7 @@ ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
|
|||
|
||||
# Reset password token expiry minutes
|
||||
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
|
||||
EMAIL_REGISTER_TOKEN_EXPIRY_MINUTES=5
|
||||
CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES=5
|
||||
OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5
|
||||
|
||||
|
|
|
|||
|
|
@ -477,12 +477,12 @@ def convert_to_agent_apps():
|
|||
click.echo(f"Converting app: {app.id}")
|
||||
|
||||
try:
|
||||
app.mode = AppMode.AGENT_CHAT.value
|
||||
app.mode = AppMode.AGENT_CHAT
|
||||
db.session.commit()
|
||||
|
||||
# update conversation mode to agent
|
||||
db.session.query(Conversation).where(Conversation.app_id == app.id).update(
|
||||
{Conversation.mode: AppMode.AGENT_CHAT.value}
|
||||
{Conversation.mode: AppMode.AGENT_CHAT}
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -31,6 +31,12 @@ class SecurityConfig(BaseSettings):
|
|||
description="Duration in minutes for which a password reset token remains valid",
|
||||
default=5,
|
||||
)
|
||||
|
||||
EMAIL_REGISTER_TOKEN_EXPIRY_MINUTES: PositiveInt = Field(
|
||||
description="Duration in minutes for which a email register token remains valid",
|
||||
default=5,
|
||||
)
|
||||
|
||||
CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES: PositiveInt = Field(
|
||||
description="Duration in minutes for which a change email token remains valid",
|
||||
default=5,
|
||||
|
|
@ -661,6 +667,11 @@ class AuthConfig(BaseSettings):
|
|||
default=86400,
|
||||
)
|
||||
|
||||
EMAIL_REGISTER_LOCKOUT_DURATION: PositiveInt = Field(
|
||||
description="Time (in seconds) a user must wait before retrying email register after exceeding the rate limit.",
|
||||
default=86400,
|
||||
)
|
||||
|
||||
|
||||
class ModerationConfig(BaseSettings):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import enum
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import Field, PositiveInt
|
||||
|
|
@ -10,7 +10,7 @@ class OpenSearchConfig(BaseSettings):
|
|||
Configuration settings for OpenSearch
|
||||
"""
|
||||
|
||||
class AuthMethod(enum.StrEnum):
|
||||
class AuthMethod(Enum):
|
||||
"""
|
||||
Authentication method for OpenSearch
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ default_app_templates: Mapping[AppMode, Mapping] = {
|
|||
# workflow default mode
|
||||
AppMode.WORKFLOW: {
|
||||
"app": {
|
||||
"mode": AppMode.WORKFLOW.value,
|
||||
"mode": AppMode.WORKFLOW,
|
||||
"enable_site": True,
|
||||
"enable_api": True,
|
||||
}
|
||||
|
|
@ -15,7 +15,7 @@ default_app_templates: Mapping[AppMode, Mapping] = {
|
|||
# completion default mode
|
||||
AppMode.COMPLETION: {
|
||||
"app": {
|
||||
"mode": AppMode.COMPLETION.value,
|
||||
"mode": AppMode.COMPLETION,
|
||||
"enable_site": True,
|
||||
"enable_api": True,
|
||||
},
|
||||
|
|
@ -44,7 +44,7 @@ default_app_templates: Mapping[AppMode, Mapping] = {
|
|||
# chat default mode
|
||||
AppMode.CHAT: {
|
||||
"app": {
|
||||
"mode": AppMode.CHAT.value,
|
||||
"mode": AppMode.CHAT,
|
||||
"enable_site": True,
|
||||
"enable_api": True,
|
||||
},
|
||||
|
|
@ -60,7 +60,7 @@ default_app_templates: Mapping[AppMode, Mapping] = {
|
|||
# advanced-chat default mode
|
||||
AppMode.ADVANCED_CHAT: {
|
||||
"app": {
|
||||
"mode": AppMode.ADVANCED_CHAT.value,
|
||||
"mode": AppMode.ADVANCED_CHAT,
|
||||
"enable_site": True,
|
||||
"enable_api": True,
|
||||
},
|
||||
|
|
@ -68,7 +68,7 @@ default_app_templates: Mapping[AppMode, Mapping] = {
|
|||
# agent-chat default mode
|
||||
AppMode.AGENT_CHAT: {
|
||||
"app": {
|
||||
"mode": AppMode.AGENT_CHAT.value,
|
||||
"mode": AppMode.AGENT_CHAT,
|
||||
"enable_site": True,
|
||||
"enable_api": True,
|
||||
},
|
||||
|
|
|
|||
|
|
@ -93,6 +93,7 @@ from .auth import (
|
|||
activate, # pyright: ignore[reportUnusedImport]
|
||||
data_source_bearer_auth, # pyright: ignore[reportUnusedImport]
|
||||
data_source_oauth, # pyright: ignore[reportUnusedImport]
|
||||
email_register, # pyright: ignore[reportUnusedImport]
|
||||
forgot_password, # pyright: ignore[reportUnusedImport]
|
||||
login, # pyright: ignore[reportUnusedImport]
|
||||
oauth, # pyright: ignore[reportUnusedImport]
|
||||
|
|
|
|||
|
|
@ -1,12 +1,26 @@
|
|||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
|
||||
|
||||
|
||||
@console_ns.route("/app/prompt-templates")
|
||||
class AdvancedPromptTemplateList(Resource):
|
||||
@api.doc("get_advanced_prompt_templates")
|
||||
@api.doc(description="Get advanced prompt templates based on app mode and model configuration")
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
|
||||
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
|
||||
.add_argument("has_context", type=str, default="true", location="args", help="Whether has context")
|
||||
.add_argument("model_name", type=str, required=True, location="args", help="Model name")
|
||||
)
|
||||
@api.response(
|
||||
200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data"))
|
||||
)
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -19,6 +33,3 @@ class AdvancedPromptTemplateList(Resource):
|
|||
args = parser.parse_args()
|
||||
|
||||
return AdvancedPromptTemplateService.get_prompt(args)
|
||||
|
||||
|
||||
api.add_resource(AdvancedPromptTemplateList, "/app/prompt-templates")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.helper import uuid_value
|
||||
|
|
@ -9,7 +9,18 @@ from models.model import AppMode
|
|||
from services.agent_service import AgentService
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent/logs")
|
||||
class AgentLogApi(Resource):
|
||||
@api.doc("get_agent_logs")
|
||||
@api.doc(description="Get agent execution logs for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("message_id", type=str, required=True, location="args", help="Message UUID")
|
||||
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation UUID")
|
||||
)
|
||||
@api.response(200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries")))
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -23,6 +34,3 @@ class AgentLogApi(Resource):
|
|||
args = parser.parse_args()
|
||||
|
||||
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])
|
||||
|
||||
|
||||
api.add_resource(AgentLogApi, "/apps/<uuid:app_id>/agent/logs")
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@ from typing import Literal
|
|||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal, marshal_with, reqparse
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
|
||||
from controllers.console import api
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
|
|
@ -21,7 +21,23 @@ from libs.login import login_required
|
|||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
|
||||
class AnnotationReplyActionApi(Resource):
|
||||
@api.doc("annotation_reply_action")
|
||||
@api.doc(description="Enable or disable annotation reply for an app")
|
||||
@api.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"AnnotationReplyActionRequest",
|
||||
{
|
||||
"score_threshold": fields.Float(required=True, description="Score threshold for annotation matching"),
|
||||
"embedding_provider_name": fields.String(required=True, description="Embedding provider name"),
|
||||
"embedding_model_name": fields.String(required=True, description="Embedding model name"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Action completed successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -43,7 +59,13 @@ class AnnotationReplyActionApi(Resource):
|
|||
return result, 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotation-setting")
|
||||
class AppAnnotationSettingDetailApi(Resource):
|
||||
@api.doc("get_annotation_setting")
|
||||
@api.doc(description="Get annotation settings for an app")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Annotation settings retrieved successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -56,7 +78,23 @@ class AppAnnotationSettingDetailApi(Resource):
|
|||
return result, 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>")
|
||||
class AppAnnotationSettingUpdateApi(Resource):
|
||||
@api.doc("update_annotation_setting")
|
||||
@api.doc(description="Update annotation settings for an app")
|
||||
@api.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"AnnotationSettingUpdateRequest",
|
||||
{
|
||||
"score_threshold": fields.Float(required=True, description="Score threshold"),
|
||||
"embedding_provider_name": fields.String(required=True, description="Embedding provider"),
|
||||
"embedding_model_name": fields.String(required=True, description="Embedding model"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Settings updated successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -75,7 +113,13 @@ class AppAnnotationSettingUpdateApi(Resource):
|
|||
return result, 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>")
|
||||
class AnnotationReplyActionStatusApi(Resource):
|
||||
@api.doc("get_annotation_reply_action_status")
|
||||
@api.doc(description="Get status of annotation reply action job")
|
||||
@api.doc(params={"app_id": "Application ID", "job_id": "Job ID", "action": "Action type"})
|
||||
@api.response(200, "Job status retrieved successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -99,7 +143,19 @@ class AnnotationReplyActionStatusApi(Resource):
|
|||
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations")
|
||||
class AnnotationApi(Resource):
|
||||
@api.doc("list_annotations")
|
||||
@api.doc(description="Get annotations for an app with pagination")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("page", type=int, location="args", default=1, help="Page number")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Page size")
|
||||
.add_argument("keyword", type=str, location="args", default="", help="Search keyword")
|
||||
)
|
||||
@api.response(200, "Annotations retrieved successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -122,6 +178,21 @@ class AnnotationApi(Resource):
|
|||
}
|
||||
return response, 200
|
||||
|
||||
@api.doc("create_annotation")
|
||||
@api.doc(description="Create a new annotation for an app")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"CreateAnnotationRequest",
|
||||
{
|
||||
"question": fields.String(required=True, description="Question text"),
|
||||
"answer": fields.String(required=True, description="Answer text"),
|
||||
"annotation_reply": fields.Raw(description="Annotation reply data"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(201, "Annotation created successfully", annotation_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -168,7 +239,13 @@ class AnnotationApi(Resource):
|
|||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/export")
|
||||
class AnnotationExportApi(Resource):
|
||||
@api.doc("export_annotations")
|
||||
@api.doc(description="Export all annotations for an app")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Annotations exported successfully", fields.List(fields.Nested(annotation_fields)))
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -182,7 +259,14 @@ class AnnotationExportApi(Resource):
|
|||
return response, 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
|
||||
class AnnotationUpdateDeleteApi(Resource):
|
||||
@api.doc("update_delete_annotation")
|
||||
@api.doc(description="Update or delete an annotation")
|
||||
@api.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
|
||||
@api.response(200, "Annotation updated successfully", annotation_fields)
|
||||
@api.response(204, "Annotation deleted successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -214,7 +298,14 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import")
|
||||
class AnnotationBatchImportApi(Resource):
|
||||
@api.doc("batch_import_annotations")
|
||||
@api.doc(description="Batch import annotations from CSV file")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Batch import started successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(400, "No file uploaded or too many files")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -239,7 +330,13 @@ class AnnotationBatchImportApi(Resource):
|
|||
return AppAnnotationService.batch_import_app_annotations(app_id, file)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>")
|
||||
class AnnotationBatchImportStatusApi(Resource):
|
||||
@api.doc("get_batch_import_status")
|
||||
@api.doc(description="Get status of batch import job")
|
||||
@api.doc(params={"app_id": "Application ID", "job_id": "Job ID"})
|
||||
@api.response(200, "Job status retrieved successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -262,7 +359,20 @@ class AnnotationBatchImportStatusApi(Resource):
|
|||
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories")
|
||||
class AnnotationHitHistoryListApi(Resource):
|
||||
@api.doc("list_annotation_hit_histories")
|
||||
@api.doc(description="Get hit histories for an annotation")
|
||||
@api.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("page", type=int, location="args", default=1, help="Page number")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Page size")
|
||||
)
|
||||
@api.response(
|
||||
200, "Hit histories retrieved successfully", fields.List(fields.Nested(annotation_hit_history_fields))
|
||||
)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -285,17 +395,3 @@ class AnnotationHitHistoryListApi(Resource):
|
|||
"page": page,
|
||||
}
|
||||
return response
|
||||
|
||||
|
||||
api.add_resource(AnnotationReplyActionApi, "/apps/<uuid:app_id>/annotation-reply/<string:action>")
|
||||
api.add_resource(
|
||||
AnnotationReplyActionStatusApi, "/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>"
|
||||
)
|
||||
api.add_resource(AnnotationApi, "/apps/<uuid:app_id>/annotations")
|
||||
api.add_resource(AnnotationExportApi, "/apps/<uuid:app_id>/annotations/export")
|
||||
api.add_resource(AnnotationUpdateDeleteApi, "/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
|
||||
api.add_resource(AnnotationBatchImportApi, "/apps/<uuid:app_id>/annotations/batch-import")
|
||||
api.add_resource(AnnotationBatchImportStatusApi, "/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>")
|
||||
api.add_resource(AnnotationHitHistoryListApi, "/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories")
|
||||
api.add_resource(AppAnnotationSettingDetailApi, "/apps/<uuid:app_id>/annotation-setting")
|
||||
api.add_resource(AppAnnotationSettingUpdateApi, "/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>")
|
||||
|
|
|
|||
|
|
@ -2,12 +2,12 @@ import uuid
|
|||
from typing import cast
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, inputs, marshal, marshal_with, reqparse
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, abort
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
|
|
@ -34,7 +34,27 @@ def _validate_description_length(description):
|
|||
return description
|
||||
|
||||
|
||||
@console_ns.route("/apps")
|
||||
class AppListApi(Resource):
|
||||
@api.doc("list_apps")
|
||||
@api.doc(description="Get list of applications with pagination and filtering")
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("page", type=int, location="args", help="Page number (1-99999)", default=1)
|
||||
.add_argument("limit", type=int, location="args", help="Page size (1-100)", default=20)
|
||||
.add_argument(
|
||||
"mode",
|
||||
type=str,
|
||||
location="args",
|
||||
choices=["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"],
|
||||
default="all",
|
||||
help="App mode filter",
|
||||
)
|
||||
.add_argument("name", type=str, location="args", help="Filter by app name")
|
||||
.add_argument("tag_ids", type=str, location="args", help="Comma-separated tag IDs")
|
||||
.add_argument("is_created_by_me", type=bool, location="args", help="Filter by creator")
|
||||
)
|
||||
@api.response(200, "Success", app_pagination_fields)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -91,6 +111,24 @@ class AppListApi(Resource):
|
|||
|
||||
return marshal(app_pagination, app_pagination_fields), 200
|
||||
|
||||
@api.doc("create_app")
|
||||
@api.doc(description="Create a new application")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"CreateAppRequest",
|
||||
{
|
||||
"name": fields.String(required=True, description="App name"),
|
||||
"description": fields.String(description="App description (max 400 chars)"),
|
||||
"mode": fields.String(required=True, enum=ALLOW_CREATE_APP_MODES, description="App mode"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon": fields.String(description="Icon"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(201, "App created successfully", app_detail_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -124,7 +162,12 @@ class AppListApi(Resource):
|
|||
return app, 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>")
|
||||
class AppApi(Resource):
|
||||
@api.doc("get_app_detail")
|
||||
@api.doc(description="Get application details")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Success", app_detail_fields_with_site)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -143,6 +186,26 @@ class AppApi(Resource):
|
|||
|
||||
return app_model
|
||||
|
||||
@api.doc("update_app")
|
||||
@api.doc(description="Update application details")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateAppRequest",
|
||||
{
|
||||
"name": fields.String(required=True, description="App name"),
|
||||
"description": fields.String(description="App description (max 400 chars)"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon": fields.String(description="Icon"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
"use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
|
||||
"max_active_requests": fields.Integer(description="Maximum active requests"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "App updated successfully", app_detail_fields_with_site)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -181,6 +244,11 @@ class AppApi(Resource):
|
|||
|
||||
return app_model
|
||||
|
||||
@api.doc("delete_app")
|
||||
@api.doc(description="Delete application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(204, "App deleted successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -197,7 +265,25 @@ class AppApi(Resource):
|
|||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/copy")
|
||||
class AppCopyApi(Resource):
|
||||
@api.doc("copy_app")
|
||||
@api.doc(description="Create a copy of an existing application")
|
||||
@api.doc(params={"app_id": "Application ID to copy"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"CopyAppRequest",
|
||||
{
|
||||
"name": fields.String(description="Name for the copied app"),
|
||||
"description": fields.String(description="Description for the copied app"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon": fields.String(description="Icon"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(201, "App copied successfully", app_detail_fields_with_site)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -239,7 +325,22 @@ class AppCopyApi(Resource):
|
|||
return app, 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/export")
|
||||
class AppExportApi(Resource):
|
||||
@api.doc("export_app")
|
||||
@api.doc(description="Export application configuration as DSL")
|
||||
@api.doc(params={"app_id": "Application ID to export"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("include_secret", type=bool, location="args", default=False, help="Include secrets in export")
|
||||
.add_argument("workflow_id", type=str, location="args", help="Specific workflow ID to export")
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"App exported successfully",
|
||||
api.model("AppExportResponse", {"data": fields.String(description="DSL export data")}),
|
||||
)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -263,7 +364,13 @@ class AppExportApi(Resource):
|
|||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/name")
|
||||
class AppNameApi(Resource):
|
||||
@api.doc("check_app_name")
|
||||
@api.doc(description="Check if app name is available")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(api.parser().add_argument("name", type=str, required=True, location="args", help="Name to check"))
|
||||
@api.response(200, "Name availability checked")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -284,7 +391,23 @@ class AppNameApi(Resource):
|
|||
return app_model
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/icon")
|
||||
class AppIconApi(Resource):
|
||||
@api.doc("update_app_icon")
|
||||
@api.doc(description="Update application icon")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"AppIconRequest",
|
||||
{
|
||||
"icon": fields.String(required=True, description="Icon data"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Icon updated successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -306,7 +429,18 @@ class AppIconApi(Resource):
|
|||
return app_model
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/site-enable")
|
||||
class AppSiteStatus(Resource):
|
||||
@api.doc("update_app_site_status")
|
||||
@api.doc(description="Enable or disable app site")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"AppSiteStatusRequest", {"enable_site": fields.Boolean(required=True, description="Enable or disable site")}
|
||||
)
|
||||
)
|
||||
@api.response(200, "Site status updated successfully", app_detail_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -327,7 +461,18 @@ class AppSiteStatus(Resource):
|
|||
return app_model
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/api-enable")
|
||||
class AppApiStatus(Resource):
|
||||
@api.doc("update_app_api_status")
|
||||
@api.doc(description="Enable or disable app API")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"AppApiStatusRequest", {"enable_api": fields.Boolean(required=True, description="Enable or disable API")}
|
||||
)
|
||||
)
|
||||
@api.response(200, "API status updated successfully", app_detail_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -348,7 +493,12 @@ class AppApiStatus(Resource):
|
|||
return app_model
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/trace")
|
||||
class AppTraceApi(Resource):
|
||||
@api.doc("get_app_trace")
|
||||
@api.doc(description="Get app tracing configuration")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Trace configuration retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -358,6 +508,20 @@ class AppTraceApi(Resource):
|
|||
|
||||
return app_trace_config
|
||||
|
||||
@api.doc("update_app_trace")
|
||||
@api.doc(description="Update app tracing configuration")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"AppTraceRequest",
|
||||
{
|
||||
"enabled": fields.Boolean(required=True, description="Enable or disable tracing"),
|
||||
"tracing_provider": fields.String(required=True, description="Tracing provider"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Trace configuration updated successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -377,14 +541,3 @@ class AppTraceApi(Resource):
|
|||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
api.add_resource(AppListApi, "/apps")
|
||||
api.add_resource(AppApi, "/apps/<uuid:app_id>")
|
||||
api.add_resource(AppCopyApi, "/apps/<uuid:app_id>/copy")
|
||||
api.add_resource(AppExportApi, "/apps/<uuid:app_id>/export")
|
||||
api.add_resource(AppNameApi, "/apps/<uuid:app_id>/name")
|
||||
api.add_resource(AppIconApi, "/apps/<uuid:app_id>/icon")
|
||||
api.add_resource(AppSiteStatus, "/apps/<uuid:app_id>/site-enable")
|
||||
api.add_resource(AppApiStatus, "/apps/<uuid:app_id>/api-enable")
|
||||
api.add_resource(AppTraceApi, "/apps/<uuid:app_id>/trace")
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
|
|
@ -34,7 +34,18 @@ from services.errors.audio import (
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/audio-to-text")
|
||||
class ChatMessageAudioApi(Resource):
|
||||
@api.doc("chat_message_audio_transcript")
|
||||
@api.doc(description="Transcript audio to text for chat messages")
|
||||
@api.doc(params={"app_id": "App ID"})
|
||||
@api.response(
|
||||
200,
|
||||
"Audio transcription successful",
|
||||
api.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}),
|
||||
)
|
||||
@api.response(400, "Bad request - No audio uploaded or unsupported type")
|
||||
@api.response(413, "Audio file too large")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -76,7 +87,24 @@ class ChatMessageAudioApi(Resource):
|
|||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/text-to-audio")
|
||||
class ChatMessageTextApi(Resource):
|
||||
@api.doc("chat_message_text_to_speech")
|
||||
@api.doc(description="Convert text to speech for chat messages")
|
||||
@api.doc(params={"app_id": "App ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"TextToSpeechRequest",
|
||||
{
|
||||
"message_id": fields.String(description="Message ID"),
|
||||
"text": fields.String(required=True, description="Text to convert to speech"),
|
||||
"voice": fields.String(description="Voice to use for TTS"),
|
||||
"streaming": fields.Boolean(description="Whether to stream the audio"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Text to speech conversion successful")
|
||||
@api.response(400, "Bad request - Invalid parameters")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -124,7 +152,14 @@ class ChatMessageTextApi(Resource):
|
|||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/text-to-audio/voices")
|
||||
class TextModesApi(Resource):
|
||||
@api.doc("get_text_to_speech_voices")
|
||||
@api.doc(description="Get available TTS voices for a specific language")
|
||||
@api.doc(params={"app_id": "App ID"})
|
||||
@api.expect(api.parser().add_argument("language", type=str, required=True, location="args", help="Language code"))
|
||||
@api.response(200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices")))
|
||||
@api.response(400, "Invalid language parameter")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -164,8 +199,3 @@ class TextModesApi(Resource):
|
|||
except Exception as e:
|
||||
logger.exception("Failed to handle get request to TextModesApi")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
api.add_resource(ChatMessageAudioApi, "/apps/<uuid:app_id>/audio-to-text")
|
||||
api.add_resource(ChatMessageTextApi, "/apps/<uuid:app_id>/text-to-audio")
|
||||
api.add_resource(TextModesApi, "/apps/<uuid:app_id>/text-to-audio/voices")
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
|
|
@ -38,7 +38,27 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
# define completion message api for user
|
||||
@console_ns.route("/apps/<uuid:app_id>/completion-messages")
|
||||
class CompletionMessageApi(Resource):
|
||||
@api.doc("create_completion_message")
|
||||
@api.doc(description="Generate completion message for debugging")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"CompletionMessageRequest",
|
||||
{
|
||||
"inputs": fields.Raw(required=True, description="Input variables"),
|
||||
"query": fields.String(description="Query text", default=""),
|
||||
"files": fields.List(fields.Raw(), description="Uploaded files"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
|
||||
"retriever_from": fields.String(default="dev", description="Retriever source"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Completion generated successfully")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@api.response(404, "App not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -86,7 +106,12 @@ class CompletionMessageApi(Resource):
|
|||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop")
|
||||
class CompletionMessageStopApi(Resource):
|
||||
@api.doc("stop_completion_message")
|
||||
@api.doc(description="Stop a running completion message generation")
|
||||
@api.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
|
||||
@api.response(200, "Task stopped successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -99,7 +124,29 @@ class CompletionMessageStopApi(Resource):
|
|||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-messages")
|
||||
class ChatMessageApi(Resource):
|
||||
@api.doc("create_chat_message")
|
||||
@api.doc(description="Generate chat message for debugging")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ChatMessageRequest",
|
||||
{
|
||||
"inputs": fields.Raw(required=True, description="Input variables"),
|
||||
"query": fields.String(required=True, description="User query"),
|
||||
"files": fields.List(fields.Raw(), description="Uploaded files"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"conversation_id": fields.String(description="Conversation ID"),
|
||||
"parent_message_id": fields.String(description="Parent message ID"),
|
||||
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
|
||||
"retriever_from": fields.String(default="dev", description="Retriever source"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Chat message generated successfully")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@api.response(404, "App or conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -161,7 +208,12 @@ class ChatMessageApi(Resource):
|
|||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop")
|
||||
class ChatMessageStopApi(Resource):
|
||||
@api.doc("stop_chat_message")
|
||||
@api.doc(description="Stop a running chat message generation")
|
||||
@api.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
|
||||
@api.response(200, "Task stopped successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -172,9 +224,3 @@ class ChatMessageStopApi(Resource):
|
|||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
api.add_resource(CompletionMessageApi, "/apps/<uuid:app_id>/completion-messages")
|
||||
api.add_resource(CompletionMessageStopApi, "/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop")
|
||||
api.add_resource(ChatMessageApi, "/apps/<uuid:app_id>/chat-messages")
|
||||
api.add_resource(ChatMessageStopApi, "/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop")
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from sqlalchemy import func, or_
|
|||
from sqlalchemy.orm import joinedload
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
|
@ -28,7 +28,29 @@ from services.conversation_service import ConversationService
|
|||
from services.errors.conversation import ConversationNotExistsError
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/completion-conversations")
|
||||
class CompletionConversationApi(Resource):
|
||||
@api.doc("list_completion_conversations")
|
||||
@api.doc(description="Get completion conversations with pagination and filtering")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("keyword", type=str, location="args", help="Search keyword")
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument(
|
||||
"annotation_status",
|
||||
type=str,
|
||||
location="args",
|
||||
choices=["annotated", "not_annotated", "all"],
|
||||
default="all",
|
||||
help="Annotation status filter",
|
||||
)
|
||||
.add_argument("page", type=int, location="args", default=1, help="Page number")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
|
||||
)
|
||||
@api.response(200, "Success", conversation_pagination_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -101,7 +123,14 @@ class CompletionConversationApi(Resource):
|
|||
return conversations
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>")
|
||||
class CompletionConversationDetailApi(Resource):
|
||||
@api.doc("get_completion_conversation")
|
||||
@api.doc(description="Get completion conversation details with messages")
|
||||
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@api.response(200, "Success", conversation_message_detail_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -114,6 +143,12 @@ class CompletionConversationDetailApi(Resource):
|
|||
|
||||
return _get_conversation(app_model, conversation_id)
|
||||
|
||||
@api.doc("delete_completion_conversation")
|
||||
@api.doc(description="Delete a completion conversation")
|
||||
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@api.response(204, "Conversation deleted successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -133,7 +168,38 @@ class CompletionConversationDetailApi(Resource):
|
|||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-conversations")
|
||||
class ChatConversationApi(Resource):
|
||||
@api.doc("list_chat_conversations")
|
||||
@api.doc(description="Get chat conversations with pagination, filtering and summary")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("keyword", type=str, location="args", help="Search keyword")
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument(
|
||||
"annotation_status",
|
||||
type=str,
|
||||
location="args",
|
||||
choices=["annotated", "not_annotated", "all"],
|
||||
default="all",
|
||||
help="Annotation status filter",
|
||||
)
|
||||
.add_argument("message_count_gte", type=int, location="args", help="Minimum message count")
|
||||
.add_argument("page", type=int, location="args", default=1, help="Page number")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
|
||||
.add_argument(
|
||||
"sort_by",
|
||||
type=str,
|
||||
location="args",
|
||||
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
||||
default="-updated_at",
|
||||
help="Sort field and direction",
|
||||
)
|
||||
)
|
||||
@api.response(200, "Success", conversation_with_summary_pagination_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -241,7 +307,7 @@ class ChatConversationApi(Resource):
|
|||
.having(func.count(Message.id) >= args["message_count_gte"])
|
||||
)
|
||||
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)
|
||||
|
||||
match args["sort_by"]:
|
||||
|
|
@ -261,7 +327,14 @@ class ChatConversationApi(Resource):
|
|||
return conversations
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>")
|
||||
class ChatConversationDetailApi(Resource):
|
||||
@api.doc("get_chat_conversation")
|
||||
@api.doc(description="Get chat conversation details")
|
||||
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@api.response(200, "Success", conversation_detail_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -274,6 +347,12 @@ class ChatConversationDetailApi(Resource):
|
|||
|
||||
return _get_conversation(app_model, conversation_id)
|
||||
|
||||
@api.doc("delete_chat_conversation")
|
||||
@api.doc(description="Delete a chat conversation")
|
||||
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@api.response(204, "Conversation deleted successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
|
|
@ -293,12 +372,6 @@ class ChatConversationDetailApi(Resource):
|
|||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
api.add_resource(CompletionConversationApi, "/apps/<uuid:app_id>/completion-conversations")
|
||||
api.add_resource(CompletionConversationDetailApi, "/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>")
|
||||
api.add_resource(ChatConversationApi, "/apps/<uuid:app_id>/chat-conversations")
|
||||
api.add_resource(ChatConversationDetailApi, "/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>")
|
||||
|
||||
|
||||
def _get_conversation(app_model, conversation_id):
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from flask_restx import Resource, marshal_with, reqparse
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -12,7 +12,17 @@ from models import ConversationVariable
|
|||
from models.model import AppMode
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/conversation-variables")
|
||||
class ConversationVariablesApi(Resource):
|
||||
@api.doc("get_conversation_variables")
|
||||
@api.doc(description="Get conversation variables for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser().add_argument(
|
||||
"conversation_id", type=str, location="args", help="Conversation ID to filter variables"
|
||||
)
|
||||
)
|
||||
@api.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_fields)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -55,6 +65,3 @@ class ConversationVariablesApi(Resource):
|
|||
for row in rows
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
api.add_resource(ConversationVariablesApi, "/apps/<uuid:app_id>/conversation-variables")
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from collections.abc import Sequence
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
|
|
@ -22,7 +22,23 @@ from models import App
|
|||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
@console_ns.route("/rule-generate")
|
||||
class RuleGenerateApi(Resource):
|
||||
@api.doc("generate_rule_config")
|
||||
@api.doc(description="Generate rule configuration using LLM")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"RuleGenerateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Rule generation instruction"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Rule configuration generated successfully")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@api.response(402, "Provider quota exceeded")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -53,7 +69,26 @@ class RuleGenerateApi(Resource):
|
|||
return rules
|
||||
|
||||
|
||||
@console_ns.route("/rule-code-generate")
|
||||
class RuleCodeGenerateApi(Resource):
|
||||
@api.doc("generate_rule_code")
|
||||
@api.doc(description="Generate code rules using LLM")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"RuleCodeGenerateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Code generation instruction"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
|
||||
"code_language": fields.String(
|
||||
default="javascript", description="Programming language for code generation"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Code rules generated successfully")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@api.response(402, "Provider quota exceeded")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -85,7 +120,22 @@ class RuleCodeGenerateApi(Resource):
|
|||
return code_result
|
||||
|
||||
|
||||
@console_ns.route("/rule-structured-output-generate")
|
||||
class RuleStructuredOutputGenerateApi(Resource):
|
||||
@api.doc("generate_structured_output")
|
||||
@api.doc(description="Generate structured output rules using LLM")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"StructuredOutputGenerateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Structured output generation instruction"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Structured output generated successfully")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@api.response(402, "Provider quota exceeded")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -114,7 +164,27 @@ class RuleStructuredOutputGenerateApi(Resource):
|
|||
return structured_output
|
||||
|
||||
|
||||
@console_ns.route("/instruction-generate")
|
||||
class InstructionGenerateApi(Resource):
|
||||
@api.doc("generate_instruction")
|
||||
@api.doc(description="Generate instruction for workflow nodes or general use")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"InstructionGenerateRequest",
|
||||
{
|
||||
"flow_id": fields.String(required=True, description="Workflow/Flow ID"),
|
||||
"node_id": fields.String(description="Node ID for workflow context"),
|
||||
"current": fields.String(description="Current instruction text"),
|
||||
"language": fields.String(default="javascript", description="Programming language (javascript/python)"),
|
||||
"instruction": fields.String(required=True, description="Instruction for generation"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"ideal_output": fields.String(description="Expected ideal output"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Instruction generated successfully")
|
||||
@api.response(400, "Invalid request parameters or flow/workflow not found")
|
||||
@api.response(402, "Provider quota exceeded")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -203,7 +273,21 @@ class InstructionGenerateApi(Resource):
|
|||
raise CompletionRequestError(e.description)
|
||||
|
||||
|
||||
@console_ns.route("/instruction-generate/template")
|
||||
class InstructionGenerationTemplateApi(Resource):
|
||||
@api.doc("get_instruction_template")
|
||||
@api.doc(description="Get instruction generation template")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"InstructionTemplateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Template instruction"),
|
||||
"ideal_output": fields.String(description="Expected ideal output"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Template retrieved successfully")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -222,10 +306,3 @@ class InstructionGenerationTemplateApi(Resource):
|
|||
return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
|
||||
case _:
|
||||
raise ValueError(f"Invalid type: {args['type']}")
|
||||
|
||||
|
||||
api.add_resource(RuleGenerateApi, "/rule-generate")
|
||||
api.add_resource(RuleCodeGenerateApi, "/rule-code-generate")
|
||||
api.add_resource(RuleStructuredOutputGenerateApi, "/rule-structured-output-generate")
|
||||
api.add_resource(InstructionGenerateApi, "/instruction-generate")
|
||||
api.add_resource(InstructionGenerationTemplateApi, "/instruction-generate/template")
|
||||
|
|
|
|||
|
|
@ -2,10 +2,10 @@ import json
|
|||
from enum import StrEnum
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -19,7 +19,12 @@ class AppMCPServerStatus(StrEnum):
|
|||
INACTIVE = "inactive"
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/server")
|
||||
class AppMCPServerController(Resource):
|
||||
@api.doc("get_app_mcp_server")
|
||||
@api.doc(description="Get MCP server configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "MCP server configuration retrieved successfully", app_server_fields)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -29,6 +34,20 @@ class AppMCPServerController(Resource):
|
|||
server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first()
|
||||
return server
|
||||
|
||||
@api.doc("create_app_mcp_server")
|
||||
@api.doc(description="Create MCP server configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"MCPServerCreateRequest",
|
||||
{
|
||||
"description": fields.String(description="Server description"),
|
||||
"parameters": fields.Raw(required=True, description="Server parameters configuration"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(201, "MCP server configuration created successfully", app_server_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -59,6 +78,23 @@ class AppMCPServerController(Resource):
|
|||
db.session.commit()
|
||||
return server
|
||||
|
||||
@api.doc("update_app_mcp_server")
|
||||
@api.doc(description="Update MCP server configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"MCPServerUpdateRequest",
|
||||
{
|
||||
"id": fields.String(required=True, description="Server ID"),
|
||||
"description": fields.String(description="Server description"),
|
||||
"parameters": fields.Raw(required=True, description="Server parameters configuration"),
|
||||
"status": fields.String(description="Server status"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "MCP server configuration updated successfully", app_server_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "Server not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -94,7 +130,14 @@ class AppMCPServerController(Resource):
|
|||
return server
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:server_id>/server/refresh")
|
||||
class AppMCPServerRefreshController(Resource):
|
||||
@api.doc("refresh_app_mcp_server")
|
||||
@api.doc(description="Refresh MCP server configuration and regenerate server code")
|
||||
@api.doc(params={"server_id": "Server ID"})
|
||||
@api.response(200, "MCP server refreshed successfully", app_server_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "Server not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -113,7 +156,3 @@ class AppMCPServerRefreshController(Resource):
|
|||
server.server_code = AppMCPServer.generate_server_code(16)
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
|
||||
api.add_resource(AppMCPServerController, "/apps/<uuid:app_id>/server")
|
||||
api.add_resource(AppMCPServerRefreshController, "/apps/<uuid:server_id>/server/refresh")
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from flask_restx.inputs import int_range
|
|||
from sqlalchemy import exists, select
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
|
|
@ -37,6 +37,7 @@ from services.message_service import MessageService
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-messages")
|
||||
class ChatMessageListApi(Resource):
|
||||
message_infinite_scroll_pagination_fields = {
|
||||
"limit": fields.Integer,
|
||||
|
|
@ -44,6 +45,17 @@ class ChatMessageListApi(Resource):
|
|||
"data": fields.List(fields.Nested(message_detail_fields)),
|
||||
}
|
||||
|
||||
@api.doc("list_chat_messages")
|
||||
@api.doc(description="Get chat messages for a conversation with pagination")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation ID")
|
||||
.add_argument("first_id", type=str, location="args", help="First message ID for pagination")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Number of messages to return (1-100)")
|
||||
)
|
||||
@api.response(200, "Success", message_infinite_scroll_pagination_fields)
|
||||
@api.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
|
|
@ -117,7 +129,23 @@ class ChatMessageListApi(Resource):
|
|||
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/feedbacks")
|
||||
class MessageFeedbackApi(Resource):
|
||||
@api.doc("create_message_feedback")
|
||||
@api.doc(description="Create or update message feedback (like/dislike)")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"MessageFeedbackRequest",
|
||||
{
|
||||
"message_id": fields.String(required=True, description="Message ID"),
|
||||
"rating": fields.String(enum=["like", "dislike"], description="Feedback rating"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Feedback updated successfully")
|
||||
@api.response(404, "Message not found")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -162,7 +190,24 @@ class MessageFeedbackApi(Resource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations")
|
||||
class MessageAnnotationApi(Resource):
|
||||
@api.doc("create_message_annotation")
|
||||
@api.doc(description="Create message annotation")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"MessageAnnotationRequest",
|
||||
{
|
||||
"message_id": fields.String(description="Message ID"),
|
||||
"question": fields.String(required=True, description="Question text"),
|
||||
"answer": fields.String(required=True, description="Answer text"),
|
||||
"annotation_reply": fields.Raw(description="Annotation reply"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Annotation created successfully", annotation_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -186,7 +231,16 @@ class MessageAnnotationApi(Resource):
|
|||
return annotation
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/count")
|
||||
class MessageAnnotationCountApi(Resource):
|
||||
@api.doc("get_annotation_count")
|
||||
@api.doc(description="Get count of message annotations for the app")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(
|
||||
200,
|
||||
"Annotation count retrieved successfully",
|
||||
api.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}),
|
||||
)
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -197,7 +251,17 @@ class MessageAnnotationCountApi(Resource):
|
|||
return {"count": count}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions")
|
||||
class MessageSuggestedQuestionApi(Resource):
|
||||
@api.doc("get_message_suggested_questions")
|
||||
@api.doc(description="Get suggested questions for a message")
|
||||
@api.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
|
||||
@api.response(
|
||||
200,
|
||||
"Suggested questions retrieved successfully",
|
||||
api.model("SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}),
|
||||
)
|
||||
@api.response(404, "Message or conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -230,7 +294,13 @@ class MessageSuggestedQuestionApi(Resource):
|
|||
return {"data": questions}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/messages/<uuid:message_id>")
|
||||
class MessageApi(Resource):
|
||||
@api.doc("get_message")
|
||||
@api.doc(description="Get message details by ID")
|
||||
@api.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
|
||||
@api.response(200, "Message retrieved successfully", message_detail_fields)
|
||||
@api.response(404, "Message not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -245,11 +315,3 @@ class MessageApi(Resource):
|
|||
raise NotFound("Message Not Exists.")
|
||||
|
||||
return message
|
||||
|
||||
|
||||
api.add_resource(MessageSuggestedQuestionApi, "/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions")
|
||||
api.add_resource(ChatMessageListApi, "/apps/<uuid:app_id>/chat-messages", endpoint="console_chat_messages")
|
||||
api.add_resource(MessageFeedbackApi, "/apps/<uuid:app_id>/feedbacks")
|
||||
api.add_resource(MessageAnnotationApi, "/apps/<uuid:app_id>/annotations")
|
||||
api.add_resource(MessageAnnotationCountApi, "/apps/<uuid:app_id>/annotations/count")
|
||||
api.add_resource(MessageApi, "/apps/<uuid:app_id>/messages/<uuid:message_id>", endpoint="console_message")
|
||||
|
|
|
|||
|
|
@ -2,10 +2,11 @@ import json
|
|||
from typing import cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.agent.entities import AgentToolEntity
|
||||
|
|
@ -13,13 +14,39 @@ from core.tools.tool_manager import ToolManager
|
|||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user, login_required
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.model import AppMode, AppModelConfig
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/model-config")
|
||||
class ModelConfigResource(Resource):
|
||||
@api.doc("update_app_model_config")
|
||||
@api.doc(description="Update application model configuration")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ModelConfigRequest",
|
||||
{
|
||||
"provider": fields.String(description="Model provider"),
|
||||
"model": fields.String(description="Model name"),
|
||||
"configs": fields.Raw(description="Model configuration parameters"),
|
||||
"opening_statement": fields.String(description="Opening statement"),
|
||||
"suggested_questions": fields.List(fields.String(), description="Suggested questions"),
|
||||
"more_like_this": fields.Raw(description="More like this configuration"),
|
||||
"speech_to_text": fields.Raw(description="Speech to text configuration"),
|
||||
"text_to_speech": fields.Raw(description="Text to speech configuration"),
|
||||
"retrieval_model": fields.Raw(description="Retrieval model configuration"),
|
||||
"tools": fields.List(fields.Raw(), description="Available tools"),
|
||||
"dataset_configs": fields.Raw(description="Dataset configurations"),
|
||||
"agent_mode": fields.Raw(description="Agent mode configuration"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Model configuration updated successfully")
|
||||
@api.response(400, "Invalid configuration")
|
||||
@api.response(404, "App not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -47,7 +74,7 @@ class ModelConfigResource(Resource):
|
|||
)
|
||||
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
|
||||
|
||||
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
|
||||
if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
|
||||
# get original app model config
|
||||
original_app_model_config = (
|
||||
db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first()
|
||||
|
|
@ -150,6 +177,3 @@ class ModelConfigResource(Resource):
|
|||
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
api.add_resource(ModelConfigResource, "/apps/<uuid:app_id>/model-config")
|
||||
|
|
|
|||
|
|
@ -1,18 +1,31 @@
|
|||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
from services.ops_service import OpsService
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/trace-config")
|
||||
class TraceAppConfigApi(Resource):
|
||||
"""
|
||||
Manage trace app configurations
|
||||
"""
|
||||
|
||||
@api.doc("get_trace_app_config")
|
||||
@api.doc(description="Get tracing configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser().add_argument(
|
||||
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
|
||||
)
|
||||
)
|
||||
@api.response(
|
||||
200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data")
|
||||
)
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -29,6 +42,22 @@ class TraceAppConfigApi(Resource):
|
|||
except Exception as e:
|
||||
raise BadRequest(str(e))
|
||||
|
||||
@api.doc("create_trace_app_config")
|
||||
@api.doc(description="Create a new tracing configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"TraceConfigCreateRequest",
|
||||
{
|
||||
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
|
||||
"tracing_config": fields.Raw(required=True, description="Tracing configuration data"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(
|
||||
201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data")
|
||||
)
|
||||
@api.response(400, "Invalid request parameters or configuration already exists")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -51,6 +80,20 @@ class TraceAppConfigApi(Resource):
|
|||
except Exception as e:
|
||||
raise BadRequest(str(e))
|
||||
|
||||
@api.doc("update_trace_app_config")
|
||||
@api.doc(description="Update an existing tracing configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"TraceConfigUpdateRequest",
|
||||
{
|
||||
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
|
||||
"tracing_config": fields.Raw(required=True, description="Updated tracing configuration data"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response"))
|
||||
@api.response(400, "Invalid request parameters or configuration not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -71,6 +114,16 @@ class TraceAppConfigApi(Resource):
|
|||
except Exception as e:
|
||||
raise BadRequest(str(e))
|
||||
|
||||
@api.doc("delete_trace_app_config")
|
||||
@api.doc(description="Delete an existing tracing configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser().add_argument(
|
||||
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
|
||||
)
|
||||
)
|
||||
@api.response(204, "Tracing configuration deleted successfully")
|
||||
@api.response(400, "Invalid request parameters or configuration not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -87,6 +140,3 @@ class TraceAppConfigApi(Resource):
|
|||
return {"result": "success"}, 204
|
||||
except Exception as e:
|
||||
raise BadRequest(str(e))
|
||||
|
||||
|
||||
api.add_resource(TraceAppConfigApi, "/apps/<uuid:app_id>/trace-config")
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import api
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -36,7 +36,39 @@ def parse_app_site_args():
|
|||
return parser.parse_args()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/site")
|
||||
class AppSite(Resource):
|
||||
@api.doc("update_app_site")
|
||||
@api.doc(description="Update application site configuration")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"AppSiteRequest",
|
||||
{
|
||||
"title": fields.String(description="Site title"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon": fields.String(description="Icon"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
"description": fields.String(description="Site description"),
|
||||
"default_language": fields.String(description="Default language"),
|
||||
"chat_color_theme": fields.String(description="Chat color theme"),
|
||||
"chat_color_theme_inverted": fields.Boolean(description="Inverted chat color theme"),
|
||||
"customize_domain": fields.String(description="Custom domain"),
|
||||
"copyright": fields.String(description="Copyright text"),
|
||||
"privacy_policy": fields.String(description="Privacy policy"),
|
||||
"custom_disclaimer": fields.String(description="Custom disclaimer"),
|
||||
"customize_token_strategy": fields.String(
|
||||
enum=["must", "allow", "not_allow"], description="Token strategy"
|
||||
),
|
||||
"prompt_public": fields.Boolean(description="Make prompt public"),
|
||||
"show_workflow_steps": fields.Boolean(description="Show workflow steps"),
|
||||
"use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Site configuration updated successfully", app_site_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "App not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -84,7 +116,14 @@ class AppSite(Resource):
|
|||
return site
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/site/access-token-reset")
|
||||
class AppSiteAccessTokenReset(Resource):
|
||||
@api.doc("reset_app_site_access_token")
|
||||
@api.doc(description="Reset access token for application site")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Access token reset successfully", app_site_fields)
|
||||
@api.response(403, "Insufficient permissions (admin/owner required)")
|
||||
@api.response(404, "App or site not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -108,7 +147,3 @@ class AppSiteAccessTokenReset(Resource):
|
|||
db.session.commit()
|
||||
|
||||
return site
|
||||
|
||||
|
||||
api.add_resource(AppSite, "/apps/<uuid:app_id>/site")
|
||||
api.add_resource(AppSiteAccessTokenReset, "/apps/<uuid:app_id>/site/access-token-reset")
|
||||
|
|
|
|||
|
|
@ -5,9 +5,9 @@ import pytz
|
|||
import sqlalchemy as sa
|
||||
from flask import jsonify
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
|
@ -17,7 +17,21 @@ from libs.login import login_required
|
|||
from models import AppMode, Message
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages")
|
||||
class DailyMessageStatistic(Resource):
|
||||
@api.doc("get_daily_message_statistics")
|
||||
@api.doc(description="Get daily message statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"Daily message statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Daily message count data")),
|
||||
)
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -74,7 +88,21 @@ WHERE
|
|||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations")
|
||||
class DailyConversationStatistic(Resource):
|
||||
@api.doc("get_daily_conversation_statistics")
|
||||
@api.doc(description="Get daily conversation statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"Daily conversation statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Daily conversation count data")),
|
||||
)
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -126,7 +154,21 @@ class DailyConversationStatistic(Resource):
|
|||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-end-users")
|
||||
class DailyTerminalsStatistic(Resource):
|
||||
@api.doc("get_daily_terminals_statistics")
|
||||
@api.doc(description="Get daily terminal/end-user statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"Daily terminal statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Daily terminal count data")),
|
||||
)
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -183,7 +225,21 @@ WHERE
|
|||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/token-costs")
|
||||
class DailyTokenCostStatistic(Resource):
|
||||
@api.doc("get_daily_token_cost_statistics")
|
||||
@api.doc(description="Get daily token cost statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"Daily token cost statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Daily token cost data")),
|
||||
)
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -243,7 +299,21 @@ WHERE
|
|||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/average-session-interactions")
|
||||
class AverageSessionInteractionStatistic(Resource):
|
||||
@api.doc("get_average_session_interaction_statistics")
|
||||
@api.doc(description="Get average session interaction statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"Average session interaction statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Average session interaction data")),
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -319,7 +389,21 @@ ORDER BY
|
|||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/user-satisfaction-rate")
|
||||
class UserSatisfactionRateStatistic(Resource):
|
||||
@api.doc("get_user_satisfaction_rate_statistics")
|
||||
@api.doc(description="Get user satisfaction rate statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"User satisfaction rate statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="User satisfaction rate data")),
|
||||
)
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -385,7 +469,21 @@ WHERE
|
|||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/average-response-time")
|
||||
class AverageResponseTimeStatistic(Resource):
|
||||
@api.doc("get_average_response_time_statistics")
|
||||
@api.doc(description="Get average response time statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"Average response time statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Average response time data")),
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -442,7 +540,21 @@ WHERE
|
|||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/tokens-per-second")
|
||||
class TokensPerSecondStatistic(Resource):
|
||||
@api.doc("get_tokens_per_second_statistics")
|
||||
@api.doc(description="Get tokens per second statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"Tokens per second statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Tokens per second data")),
|
||||
)
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -500,13 +612,3 @@ WHERE
|
|||
response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)})
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
api.add_resource(DailyMessageStatistic, "/apps/<uuid:app_id>/statistics/daily-messages")
|
||||
api.add_resource(DailyConversationStatistic, "/apps/<uuid:app_id>/statistics/daily-conversations")
|
||||
api.add_resource(DailyTerminalsStatistic, "/apps/<uuid:app_id>/statistics/daily-end-users")
|
||||
api.add_resource(DailyTokenCostStatistic, "/apps/<uuid:app_id>/statistics/token-costs")
|
||||
api.add_resource(AverageSessionInteractionStatistic, "/apps/<uuid:app_id>/statistics/average-session-interactions")
|
||||
api.add_resource(UserSatisfactionRateStatistic, "/apps/<uuid:app_id>/statistics/user-satisfaction-rate")
|
||||
api.add_resource(AverageResponseTimeStatistic, "/apps/<uuid:app_id>/statistics/average-response-time")
|
||||
api.add_resource(TokensPerSecondStatistic, "/apps/<uuid:app_id>/statistics/tokens-per-second")
|
||||
|
|
|
|||
|
|
@ -721,7 +721,6 @@ class WorkflowByIdApi(Resource):
|
|||
raise ValueError("Marked name cannot exceed 20 characters")
|
||||
if args.marked_comment and len(args.marked_comment) > 100:
|
||||
raise ValueError("Marked comment cannot exceed 100 characters")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Prepare update data
|
||||
update_data = {}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,155 @@
|
|||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from controllers.console import api
|
||||
from controllers.console.auth.error import (
|
||||
EmailAlreadyInUseError,
|
||||
EmailCodeError,
|
||||
EmailRegisterLimitError,
|
||||
InvalidEmailError,
|
||||
InvalidTokenError,
|
||||
PasswordMismatchError,
|
||||
)
|
||||
from controllers.console.error import AccountInFreezeError, EmailSendIpLimitError
|
||||
from controllers.console.wraps import email_password_login_enabled, email_register_enabled, setup_required
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.password import valid_password
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
||||
|
||||
|
||||
class EmailRegisterSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("language", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
language = "en-US"
|
||||
if args["language"] in languages:
|
||||
language = args["language"]
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
|
||||
token = None
|
||||
token = AccountService.send_email_register_email(email=args["email"], account=account, language=language)
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
class EmailRegisterCheckApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=str, required=True, location="json")
|
||||
parser.add_argument("code", type=str, required=True, location="json")
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
user_email = args["email"]
|
||||
|
||||
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args["email"])
|
||||
if is_email_register_error_rate_limit:
|
||||
raise EmailRegisterLimitError()
|
||||
|
||||
token_data = AccountService.get_email_register_data(args["token"])
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args["code"] != token_data.get("code"):
|
||||
AccountService.add_email_register_error_rate_limit(args["email"])
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
AccountService.revoke_email_register_token(args["token"])
|
||||
|
||||
# Refresh token data by generating a new token
|
||||
_, new_token = AccountService.generate_email_register_token(
|
||||
user_email, code=args["code"], additional_data={"phase": "register"}
|
||||
)
|
||||
|
||||
AccountService.reset_email_register_error_rate_limit(args["email"])
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
|
||||
|
||||
class EmailRegisterResetApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate passwords match
|
||||
if args["new_password"] != args["password_confirm"]:
|
||||
raise PasswordMismatchError()
|
||||
|
||||
# Validate token and get register data
|
||||
register_data = AccountService.get_email_register_data(args["token"])
|
||||
if not register_data:
|
||||
raise InvalidTokenError()
|
||||
# Must use token in reset phase
|
||||
if register_data.get("phase", "") != "register":
|
||||
raise InvalidTokenError()
|
||||
|
||||
# Revoke token to prevent reuse
|
||||
AccountService.revoke_email_register_token(args["token"])
|
||||
|
||||
email = register_data.get("email", "")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||
|
||||
if account:
|
||||
raise EmailAlreadyInUseError()
|
||||
else:
|
||||
account = self._create_new_account(email, args["password_confirm"])
|
||||
if not account:
|
||||
raise AccountNotFoundError()
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(email)
|
||||
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
def _create_new_account(self, email, password) -> Account | None:
|
||||
# Create new account if allowed
|
||||
account = None
|
||||
try:
|
||||
account = AccountService.create_account_and_tenant(
|
||||
email=email,
|
||||
name=email,
|
||||
password=password,
|
||||
interface_language=languages[0],
|
||||
)
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
return account
|
||||
|
||||
|
||||
api.add_resource(EmailRegisterSendEmailApi, "/email-register/send-email")
|
||||
api.add_resource(EmailRegisterCheckApi, "/email-register/validity")
|
||||
api.add_resource(EmailRegisterResetApi, "/email-register")
|
||||
|
|
@ -27,21 +27,43 @@ class InvalidTokenError(BaseHTTPException):
|
|||
|
||||
class PasswordResetRateLimitExceededError(BaseHTTPException):
|
||||
error_code = "password_reset_rate_limit_exceeded"
|
||||
description = "Too many password reset emails have been sent. Please try again in 1 minute."
|
||||
description = "Too many password reset emails have been sent. Please try again in {minutes} minutes."
|
||||
code = 429
|
||||
|
||||
def __init__(self, minutes: int = 1):
|
||||
description = self.description.format(minutes=int(minutes)) if self.description else None
|
||||
super().__init__(description=description)
|
||||
|
||||
|
||||
class EmailRegisterRateLimitExceededError(BaseHTTPException):
|
||||
error_code = "email_register_rate_limit_exceeded"
|
||||
description = "Too many email register emails have been sent. Please try again in {minutes} minutes."
|
||||
code = 429
|
||||
|
||||
def __init__(self, minutes: int = 1):
|
||||
description = self.description.format(minutes=int(minutes)) if self.description else None
|
||||
super().__init__(description=description)
|
||||
|
||||
|
||||
class EmailChangeRateLimitExceededError(BaseHTTPException):
|
||||
error_code = "email_change_rate_limit_exceeded"
|
||||
description = "Too many email change emails have been sent. Please try again in 1 minute."
|
||||
description = "Too many email change emails have been sent. Please try again in {minutes} minutes."
|
||||
code = 429
|
||||
|
||||
def __init__(self, minutes: int = 1):
|
||||
description = self.description.format(minutes=int(minutes)) if self.description else None
|
||||
super().__init__(description=description)
|
||||
|
||||
|
||||
class OwnerTransferRateLimitExceededError(BaseHTTPException):
|
||||
error_code = "owner_transfer_rate_limit_exceeded"
|
||||
description = "Too many owner transfer emails have been sent. Please try again in 1 minute."
|
||||
description = "Too many owner transfer emails have been sent. Please try again in {minutes} minutes."
|
||||
code = 429
|
||||
|
||||
def __init__(self, minutes: int = 1):
|
||||
description = self.description.format(minutes=int(minutes)) if self.description else None
|
||||
super().__init__(description=description)
|
||||
|
||||
|
||||
class EmailCodeError(BaseHTTPException):
|
||||
error_code = "email_code_error"
|
||||
|
|
@ -69,15 +91,23 @@ class EmailPasswordLoginLimitError(BaseHTTPException):
|
|||
|
||||
class EmailCodeLoginRateLimitExceededError(BaseHTTPException):
|
||||
error_code = "email_code_login_rate_limit_exceeded"
|
||||
description = "Too many login emails have been sent. Please try again in 5 minutes."
|
||||
description = "Too many login emails have been sent. Please try again in {minutes} minutes."
|
||||
code = 429
|
||||
|
||||
def __init__(self, minutes: int = 5):
|
||||
description = self.description.format(minutes=int(minutes)) if self.description else None
|
||||
super().__init__(description=description)
|
||||
|
||||
|
||||
class EmailCodeAccountDeletionRateLimitExceededError(BaseHTTPException):
|
||||
error_code = "email_code_account_deletion_rate_limit_exceeded"
|
||||
description = "Too many account deletion emails have been sent. Please try again in 5 minutes."
|
||||
description = "Too many account deletion emails have been sent. Please try again in {minutes} minutes."
|
||||
code = 429
|
||||
|
||||
def __init__(self, minutes: int = 5):
|
||||
description = self.description.format(minutes=int(minutes)) if self.description else None
|
||||
super().__init__(description=description)
|
||||
|
||||
|
||||
class EmailPasswordResetLimitError(BaseHTTPException):
|
||||
error_code = "email_password_reset_limit"
|
||||
|
|
@ -85,6 +115,12 @@ class EmailPasswordResetLimitError(BaseHTTPException):
|
|||
code = 429
|
||||
|
||||
|
||||
class EmailRegisterLimitError(BaseHTTPException):
|
||||
error_code = "email_register_limit"
|
||||
description = "Too many failed email register attempts. Please try again in 24 hours."
|
||||
code = 429
|
||||
|
||||
|
||||
class EmailChangeLimitError(BaseHTTPException):
|
||||
error_code = "email_change_limit"
|
||||
description = "Too many failed email change attempts. Please try again in 24 hours."
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from flask_restx import Resource, fields, reqparse
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from constants.languages import languages
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailCodeError,
|
||||
|
|
@ -15,7 +14,7 @@ from controllers.console.auth.error import (
|
|||
InvalidTokenError,
|
||||
PasswordMismatchError,
|
||||
)
|
||||
from controllers.console.error import AccountInFreezeError, AccountNotFound, EmailSendIpLimitError
|
||||
from controllers.console.error import AccountNotFound, EmailSendIpLimitError
|
||||
from controllers.console.wraps import email_password_login_enabled, setup_required
|
||||
from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -23,8 +22,6 @@ from libs.helper import email, extract_remote_ip
|
|||
from libs.password import hash_password, valid_password
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.errors.account import AccountRegisterError
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
|
|
@ -73,15 +70,13 @@ class ForgotPasswordSendEmailApi(Resource):
|
|||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
|
||||
token = None
|
||||
if account is None:
|
||||
if FeatureService.get_system_features().is_allow_register:
|
||||
token = AccountService.send_reset_password_email(email=args["email"], language=language)
|
||||
return {"result": "fail", "data": token, "code": "account_not_found"}
|
||||
else:
|
||||
raise AccountNotFound()
|
||||
else:
|
||||
token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language)
|
||||
|
||||
token = AccountService.send_reset_password_email(
|
||||
account=account,
|
||||
email=args["email"],
|
||||
language=language,
|
||||
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
||||
)
|
||||
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
|
@ -207,7 +202,7 @@ class ForgotPasswordResetApi(Resource):
|
|||
if account:
|
||||
self._update_existing_account(account, password_hashed, salt, session)
|
||||
else:
|
||||
self._create_new_account(email, args["password_confirm"])
|
||||
raise AccountNotFound()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
|
@ -227,18 +222,7 @@ class ForgotPasswordResetApi(Resource):
|
|||
account.current_tenant = tenant
|
||||
tenant_was_created.send(tenant)
|
||||
|
||||
def _create_new_account(self, email, password):
|
||||
# Create new account if allowed
|
||||
try:
|
||||
AccountService.create_account_and_tenant(
|
||||
email=email,
|
||||
name=email,
|
||||
password=password,
|
||||
interface_language=languages[0],
|
||||
)
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
pass
|
||||
except WorkspacesLimitExceededError:
|
||||
pass
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
|
||||
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
|
||||
api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")
|
||||
|
|
|
|||
|
|
@ -26,7 +26,6 @@ from controllers.console.error import (
|
|||
from controllers.console.wraps import email_password_login_enabled, setup_required
|
||||
from events.tenant_event import tenant_was_created
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.password import valid_password
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.billing_service import BillingService
|
||||
|
|
@ -44,10 +43,9 @@ class LoginApi(Resource):
|
|||
"""Authenticate user and login."""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("password", type=valid_password, required=True, location="json")
|
||||
parser.add_argument("password", type=str, required=True, location="json")
|
||||
parser.add_argument("remember_me", type=bool, required=False, default=False, location="json")
|
||||
parser.add_argument("invite_token", type=str, required=False, default=None, location="json")
|
||||
parser.add_argument("language", type=str, required=False, default="en-US", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
|
||||
|
|
@ -61,11 +59,6 @@ class LoginApi(Resource):
|
|||
if invitation:
|
||||
invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation)
|
||||
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
try:
|
||||
if invitation:
|
||||
data = invitation.get("data", {})
|
||||
|
|
@ -80,12 +73,6 @@ class LoginApi(Resource):
|
|||
except services.errors.account.AccountPasswordError:
|
||||
AccountService.add_login_error_rate_limit(args["email"])
|
||||
raise AuthenticationFailedError()
|
||||
except services.errors.account.AccountNotFoundError:
|
||||
if FeatureService.get_system_features().is_allow_register:
|
||||
token = AccountService.send_reset_password_email(email=args["email"], language=language)
|
||||
return {"result": "fail", "data": token, "code": "account_not_found"}
|
||||
else:
|
||||
raise AccountNotFound()
|
||||
# SELF_HOSTED only have one workspace
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
if len(tenants) == 0:
|
||||
|
|
@ -133,13 +120,12 @@ class ResetPasswordSendEmailApi(Resource):
|
|||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
if account is None:
|
||||
if FeatureService.get_system_features().is_allow_register:
|
||||
token = AccountService.send_reset_password_email(email=args["email"], language=language)
|
||||
else:
|
||||
raise AccountNotFound()
|
||||
else:
|
||||
token = AccountService.send_reset_password_email(account=account, language=language)
|
||||
token = AccountService.send_reset_password_email(
|
||||
email=args["email"],
|
||||
account=account,
|
||||
language=language,
|
||||
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
||||
)
|
||||
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
|||
from models import Account
|
||||
from models.account import AccountStatus
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
|
||||
from services.feature_service import FeatureService
|
||||
|
|
@ -183,7 +184,15 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
|
|||
|
||||
if not account:
|
||||
if not FeatureService.get_system_features().is_allow_register:
|
||||
raise AccountNotFoundError()
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email):
|
||||
raise AccountRegisterError(
|
||||
description=(
|
||||
"This email account has been deleted within the past "
|
||||
"30 days and is temporarily unavailable for new account registration"
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise AccountRegisterError(description=("Invalid email or password"))
|
||||
account_name = user_info.name or "Dify"
|
||||
account = RegisterService.register(
|
||||
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class AppParameterApi(InstalledAppResource):
|
|||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
|
|
|||
|
|
@ -242,6 +242,19 @@ def email_password_login_enabled(view: Callable[P, R]):
|
|||
return decorated
|
||||
|
||||
|
||||
def email_register_enabled(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
if features.is_allow_register:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
# otherwise, return 403
|
||||
abort(403)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def enable_change_email(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ class PluginUploadFileApi(Resource):
|
|||
filename=filename,
|
||||
mimetype=mimetype,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
user_id=user.id,
|
||||
timestamp=timestamp,
|
||||
nonce=nonce,
|
||||
sign=sign,
|
||||
|
|
|
|||
|
|
@ -8,11 +8,10 @@ from flask_restx import reqparse
|
|||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.file.constants import DEFAULT_SERVICE_API_USER_ID
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user
|
||||
from models.account import Tenant
|
||||
from models.model import EndUser
|
||||
from models.model import DefaultEndUserSessionID, EndUser
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
|
@ -28,7 +27,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
|||
try:
|
||||
with Session(db.engine) as session:
|
||||
if not user_id:
|
||||
user_id = DEFAULT_SERVICE_API_USER_ID
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
|
||||
|
||||
user_model = (
|
||||
session.query(EndUser)
|
||||
|
|
@ -42,7 +41,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
|||
user_model = EndUser(
|
||||
tenant_id=tenant_id,
|
||||
type="service_api",
|
||||
is_anonymous=user_id == DEFAULT_SERVICE_API_USER_ID,
|
||||
is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value,
|
||||
session_id=user_id,
|
||||
)
|
||||
session.add(user_model)
|
||||
|
|
@ -73,7 +72,7 @@ def get_user_tenant(view: Optional[Callable[P, R]] = None):
|
|||
raise ValueError("tenant_id is required")
|
||||
|
||||
if not user_id:
|
||||
user_id = DEFAULT_SERVICE_API_USER_ID
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
|
||||
|
||||
try:
|
||||
tenant_model = (
|
||||
|
|
|
|||
|
|
@ -150,7 +150,7 @@ class MCPAppApi(Resource):
|
|||
def _get_user_input_form(self, app: App) -> list[VariableEntity]:
|
||||
"""Get and convert user input form"""
|
||||
# Get raw user input form based on app mode
|
||||
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
if not app.workflow:
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable")
|
||||
raw_user_input_form = app.workflow.user_input_form(to_old_structure=True)
|
||||
|
|
|
|||
|
|
@ -33,7 +33,6 @@ from .dataset import (
|
|||
hit_testing, # pyright: ignore[reportUnusedImport]
|
||||
metadata, # pyright: ignore[reportUnusedImport]
|
||||
segment, # pyright: ignore[reportUnusedImport]
|
||||
upload_file, # pyright: ignore[reportUnusedImport]
|
||||
)
|
||||
from .workspace import models # pyright: ignore[reportUnusedImport]
|
||||
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ class AppParameterApi(Resource):
|
|||
|
||||
Returns the input form parameters and configuration for the application.
|
||||
"""
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
|
|
|||
|
|
@ -340,6 +340,9 @@ class DatasetApi(DatasetApiResource):
|
|||
else:
|
||||
data["embedding_available"] = True
|
||||
|
||||
# force update search method to keyword_search if indexing_technique is economic
|
||||
data["retrieval_model_dict"]["search_method"] = "keyword_search"
|
||||
|
||||
if data.get("permission") == "partial_members":
|
||||
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
data.update({"partial_member_list": part_users_list})
|
||||
|
|
|
|||
|
|
@ -1,65 +0,0 @@
|
|||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import (
|
||||
DatasetApiResource,
|
||||
)
|
||||
from core.file import helpers as file_helpers
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/upload-file")
|
||||
class UploadFileApi(DatasetApiResource):
|
||||
@service_api_ns.doc("get_upload_file")
|
||||
@service_api_ns.doc(description="Get upload file information and download URL")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Upload file information retrieved successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
404: "Dataset, document, or upload file not found",
|
||||
}
|
||||
)
|
||||
def get(self, tenant_id, dataset_id, document_id):
|
||||
"""Get upload file information and download URL.
|
||||
|
||||
Returns information about an uploaded file including its download URL.
|
||||
"""
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check upload file
|
||||
if document.data_source_type != "upload_file":
|
||||
raise ValueError(f"Document data source type ({document.data_source_type}) is not upload_file.")
|
||||
data_source_info = document.data_source_info_dict
|
||||
if data_source_info and "upload_file_id" in data_source_info:
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
if not upload_file:
|
||||
raise NotFound("UploadFile not found.")
|
||||
else:
|
||||
raise ValueError("Upload file id not found in document data source info.")
|
||||
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"url": url,
|
||||
"download_url": f"{url}&as_attachment=true",
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_by": upload_file.created_by,
|
||||
"created_at": upload_file.created_at.timestamp(),
|
||||
}, 200
|
||||
|
|
@ -13,14 +13,13 @@ from sqlalchemy import select, update
|
|||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from core.file.constants import DEFAULT_SERVICE_API_USER_ID
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
|
||||
from models.dataset import Dataset, RateLimitLog
|
||||
from models.model import ApiToken, App, EndUser
|
||||
from models.model import ApiToken, App, DefaultEndUserSessionID, EndUser
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
|
@ -273,7 +272,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
|
|||
Create or update session terminal based on user ID.
|
||||
"""
|
||||
if not user_id:
|
||||
user_id = DEFAULT_SERVICE_API_USER_ID
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
end_user = (
|
||||
|
|
@ -292,7 +291,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
|
|||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type="service_api",
|
||||
is_anonymous=user_id == DEFAULT_SERVICE_API_USER_ID,
|
||||
is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value,
|
||||
session_id=user_id,
|
||||
)
|
||||
session.add(end_user)
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ class AppParameterApi(WebApiResource):
|
|||
@marshal_with(fields.parameters_fields)
|
||||
def get(self, app_model: App, end_user):
|
||||
"""Retrieve app parameters."""
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import enum
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
||||
|
|
@ -26,25 +26,25 @@ class AgentStrategyProviderIdentity(ToolProviderIdentity):
|
|||
|
||||
|
||||
class AgentStrategyParameter(PluginParameter):
|
||||
class AgentStrategyParameterType(enum.StrEnum):
|
||||
class AgentStrategyParameterType(StrEnum):
|
||||
"""
|
||||
Keep all the types from PluginParameterType
|
||||
"""
|
||||
|
||||
STRING = CommonParameterType.STRING.value
|
||||
NUMBER = CommonParameterType.NUMBER.value
|
||||
BOOLEAN = CommonParameterType.BOOLEAN.value
|
||||
SELECT = CommonParameterType.SELECT.value
|
||||
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
|
||||
FILE = CommonParameterType.FILE.value
|
||||
FILES = CommonParameterType.FILES.value
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
|
||||
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
|
||||
ANY = CommonParameterType.ANY.value
|
||||
STRING = CommonParameterType.STRING
|
||||
NUMBER = CommonParameterType.NUMBER
|
||||
BOOLEAN = CommonParameterType.BOOLEAN
|
||||
SELECT = CommonParameterType.SELECT
|
||||
SECRET_INPUT = CommonParameterType.SECRET_INPUT
|
||||
FILE = CommonParameterType.FILE
|
||||
FILES = CommonParameterType.FILES
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR
|
||||
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
|
||||
ANY = CommonParameterType.ANY
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
|
||||
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES
|
||||
|
||||
def as_normal_type(self):
|
||||
return as_normal_type(self)
|
||||
|
|
@ -72,7 +72,7 @@ class AgentStrategyIdentity(ToolIdentity):
|
|||
pass
|
||||
|
||||
|
||||
class AgentFeature(enum.StrEnum):
|
||||
class AgentFeature(StrEnum):
|
||||
"""
|
||||
Agent Feature, used to describe the features of the agent strategy.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ class PromptTemplateConfigManager:
|
|||
:param config: app model config args
|
||||
"""
|
||||
if not config.get("prompt_type"):
|
||||
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
|
||||
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE
|
||||
|
||||
prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType]
|
||||
if config["prompt_type"] not in prompt_type_vals:
|
||||
|
|
@ -90,7 +90,7 @@ class PromptTemplateConfigManager:
|
|||
if not isinstance(config["completion_prompt_config"], dict):
|
||||
raise ValueError("completion_prompt_config must be of object type")
|
||||
|
||||
if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value:
|
||||
if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED:
|
||||
if not config["chat_prompt_config"] and not config["completion_prompt_config"]:
|
||||
raise ValueError(
|
||||
"chat_prompt_config or completion_prompt_config is required when prompt_type is advanced"
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Sequence
|
||||
from enum import Enum, StrEnum
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
|
@ -61,14 +61,14 @@ class PromptTemplateEntity(BaseModel):
|
|||
Prompt Template Entity.
|
||||
"""
|
||||
|
||||
class PromptType(Enum):
|
||||
class PromptType(StrEnum):
|
||||
"""
|
||||
Prompt Type.
|
||||
'simple', 'advanced'
|
||||
"""
|
||||
|
||||
SIMPLE = "simple"
|
||||
ADVANCED = "advanced"
|
||||
SIMPLE = auto()
|
||||
ADVANCED = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str):
|
||||
|
|
@ -195,14 +195,14 @@ class DatasetRetrieveConfigEntity(BaseModel):
|
|||
Dataset Retrieve Config Entity.
|
||||
"""
|
||||
|
||||
class RetrieveStrategy(Enum):
|
||||
class RetrieveStrategy(StrEnum):
|
||||
"""
|
||||
Dataset Retrieve Strategy.
|
||||
'single' or 'multiple'
|
||||
"""
|
||||
|
||||
SINGLE = "single"
|
||||
MULTIPLE = "multiple"
|
||||
SINGLE = auto()
|
||||
MULTIPLE = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str):
|
||||
|
|
@ -293,12 +293,12 @@ class AppConfig(BaseModel):
|
|||
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
|
||||
|
||||
|
||||
class EasyUIBasedAppModelConfigFrom(Enum):
|
||||
class EasyUIBasedAppModelConfigFrom(StrEnum):
|
||||
"""
|
||||
App Model Config From.
|
||||
"""
|
||||
|
||||
ARGS = "args"
|
||||
ARGS = auto()
|
||||
APP_LATEST_CONFIG = "app-latest-config"
|
||||
CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import Enum, StrEnum
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
|
@ -510,15 +510,15 @@ class QueueStopEvent(AppQueueEvent):
|
|||
QueueStopEvent entity
|
||||
"""
|
||||
|
||||
class StopBy(Enum):
|
||||
class StopBy(StrEnum):
|
||||
"""
|
||||
Stop by enum
|
||||
"""
|
||||
|
||||
USER_MANUAL = "user-manual"
|
||||
ANNOTATION_REPLY = "annotation-reply"
|
||||
OUTPUT_MODERATION = "output-moderation"
|
||||
INPUT_MODERATION = "input-moderation"
|
||||
USER_MANUAL = auto()
|
||||
ANNOTATION_REPLY = auto()
|
||||
OUTPUT_MODERATION = auto()
|
||||
INPUT_MODERATION = auto()
|
||||
|
||||
event: QueueEvent = QueueEvent.STOP
|
||||
stopped_by: StopBy
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Mapping, Sequence
|
||||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
|
@ -50,35 +50,37 @@ class WorkflowTaskState(TaskState):
|
|||
answer: str = ""
|
||||
|
||||
|
||||
class StreamEvent(Enum):
|
||||
class StreamEvent(StrEnum):
|
||||
"""
|
||||
Stream event
|
||||
"""
|
||||
|
||||
PING = "ping"
|
||||
ERROR = "error"
|
||||
MESSAGE = "message"
|
||||
MESSAGE_END = "message_end"
|
||||
TTS_MESSAGE = "tts_message"
|
||||
TTS_MESSAGE_END = "tts_message_end"
|
||||
MESSAGE_FILE = "message_file"
|
||||
MESSAGE_REPLACE = "message_replace"
|
||||
AGENT_THOUGHT = "agent_thought"
|
||||
AGENT_MESSAGE = "agent_message"
|
||||
WORKFLOW_STARTED = "workflow_started"
|
||||
WORKFLOW_FINISHED = "workflow_finished"
|
||||
NODE_STARTED = "node_started"
|
||||
NODE_FINISHED = "node_finished"
|
||||
NODE_RETRY = "node_retry"
|
||||
ITERATION_STARTED = "iteration_started"
|
||||
ITERATION_NEXT = "iteration_next"
|
||||
ITERATION_COMPLETED = "iteration_completed"
|
||||
LOOP_STARTED = "loop_started"
|
||||
LOOP_NEXT = "loop_next"
|
||||
LOOP_COMPLETED = "loop_completed"
|
||||
TEXT_CHUNK = "text_chunk"
|
||||
TEXT_REPLACE = "text_replace"
|
||||
AGENT_LOG = "agent_log"
|
||||
PING = auto()
|
||||
ERROR = auto()
|
||||
MESSAGE = auto()
|
||||
MESSAGE_END = auto()
|
||||
TTS_MESSAGE = auto()
|
||||
TTS_MESSAGE_END = auto()
|
||||
MESSAGE_FILE = auto()
|
||||
MESSAGE_REPLACE = auto()
|
||||
AGENT_THOUGHT = auto()
|
||||
AGENT_MESSAGE = auto()
|
||||
WORKFLOW_STARTED = auto()
|
||||
WORKFLOW_FINISHED = auto()
|
||||
NODE_STARTED = auto()
|
||||
NODE_FINISHED = auto()
|
||||
NODE_RETRY = auto()
|
||||
PARALLEL_BRANCH_STARTED = auto()
|
||||
PARALLEL_BRANCH_FINISHED = auto()
|
||||
ITERATION_STARTED = auto()
|
||||
ITERATION_NEXT = auto()
|
||||
ITERATION_COMPLETED = auto()
|
||||
LOOP_STARTED = auto()
|
||||
LOOP_NEXT = auto()
|
||||
LOOP_COMPLETED = auto()
|
||||
TEXT_CHUNK = auto()
|
||||
TEXT_REPLACE = auto()
|
||||
AGENT_LOG = auto()
|
||||
|
||||
|
||||
class StreamResponse(BaseModel):
|
||||
|
|
|
|||
|
|
@ -145,7 +145,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||
if self._task_state.metadata:
|
||||
extras["metadata"] = self._task_state.metadata.model_dump()
|
||||
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
|
||||
if self._conversation_mode == AppMode.COMPLETION.value:
|
||||
if self._conversation_mode == AppMode.COMPLETION:
|
||||
response = CompletionAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
data=CompletionAppBlockingResponse.Data(
|
||||
|
|
|
|||
|
|
@ -92,7 +92,7 @@ class MessageCycleManager:
|
|||
if not conversation:
|
||||
return
|
||||
|
||||
if conversation.mode != AppMode.COMPLETION.value:
|
||||
if conversation.mode != AppMode.COMPLETION:
|
||||
app_model = conversation.app
|
||||
if not app_model:
|
||||
return
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
|
||||
|
||||
class PlanningStrategy(Enum):
|
||||
ROUTER = "router"
|
||||
REACT_ROUTER = "react_router"
|
||||
REACT = "react"
|
||||
FUNCTION_CALL = "function_call"
|
||||
class PlanningStrategy(StrEnum):
|
||||
ROUTER = auto()
|
||||
REACT_ROUTER = auto()
|
||||
REACT = auto()
|
||||
FUNCTION_CALL = auto()
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
|
||||
|
||||
class EmbeddingInputType(Enum):
|
||||
class EmbeddingInputType(StrEnum):
|
||||
"""
|
||||
Enum for embedding input type.
|
||||
"""
|
||||
|
||||
DOCUMENT = "document"
|
||||
QUERY = "query"
|
||||
DOCUMENT = auto()
|
||||
QUERY = auto()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
|
@ -9,16 +9,16 @@ from core.model_runtime.entities.model_entities import ModelType, ProviderModel
|
|||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
|
||||
|
||||
class ModelStatus(Enum):
|
||||
class ModelStatus(StrEnum):
|
||||
"""
|
||||
Enum class for model status.
|
||||
"""
|
||||
|
||||
ACTIVE = "active"
|
||||
ACTIVE = auto()
|
||||
NO_CONFIGURE = "no-configure"
|
||||
QUOTA_EXCEEDED = "quota-exceeded"
|
||||
NO_PERMISSION = "no-permission"
|
||||
DISABLED = "disabled"
|
||||
DISABLED = auto()
|
||||
CREDENTIAL_REMOVED = "credential-removed"
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,20 +1,20 @@
|
|||
from enum import StrEnum
|
||||
from enum import StrEnum, auto
|
||||
|
||||
|
||||
class CommonParameterType(StrEnum):
|
||||
SECRET_INPUT = "secret-input"
|
||||
TEXT_INPUT = "text-input"
|
||||
SELECT = "select"
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
FILE = "file"
|
||||
FILES = "files"
|
||||
SELECT = auto()
|
||||
STRING = auto()
|
||||
NUMBER = auto()
|
||||
FILE = auto()
|
||||
FILES = auto()
|
||||
SYSTEM_FILES = "system-files"
|
||||
BOOLEAN = "boolean"
|
||||
BOOLEAN = auto()
|
||||
APP_SELECTOR = "app-selector"
|
||||
MODEL_SELECTOR = "model-selector"
|
||||
TOOLS_SELECTOR = "array[tools]"
|
||||
ANY = "any"
|
||||
ANY = auto()
|
||||
|
||||
# Dynamic select parameter
|
||||
# Once you are not sure about the available options until authorization is done
|
||||
|
|
@ -23,29 +23,29 @@ class CommonParameterType(StrEnum):
|
|||
|
||||
# TOOL_SELECTOR = "tool-selector"
|
||||
# MCP object and array type parameters
|
||||
ARRAY = "array"
|
||||
OBJECT = "object"
|
||||
ARRAY = auto()
|
||||
OBJECT = auto()
|
||||
|
||||
|
||||
class AppSelectorScope(StrEnum):
|
||||
ALL = "all"
|
||||
CHAT = "chat"
|
||||
WORKFLOW = "workflow"
|
||||
COMPLETION = "completion"
|
||||
ALL = auto()
|
||||
CHAT = auto()
|
||||
WORKFLOW = auto()
|
||||
COMPLETION = auto()
|
||||
|
||||
|
||||
class ModelSelectorScope(StrEnum):
|
||||
LLM = "llm"
|
||||
LLM = auto()
|
||||
TEXT_EMBEDDING = "text-embedding"
|
||||
RERANK = "rerank"
|
||||
TTS = "tts"
|
||||
SPEECH2TEXT = "speech2text"
|
||||
MODERATION = "moderation"
|
||||
VISION = "vision"
|
||||
RERANK = auto()
|
||||
TTS = auto()
|
||||
SPEECH2TEXT = auto()
|
||||
MODERATION = auto()
|
||||
VISION = auto()
|
||||
|
||||
|
||||
class ToolSelectorScope(StrEnum):
|
||||
ALL = "all"
|
||||
CUSTOM = "custom"
|
||||
BUILTIN = "builtin"
|
||||
WORKFLOW = "workflow"
|
||||
ALL = auto()
|
||||
CUSTOM = auto()
|
||||
BUILTIN = auto()
|
||||
WORKFLOW = auto()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
|
@ -13,14 +13,14 @@ from core.model_runtime.entities.model_entities import ModelType
|
|||
from core.tools.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
class ProviderQuotaType(Enum):
|
||||
PAID = "paid"
|
||||
class ProviderQuotaType(StrEnum):
|
||||
PAID = auto()
|
||||
"""hosted paid quota"""
|
||||
|
||||
FREE = "free"
|
||||
FREE = auto()
|
||||
"""third-party free quota"""
|
||||
|
||||
TRIAL = "trial"
|
||||
TRIAL = auto()
|
||||
"""hosted trial quota"""
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -31,20 +31,20 @@ class ProviderQuotaType(Enum):
|
|||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class QuotaUnit(Enum):
|
||||
TIMES = "times"
|
||||
TOKENS = "tokens"
|
||||
CREDITS = "credits"
|
||||
class QuotaUnit(StrEnum):
|
||||
TIMES = auto()
|
||||
TOKENS = auto()
|
||||
CREDITS = auto()
|
||||
|
||||
|
||||
class SystemConfigurationStatus(Enum):
|
||||
class SystemConfigurationStatus(StrEnum):
|
||||
"""
|
||||
Enum class for system configuration status.
|
||||
"""
|
||||
|
||||
ACTIVE = "active"
|
||||
ACTIVE = auto()
|
||||
QUOTA_EXCEEDED = "quota-exceeded"
|
||||
UNSUPPORTED = "unsupported"
|
||||
UNSUPPORTED = auto()
|
||||
|
||||
|
||||
class RestrictModel(BaseModel):
|
||||
|
|
@ -168,14 +168,14 @@ class BasicProviderConfig(BaseModel):
|
|||
Base model class for common provider settings like credentials
|
||||
"""
|
||||
|
||||
class Type(Enum):
|
||||
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
|
||||
TEXT_INPUT = CommonParameterType.TEXT_INPUT.value
|
||||
SELECT = CommonParameterType.SELECT.value
|
||||
BOOLEAN = CommonParameterType.BOOLEAN.value
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
|
||||
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
|
||||
class Type(StrEnum):
|
||||
SECRET_INPUT = CommonParameterType.SECRET_INPUT
|
||||
TEXT_INPUT = CommonParameterType.TEXT_INPUT
|
||||
SELECT = CommonParameterType.SELECT
|
||||
BOOLEAN = CommonParameterType.BOOLEAN
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR
|
||||
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ProviderConfig.Type":
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import enum
|
||||
import importlib.util
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from enum import StrEnum, auto
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
|
|
@ -13,9 +13,9 @@ from core.helper.position_helper import sort_to_dict_by_position_map
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExtensionModule(enum.Enum):
|
||||
MODERATION = "moderation"
|
||||
EXTERNAL_DATA_TOOL = "external_data_tool"
|
||||
class ExtensionModule(StrEnum):
|
||||
MODERATION = auto()
|
||||
EXTERNAL_DATA_TOOL = auto()
|
||||
|
||||
|
||||
class ModuleExtension(BaseModel):
|
||||
|
|
|
|||
|
|
@ -9,7 +9,3 @@ FILE_MODEL_IDENTITY = "__dify__file__"
|
|||
|
||||
def maybe_file_object(o: Any) -> bool:
|
||||
return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY
|
||||
|
||||
|
||||
# The default user ID for service API calls.
|
||||
DEFAULT_SERVICE_API_USER_ID = "DEFAULT-USER"
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import os
|
|||
import time
|
||||
|
||||
from configs import dify_config
|
||||
from core.file.constants import DEFAULT_SERVICE_API_USER_ID
|
||||
|
||||
|
||||
def get_signed_file_url(upload_file_id: str) -> str:
|
||||
|
|
@ -25,10 +24,6 @@ def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str,
|
|||
# Plugin access should use internal URL for Docker network communication
|
||||
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
|
||||
url = f"{base_url}/files/upload/for-plugin"
|
||||
|
||||
if user_id is None:
|
||||
user_id = DEFAULT_SERVICE_API_USER_ID
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = os.urandom(16).hex()
|
||||
key = dify_config.SECRET_KEY.encode()
|
||||
|
|
@ -40,11 +35,8 @@ def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str,
|
|||
|
||||
|
||||
def verify_plugin_file_signature(
|
||||
*, filename: str, mimetype: str, tenant_id: str, user_id: str | None, timestamp: str, nonce: str, sign: str
|
||||
*, filename: str, mimetype: str, tenant_id: str, user_id: str, timestamp: str, nonce: str, sign: str
|
||||
) -> bool:
|
||||
if user_id is None:
|
||||
user_id = DEFAULT_SERVICE_API_USER_ID
|
||||
|
||||
data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode()
|
||||
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
import json
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class ProviderCredentialsCacheType(Enum):
|
||||
class ProviderCredentialsCacheType(StrEnum):
|
||||
PROVIDER = "provider"
|
||||
MODEL = "provider_model"
|
||||
LOAD_BALANCING_MODEL = "load_balancing_provider_model"
|
||||
|
|
@ -14,7 +14,7 @@ class ProviderCredentialsCacheType(Enum):
|
|||
|
||||
class ProviderCredentialsCache:
|
||||
def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType):
|
||||
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
|
||||
self.cache_key = f"{cache_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
|
||||
|
||||
def get(self) -> Optional[dict]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,12 +1,14 @@
|
|||
import os
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Callable
|
||||
from functools import lru_cache
|
||||
from typing import TypeVar
|
||||
|
||||
from configs import dify_config
|
||||
from core.tools.utils.yaml_utils import load_yaml_file
|
||||
from core.tools.utils.yaml_utils import load_yaml_file_cached
|
||||
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> dict[str, int]:
|
||||
"""
|
||||
Get the mapping from name to index from a YAML file
|
||||
|
|
@ -14,12 +16,17 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") ->
|
|||
:param file_name: the YAML file name, default to '_position.yaml'
|
||||
:return: a dict with name as key and index as value
|
||||
"""
|
||||
# FIXME(-LAN-): Cache position maps to prevent file descriptor exhaustion during high-load benchmarks
|
||||
position_file_path = os.path.join(folder_path, file_name)
|
||||
yaml_content = load_yaml_file(file_path=position_file_path, default_value=[])
|
||||
try:
|
||||
yaml_content = load_yaml_file_cached(file_path=position_file_path)
|
||||
except Exception:
|
||||
yaml_content = []
|
||||
positions = [item.strip() for item in yaml_content if item and isinstance(item, str) and item.strip()]
|
||||
return {name: index for index, name in enumerate(positions)}
|
||||
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
|
||||
"""
|
||||
Get the mapping for tools from name to index from a YAML file.
|
||||
|
|
@ -35,20 +42,6 @@ def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -
|
|||
)
|
||||
|
||||
|
||||
def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
|
||||
"""
|
||||
Get the mapping for providers from name to index from a YAML file.
|
||||
:param folder_path:
|
||||
:param file_name: the YAML file name, default to '_position.yaml'
|
||||
:return: a dict with name as key and index as value
|
||||
"""
|
||||
position_map = get_position_map(folder_path, file_name=file_name)
|
||||
return pin_position_map(
|
||||
position_map,
|
||||
pin_list=dify_config.POSITION_PROVIDER_PINS_LIST,
|
||||
)
|
||||
|
||||
|
||||
def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]:
|
||||
"""
|
||||
Pin the items in the pin list to the beginning of the position map.
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
import json
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class ToolParameterCacheType(Enum):
|
||||
class ToolParameterCacheType(StrEnum):
|
||||
PARAMETER = "tool_parameter"
|
||||
|
||||
|
||||
|
|
@ -15,7 +15,7 @@ class ToolParameterCache:
|
|||
self, tenant_id: str, provider: str, tool_name: str, cache_type: ToolParameterCacheType, identity_id: str
|
||||
):
|
||||
self.cache_key = (
|
||||
f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
|
||||
f"{cache_type}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
|
||||
f":identity_id:{identity_id}"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -142,7 +142,7 @@ def handle_call_tool(
|
|||
end_user,
|
||||
args,
|
||||
InvokeFrom.SERVICE_API,
|
||||
streaming=app.mode == AppMode.AGENT_CHAT.value,
|
||||
streaming=app.mode == AppMode.AGENT_CHAT,
|
||||
)
|
||||
|
||||
answer = extract_answer_from_response(app, response)
|
||||
|
|
@ -157,7 +157,7 @@ def build_parameter_schema(
|
|||
"""Build parameter schema for the tool"""
|
||||
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
|
||||
|
||||
if app_mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
|
||||
if app_mode in {AppMode.COMPLETION, AppMode.WORKFLOW}:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": parameters,
|
||||
|
|
@ -175,9 +175,9 @@ def build_parameter_schema(
|
|||
|
||||
def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Prepare arguments based on app mode"""
|
||||
if app.mode == AppMode.WORKFLOW.value:
|
||||
if app.mode == AppMode.WORKFLOW:
|
||||
return {"inputs": arguments}
|
||||
elif app.mode == AppMode.COMPLETION.value:
|
||||
elif app.mode == AppMode.COMPLETION:
|
||||
return {"query": "", "inputs": arguments}
|
||||
else:
|
||||
# Chat modes - create a copy to avoid modifying original dict
|
||||
|
|
@ -218,13 +218,13 @@ def process_streaming_response(response: RateLimitGenerator) -> str:
|
|||
def process_mapping_response(app: App, response: Mapping) -> str:
|
||||
"""Process mapping response based on app mode"""
|
||||
if app.mode in {
|
||||
AppMode.ADVANCED_CHAT.value,
|
||||
AppMode.COMPLETION.value,
|
||||
AppMode.CHAT.value,
|
||||
AppMode.AGENT_CHAT.value,
|
||||
AppMode.ADVANCED_CHAT,
|
||||
AppMode.COMPLETION,
|
||||
AppMode.CHAT,
|
||||
AppMode.AGENT_CHAT,
|
||||
}:
|
||||
return response.get("answer", "")
|
||||
elif app.mode == AppMode.WORKFLOW.value:
|
||||
elif app.mode == AppMode.WORKFLOW:
|
||||
return json.dumps(response["data"]["outputs"], ensure_ascii=False)
|
||||
else:
|
||||
raise ValueError("Invalid app mode: " + str(app.mode))
|
||||
|
|
|
|||
|
|
@ -1,20 +1,20 @@
|
|||
from abc import ABC
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import Enum, StrEnum
|
||||
from enum import StrEnum, auto
|
||||
from typing import Annotated, Any, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_serializer, field_validator
|
||||
|
||||
|
||||
class PromptMessageRole(Enum):
|
||||
class PromptMessageRole(StrEnum):
|
||||
"""
|
||||
Enum class for prompt message.
|
||||
"""
|
||||
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
TOOL = "tool"
|
||||
SYSTEM = auto()
|
||||
USER = auto()
|
||||
ASSISTANT = auto()
|
||||
TOOL = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "PromptMessageRole":
|
||||
|
|
@ -54,11 +54,11 @@ class PromptMessageContentType(StrEnum):
|
|||
Enum class for prompt message content type.
|
||||
"""
|
||||
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
AUDIO = "audio"
|
||||
VIDEO = "video"
|
||||
DOCUMENT = "document"
|
||||
TEXT = auto()
|
||||
IMAGE = auto()
|
||||
AUDIO = auto()
|
||||
VIDEO = auto()
|
||||
DOCUMENT = auto()
|
||||
|
||||
|
||||
class PromptMessageContent(ABC, BaseModel):
|
||||
|
|
@ -108,8 +108,8 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent):
|
|||
"""
|
||||
|
||||
class DETAIL(StrEnum):
|
||||
LOW = "low"
|
||||
HIGH = "high"
|
||||
LOW = auto()
|
||||
HIGH = auto()
|
||||
|
||||
type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE
|
||||
detail: DETAIL = DETAIL.LOW
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from decimal import Decimal
|
||||
from enum import Enum, StrEnum
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
|
@ -7,17 +7,17 @@ from pydantic import BaseModel, ConfigDict, model_validator
|
|||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
class ModelType(Enum):
|
||||
class ModelType(StrEnum):
|
||||
"""
|
||||
Enum class for model type.
|
||||
"""
|
||||
|
||||
LLM = "llm"
|
||||
LLM = auto()
|
||||
TEXT_EMBEDDING = "text-embedding"
|
||||
RERANK = "rerank"
|
||||
SPEECH2TEXT = "speech2text"
|
||||
MODERATION = "moderation"
|
||||
TTS = "tts"
|
||||
RERANK = auto()
|
||||
SPEECH2TEXT = auto()
|
||||
MODERATION = auto()
|
||||
TTS = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, origin_model_type: str) -> "ModelType":
|
||||
|
|
@ -26,17 +26,17 @@ class ModelType(Enum):
|
|||
|
||||
:return: model type
|
||||
"""
|
||||
if origin_model_type in {"text-generation", cls.LLM.value}:
|
||||
if origin_model_type in {"text-generation", cls.LLM}:
|
||||
return cls.LLM
|
||||
elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING.value}:
|
||||
elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING}:
|
||||
return cls.TEXT_EMBEDDING
|
||||
elif origin_model_type in {"reranking", cls.RERANK.value}:
|
||||
elif origin_model_type in {"reranking", cls.RERANK}:
|
||||
return cls.RERANK
|
||||
elif origin_model_type in {"speech2text", cls.SPEECH2TEXT.value}:
|
||||
elif origin_model_type in {"speech2text", cls.SPEECH2TEXT}:
|
||||
return cls.SPEECH2TEXT
|
||||
elif origin_model_type in {"tts", cls.TTS.value}:
|
||||
elif origin_model_type in {"tts", cls.TTS}:
|
||||
return cls.TTS
|
||||
elif origin_model_type == cls.MODERATION.value:
|
||||
elif origin_model_type == cls.MODERATION:
|
||||
return cls.MODERATION
|
||||
else:
|
||||
raise ValueError(f"invalid origin model type {origin_model_type}")
|
||||
|
|
@ -63,7 +63,7 @@ class ModelType(Enum):
|
|||
raise ValueError(f"invalid model type {self}")
|
||||
|
||||
|
||||
class FetchFrom(Enum):
|
||||
class FetchFrom(StrEnum):
|
||||
"""
|
||||
Enum class for fetch from.
|
||||
"""
|
||||
|
|
@ -72,7 +72,7 @@ class FetchFrom(Enum):
|
|||
CUSTOMIZABLE_MODEL = "customizable-model"
|
||||
|
||||
|
||||
class ModelFeature(Enum):
|
||||
class ModelFeature(StrEnum):
|
||||
"""
|
||||
Enum class for llm feature.
|
||||
"""
|
||||
|
|
@ -80,11 +80,11 @@ class ModelFeature(Enum):
|
|||
TOOL_CALL = "tool-call"
|
||||
MULTI_TOOL_CALL = "multi-tool-call"
|
||||
AGENT_THOUGHT = "agent-thought"
|
||||
VISION = "vision"
|
||||
VISION = auto()
|
||||
STREAM_TOOL_CALL = "stream-tool-call"
|
||||
DOCUMENT = "document"
|
||||
VIDEO = "video"
|
||||
AUDIO = "audio"
|
||||
DOCUMENT = auto()
|
||||
VIDEO = auto()
|
||||
AUDIO = auto()
|
||||
STRUCTURED_OUTPUT = "structured-output"
|
||||
|
||||
|
||||
|
|
@ -93,14 +93,14 @@ class DefaultParameterName(StrEnum):
|
|||
Enum class for parameter template variable.
|
||||
"""
|
||||
|
||||
TEMPERATURE = "temperature"
|
||||
TOP_P = "top_p"
|
||||
TOP_K = "top_k"
|
||||
PRESENCE_PENALTY = "presence_penalty"
|
||||
FREQUENCY_PENALTY = "frequency_penalty"
|
||||
MAX_TOKENS = "max_tokens"
|
||||
RESPONSE_FORMAT = "response_format"
|
||||
JSON_SCHEMA = "json_schema"
|
||||
TEMPERATURE = auto()
|
||||
TOP_P = auto()
|
||||
TOP_K = auto()
|
||||
PRESENCE_PENALTY = auto()
|
||||
FREQUENCY_PENALTY = auto()
|
||||
MAX_TOKENS = auto()
|
||||
RESPONSE_FORMAT = auto()
|
||||
JSON_SCHEMA = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: Any) -> "DefaultParameterName":
|
||||
|
|
@ -116,34 +116,34 @@ class DefaultParameterName(StrEnum):
|
|||
raise ValueError(f"invalid parameter name {value}")
|
||||
|
||||
|
||||
class ParameterType(Enum):
|
||||
class ParameterType(StrEnum):
|
||||
"""
|
||||
Enum class for parameter type.
|
||||
"""
|
||||
|
||||
FLOAT = "float"
|
||||
INT = "int"
|
||||
STRING = "string"
|
||||
BOOLEAN = "boolean"
|
||||
TEXT = "text"
|
||||
FLOAT = auto()
|
||||
INT = auto()
|
||||
STRING = auto()
|
||||
BOOLEAN = auto()
|
||||
TEXT = auto()
|
||||
|
||||
|
||||
class ModelPropertyKey(Enum):
|
||||
class ModelPropertyKey(StrEnum):
|
||||
"""
|
||||
Enum class for model property key.
|
||||
"""
|
||||
|
||||
MODE = "mode"
|
||||
CONTEXT_SIZE = "context_size"
|
||||
MAX_CHUNKS = "max_chunks"
|
||||
FILE_UPLOAD_LIMIT = "file_upload_limit"
|
||||
SUPPORTED_FILE_EXTENSIONS = "supported_file_extensions"
|
||||
MAX_CHARACTERS_PER_CHUNK = "max_characters_per_chunk"
|
||||
DEFAULT_VOICE = "default_voice"
|
||||
VOICES = "voices"
|
||||
WORD_LIMIT = "word_limit"
|
||||
AUDIO_TYPE = "audio_type"
|
||||
MAX_WORKERS = "max_workers"
|
||||
MODE = auto()
|
||||
CONTEXT_SIZE = auto()
|
||||
MAX_CHUNKS = auto()
|
||||
FILE_UPLOAD_LIMIT = auto()
|
||||
SUPPORTED_FILE_EXTENSIONS = auto()
|
||||
MAX_CHARACTERS_PER_CHUNK = auto()
|
||||
DEFAULT_VOICE = auto()
|
||||
VOICES = auto()
|
||||
WORD_LIMIT = auto()
|
||||
AUDIO_TYPE = auto()
|
||||
MAX_WORKERS = auto()
|
||||
|
||||
|
||||
class ProviderModel(BaseModel):
|
||||
|
|
@ -220,13 +220,13 @@ class ModelUsage(BaseModel):
|
|||
pass
|
||||
|
||||
|
||||
class PriceType(Enum):
|
||||
class PriceType(StrEnum):
|
||||
"""
|
||||
Enum class for price type.
|
||||
"""
|
||||
|
||||
INPUT = "input"
|
||||
OUTPUT = "output"
|
||||
INPUT = auto()
|
||||
OUTPUT = auto()
|
||||
|
||||
|
||||
class PriceInfo(BaseModel):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from enum import Enum, StrEnum, auto
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
|
@ -17,16 +17,16 @@ class ConfigurateMethod(Enum):
|
|||
CUSTOMIZABLE_MODEL = "customizable-model"
|
||||
|
||||
|
||||
class FormType(Enum):
|
||||
class FormType(StrEnum):
|
||||
"""
|
||||
Enum class for form type.
|
||||
"""
|
||||
|
||||
TEXT_INPUT = "text-input"
|
||||
SECRET_INPUT = "secret-input"
|
||||
SELECT = "select"
|
||||
RADIO = "radio"
|
||||
SWITCH = "switch"
|
||||
SELECT = auto()
|
||||
RADIO = auto()
|
||||
SWITCH = auto()
|
||||
|
||||
|
||||
class FormShowOnObject(BaseModel):
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ class TextEmbeddingModel(AIModel):
|
|||
model=model,
|
||||
credentials=credentials,
|
||||
texts=texts,
|
||||
input_type=input_type.value,
|
||||
input_type=input_type,
|
||||
)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
|
|
|||
|
|
@ -1,14 +1,10 @@
|
|||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from threading import Lock
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import contexts
|
||||
from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
|
@ -26,50 +22,22 @@ from models.provider_ids import ModelProviderID
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelProviderExtension(BaseModel):
|
||||
plugin_model_provider_entity: PluginModelProviderEntity
|
||||
position: Optional[int] = None
|
||||
|
||||
|
||||
class ModelProviderFactory:
|
||||
provider_position_map: dict[str, int]
|
||||
|
||||
def __init__(self, tenant_id: str) -> None:
|
||||
def __init__(self, tenant_id: str):
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
self.provider_position_map = {}
|
||||
|
||||
self.tenant_id = tenant_id
|
||||
self.plugin_model_manager = PluginModelClient()
|
||||
|
||||
if not self.provider_position_map:
|
||||
# get the path of current classes
|
||||
current_path = os.path.abspath(__file__)
|
||||
model_providers_path = os.path.dirname(current_path)
|
||||
|
||||
# get _position.yaml file path
|
||||
self.provider_position_map = get_provider_position_map(model_providers_path)
|
||||
|
||||
def get_providers(self) -> Sequence[ProviderEntity]:
|
||||
"""
|
||||
Get all providers
|
||||
:return: list of providers
|
||||
"""
|
||||
# Fetch plugin model providers
|
||||
# FIXME(-LAN-): Removed position map sorting since providers are fetched from plugin server
|
||||
# The plugin server should return providers in the desired order
|
||||
plugin_providers = self.get_plugin_model_providers()
|
||||
|
||||
# Convert PluginModelProviderEntity to ModelProviderExtension
|
||||
model_provider_extensions = []
|
||||
for provider in plugin_providers:
|
||||
model_provider_extensions.append(ModelProviderExtension(plugin_model_provider_entity=provider))
|
||||
|
||||
sorted_extensions = sort_to_dict_by_position_map(
|
||||
position_map=self.provider_position_map,
|
||||
data=model_provider_extensions,
|
||||
name_func=lambda x: x.plugin_model_provider_entity.declaration.provider,
|
||||
)
|
||||
|
||||
return [extension.plugin_model_provider_entity.declaration for extension in sorted_extensions.values()]
|
||||
return [provider.declaration for provider in plugin_providers]
|
||||
|
||||
def get_plugin_model_providers(self) -> Sequence["PluginModelProviderEntity"]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from pydantic_core import Url
|
|||
from pydantic_extra_types.color import Color
|
||||
|
||||
|
||||
def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any):
|
||||
def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any:
|
||||
return model.model_dump(mode=mode, **kwargs)
|
||||
|
||||
|
||||
|
|
@ -100,7 +100,7 @@ def jsonable_encoder(
|
|||
exclude_none: bool = False,
|
||||
custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None,
|
||||
sqlalchemy_safe: bool = True,
|
||||
):
|
||||
) -> Any:
|
||||
custom_encoder = custom_encoder or {}
|
||||
if custom_encoder:
|
||||
if type(obj) in custom_encoder:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
|
@ -7,9 +7,9 @@ from pydantic import BaseModel, Field
|
|||
from core.extension.extensible import Extensible, ExtensionModule
|
||||
|
||||
|
||||
class ModerationAction(Enum):
|
||||
DIRECT_OUTPUT = "direct_output"
|
||||
OVERRIDDEN = "overridden"
|
||||
class ModerationAction(StrEnum):
|
||||
DIRECT_OUTPUT = auto()
|
||||
OVERRIDDEN = auto()
|
||||
|
||||
|
||||
class ModerationInputsResult(BaseModel):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
|
||||
# public
|
||||
GEN_AI_SESSION_ID = "gen_ai.session.id"
|
||||
|
|
@ -53,7 +53,7 @@ TOOL_DESCRIPTION = "tool.description"
|
|||
TOOL_PARAMETERS = "tool.parameters"
|
||||
|
||||
|
||||
class GenAISpanKind(Enum):
|
||||
class GenAISpanKind(StrEnum):
|
||||
CHAIN = "CHAIN"
|
||||
RETRIEVER = "RETRIEVER"
|
||||
RERANKER = "RERANKER"
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
|||
app = cls._get_app(app_id, tenant_id)
|
||||
|
||||
"""Retrieve app parameters."""
|
||||
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app.workflow
|
||||
if workflow is None:
|
||||
raise ValueError("unexpected app type")
|
||||
|
|
@ -70,7 +70,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
|||
|
||||
conversation_id = conversation_id or ""
|
||||
|
||||
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.AGENT_CHAT.value, AppMode.CHAT.value}:
|
||||
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.AGENT_CHAT, AppMode.CHAT}:
|
||||
if not query:
|
||||
raise ValueError("missing query")
|
||||
|
||||
|
|
@ -96,7 +96,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
|||
"""
|
||||
invoke chat app
|
||||
"""
|
||||
if app.mode == AppMode.ADVANCED_CHAT.value:
|
||||
if app.mode == AppMode.ADVANCED_CHAT:
|
||||
workflow = app.workflow
|
||||
if not workflow:
|
||||
raise ValueError("unexpected app type")
|
||||
|
|
@ -114,7 +114,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
|||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=stream,
|
||||
)
|
||||
elif app.mode == AppMode.AGENT_CHAT.value:
|
||||
elif app.mode == AppMode.AGENT_CHAT:
|
||||
return AgentChatAppGenerator().generate(
|
||||
app_model=app,
|
||||
user=user,
|
||||
|
|
@ -127,7 +127,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
|||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=stream,
|
||||
)
|
||||
elif app.mode == AppMode.CHAT.value:
|
||||
elif app.mode == AppMode.CHAT:
|
||||
return ChatAppGenerator().generate(
|
||||
app_model=app,
|
||||
user=user,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import enum
|
||||
import json
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
|
@ -24,44 +24,44 @@ class PluginParameterOption(BaseModel):
|
|||
return value
|
||||
|
||||
|
||||
class PluginParameterType(enum.StrEnum):
|
||||
class PluginParameterType(StrEnum):
|
||||
"""
|
||||
all available parameter types
|
||||
"""
|
||||
|
||||
STRING = CommonParameterType.STRING.value
|
||||
NUMBER = CommonParameterType.NUMBER.value
|
||||
BOOLEAN = CommonParameterType.BOOLEAN.value
|
||||
SELECT = CommonParameterType.SELECT.value
|
||||
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
|
||||
FILE = CommonParameterType.FILE.value
|
||||
FILES = CommonParameterType.FILES.value
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
|
||||
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
|
||||
ANY = CommonParameterType.ANY.value
|
||||
DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT.value
|
||||
STRING = CommonParameterType.STRING
|
||||
NUMBER = CommonParameterType.NUMBER
|
||||
BOOLEAN = CommonParameterType.BOOLEAN
|
||||
SELECT = CommonParameterType.SELECT
|
||||
SECRET_INPUT = CommonParameterType.SECRET_INPUT
|
||||
FILE = CommonParameterType.FILE
|
||||
FILES = CommonParameterType.FILES
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR
|
||||
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
|
||||
ANY = CommonParameterType.ANY
|
||||
DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
|
||||
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES
|
||||
|
||||
# MCP object and array type parameters
|
||||
ARRAY = CommonParameterType.ARRAY.value
|
||||
OBJECT = CommonParameterType.OBJECT.value
|
||||
ARRAY = CommonParameterType.ARRAY
|
||||
OBJECT = CommonParameterType.OBJECT
|
||||
|
||||
|
||||
class MCPServerParameterType(enum.StrEnum):
|
||||
class MCPServerParameterType(StrEnum):
|
||||
"""
|
||||
MCP server got complex parameter types
|
||||
"""
|
||||
|
||||
ARRAY = "array"
|
||||
OBJECT = "object"
|
||||
ARRAY = auto()
|
||||
OBJECT = auto()
|
||||
|
||||
|
||||
class PluginParameterAutoGenerate(BaseModel):
|
||||
class Type(enum.StrEnum):
|
||||
PROMPT_INSTRUCTION = "prompt_instruction"
|
||||
class Type(StrEnum):
|
||||
PROMPT_INSTRUCTION = auto()
|
||||
|
||||
type: Type
|
||||
|
||||
|
|
@ -92,7 +92,7 @@ class PluginParameter(BaseModel):
|
|||
return v
|
||||
|
||||
|
||||
def as_normal_type(typ: enum.StrEnum):
|
||||
def as_normal_type(typ: StrEnum):
|
||||
if typ.value in {
|
||||
PluginParameterType.SECRET_INPUT,
|
||||
PluginParameterType.SELECT,
|
||||
|
|
@ -101,7 +101,7 @@ def as_normal_type(typ: enum.StrEnum):
|
|||
return typ.value
|
||||
|
||||
|
||||
def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
|
||||
def cast_parameter_value(typ: StrEnum, value: Any, /):
|
||||
try:
|
||||
match typ.value:
|
||||
case PluginParameterType.STRING | PluginParameterType.SECRET_INPUT | PluginParameterType.SELECT:
|
||||
|
|
@ -189,7 +189,7 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
|
|||
raise ValueError(f"The tool parameter value {value} is not in correct type of {as_normal_type(typ)}.")
|
||||
|
||||
|
||||
def init_frontend_parameter(rule: PluginParameter, type: enum.StrEnum, value: Any):
|
||||
def init_frontend_parameter(rule: PluginParameter, type: StrEnum, value: Any):
|
||||
"""
|
||||
init frontend parameter by rule
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import datetime
|
||||
import enum
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Optional
|
||||
|
||||
from packaging.version import InvalidVersion, Version
|
||||
|
|
@ -14,11 +14,11 @@ from core.tools.entities.common_entities import I18nObject
|
|||
from core.tools.entities.tool_entities import ToolProviderEntity
|
||||
|
||||
|
||||
class PluginInstallationSource(enum.StrEnum):
|
||||
Github = "github"
|
||||
Marketplace = "marketplace"
|
||||
Package = "package"
|
||||
Remote = "remote"
|
||||
class PluginInstallationSource(StrEnum):
|
||||
Github = auto()
|
||||
Marketplace = auto()
|
||||
Package = auto()
|
||||
Remote = auto()
|
||||
|
||||
|
||||
class PluginResourceRequirements(BaseModel):
|
||||
|
|
@ -56,10 +56,10 @@ class PluginResourceRequirements(BaseModel):
|
|||
permission: Optional[Permission] = Field(default=None)
|
||||
|
||||
|
||||
class PluginCategory(enum.StrEnum):
|
||||
Tool = "tool"
|
||||
Model = "model"
|
||||
Extension = "extension"
|
||||
class PluginCategory(StrEnum):
|
||||
Tool = auto()
|
||||
Model = auto()
|
||||
Extension = auto()
|
||||
AgentStrategy = "agent-strategy"
|
||||
|
||||
|
||||
|
|
@ -155,10 +155,10 @@ class PluginEntity(PluginInstallation):
|
|||
|
||||
|
||||
class PluginDependency(BaseModel):
|
||||
class Type(enum.StrEnum):
|
||||
Github = PluginInstallationSource.Github.value
|
||||
Marketplace = PluginInstallationSource.Marketplace.value
|
||||
Package = PluginInstallationSource.Package.value
|
||||
class Type(StrEnum):
|
||||
Github = PluginInstallationSource.Github
|
||||
Marketplace = PluginInstallationSource.Marketplace
|
||||
Package = PluginInstallationSource.Package
|
||||
|
||||
class Github(BaseModel):
|
||||
repo: str
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import enum
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum, auto
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
|
||||
from core.app.app_config.entities import PromptTemplateEntity
|
||||
|
|
@ -25,9 +25,9 @@ if TYPE_CHECKING:
|
|||
from core.file.models import File
|
||||
|
||||
|
||||
class ModelMode(enum.StrEnum):
|
||||
COMPLETION = "completion"
|
||||
CHAT = "chat"
|
||||
class ModelMode(StrEnum):
|
||||
COMPLETION = auto()
|
||||
CHAT = auto()
|
||||
|
||||
|
||||
prompt_file_contents: dict[str, Any] = {}
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
|
||||
|
||||
class Field(Enum):
|
||||
class Field(StrEnum):
|
||||
CONTENT_KEY = "page_content"
|
||||
METADATA_KEY = "metadata"
|
||||
GROUP_KEY = "group_id"
|
||||
VECTOR = "vector"
|
||||
VECTOR = auto()
|
||||
# Sparse Vector aims to support full text search
|
||||
SPARSE_VECTOR = "sparse_vector"
|
||||
SPARSE_VECTOR = auto()
|
||||
TEXT_KEY = "text"
|
||||
PRIMARY_KEY = "id"
|
||||
DOC_ID = "metadata.doc_id"
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from clickhouse_connect import get_client
|
||||
|
|
@ -27,7 +27,7 @@ class MyScaleConfig(BaseModel):
|
|||
fts_params: str
|
||||
|
||||
|
||||
class SortOrder(Enum):
|
||||
class SortOrder(StrEnum):
|
||||
ASC = "ASC"
|
||||
DESC = "DESC"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class DatasourceType(Enum):
|
||||
class DatasourceType(StrEnum):
|
||||
FILE = "upload_file"
|
||||
NOTION = "notion_import"
|
||||
WEBSITE = "website_crawl"
|
||||
|
|
|
|||
|
|
@ -1,15 +1,15 @@
|
|||
from enum import Enum, StrEnum
|
||||
from enum import StrEnum, auto
|
||||
|
||||
|
||||
class BuiltInField(StrEnum):
|
||||
document_name = "document_name"
|
||||
uploader = "uploader"
|
||||
upload_date = "upload_date"
|
||||
last_update_date = "last_update_date"
|
||||
source = "source"
|
||||
document_name = auto()
|
||||
uploader = auto()
|
||||
upload_date = auto()
|
||||
last_update_date = auto()
|
||||
source = auto()
|
||||
|
||||
|
||||
class MetadataDataSource(Enum):
|
||||
class MetadataDataSource(StrEnum):
|
||||
upload_file = "file_upload"
|
||||
website_crawl = "website"
|
||||
notion_import = "notion"
|
||||
|
|
|
|||
|
|
@ -113,21 +113,33 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
|||
# node_ids is segment's node_ids
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
delete_child_chunks = kwargs.get("delete_child_chunks") or False
|
||||
precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids")
|
||||
vector = Vector(dataset)
|
||||
|
||||
if node_ids:
|
||||
child_node_ids = (
|
||||
db.session.query(ChildChunk.index_node_id)
|
||||
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
|
||||
.where(
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.index_node_id.in_(node_ids),
|
||||
ChildChunk.dataset_id == dataset.id,
|
||||
# Use precomputed child_node_ids if available (to avoid race conditions)
|
||||
if precomputed_child_node_ids is not None:
|
||||
child_node_ids = precomputed_child_node_ids
|
||||
else:
|
||||
# Fallback to original query (may fail if segments are already deleted)
|
||||
child_node_ids = (
|
||||
db.session.query(ChildChunk.index_node_id)
|
||||
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
|
||||
.where(
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.index_node_id.in_(node_ids),
|
||||
ChildChunk.dataset_id == dataset.id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
.all()
|
||||
)
|
||||
child_node_ids = [child_node_id[0] for child_node_id in child_node_ids]
|
||||
vector.delete_by_ids(child_node_ids)
|
||||
if delete_child_chunks:
|
||||
child_node_ids = [child_node_id[0] for child_node_id in child_node_ids if child_node_id[0]]
|
||||
|
||||
# Delete from vector index
|
||||
if child_node_ids:
|
||||
vector.delete_by_ids(child_node_ids)
|
||||
|
||||
# Delete from database
|
||||
if delete_child_chunks and child_node_ids:
|
||||
db.session.query(ChildChunk).where(
|
||||
ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
|
||||
).delete(synchronize_session=False)
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
|
|||
from core.tools.errors import (
|
||||
ToolProviderNotFoundError,
|
||||
)
|
||||
from core.tools.utils.yaml_utils import load_yaml_file
|
||||
from core.tools.utils.yaml_utils import load_yaml_file_cached
|
||||
|
||||
|
||||
class BuiltinToolProviderController(ToolProviderController):
|
||||
|
|
@ -31,7 +31,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||
provider = self.__class__.__module__.split(".")[-1]
|
||||
yaml_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, f"{provider}.yaml")
|
||||
try:
|
||||
provider_yaml = load_yaml_file(yaml_path, ignore_error=False)
|
||||
provider_yaml = load_yaml_file_cached(yaml_path)
|
||||
except Exception as e:
|
||||
raise ToolProviderNotFoundError(f"can not load provider yaml for {provider}: {e}")
|
||||
|
||||
|
|
@ -71,7 +71,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
|||
for tool_file in tool_files:
|
||||
# get tool name
|
||||
tool_name = tool_file.split(".")[0]
|
||||
tool = load_yaml_file(path.join(tool_path, tool_file), ignore_error=False)
|
||||
tool = load_yaml_file_cached(path.join(tool_path, tool_file))
|
||||
|
||||
# get tool class, import the module
|
||||
assistant_tool_class: type = load_single_subclass_from_source(
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
import base64
|
||||
import contextlib
|
||||
import enum
|
||||
from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator
|
||||
|
|
@ -22,37 +21,37 @@ from core.tools.entities.common_entities import I18nObject
|
|||
from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY
|
||||
|
||||
|
||||
class ToolLabelEnum(Enum):
|
||||
SEARCH = "search"
|
||||
IMAGE = "image"
|
||||
VIDEOS = "videos"
|
||||
WEATHER = "weather"
|
||||
FINANCE = "finance"
|
||||
DESIGN = "design"
|
||||
TRAVEL = "travel"
|
||||
SOCIAL = "social"
|
||||
NEWS = "news"
|
||||
MEDICAL = "medical"
|
||||
PRODUCTIVITY = "productivity"
|
||||
EDUCATION = "education"
|
||||
BUSINESS = "business"
|
||||
ENTERTAINMENT = "entertainment"
|
||||
UTILITIES = "utilities"
|
||||
OTHER = "other"
|
||||
class ToolLabelEnum(StrEnum):
|
||||
SEARCH = auto()
|
||||
IMAGE = auto()
|
||||
VIDEOS = auto()
|
||||
WEATHER = auto()
|
||||
FINANCE = auto()
|
||||
DESIGN = auto()
|
||||
TRAVEL = auto()
|
||||
SOCIAL = auto()
|
||||
NEWS = auto()
|
||||
MEDICAL = auto()
|
||||
PRODUCTIVITY = auto()
|
||||
EDUCATION = auto()
|
||||
BUSINESS = auto()
|
||||
ENTERTAINMENT = auto()
|
||||
UTILITIES = auto()
|
||||
OTHER = auto()
|
||||
|
||||
|
||||
class ToolProviderType(enum.StrEnum):
|
||||
class ToolProviderType(StrEnum):
|
||||
"""
|
||||
Enum class for tool provider
|
||||
"""
|
||||
|
||||
PLUGIN = "plugin"
|
||||
PLUGIN = auto()
|
||||
BUILT_IN = "builtin"
|
||||
WORKFLOW = "workflow"
|
||||
API = "api"
|
||||
APP = "app"
|
||||
WORKFLOW = auto()
|
||||
API = auto()
|
||||
APP = auto()
|
||||
DATASET_RETRIEVAL = "dataset-retrieval"
|
||||
MCP = "mcp"
|
||||
MCP = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ToolProviderType":
|
||||
|
|
@ -68,15 +67,15 @@ class ToolProviderType(enum.StrEnum):
|
|||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
|
||||
class ApiProviderSchemaType(Enum):
|
||||
class ApiProviderSchemaType(StrEnum):
|
||||
"""
|
||||
Enum class for api provider schema type.
|
||||
"""
|
||||
|
||||
OPENAPI = "openapi"
|
||||
SWAGGER = "swagger"
|
||||
OPENAI_PLUGIN = "openai_plugin"
|
||||
OPENAI_ACTIONS = "openai_actions"
|
||||
OPENAPI = auto()
|
||||
SWAGGER = auto()
|
||||
OPENAI_PLUGIN = auto()
|
||||
OPENAI_ACTIONS = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ApiProviderSchemaType":
|
||||
|
|
@ -92,14 +91,14 @@ class ApiProviderSchemaType(Enum):
|
|||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
|
||||
class ApiProviderAuthType(Enum):
|
||||
class ApiProviderAuthType(StrEnum):
|
||||
"""
|
||||
Enum class for api provider auth type.
|
||||
"""
|
||||
|
||||
NONE = "none"
|
||||
API_KEY_HEADER = "api_key_header"
|
||||
API_KEY_QUERY = "api_key_query"
|
||||
NONE = auto()
|
||||
API_KEY_HEADER = auto()
|
||||
API_KEY_QUERY = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ApiProviderAuthType":
|
||||
|
|
@ -176,10 +175,10 @@ class ToolInvokeMessage(BaseModel):
|
|||
return value
|
||||
|
||||
class LogMessage(BaseModel):
|
||||
class LogStatus(Enum):
|
||||
START = "start"
|
||||
ERROR = "error"
|
||||
SUCCESS = "success"
|
||||
class LogStatus(StrEnum):
|
||||
START = auto()
|
||||
ERROR = auto()
|
||||
SUCCESS = auto()
|
||||
|
||||
id: str
|
||||
label: str = Field(..., description="The label of the log")
|
||||
|
|
@ -193,19 +192,19 @@ class ToolInvokeMessage(BaseModel):
|
|||
retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
||||
context: str = Field(..., description="context")
|
||||
|
||||
class MessageType(Enum):
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
LINK = "link"
|
||||
BLOB = "blob"
|
||||
JSON = "json"
|
||||
IMAGE_LINK = "image_link"
|
||||
BINARY_LINK = "binary_link"
|
||||
VARIABLE = "variable"
|
||||
FILE = "file"
|
||||
LOG = "log"
|
||||
BLOB_CHUNK = "blob_chunk"
|
||||
RETRIEVER_RESOURCES = "retriever_resources"
|
||||
class MessageType(StrEnum):
|
||||
TEXT = auto()
|
||||
IMAGE = auto()
|
||||
LINK = auto()
|
||||
BLOB = auto()
|
||||
JSON = auto()
|
||||
IMAGE_LINK = auto()
|
||||
BINARY_LINK = auto()
|
||||
VARIABLE = auto()
|
||||
FILE = auto()
|
||||
LOG = auto()
|
||||
BLOB_CHUNK = auto()
|
||||
RETRIEVER_RESOURCES = auto()
|
||||
|
||||
type: MessageType = MessageType.TEXT
|
||||
"""
|
||||
|
|
@ -250,29 +249,29 @@ class ToolParameter(PluginParameter):
|
|||
Overrides type
|
||||
"""
|
||||
|
||||
class ToolParameterType(enum.StrEnum):
|
||||
class ToolParameterType(StrEnum):
|
||||
"""
|
||||
removes TOOLS_SELECTOR from PluginParameterType
|
||||
"""
|
||||
|
||||
STRING = PluginParameterType.STRING.value
|
||||
NUMBER = PluginParameterType.NUMBER.value
|
||||
BOOLEAN = PluginParameterType.BOOLEAN.value
|
||||
SELECT = PluginParameterType.SELECT.value
|
||||
SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
|
||||
FILE = PluginParameterType.FILE.value
|
||||
FILES = PluginParameterType.FILES.value
|
||||
APP_SELECTOR = PluginParameterType.APP_SELECTOR.value
|
||||
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value
|
||||
ANY = PluginParameterType.ANY.value
|
||||
DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT.value
|
||||
STRING = PluginParameterType.STRING
|
||||
NUMBER = PluginParameterType.NUMBER
|
||||
BOOLEAN = PluginParameterType.BOOLEAN
|
||||
SELECT = PluginParameterType.SELECT
|
||||
SECRET_INPUT = PluginParameterType.SECRET_INPUT
|
||||
FILE = PluginParameterType.FILE
|
||||
FILES = PluginParameterType.FILES
|
||||
APP_SELECTOR = PluginParameterType.APP_SELECTOR
|
||||
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR
|
||||
ANY = PluginParameterType.ANY
|
||||
DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT
|
||||
|
||||
# MCP object and array type parameters
|
||||
ARRAY = MCPServerParameterType.ARRAY.value
|
||||
OBJECT = MCPServerParameterType.OBJECT.value
|
||||
ARRAY = MCPServerParameterType.ARRAY
|
||||
OBJECT = MCPServerParameterType.OBJECT
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
|
||||
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES
|
||||
|
||||
def as_normal_type(self):
|
||||
return as_normal_type(self)
|
||||
|
|
@ -280,10 +279,10 @@ class ToolParameter(PluginParameter):
|
|||
def cast_value(self, value: Any):
|
||||
return cast_parameter_value(self, value)
|
||||
|
||||
class ToolParameterForm(Enum):
|
||||
SCHEMA = "schema" # should be set while adding tool
|
||||
FORM = "form" # should be set before invoking tool
|
||||
LLM = "llm" # will be set by LLM
|
||||
class ToolParameterForm(StrEnum):
|
||||
SCHEMA = auto() # should be set while adding tool
|
||||
FORM = auto() # should be set before invoking tool
|
||||
LLM = auto() # will be set by LLM
|
||||
|
||||
type: ToolParameterType = Field(..., description="The type of the parameter")
|
||||
human_description: I18nObject | None = Field(default=None, description="The description presented to the user")
|
||||
|
|
@ -448,14 +447,14 @@ class ToolLabel(BaseModel):
|
|||
icon: str = Field(..., description="The icon of the tool")
|
||||
|
||||
|
||||
class ToolInvokeFrom(Enum):
|
||||
class ToolInvokeFrom(StrEnum):
|
||||
"""
|
||||
Enum class for tool invoke
|
||||
"""
|
||||
|
||||
WORKFLOW = "workflow"
|
||||
AGENT = "agent"
|
||||
PLUGIN = "plugin"
|
||||
WORKFLOW = auto()
|
||||
AGENT = auto()
|
||||
PLUGIN = auto()
|
||||
|
||||
|
||||
class ToolSelector(BaseModel):
|
||||
|
|
@ -480,9 +479,9 @@ class ToolSelector(BaseModel):
|
|||
return self.model_dump()
|
||||
|
||||
|
||||
class CredentialType(enum.StrEnum):
|
||||
class CredentialType(StrEnum):
|
||||
API_KEY = "api-key"
|
||||
OAUTH2 = "oauth2"
|
||||
OAUTH2 = auto()
|
||||
|
||||
def get_name(self):
|
||||
if self == CredentialType.API_KEY:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -8,28 +9,25 @@ from yaml import YAMLError
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}):
|
||||
"""
|
||||
Safe loading a YAML file
|
||||
:param file_path: the path of the YAML file
|
||||
:param ignore_error:
|
||||
if True, return default_value if error occurs and the error will be logged in debug level
|
||||
if False, raise error if error occurs
|
||||
:param default_value: the value returned when errors ignored
|
||||
:return: an object of the YAML content
|
||||
"""
|
||||
def _load_yaml_file(*, file_path: str):
|
||||
if not file_path or not Path(file_path).exists():
|
||||
if ignore_error:
|
||||
return default_value
|
||||
else:
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
with open(file_path, encoding="utf-8") as yaml_file:
|
||||
try:
|
||||
yaml_content = yaml.safe_load(yaml_file)
|
||||
return yaml_content or default_value
|
||||
return yaml_content
|
||||
except Exception as e:
|
||||
if ignore_error:
|
||||
return default_value
|
||||
else:
|
||||
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e
|
||||
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e
|
||||
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def load_yaml_file_cached(file_path: str) -> Any:
|
||||
"""
|
||||
Cached version of load_yaml_file for static configuration files.
|
||||
Only use for files that don't change during runtime (e.g., position files)
|
||||
|
||||
:param file_path: the path of the YAML file
|
||||
:return: an object of the YAML content
|
||||
"""
|
||||
return _load_yaml_file(file_path=file_path)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from enum import Enum, StrEnum
|
||||
from enum import IntEnum, StrEnum, auto
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -25,9 +25,9 @@ class AgentNodeData(BaseNodeData):
|
|||
agent_parameters: dict[str, AgentInput]
|
||||
|
||||
|
||||
class ParamsAutoGenerated(Enum):
|
||||
CLOSE = 0
|
||||
OPEN = 1
|
||||
class ParamsAutoGenerated(IntEnum):
|
||||
CLOSE = auto()
|
||||
OPEN = auto()
|
||||
|
||||
|
||||
class AgentOldVersionModelFeatures(StrEnum):
|
||||
|
|
@ -38,8 +38,8 @@ class AgentOldVersionModelFeatures(StrEnum):
|
|||
TOOL_CALL = "tool-call"
|
||||
MULTI_TOOL_CALL = "multi-tool-call"
|
||||
AGENT_THOUGHT = "agent-thought"
|
||||
VISION = "vision"
|
||||
VISION = auto()
|
||||
STREAM_TOOL_CALL = "stream-tool-call"
|
||||
DOCUMENT = "document"
|
||||
VIDEO = "video"
|
||||
AUDIO = "audio"
|
||||
DOCUMENT = auto()
|
||||
VIDEO = auto()
|
||||
AUDIO = auto()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -19,9 +19,9 @@ class GenerateRouteChunk(BaseModel):
|
|||
Generate Route Chunk.
|
||||
"""
|
||||
|
||||
class ChunkType(Enum):
|
||||
VAR = "var"
|
||||
TEXT = "text"
|
||||
class ChunkType(StrEnum):
|
||||
VAR = auto()
|
||||
TEXT = auto()
|
||||
|
||||
type: ChunkType = Field(..., description="generate route chunk type")
|
||||
|
||||
|
|
|
|||
|
|
@ -250,7 +250,7 @@ class KnowledgeRetrievalNode(Node):
|
|||
)
|
||||
all_documents = []
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
|
||||
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
||||
# fetch model config
|
||||
if node_data.single_retrieval_config is None:
|
||||
raise ValueError("single_retrieval_config is required")
|
||||
|
|
@ -282,7 +282,7 @@ class KnowledgeRetrievalNode(Node):
|
|||
metadata_filter_document_ids=metadata_filter_document_ids,
|
||||
metadata_condition=metadata_condition,
|
||||
)
|
||||
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
|
||||
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
|
||||
if node_data.multiple_retrieval_config is None:
|
||||
raise ValueError("multiple_retrieval_config is required")
|
||||
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
|
||||
|
|
|
|||
|
|
@ -397,7 +397,7 @@ class WorkflowEntry:
|
|||
raise ValueError(f"Variable key {node_variable} not found in user inputs.")
|
||||
|
||||
# environment variable already exist in variable pool, not from user inputs
|
||||
if variable_pool.get(variable_selector):
|
||||
if variable_pool.get(variable_selector) and variable_selector[0] == ENVIRONMENT_VARIABLE_NODE_ID:
|
||||
continue
|
||||
|
||||
# fetch variable node id from variable selector
|
||||
|
|
|
|||
|
|
@ -9,19 +9,19 @@ import json
|
|||
import logging
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileStatus(Enum):
|
||||
class FileStatus(StrEnum):
|
||||
"""File status enumeration"""
|
||||
|
||||
ACTIVE = "active" # Active status
|
||||
ARCHIVED = "archived" # Archived
|
||||
DELETED = "deleted" # Deleted (soft delete)
|
||||
BACKUP = "backup" # Backup file
|
||||
ACTIVE = auto() # Active status
|
||||
ARCHIVED = auto() # Archived
|
||||
DELETED = auto() # Deleted (soft delete)
|
||||
BACKUP = auto() # Backup file
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
|
|
@ -5,13 +5,13 @@ According to ClickZetta's permission model, different Volume types have differen
|
|||
"""
|
||||
|
||||
import logging
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VolumePermission(Enum):
|
||||
class VolumePermission(StrEnum):
|
||||
"""Volume permission type enumeration"""
|
||||
|
||||
READ = "SELECT" # Corresponds to ClickZetta's SELECT permission
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ eliminates the need for repetitive language switching logic.
|
|||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Optional, Protocol
|
||||
|
||||
from flask import render_template
|
||||
|
|
@ -17,26 +17,30 @@ from extensions.ext_mail import mail
|
|||
from services.feature_service import BrandingModel, FeatureService
|
||||
|
||||
|
||||
class EmailType(Enum):
|
||||
class EmailType(StrEnum):
|
||||
"""Enumeration of supported email types."""
|
||||
|
||||
RESET_PASSWORD = "reset_password"
|
||||
INVITE_MEMBER = "invite_member"
|
||||
EMAIL_CODE_LOGIN = "email_code_login"
|
||||
CHANGE_EMAIL_OLD = "change_email_old"
|
||||
CHANGE_EMAIL_NEW = "change_email_new"
|
||||
CHANGE_EMAIL_COMPLETED = "change_email_completed"
|
||||
OWNER_TRANSFER_CONFIRM = "owner_transfer_confirm"
|
||||
OWNER_TRANSFER_OLD_NOTIFY = "owner_transfer_old_notify"
|
||||
OWNER_TRANSFER_NEW_NOTIFY = "owner_transfer_new_notify"
|
||||
ACCOUNT_DELETION_SUCCESS = "account_deletion_success"
|
||||
ACCOUNT_DELETION_VERIFICATION = "account_deletion_verification"
|
||||
ENTERPRISE_CUSTOM = "enterprise_custom"
|
||||
QUEUE_MONITOR_ALERT = "queue_monitor_alert"
|
||||
DOCUMENT_CLEAN_NOTIFY = "document_clean_notify"
|
||||
RESET_PASSWORD = auto()
|
||||
RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST = auto()
|
||||
INVITE_MEMBER = auto()
|
||||
EMAIL_CODE_LOGIN = auto()
|
||||
CHANGE_EMAIL_OLD = auto()
|
||||
CHANGE_EMAIL_NEW = auto()
|
||||
CHANGE_EMAIL_COMPLETED = auto()
|
||||
OWNER_TRANSFER_CONFIRM = auto()
|
||||
OWNER_TRANSFER_OLD_NOTIFY = auto()
|
||||
OWNER_TRANSFER_NEW_NOTIFY = auto()
|
||||
ACCOUNT_DELETION_SUCCESS = auto()
|
||||
ACCOUNT_DELETION_VERIFICATION = auto()
|
||||
ENTERPRISE_CUSTOM = auto()
|
||||
QUEUE_MONITOR_ALERT = auto()
|
||||
DOCUMENT_CLEAN_NOTIFY = auto()
|
||||
EMAIL_REGISTER = auto()
|
||||
EMAIL_REGISTER_WHEN_ACCOUNT_EXIST = auto()
|
||||
RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER = auto()
|
||||
|
||||
|
||||
class EmailLanguage(Enum):
|
||||
class EmailLanguage(StrEnum):
|
||||
"""Supported email languages with fallback handling."""
|
||||
|
||||
EN_US = "en-US"
|
||||
|
|
@ -441,6 +445,54 @@ def create_default_email_config() -> EmailI18nConfig:
|
|||
branded_template_path="clean_document_job_mail_template_zh-CN.html",
|
||||
),
|
||||
},
|
||||
EmailType.EMAIL_REGISTER: {
|
||||
EmailLanguage.EN_US: EmailTemplate(
|
||||
subject="Register Your {application_title} Account",
|
||||
template_path="register_email_template_en-US.html",
|
||||
branded_template_path="without-brand/register_email_template_en-US.html",
|
||||
),
|
||||
EmailLanguage.ZH_HANS: EmailTemplate(
|
||||
subject="注册您的 {application_title} 账户",
|
||||
template_path="register_email_template_zh-CN.html",
|
||||
branded_template_path="without-brand/register_email_template_zh-CN.html",
|
||||
),
|
||||
},
|
||||
EmailType.EMAIL_REGISTER_WHEN_ACCOUNT_EXIST: {
|
||||
EmailLanguage.EN_US: EmailTemplate(
|
||||
subject="Register Your {application_title} Account",
|
||||
template_path="register_email_when_account_exist_template_en-US.html",
|
||||
branded_template_path="without-brand/register_email_when_account_exist_template_en-US.html",
|
||||
),
|
||||
EmailLanguage.ZH_HANS: EmailTemplate(
|
||||
subject="注册您的 {application_title} 账户",
|
||||
template_path="register_email_when_account_exist_template_zh-CN.html",
|
||||
branded_template_path="without-brand/register_email_when_account_exist_template_zh-CN.html",
|
||||
),
|
||||
},
|
||||
EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST: {
|
||||
EmailLanguage.EN_US: EmailTemplate(
|
||||
subject="Reset Your {application_title} Password",
|
||||
template_path="reset_password_mail_when_account_not_exist_template_en-US.html",
|
||||
branded_template_path="without-brand/reset_password_mail_when_account_not_exist_template_en-US.html",
|
||||
),
|
||||
EmailLanguage.ZH_HANS: EmailTemplate(
|
||||
subject="重置您的 {application_title} 密码",
|
||||
template_path="reset_password_mail_when_account_not_exist_template_zh-CN.html",
|
||||
branded_template_path="without-brand/reset_password_mail_when_account_not_exist_template_zh-CN.html",
|
||||
),
|
||||
},
|
||||
EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER: {
|
||||
EmailLanguage.EN_US: EmailTemplate(
|
||||
subject="Reset Your {application_title} Password",
|
||||
template_path="reset_password_mail_when_account_not_exist_no_register_template_en-US.html",
|
||||
branded_template_path="without-brand/reset_password_mail_when_account_not_exist_no_register_template_en-US.html",
|
||||
),
|
||||
EmailLanguage.ZH_HANS: EmailTemplate(
|
||||
subject="重置您的 {application_title} 密码",
|
||||
template_path="reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html",
|
||||
branded_template_path="without-brand/reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html",
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
return EmailI18nConfig(templates=templates)
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ class AppIconUrlField(fields.Raw):
|
|||
if isinstance(obj, dict) and "app" in obj:
|
||||
obj = obj["app"]
|
||||
|
||||
if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE.value:
|
||||
if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE:
|
||||
return file_helpers.get_signed_file_url(obj.icon)
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -224,35 +224,35 @@ class Dataset(Base):
|
|||
doc_metadata.append(
|
||||
{
|
||||
"id": "built-in",
|
||||
"name": BuiltInField.document_name.value,
|
||||
"name": BuiltInField.document_name,
|
||||
"type": "string",
|
||||
}
|
||||
)
|
||||
doc_metadata.append(
|
||||
{
|
||||
"id": "built-in",
|
||||
"name": BuiltInField.uploader.value,
|
||||
"name": BuiltInField.uploader,
|
||||
"type": "string",
|
||||
}
|
||||
)
|
||||
doc_metadata.append(
|
||||
{
|
||||
"id": "built-in",
|
||||
"name": BuiltInField.upload_date.value,
|
||||
"name": BuiltInField.upload_date,
|
||||
"type": "time",
|
||||
}
|
||||
)
|
||||
doc_metadata.append(
|
||||
{
|
||||
"id": "built-in",
|
||||
"name": BuiltInField.last_update_date.value,
|
||||
"name": BuiltInField.last_update_date,
|
||||
"type": "time",
|
||||
}
|
||||
)
|
||||
doc_metadata.append(
|
||||
{
|
||||
"id": "built-in",
|
||||
"name": BuiltInField.source.value,
|
||||
"name": BuiltInField.source,
|
||||
"type": "string",
|
||||
}
|
||||
)
|
||||
|
|
@ -544,7 +544,7 @@ class Document(Base):
|
|||
"id": "built-in",
|
||||
"name": BuiltInField.source,
|
||||
"type": "string",
|
||||
"value": MetadataDataSource[self.data_source_type].value,
|
||||
"value": MetadataDataSource[self.data_source_type],
|
||||
}
|
||||
)
|
||||
return built_in_fields
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import re
|
|||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from enum import Enum, StrEnum
|
||||
from enum import StrEnum, auto
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
|
@ -60,9 +60,9 @@ class AppMode(StrEnum):
|
|||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
|
||||
class IconType(Enum):
|
||||
IMAGE = "image"
|
||||
EMOJI = "emoji"
|
||||
class IconType(StrEnum):
|
||||
IMAGE = auto()
|
||||
EMOJI = auto()
|
||||
|
||||
|
||||
class App(Base):
|
||||
|
|
@ -147,15 +147,15 @@ class App(Base):
|
|||
if app_model_config.agent_mode_dict.get("enabled", False) and app_model_config.agent_mode_dict.get(
|
||||
"strategy", ""
|
||||
) in {"function_call", "react"}:
|
||||
self.mode = AppMode.AGENT_CHAT.value
|
||||
self.mode = AppMode.AGENT_CHAT
|
||||
db.session.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def mode_compatible_with_agent(self) -> str:
|
||||
if self.mode == AppMode.CHAT.value and self.is_agent:
|
||||
return AppMode.AGENT_CHAT.value
|
||||
if self.mode == AppMode.CHAT and self.is_agent:
|
||||
return AppMode.AGENT_CHAT
|
||||
|
||||
return str(self.mode)
|
||||
|
||||
|
|
@ -712,7 +712,7 @@ class Conversation(Base):
|
|||
model_config = {}
|
||||
app_model_config: Optional[AppModelConfig] = None
|
||||
|
||||
if self.mode == AppMode.ADVANCED_CHAT.value:
|
||||
if self.mode == AppMode.ADVANCED_CHAT:
|
||||
if self.override_model_configs:
|
||||
override_model_configs = json.loads(self.override_model_configs)
|
||||
model_config = override_model_configs
|
||||
|
|
@ -1459,6 +1459,14 @@ class OperationLog(Base):
|
|||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class DefaultEndUserSessionID(StrEnum):
|
||||
"""
|
||||
End User Session ID enum.
|
||||
"""
|
||||
|
||||
DEFAULT_SESSION_ID = "DEFAULT-USER"
|
||||
|
||||
|
||||
class EndUser(Base, UserMixin):
|
||||
__tablename__ = "end_users"
|
||||
__table_args__ = (
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
from functools import cached_property
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -12,9 +12,9 @@ from .engine import db
|
|||
from .types import StringUUID
|
||||
|
||||
|
||||
class ProviderType(Enum):
|
||||
CUSTOM = "custom"
|
||||
SYSTEM = "system"
|
||||
class ProviderType(StrEnum):
|
||||
CUSTOM = auto()
|
||||
SYSTEM = auto()
|
||||
|
||||
@staticmethod
|
||||
def value_of(value: str) -> "ProviderType":
|
||||
|
|
@ -24,14 +24,14 @@ class ProviderType(Enum):
|
|||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class ProviderQuotaType(Enum):
|
||||
PAID = "paid"
|
||||
class ProviderQuotaType(StrEnum):
|
||||
PAID = auto()
|
||||
"""hosted paid quota"""
|
||||
|
||||
FREE = "free"
|
||||
FREE = auto()
|
||||
"""third-party free quota"""
|
||||
|
||||
TRIAL = "trial"
|
||||
TRIAL = auto()
|
||||
"""hosted trial quota"""
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import json
|
|||
import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import Enum, StrEnum
|
||||
from enum import StrEnum, auto
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
from uuid import uuid4
|
||||
|
||||
|
|
@ -41,13 +41,13 @@ from .types import EnumText, StringUUID
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowType(Enum):
|
||||
class WorkflowType(StrEnum):
|
||||
"""
|
||||
Workflow Type Enum
|
||||
"""
|
||||
|
||||
WORKFLOW = "workflow"
|
||||
CHAT = "chat"
|
||||
WORKFLOW = auto()
|
||||
CHAT = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "WorkflowType":
|
||||
|
|
@ -777,7 +777,7 @@ class WorkflowNodeExecutionModel(Base):
|
|||
return extras
|
||||
|
||||
|
||||
class WorkflowAppLogCreatedFrom(Enum):
|
||||
class WorkflowAppLogCreatedFrom(StrEnum):
|
||||
"""
|
||||
Workflow App Log Created From Enum
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -169,6 +169,8 @@ dev = [
|
|||
"types-redis>=4.6.0.20241004",
|
||||
"celery-types>=0.23.0",
|
||||
"mypy~=1.17.1",
|
||||
"locust>=2.40.4",
|
||||
"sseclient-py>=1.8.0",
|
||||
]
|
||||
|
||||
############################################################
|
||||
|
|
|
|||
|
|
@ -37,7 +37,6 @@ from services.billing_service import BillingService
|
|||
from services.errors.account import (
|
||||
AccountAlreadyInTenantError,
|
||||
AccountLoginError,
|
||||
AccountNotFoundError,
|
||||
AccountNotLinkTenantError,
|
||||
AccountPasswordError,
|
||||
AccountRegisterError,
|
||||
|
|
@ -65,7 +64,11 @@ from tasks.mail_owner_transfer_task import (
|
|||
send_old_owner_transfer_notify_email_task,
|
||||
send_owner_transfer_confirm_task,
|
||||
)
|
||||
from tasks.mail_reset_password_task import send_reset_password_mail_task
|
||||
from tasks.mail_register_task import send_email_register_mail_task, send_email_register_mail_task_when_account_exist
|
||||
from tasks.mail_reset_password_task import (
|
||||
send_reset_password_mail_task,
|
||||
send_reset_password_mail_task_when_account_not_exist,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -82,8 +85,9 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
|
|||
|
||||
class AccountService:
|
||||
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
|
||||
email_register_rate_limiter = RateLimiter(prefix="email_register_rate_limit", max_attempts=1, time_window=60 * 1)
|
||||
email_code_login_rate_limiter = RateLimiter(
|
||||
prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1
|
||||
prefix="email_code_login_rate_limit", max_attempts=3, time_window=300 * 1
|
||||
)
|
||||
email_code_account_deletion_rate_limiter = RateLimiter(
|
||||
prefix="email_code_account_deletion_rate_limit", max_attempts=1, time_window=60 * 1
|
||||
|
|
@ -95,6 +99,7 @@ class AccountService:
|
|||
FORGOT_PASSWORD_MAX_ERROR_LIMITS = 5
|
||||
CHANGE_EMAIL_MAX_ERROR_LIMITS = 5
|
||||
OWNER_TRANSFER_MAX_ERROR_LIMITS = 5
|
||||
EMAIL_REGISTER_MAX_ERROR_LIMITS = 5
|
||||
|
||||
@staticmethod
|
||||
def _get_refresh_token_key(refresh_token: str) -> str:
|
||||
|
|
@ -171,7 +176,7 @@ class AccountService:
|
|||
|
||||
account = db.session.query(Account).filter_by(email=email).first()
|
||||
if not account:
|
||||
raise AccountNotFoundError()
|
||||
raise AccountPasswordError("Invalid email or password.")
|
||||
|
||||
if account.status == AccountStatus.BANNED.value:
|
||||
raise AccountLoginError("Account is banned.")
|
||||
|
|
@ -296,7 +301,9 @@ class AccountService:
|
|||
if cls.email_code_account_deletion_rate_limiter.is_rate_limited(email):
|
||||
from controllers.console.auth.error import EmailCodeAccountDeletionRateLimitExceededError
|
||||
|
||||
raise EmailCodeAccountDeletionRateLimitExceededError()
|
||||
raise EmailCodeAccountDeletionRateLimitExceededError(
|
||||
int(cls.email_code_account_deletion_rate_limiter.time_window / 60)
|
||||
)
|
||||
|
||||
send_account_deletion_verification_code.delay(to=email, code=code)
|
||||
|
||||
|
|
@ -435,6 +442,7 @@ class AccountService:
|
|||
account: Optional[Account] = None,
|
||||
email: Optional[str] = None,
|
||||
language: str = "en-US",
|
||||
is_allow_register: bool = False,
|
||||
):
|
||||
account_email = account.email if account else email
|
||||
if account_email is None:
|
||||
|
|
@ -443,18 +451,59 @@ class AccountService:
|
|||
if cls.reset_password_rate_limiter.is_rate_limited(account_email):
|
||||
from controllers.console.auth.error import PasswordResetRateLimitExceededError
|
||||
|
||||
raise PasswordResetRateLimitExceededError()
|
||||
raise PasswordResetRateLimitExceededError(int(cls.reset_password_rate_limiter.time_window / 60))
|
||||
|
||||
code, token = cls.generate_reset_password_token(account_email, account)
|
||||
|
||||
send_reset_password_mail_task.delay(
|
||||
language=language,
|
||||
to=account_email,
|
||||
code=code,
|
||||
)
|
||||
if account:
|
||||
send_reset_password_mail_task.delay(
|
||||
language=language,
|
||||
to=account_email,
|
||||
code=code,
|
||||
)
|
||||
else:
|
||||
send_reset_password_mail_task_when_account_not_exist.delay(
|
||||
language=language,
|
||||
to=account_email,
|
||||
is_allow_register=is_allow_register,
|
||||
)
|
||||
cls.reset_password_rate_limiter.increment_rate_limit(account_email)
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def send_email_register_email(
|
||||
cls,
|
||||
account: Optional[Account] = None,
|
||||
email: Optional[str] = None,
|
||||
language: str = "en-US",
|
||||
):
|
||||
account_email = account.email if account else email
|
||||
if account_email is None:
|
||||
raise ValueError("Email must be provided.")
|
||||
|
||||
if cls.email_register_rate_limiter.is_rate_limited(account_email):
|
||||
from controllers.console.auth.error import EmailRegisterRateLimitExceededError
|
||||
|
||||
raise EmailRegisterRateLimitExceededError(int(cls.email_register_rate_limiter.time_window / 60))
|
||||
|
||||
code, token = cls.generate_email_register_token(account_email)
|
||||
|
||||
if account:
|
||||
send_email_register_mail_task_when_account_exist.delay(
|
||||
language=language,
|
||||
to=account_email,
|
||||
account_name=account.name,
|
||||
)
|
||||
|
||||
else:
|
||||
send_email_register_mail_task.delay(
|
||||
language=language,
|
||||
to=account_email,
|
||||
code=code,
|
||||
)
|
||||
cls.email_register_rate_limiter.increment_rate_limit(account_email)
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def send_change_email_email(
|
||||
cls,
|
||||
|
|
@ -473,7 +522,7 @@ class AccountService:
|
|||
if cls.change_email_rate_limiter.is_rate_limited(account_email):
|
||||
from controllers.console.auth.error import EmailChangeRateLimitExceededError
|
||||
|
||||
raise EmailChangeRateLimitExceededError()
|
||||
raise EmailChangeRateLimitExceededError(int(cls.change_email_rate_limiter.time_window / 60))
|
||||
|
||||
code, token = cls.generate_change_email_token(account_email, account, old_email=old_email)
|
||||
|
||||
|
|
@ -517,7 +566,7 @@ class AccountService:
|
|||
if cls.owner_transfer_rate_limiter.is_rate_limited(account_email):
|
||||
from controllers.console.auth.error import OwnerTransferRateLimitExceededError
|
||||
|
||||
raise OwnerTransferRateLimitExceededError()
|
||||
raise OwnerTransferRateLimitExceededError(int(cls.owner_transfer_rate_limiter.time_window / 60))
|
||||
|
||||
code, token = cls.generate_owner_transfer_token(account_email, account)
|
||||
workspace_name = workspace_name or ""
|
||||
|
|
@ -587,6 +636,19 @@ class AccountService:
|
|||
)
|
||||
return code, token
|
||||
|
||||
@classmethod
|
||||
def generate_email_register_token(
|
||||
cls,
|
||||
email: str,
|
||||
code: Optional[str] = None,
|
||||
additional_data: dict[str, Any] = {},
|
||||
):
|
||||
if not code:
|
||||
code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
|
||||
additional_data["code"] = code
|
||||
token = TokenManager.generate_token(email=email, token_type="email_register", additional_data=additional_data)
|
||||
return code, token
|
||||
|
||||
@classmethod
|
||||
def generate_change_email_token(
|
||||
cls,
|
||||
|
|
@ -625,6 +687,10 @@ class AccountService:
|
|||
def revoke_reset_password_token(cls, token: str):
|
||||
TokenManager.revoke_token(token, "reset_password")
|
||||
|
||||
@classmethod
|
||||
def revoke_email_register_token(cls, token: str):
|
||||
TokenManager.revoke_token(token, "email_register")
|
||||
|
||||
@classmethod
|
||||
def revoke_change_email_token(cls, token: str):
|
||||
TokenManager.revoke_token(token, "change_email")
|
||||
|
|
@ -637,6 +703,10 @@ class AccountService:
|
|||
def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]:
|
||||
return TokenManager.get_token_data(token, "reset_password")
|
||||
|
||||
@classmethod
|
||||
def get_email_register_data(cls, token: str) -> Optional[dict[str, Any]]:
|
||||
return TokenManager.get_token_data(token, "email_register")
|
||||
|
||||
@classmethod
|
||||
def get_change_email_data(cls, token: str) -> Optional[dict[str, Any]]:
|
||||
return TokenManager.get_token_data(token, "change_email")
|
||||
|
|
@ -658,7 +728,7 @@ class AccountService:
|
|||
if cls.email_code_login_rate_limiter.is_rate_limited(email):
|
||||
from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError
|
||||
|
||||
raise EmailCodeLoginRateLimitExceededError()
|
||||
raise EmailCodeLoginRateLimitExceededError(int(cls.email_code_login_rate_limiter.time_window / 60))
|
||||
|
||||
code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
|
||||
token = TokenManager.generate_token(
|
||||
|
|
@ -744,6 +814,16 @@ class AccountService:
|
|||
count = int(count) + 1
|
||||
redis_client.setex(key, dify_config.FORGOT_PASSWORD_LOCKOUT_DURATION, count)
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=None)
|
||||
def add_email_register_error_rate_limit(email: str) -> None:
|
||||
key = f"email_register_error_rate_limit:{email}"
|
||||
count = redis_client.get(key)
|
||||
if count is None:
|
||||
count = 0
|
||||
count = int(count) + 1
|
||||
redis_client.setex(key, dify_config.EMAIL_REGISTER_LOCKOUT_DURATION, count)
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=False)
|
||||
def is_forgot_password_error_rate_limit(email: str) -> bool:
|
||||
|
|
@ -763,6 +843,24 @@ class AccountService:
|
|||
key = f"forgot_password_error_rate_limit:{email}"
|
||||
redis_client.delete(key)
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=False)
|
||||
def is_email_register_error_rate_limit(email: str) -> bool:
|
||||
key = f"email_register_error_rate_limit:{email}"
|
||||
count = redis_client.get(key)
|
||||
if count is None:
|
||||
return False
|
||||
count = int(count)
|
||||
if count > AccountService.EMAIL_REGISTER_MAX_ERROR_LIMITS:
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=None)
|
||||
def reset_email_register_error_rate_limit(email: str):
|
||||
key = f"email_register_error_rate_limit:{email}"
|
||||
redis_client.delete(key)
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=None)
|
||||
def add_change_email_error_rate_limit(email: str):
|
||||
|
|
|
|||
|
|
@ -32,14 +32,14 @@ class AdvancedPromptTemplateService:
|
|||
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str):
|
||||
context_prompt = copy.deepcopy(CONTEXT)
|
||||
|
||||
if app_mode == AppMode.CHAT.value:
|
||||
if app_mode == AppMode.CHAT:
|
||||
if model_mode == "completion":
|
||||
return cls.get_completion_prompt(
|
||||
copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
|
||||
)
|
||||
elif model_mode == "chat":
|
||||
return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
|
||||
elif app_mode == AppMode.COMPLETION.value:
|
||||
elif app_mode == AppMode.COMPLETION:
|
||||
if model_mode == "completion":
|
||||
return cls.get_completion_prompt(
|
||||
copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
|
||||
|
|
@ -73,7 +73,7 @@ class AdvancedPromptTemplateService:
|
|||
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str):
|
||||
baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)
|
||||
|
||||
if app_mode == AppMode.CHAT.value:
|
||||
if app_mode == AppMode.CHAT:
|
||||
if model_mode == "completion":
|
||||
return cls.get_completion_prompt(
|
||||
copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt
|
||||
|
|
@ -82,7 +82,7 @@ class AdvancedPromptTemplateService:
|
|||
return cls.get_chat_prompt(
|
||||
copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
|
||||
)
|
||||
elif app_mode == AppMode.COMPLETION.value:
|
||||
elif app_mode == AppMode.COMPLETION:
|
||||
if model_mode == "completion":
|
||||
return cls.get_completion_prompt(
|
||||
copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG),
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ class AppGenerateService:
|
|||
request_id = RateLimit.gen_request_key()
|
||||
try:
|
||||
request_id = rate_limit.enter(request_id)
|
||||
if app_model.mode == AppMode.COMPLETION.value:
|
||||
if app_model.mode == AppMode.COMPLETION:
|
||||
return rate_limit.generate(
|
||||
CompletionAppGenerator.convert_to_event_stream(
|
||||
CompletionAppGenerator().generate(
|
||||
|
|
@ -69,7 +69,7 @@ class AppGenerateService:
|
|||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
|
||||
elif app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
|
||||
return rate_limit.generate(
|
||||
AgentChatAppGenerator.convert_to_event_stream(
|
||||
AgentChatAppGenerator().generate(
|
||||
|
|
@ -78,7 +78,7 @@ class AppGenerateService:
|
|||
),
|
||||
request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.CHAT.value:
|
||||
elif app_model.mode == AppMode.CHAT:
|
||||
return rate_limit.generate(
|
||||
ChatAppGenerator.convert_to_event_stream(
|
||||
ChatAppGenerator().generate(
|
||||
|
|
@ -87,7 +87,7 @@ class AppGenerateService:
|
|||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
elif app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
workflow_id = args.get("workflow_id")
|
||||
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
|
||||
return rate_limit.generate(
|
||||
|
|
@ -103,7 +103,7 @@ class AppGenerateService:
|
|||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
elif app_model.mode == AppMode.WORKFLOW:
|
||||
workflow_id = args.get("workflow_id")
|
||||
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
|
||||
return rate_limit.generate(
|
||||
|
|
@ -154,14 +154,14 @@ class AppGenerateService:
|
|||
|
||||
@classmethod
|
||||
def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
AdvancedChatAppGenerator().single_iteration_generate(
|
||||
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||
)
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
elif app_model.mode == AppMode.WORKFLOW:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().single_iteration_generate(
|
||||
|
|
@ -173,14 +173,14 @@ class AppGenerateService:
|
|||
|
||||
@classmethod
|
||||
def generate_single_loop(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
AdvancedChatAppGenerator().single_loop_generate(
|
||||
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||
)
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
elif app_model.mode == AppMode.WORKFLOW:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().single_loop_generate(
|
||||
|
|
|
|||
|
|
@ -40,15 +40,15 @@ class AppService:
|
|||
filters = [App.tenant_id == tenant_id, App.is_universal == False]
|
||||
|
||||
if args["mode"] == "workflow":
|
||||
filters.append(App.mode == AppMode.WORKFLOW.value)
|
||||
filters.append(App.mode == AppMode.WORKFLOW)
|
||||
elif args["mode"] == "completion":
|
||||
filters.append(App.mode == AppMode.COMPLETION.value)
|
||||
filters.append(App.mode == AppMode.COMPLETION)
|
||||
elif args["mode"] == "chat":
|
||||
filters.append(App.mode == AppMode.CHAT.value)
|
||||
filters.append(App.mode == AppMode.CHAT)
|
||||
elif args["mode"] == "advanced-chat":
|
||||
filters.append(App.mode == AppMode.ADVANCED_CHAT.value)
|
||||
filters.append(App.mode == AppMode.ADVANCED_CHAT)
|
||||
elif args["mode"] == "agent-chat":
|
||||
filters.append(App.mode == AppMode.AGENT_CHAT.value)
|
||||
filters.append(App.mode == AppMode.AGENT_CHAT)
|
||||
|
||||
if args.get("is_created_by_me", False):
|
||||
filters.append(App.created_by == user_id)
|
||||
|
|
@ -171,7 +171,7 @@ class AppService:
|
|||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
# get original app model config
|
||||
if app.mode == AppMode.AGENT_CHAT.value or app.is_agent:
|
||||
if app.mode == AppMode.AGENT_CHAT or app.is_agent:
|
||||
model_config = app.app_model_config
|
||||
if not model_config:
|
||||
return app
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)
|
|||
class AudioService:
|
||||
@classmethod
|
||||
def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[str] = None):
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise ValueError("Speech to text is not enabled")
|
||||
|
|
@ -88,7 +88,7 @@ class AudioService:
|
|||
def invoke_tts(text_content: str, app_model: App, voice: Optional[str] = None, is_draft: bool = False):
|
||||
with app.app_context():
|
||||
if voice is None:
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
if is_draft:
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app_model)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -222,8 +222,8 @@ class ConversationService:
|
|||
# Filter for variables created after the last_id
|
||||
stmt = stmt.where(ConversationVariable.created_at > last_variable.created_at)
|
||||
|
||||
# Apply limit to query
|
||||
query_stmt = stmt.limit(limit) # Get one extra to check if there are more
|
||||
# Apply limit to query: fetch one extra row to determine has_more
|
||||
query_stmt = stmt.limit(limit + 1)
|
||||
rows = session.scalars(query_stmt).all()
|
||||
|
||||
has_more = False
|
||||
|
|
|
|||
|
|
@ -1004,7 +1004,7 @@ class DocumentService:
|
|||
if dataset.built_in_field_enabled:
|
||||
if document.doc_metadata:
|
||||
doc_metadata = copy.deepcopy(document.doc_metadata)
|
||||
doc_metadata[BuiltInField.document_name.value] = name
|
||||
doc_metadata[BuiltInField.document_name] = name
|
||||
document.doc_metadata = doc_metadata
|
||||
|
||||
document.name = name
|
||||
|
|
@ -2365,7 +2365,22 @@ class SegmentService:
|
|||
if segment.enabled:
|
||||
# send delete segment index task
|
||||
redis_client.setex(indexing_cache_key, 600, 1)
|
||||
delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id)
|
||||
|
||||
# Get child chunk IDs before parent segment is deleted
|
||||
child_node_ids = []
|
||||
if segment.index_node_id:
|
||||
child_chunks = (
|
||||
db.session.query(ChildChunk.index_node_id)
|
||||
.where(
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.dataset_id == dataset.id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
|
||||
|
||||
delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id, child_node_ids)
|
||||
|
||||
db.session.delete(segment)
|
||||
# update document word count
|
||||
assert document.word_count is not None
|
||||
|
|
@ -2375,9 +2390,13 @@ class SegmentService:
|
|||
|
||||
@classmethod
|
||||
def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
|
||||
assert isinstance(current_user, Account)
|
||||
segments = (
|
||||
db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count)
|
||||
assert current_user is not None
|
||||
# Check if segment_ids is not empty to avoid WHERE false condition
|
||||
if not segment_ids or len(segment_ids) == 0:
|
||||
return
|
||||
segments_info = (
|
||||
db.session.query(DocumentSegment)
|
||||
.with_entities(DocumentSegment.index_node_id, DocumentSegment.id, DocumentSegment.word_count)
|
||||
.where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
|
|
@ -2387,18 +2406,36 @@ class SegmentService:
|
|||
.all()
|
||||
)
|
||||
|
||||
if not segments:
|
||||
if not segments_info:
|
||||
return
|
||||
|
||||
index_node_ids = [seg.index_node_id for seg in segments]
|
||||
total_words = sum(seg.word_count for seg in segments)
|
||||
index_node_ids = [info[0] for info in segments_info]
|
||||
segment_db_ids = [info[1] for info in segments_info]
|
||||
total_words = sum(info[2] for info in segments_info if info[2] is not None)
|
||||
|
||||
# Get child chunk IDs before parent segments are deleted
|
||||
child_node_ids = []
|
||||
if index_node_ids:
|
||||
child_chunks = (
|
||||
db.session.query(ChildChunk.index_node_id)
|
||||
.where(
|
||||
ChildChunk.segment_id.in_(segment_db_ids),
|
||||
ChildChunk.dataset_id == dataset.id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
|
||||
|
||||
# Start async cleanup with both parent and child node IDs
|
||||
if index_node_ids or child_node_ids:
|
||||
delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id, child_node_ids)
|
||||
|
||||
document.word_count = (
|
||||
document.word_count - total_words if document.word_count and document.word_count > total_words else 0
|
||||
)
|
||||
db.session.add(document)
|
||||
|
||||
delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id)
|
||||
# Delete database records
|
||||
db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).delete()
|
||||
db.session.commit()
|
||||
|
||||
|
|
|
|||
|
|
@ -229,7 +229,7 @@ class MessageService:
|
|||
|
||||
model_manager = ModelManager()
|
||||
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
workflow_service = WorkflowService()
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue