mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into fix/app-list-walk-nodes-graceful
This commit is contained in:
commit
7166077230
|
|
@ -106,7 +106,7 @@ jobs:
|
|||
- name: Web type check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run type-check
|
||||
run: pnpm run type-check:tsgo
|
||||
|
||||
docker-compose-template:
|
||||
name: Docker Compose Template
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@ from functools import wraps
|
|||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
|
@ -18,6 +19,30 @@ from extensions.ext_database import db
|
|||
from libs.token import extract_access_token
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class InsertExploreAppPayload(BaseModel):
|
||||
app_id: str = Field(...)
|
||||
desc: str | None = None
|
||||
copyright: str | None = None
|
||||
privacy_policy: str | None = None
|
||||
custom_disclaimer: str | None = None
|
||||
language: str = Field(...)
|
||||
category: str = Field(...)
|
||||
position: int = Field(...)
|
||||
|
||||
@field_validator("language")
|
||||
@classmethod
|
||||
def validate_language(cls, value: str) -> str:
|
||||
return supported_language(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
InsertExploreAppPayload.__name__,
|
||||
InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
def admin_required(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
|
|
@ -40,59 +65,34 @@ def admin_required(view: Callable[P, R]):
|
|||
class InsertExploreAppListApi(Resource):
|
||||
@console_ns.doc("insert_explore_app")
|
||||
@console_ns.doc(description="Insert or update an app in the explore list")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"InsertExploreAppRequest",
|
||||
{
|
||||
"app_id": fields.String(required=True, description="Application ID"),
|
||||
"desc": fields.String(description="App description"),
|
||||
"copyright": fields.String(description="Copyright information"),
|
||||
"privacy_policy": fields.String(description="Privacy policy"),
|
||||
"custom_disclaimer": fields.String(description="Custom disclaimer"),
|
||||
"language": fields.String(required=True, description="Language code"),
|
||||
"category": fields.String(required=True, description="App category"),
|
||||
"position": fields.Integer(required=True, description="Display position"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[InsertExploreAppPayload.__name__])
|
||||
@console_ns.response(200, "App updated successfully")
|
||||
@console_ns.response(201, "App inserted successfully")
|
||||
@console_ns.response(404, "App not found")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("app_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("desc", type=str, location="json")
|
||||
.add_argument("copyright", type=str, location="json")
|
||||
.add_argument("privacy_policy", type=str, location="json")
|
||||
.add_argument("custom_disclaimer", type=str, location="json")
|
||||
.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
|
||||
.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("position", type=int, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = InsertExploreAppPayload.model_validate(console_ns.payload)
|
||||
|
||||
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()
|
||||
app = db.session.execute(select(App).where(App.id == payload.app_id)).scalar_one_or_none()
|
||||
if not app:
|
||||
raise NotFound(f"App '{args['app_id']}' is not found")
|
||||
raise NotFound(f"App '{payload.app_id}' is not found")
|
||||
|
||||
site = app.site
|
||||
if not site:
|
||||
desc = args["desc"] or ""
|
||||
copy_right = args["copyright"] or ""
|
||||
privacy_policy = args["privacy_policy"] or ""
|
||||
custom_disclaimer = args["custom_disclaimer"] or ""
|
||||
desc = payload.desc or ""
|
||||
copy_right = payload.copyright or ""
|
||||
privacy_policy = payload.privacy_policy or ""
|
||||
custom_disclaimer = payload.custom_disclaimer or ""
|
||||
else:
|
||||
desc = site.description or args["desc"] or ""
|
||||
copy_right = site.copyright or args["copyright"] or ""
|
||||
privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
|
||||
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
|
||||
desc = site.description or payload.desc or ""
|
||||
copy_right = site.copyright or payload.copyright or ""
|
||||
privacy_policy = site.privacy_policy or payload.privacy_policy or ""
|
||||
custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or ""
|
||||
|
||||
with Session(db.engine) as session:
|
||||
recommended_app = session.execute(
|
||||
select(RecommendedApp).where(RecommendedApp.app_id == args["app_id"])
|
||||
select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not recommended_app:
|
||||
|
|
@ -102,9 +102,9 @@ class InsertExploreAppListApi(Resource):
|
|||
copyright=copy_right,
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
language=args["language"],
|
||||
category=args["category"],
|
||||
position=args["position"],
|
||||
language=payload.language,
|
||||
category=payload.category,
|
||||
position=payload.position,
|
||||
)
|
||||
|
||||
db.session.add(recommended_app)
|
||||
|
|
@ -118,9 +118,9 @@ class InsertExploreAppListApi(Resource):
|
|||
recommended_app.copyright = copy_right
|
||||
recommended_app.privacy_policy = privacy_policy
|
||||
recommended_app.custom_disclaimer = custom_disclaimer
|
||||
recommended_app.language = args["language"]
|
||||
recommended_app.category = args["category"]
|
||||
recommended_app.position = args["position"]
|
||||
recommended_app.language = payload.language
|
||||
recommended_app.category = payload.category
|
||||
recommended_app.position = payload.position
|
||||
|
||||
app.is_public = True
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
from flask_restx import Resource, fields, reqparse
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
|
|
@ -8,10 +10,21 @@ from libs.login import login_required
|
|||
from models.model import AppMode
|
||||
from services.agent_service import AgentService
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", type=uuid_value, required=True, location="args", help="Message UUID")
|
||||
.add_argument("conversation_id", type=uuid_value, required=True, location="args", help="Conversation UUID")
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AgentLogQuery(BaseModel):
|
||||
message_id: str = Field(..., description="Message UUID")
|
||||
conversation_id: str = Field(..., description="Conversation UUID")
|
||||
|
||||
@field_validator("message_id", "conversation_id")
|
||||
@classmethod
|
||||
def validate_uuid(cls, value: str) -> str:
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
AgentLogQuery.__name__, AgentLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -20,7 +33,7 @@ class AgentLogApi(Resource):
|
|||
@console_ns.doc("get_agent_logs")
|
||||
@console_ns.doc(description="Get agent execution logs for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[AgentLogQuery.__name__])
|
||||
@console_ns.response(
|
||||
200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries"))
|
||||
)
|
||||
|
|
@ -31,6 +44,6 @@ class AgentLogApi(Resource):
|
|||
@get_app_model(mode=[AppMode.AGENT_CHAT])
|
||||
def get(self, app_model):
|
||||
"""Get agent logs"""
|
||||
args = parser.parse_args()
|
||||
args = AgentLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])
|
||||
return AgentService.get_agent_logs(app_model, args.conversation_id, args.message_id)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
|
||||
from controllers.console import console_ns
|
||||
|
|
@ -21,22 +22,79 @@ from libs.helper import uuid_value
|
|||
from libs.login import login_required
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AnnotationReplyPayload(BaseModel):
|
||||
score_threshold: float = Field(..., description="Score threshold for annotation matching")
|
||||
embedding_provider_name: str = Field(..., description="Embedding provider name")
|
||||
embedding_model_name: str = Field(..., description="Embedding model name")
|
||||
|
||||
|
||||
class AnnotationSettingUpdatePayload(BaseModel):
|
||||
score_threshold: float = Field(..., description="Score threshold")
|
||||
|
||||
|
||||
class AnnotationListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, description="Page number")
|
||||
limit: int = Field(default=20, ge=1, description="Page size")
|
||||
keyword: str = Field(default="", description="Search keyword")
|
||||
|
||||
|
||||
class CreateAnnotationPayload(BaseModel):
|
||||
message_id: str | None = Field(default=None, description="Message ID")
|
||||
question: str | None = Field(default=None, description="Question text")
|
||||
answer: str | None = Field(default=None, description="Answer text")
|
||||
content: str | None = Field(default=None, description="Content text")
|
||||
annotation_reply: dict[str, Any] | None = Field(default=None, description="Annotation reply data")
|
||||
|
||||
@field_validator("message_id")
|
||||
@classmethod
|
||||
def validate_message_id(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class UpdateAnnotationPayload(BaseModel):
|
||||
question: str | None = None
|
||||
answer: str | None = None
|
||||
content: str | None = None
|
||||
annotation_reply: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class AnnotationReplyStatusQuery(BaseModel):
|
||||
action: Literal["enable", "disable"]
|
||||
|
||||
|
||||
class AnnotationFilePayload(BaseModel):
|
||||
message_id: str = Field(..., description="Message ID")
|
||||
|
||||
@field_validator("message_id")
|
||||
@classmethod
|
||||
def validate_message_id(cls, value: str) -> str:
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
def reg(model: type[BaseModel]) -> None:
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(AnnotationReplyPayload)
|
||||
reg(AnnotationSettingUpdatePayload)
|
||||
reg(AnnotationListQuery)
|
||||
reg(CreateAnnotationPayload)
|
||||
reg(UpdateAnnotationPayload)
|
||||
reg(AnnotationReplyStatusQuery)
|
||||
reg(AnnotationFilePayload)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
|
||||
class AnnotationReplyActionApi(Resource):
|
||||
@console_ns.doc("annotation_reply_action")
|
||||
@console_ns.doc(description="Enable or disable annotation reply for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AnnotationReplyActionRequest",
|
||||
{
|
||||
"score_threshold": fields.Float(required=True, description="Score threshold for annotation matching"),
|
||||
"embedding_provider_name": fields.String(required=True, description="Embedding provider name"),
|
||||
"embedding_model_name": fields.String(required=True, description="Embedding model name"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AnnotationReplyPayload.__name__])
|
||||
@console_ns.response(200, "Action completed successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
|
|
@ -46,15 +104,9 @@ class AnnotationReplyActionApi(Resource):
|
|||
@edit_permission_required
|
||||
def post(self, app_id, action: Literal["enable", "disable"]):
|
||||
app_id = str(app_id)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("score_threshold", required=True, type=float, location="json")
|
||||
.add_argument("embedding_provider_name", required=True, type=str, location="json")
|
||||
.add_argument("embedding_model_name", required=True, type=str, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = AnnotationReplyPayload.model_validate(console_ns.payload)
|
||||
if action == "enable":
|
||||
result = AppAnnotationService.enable_app_annotation(args, app_id)
|
||||
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
|
||||
elif action == "disable":
|
||||
result = AppAnnotationService.disable_app_annotation(app_id)
|
||||
return result, 200
|
||||
|
|
@ -82,16 +134,7 @@ class AppAnnotationSettingUpdateApi(Resource):
|
|||
@console_ns.doc("update_annotation_setting")
|
||||
@console_ns.doc(description="Update annotation settings for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AnnotationSettingUpdateRequest",
|
||||
{
|
||||
"score_threshold": fields.Float(required=True, description="Score threshold"),
|
||||
"embedding_provider_name": fields.String(required=True, description="Embedding provider"),
|
||||
"embedding_model_name": fields.String(required=True, description="Embedding model"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AnnotationSettingUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Settings updated successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
|
|
@ -102,10 +145,9 @@ class AppAnnotationSettingUpdateApi(Resource):
|
|||
app_id = str(app_id)
|
||||
annotation_setting_id = str(annotation_setting_id)
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("score_threshold", required=True, type=float, location="json")
|
||||
args = parser.parse_args()
|
||||
args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
|
||||
|
||||
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
|
||||
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump())
|
||||
return result, 200
|
||||
|
||||
|
||||
|
|
@ -142,12 +184,7 @@ class AnnotationApi(Resource):
|
|||
@console_ns.doc("list_annotations")
|
||||
@console_ns.doc(description="Get annotations for an app with pagination")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("page", type=int, location="args", default=1, help="Page number")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Page size")
|
||||
.add_argument("keyword", type=str, location="args", default="", help="Search keyword")
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AnnotationListQuery.__name__])
|
||||
@console_ns.response(200, "Annotations retrieved successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
|
|
@ -155,9 +192,10 @@ class AnnotationApi(Resource):
|
|||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_id):
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
keyword = request.args.get("keyword", default="", type=str)
|
||||
args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
page = args.page
|
||||
limit = args.limit
|
||||
keyword = args.keyword
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
|
||||
|
|
@ -173,18 +211,7 @@ class AnnotationApi(Resource):
|
|||
@console_ns.doc("create_annotation")
|
||||
@console_ns.doc(description="Create a new annotation for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"CreateAnnotationRequest",
|
||||
{
|
||||
"message_id": fields.String(description="Message ID (optional)"),
|
||||
"question": fields.String(description="Question text (required when message_id not provided)"),
|
||||
"answer": fields.String(description="Answer text (use 'answer' or 'content')"),
|
||||
"content": fields.String(description="Content text (use 'answer' or 'content')"),
|
||||
"annotation_reply": fields.Raw(description="Annotation reply data"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__])
|
||||
@console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns))
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
|
|
@ -195,16 +222,9 @@ class AnnotationApi(Resource):
|
|||
@edit_permission_required
|
||||
def post(self, app_id):
|
||||
app_id = str(app_id)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", required=False, type=uuid_value, location="json")
|
||||
.add_argument("question", required=False, type=str, location="json")
|
||||
.add_argument("answer", required=False, type=str, location="json")
|
||||
.add_argument("content", required=False, type=str, location="json")
|
||||
.add_argument("annotation_reply", required=False, type=dict, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
|
||||
args = CreateAnnotationPayload.model_validate(console_ns.payload)
|
||||
data = args.model_dump(exclude_none=True)
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
|
||||
return annotation
|
||||
|
||||
@setup_required
|
||||
|
|
@ -256,13 +276,6 @@ class AnnotationExportApi(Resource):
|
|||
return response, 200
|
||||
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("question", required=True, type=str, location="json")
|
||||
.add_argument("answer", required=True, type=str, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
|
||||
class AnnotationUpdateDeleteApi(Resource):
|
||||
@console_ns.doc("update_delete_annotation")
|
||||
|
|
@ -271,7 +284,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||
@console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns))
|
||||
@console_ns.response(204, "Annotation deleted successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -281,8 +294,10 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||
def post(self, app_id, annotation_id):
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
|
||||
args = UpdateAnnotationPayload.model_validate(console_ns.payload)
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(
|
||||
args.model_dump(exclude_none=True), app_id, annotation_id
|
||||
)
|
||||
return annotation
|
||||
|
||||
@setup_required
|
||||
|
|
|
|||
|
|
@ -146,7 +146,14 @@ class AppApiStatusPayload(BaseModel):
|
|||
|
||||
class AppTracePayload(BaseModel):
|
||||
enabled: bool = Field(..., description="Enable or disable tracing")
|
||||
tracing_provider: str = Field(..., description="Tracing provider")
|
||||
tracing_provider: str | None = Field(default=None, description="Tracing provider")
|
||||
|
||||
@field_validator("tracing_provider")
|
||||
@classmethod
|
||||
def validate_tracing_provider(cls, value: str | None, info) -> str | None:
|
||||
if info.data.get("enabled") and not value:
|
||||
raise ValueError("tracing_provider is required when enabled is True")
|
||||
return value
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
|
|
@ -324,10 +331,13 @@ class AppListApi(Resource):
|
|||
NodeType.TRIGGER_PLUGIN,
|
||||
}
|
||||
for workflow in draft_workflows:
|
||||
for _, node_data in workflow.walk_nodes():
|
||||
if node_data.get("type") in trigger_node_types:
|
||||
draft_trigger_app_ids.add(str(workflow.app_id))
|
||||
break
|
||||
try:
|
||||
for _, node_data in workflow.walk_nodes():
|
||||
if node_data.get("type") in trigger_node_types:
|
||||
draft_trigger_app_ids.add(str(workflow.app_id))
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
for app in app_pagination.items:
|
||||
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
|
|
@ -35,23 +36,29 @@ app_import_check_dependencies_model = console_ns.model(
|
|||
"AppImportCheckDependencies", app_import_check_dependencies_fields_copy
|
||||
)
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("mode", type=str, required=True, location="json")
|
||||
.add_argument("yaml_content", type=str, location="json")
|
||||
.add_argument("yaml_url", type=str, location="json")
|
||||
.add_argument("name", type=str, location="json")
|
||||
.add_argument("description", type=str, location="json")
|
||||
.add_argument("icon_type", type=str, location="json")
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
.add_argument("app_id", type=str, location="json")
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AppImportPayload(BaseModel):
|
||||
mode: str = Field(..., description="Import mode")
|
||||
yaml_content: str | None = None
|
||||
yaml_url: str | None = None
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
icon_type: str | None = None
|
||||
icon: str | None = None
|
||||
icon_background: str | None = None
|
||||
app_id: str | None = None
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
AppImportPayload.__name__, AppImportPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/imports")
|
||||
class AppImportApi(Resource):
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[AppImportPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -61,7 +68,7 @@ class AppImportApi(Resource):
|
|||
def post(self):
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = parser.parse_args()
|
||||
args = AppImportPayload.model_validate(console_ns.payload)
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
|
|
@ -70,15 +77,15 @@ class AppImportApi(Resource):
|
|||
account = current_user
|
||||
result = import_service.import_app(
|
||||
account=account,
|
||||
import_mode=args["mode"],
|
||||
yaml_content=args.get("yaml_content"),
|
||||
yaml_url=args.get("yaml_url"),
|
||||
name=args.get("name"),
|
||||
description=args.get("description"),
|
||||
icon_type=args.get("icon_type"),
|
||||
icon=args.get("icon"),
|
||||
icon_background=args.get("icon_background"),
|
||||
app_id=args.get("app_id"),
|
||||
import_mode=args.mode,
|
||||
yaml_content=args.yaml_content,
|
||||
yaml_url=args.yaml_url,
|
||||
name=args.name,
|
||||
description=args.description,
|
||||
icon_type=args.icon_type,
|
||||
icon=args.icon,
|
||||
icon_background=args.icon_background,
|
||||
app_id=args.app_id,
|
||||
)
|
||||
session.commit()
|
||||
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
|
|
@ -32,6 +33,27 @@ from services.errors.audio import (
|
|||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class TextToSpeechPayload(BaseModel):
|
||||
message_id: str | None = Field(default=None, description="Message ID")
|
||||
text: str = Field(..., description="Text to convert")
|
||||
voice: str | None = Field(default=None, description="Voice name")
|
||||
streaming: bool | None = Field(default=None, description="Whether to stream audio")
|
||||
|
||||
|
||||
class TextToSpeechVoiceQuery(BaseModel):
|
||||
language: str = Field(..., description="Language code")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
TextToSpeechVoiceQuery.__name__,
|
||||
TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/audio-to-text")
|
||||
|
|
@ -92,17 +114,7 @@ class ChatMessageTextApi(Resource):
|
|||
@console_ns.doc("chat_message_text_to_speech")
|
||||
@console_ns.doc(description="Convert text to speech for chat messages")
|
||||
@console_ns.doc(params={"app_id": "App ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"TextToSpeechRequest",
|
||||
{
|
||||
"message_id": fields.String(description="Message ID"),
|
||||
"text": fields.String(required=True, description="Text to convert to speech"),
|
||||
"voice": fields.String(description="Voice to use for TTS"),
|
||||
"streaming": fields.Boolean(description="Whether to stream the audio"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[TextToSpeechPayload.__name__])
|
||||
@console_ns.response(200, "Text to speech conversion successful")
|
||||
@console_ns.response(400, "Bad request - Invalid parameters")
|
||||
@get_app_model
|
||||
|
|
@ -111,21 +123,14 @@ class ChatMessageTextApi(Resource):
|
|||
@account_initialization_required
|
||||
def post(self, app_model: App):
|
||||
try:
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", type=str, location="json")
|
||||
.add_argument("text", type=str, location="json")
|
||||
.add_argument("voice", type=str, location="json")
|
||||
.add_argument("streaming", type=bool, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
voice = args.get("voice", None)
|
||||
payload = TextToSpeechPayload.model_validate(console_ns.payload)
|
||||
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model, text=text, voice=voice, message_id=message_id, is_draft=True
|
||||
app_model=app_model,
|
||||
text=payload.text,
|
||||
voice=payload.voice,
|
||||
message_id=payload.message_id,
|
||||
is_draft=True,
|
||||
)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
|
|
@ -159,9 +164,7 @@ class TextModesApi(Resource):
|
|||
@console_ns.doc("get_text_to_speech_voices")
|
||||
@console_ns.doc(description="Get available TTS voices for a specific language")
|
||||
@console_ns.doc(params={"app_id": "App ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser().add_argument("language", type=str, required=True, location="args", help="Language code")
|
||||
)
|
||||
@console_ns.expect(console_ns.models[TextToSpeechVoiceQuery.__name__])
|
||||
@console_ns.response(
|
||||
200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices"))
|
||||
)
|
||||
|
|
@ -172,12 +175,11 @@ class TextModesApi(Resource):
|
|||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
try:
|
||||
parser = reqparse.RequestParser().add_argument("language", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
response = AudioService.transcript_tts_voices(
|
||||
tenant_id=app_model.tenant_id,
|
||||
language=args["language"],
|
||||
language=args.language,
|
||||
)
|
||||
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import json
|
||||
from enum import StrEnum
|
||||
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
|
|
@ -12,6 +13,8 @@ from fields.app_fields import app_server_fields
|
|||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import AppMCPServer
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||
app_server_model = console_ns.model("AppServer", app_server_fields)
|
||||
|
||||
|
|
@ -21,6 +24,22 @@ class AppMCPServerStatus(StrEnum):
|
|||
INACTIVE = "inactive"
|
||||
|
||||
|
||||
class MCPServerCreatePayload(BaseModel):
|
||||
description: str | None = Field(default=None, description="Server description")
|
||||
parameters: dict = Field(..., description="Server parameters configuration")
|
||||
|
||||
|
||||
class MCPServerUpdatePayload(BaseModel):
|
||||
id: str = Field(..., description="Server ID")
|
||||
description: str | None = Field(default=None, description="Server description")
|
||||
parameters: dict = Field(..., description="Server parameters configuration")
|
||||
status: str | None = Field(default=None, description="Server status")
|
||||
|
||||
|
||||
for model in (MCPServerCreatePayload, MCPServerUpdatePayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/server")
|
||||
class AppMCPServerController(Resource):
|
||||
@console_ns.doc("get_app_mcp_server")
|
||||
|
|
@ -39,15 +58,7 @@ class AppMCPServerController(Resource):
|
|||
@console_ns.doc("create_app_mcp_server")
|
||||
@console_ns.doc(description="Create MCP server configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"MCPServerCreateRequest",
|
||||
{
|
||||
"description": fields.String(description="Server description"),
|
||||
"parameters": fields.Raw(required=True, description="Server parameters configuration"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__])
|
||||
@console_ns.response(201, "MCP server configuration created successfully", app_server_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@account_initialization_required
|
||||
|
|
@ -58,21 +69,16 @@ class AppMCPServerController(Resource):
|
|||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("description", type=str, required=False, location="json")
|
||||
.add_argument("parameters", type=dict, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = MCPServerCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
description = args.get("description")
|
||||
description = payload.description
|
||||
if not description:
|
||||
description = app_model.description or ""
|
||||
|
||||
server = AppMCPServer(
|
||||
name=app_model.name,
|
||||
description=description,
|
||||
parameters=json.dumps(args["parameters"], ensure_ascii=False),
|
||||
parameters=json.dumps(payload.parameters, ensure_ascii=False),
|
||||
status=AppMCPServerStatus.ACTIVE,
|
||||
app_id=app_model.id,
|
||||
tenant_id=current_tenant_id,
|
||||
|
|
@ -85,17 +91,7 @@ class AppMCPServerController(Resource):
|
|||
@console_ns.doc("update_app_mcp_server")
|
||||
@console_ns.doc(description="Update MCP server configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"MCPServerUpdateRequest",
|
||||
{
|
||||
"id": fields.String(required=True, description="Server ID"),
|
||||
"description": fields.String(description="Server description"),
|
||||
"parameters": fields.Raw(required=True, description="Server parameters configuration"),
|
||||
"status": fields.String(description="Server status"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__])
|
||||
@console_ns.response(200, "MCP server configuration updated successfully", app_server_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Server not found")
|
||||
|
|
@ -106,19 +102,12 @@ class AppMCPServerController(Resource):
|
|||
@marshal_with(app_server_model)
|
||||
@edit_permission_required
|
||||
def put(self, app_model):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("id", type=str, required=True, location="json")
|
||||
.add_argument("description", type=str, required=False, location="json")
|
||||
.add_argument("parameters", type=dict, required=True, location="json")
|
||||
.add_argument("status", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
|
||||
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
|
||||
server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first()
|
||||
if not server:
|
||||
raise NotFound()
|
||||
|
||||
description = args.get("description")
|
||||
description = payload.description
|
||||
if description is None:
|
||||
pass
|
||||
elif not description:
|
||||
|
|
@ -126,11 +115,11 @@ class AppMCPServerController(Resource):
|
|||
else:
|
||||
server.description = description
|
||||
|
||||
server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
|
||||
if args["status"]:
|
||||
if args["status"] not in [status.value for status in AppMCPServerStatus]:
|
||||
server.parameters = json.dumps(payload.parameters, ensure_ascii=False)
|
||||
if payload.status:
|
||||
if payload.status not in [status.value for status in AppMCPServerStatus]:
|
||||
raise ValueError("Invalid status")
|
||||
server.status = args["status"]
|
||||
server.status = payload.status
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,8 @@
|
|||
from flask_restx import Resource, fields, reqparse
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console import console_ns
|
||||
|
|
@ -7,6 +11,26 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
|||
from libs.login import login_required
|
||||
from services.ops_service import OpsService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class TraceProviderQuery(BaseModel):
|
||||
tracing_provider: str = Field(..., description="Tracing provider name")
|
||||
|
||||
|
||||
class TraceConfigPayload(BaseModel):
|
||||
tracing_provider: str = Field(..., description="Tracing provider name")
|
||||
tracing_config: dict[str, Any] = Field(..., description="Tracing configuration data")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
TraceProviderQuery.__name__,
|
||||
TraceProviderQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
TraceConfigPayload.__name__, TraceConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/trace-config")
|
||||
class TraceAppConfigApi(Resource):
|
||||
|
|
@ -17,11 +41,7 @@ class TraceAppConfigApi(Resource):
|
|||
@console_ns.doc("get_trace_app_config")
|
||||
@console_ns.doc(description="Get tracing configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser().add_argument(
|
||||
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[TraceProviderQuery.__name__])
|
||||
@console_ns.response(
|
||||
200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data")
|
||||
)
|
||||
|
|
@ -30,11 +50,10 @@ class TraceAppConfigApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
try:
|
||||
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
|
||||
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
|
||||
if not trace_config:
|
||||
return {"has_not_configured": True}
|
||||
return trace_config
|
||||
|
|
@ -44,15 +63,7 @@ class TraceAppConfigApi(Resource):
|
|||
@console_ns.doc("create_trace_app_config")
|
||||
@console_ns.doc(description="Create a new tracing configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"TraceConfigCreateRequest",
|
||||
{
|
||||
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
|
||||
"tracing_config": fields.Raw(required=True, description="Tracing configuration data"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[TraceConfigPayload.__name__])
|
||||
@console_ns.response(
|
||||
201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data")
|
||||
)
|
||||
|
|
@ -62,16 +73,11 @@ class TraceAppConfigApi(Resource):
|
|||
@account_initialization_required
|
||||
def post(self, app_id):
|
||||
"""Create a new trace app configuration"""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
.add_argument("tracing_config", type=dict, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = TraceConfigPayload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
result = OpsService.create_tracing_app_config(
|
||||
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
|
||||
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
)
|
||||
if not result:
|
||||
raise TracingConfigIsExist()
|
||||
|
|
@ -84,15 +90,7 @@ class TraceAppConfigApi(Resource):
|
|||
@console_ns.doc("update_trace_app_config")
|
||||
@console_ns.doc(description="Update an existing tracing configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"TraceConfigUpdateRequest",
|
||||
{
|
||||
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
|
||||
"tracing_config": fields.Raw(required=True, description="Updated tracing configuration data"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[TraceConfigPayload.__name__])
|
||||
@console_ns.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response"))
|
||||
@console_ns.response(400, "Invalid request parameters or configuration not found")
|
||||
@setup_required
|
||||
|
|
@ -100,16 +98,11 @@ class TraceAppConfigApi(Resource):
|
|||
@account_initialization_required
|
||||
def patch(self, app_id):
|
||||
"""Update an existing trace app configuration"""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
.add_argument("tracing_config", type=dict, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = TraceConfigPayload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
result = OpsService.update_tracing_app_config(
|
||||
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
|
||||
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
)
|
||||
if not result:
|
||||
raise TracingConfigNotExist()
|
||||
|
|
@ -120,11 +113,7 @@ class TraceAppConfigApi(Resource):
|
|||
@console_ns.doc("delete_trace_app_config")
|
||||
@console_ns.doc(description="Delete an existing tracing configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser().add_argument(
|
||||
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[TraceProviderQuery.__name__])
|
||||
@console_ns.response(204, "Tracing configuration deleted successfully")
|
||||
@console_ns.response(400, "Invalid request parameters or configuration not found")
|
||||
@setup_required
|
||||
|
|
@ -132,11 +121,10 @@ class TraceAppConfigApi(Resource):
|
|||
@account_initialization_required
|
||||
def delete(self, app_id):
|
||||
"""Delete an existing trace app configuration"""
|
||||
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
try:
|
||||
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
|
||||
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
|
||||
if not result:
|
||||
raise TracingConfigNotExist()
|
||||
return {"result": "success"}, 204
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from typing import Literal
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from constants.languages import supported_language
|
||||
|
|
@ -16,69 +19,50 @@ from libs.datetime_utils import naive_utc_now
|
|||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Site
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AppSiteUpdatePayload(BaseModel):
|
||||
title: str | None = Field(default=None)
|
||||
icon_type: str | None = Field(default=None)
|
||||
icon: str | None = Field(default=None)
|
||||
icon_background: str | None = Field(default=None)
|
||||
description: str | None = Field(default=None)
|
||||
default_language: str | None = Field(default=None)
|
||||
chat_color_theme: str | None = Field(default=None)
|
||||
chat_color_theme_inverted: bool | None = Field(default=None)
|
||||
customize_domain: str | None = Field(default=None)
|
||||
copyright: str | None = Field(default=None)
|
||||
privacy_policy: str | None = Field(default=None)
|
||||
custom_disclaimer: str | None = Field(default=None)
|
||||
customize_token_strategy: Literal["must", "allow", "not_allow"] | None = Field(default=None)
|
||||
prompt_public: bool | None = Field(default=None)
|
||||
show_workflow_steps: bool | None = Field(default=None)
|
||||
use_icon_as_answer_icon: bool | None = Field(default=None)
|
||||
|
||||
@field_validator("default_language")
|
||||
@classmethod
|
||||
def validate_language(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return supported_language(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
AppSiteUpdatePayload.__name__,
|
||||
AppSiteUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||
app_site_model = console_ns.model("AppSite", app_site_fields)
|
||||
|
||||
|
||||
def parse_app_site_args():
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("title", type=str, required=False, location="json")
|
||||
.add_argument("icon_type", type=str, required=False, location="json")
|
||||
.add_argument("icon", type=str, required=False, location="json")
|
||||
.add_argument("icon_background", type=str, required=False, location="json")
|
||||
.add_argument("description", type=str, required=False, location="json")
|
||||
.add_argument("default_language", type=supported_language, required=False, location="json")
|
||||
.add_argument("chat_color_theme", type=str, required=False, location="json")
|
||||
.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
|
||||
.add_argument("customize_domain", type=str, required=False, location="json")
|
||||
.add_argument("copyright", type=str, required=False, location="json")
|
||||
.add_argument("privacy_policy", type=str, required=False, location="json")
|
||||
.add_argument("custom_disclaimer", type=str, required=False, location="json")
|
||||
.add_argument(
|
||||
"customize_token_strategy",
|
||||
type=str,
|
||||
choices=["must", "allow", "not_allow"],
|
||||
required=False,
|
||||
location="json",
|
||||
)
|
||||
.add_argument("prompt_public", type=bool, required=False, location="json")
|
||||
.add_argument("show_workflow_steps", type=bool, required=False, location="json")
|
||||
.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/site")
|
||||
class AppSite(Resource):
|
||||
@console_ns.doc("update_app_site")
|
||||
@console_ns.doc(description="Update application site configuration")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AppSiteRequest",
|
||||
{
|
||||
"title": fields.String(description="Site title"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon": fields.String(description="Icon"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
"description": fields.String(description="Site description"),
|
||||
"default_language": fields.String(description="Default language"),
|
||||
"chat_color_theme": fields.String(description="Chat color theme"),
|
||||
"chat_color_theme_inverted": fields.Boolean(description="Inverted chat color theme"),
|
||||
"customize_domain": fields.String(description="Custom domain"),
|
||||
"copyright": fields.String(description="Copyright text"),
|
||||
"privacy_policy": fields.String(description="Privacy policy"),
|
||||
"custom_disclaimer": fields.String(description="Custom disclaimer"),
|
||||
"customize_token_strategy": fields.String(
|
||||
enum=["must", "allow", "not_allow"], description="Token strategy"
|
||||
),
|
||||
"prompt_public": fields.Boolean(description="Make prompt public"),
|
||||
"show_workflow_steps": fields.Boolean(description="Show workflow steps"),
|
||||
"use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AppSiteUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Site configuration updated successfully", app_site_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "App not found")
|
||||
|
|
@ -89,7 +73,7 @@ class AppSite(Resource):
|
|||
@get_app_model
|
||||
@marshal_with(app_site_model)
|
||||
def post(self, app_model):
|
||||
args = parse_app_site_args()
|
||||
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
if not site:
|
||||
|
|
@ -113,7 +97,7 @@ class AppSite(Resource):
|
|||
"show_workflow_steps",
|
||||
"use_icon_as_answer_icon",
|
||||
]:
|
||||
value = args.get(attr_name)
|
||||
value = getattr(args, attr_name)
|
||||
if value is not None:
|
||||
setattr(site, attr_name, value)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import NoReturn, ParamSpec, TypeVar
|
||||
from typing import Any, NoReturn, ParamSpec, TypeVar
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import console_ns
|
||||
|
|
@ -29,6 +30,27 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList,
|
|||
from services.workflow_service import WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowDraftVariableListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=100_000, description="Page number")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Items per page")
|
||||
|
||||
|
||||
class WorkflowDraftVariableUpdatePayload(BaseModel):
|
||||
name: str | None = Field(default=None, description="Variable name")
|
||||
value: Any | None = Field(default=None, description="Variable value")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowDraftVariableListQuery.__name__,
|
||||
WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
WorkflowDraftVariableUpdatePayload.__name__,
|
||||
WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
def _convert_values_to_json_serializable_object(value: Segment):
|
||||
|
|
@ -57,22 +79,6 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
|
|||
return _convert_values_to_json_serializable_object(value)
|
||||
|
||||
|
||||
def _create_pagination_parser():
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"page",
|
||||
type=inputs.int_range(1, 100_000),
|
||||
required=False,
|
||||
default=1,
|
||||
location="args",
|
||||
help="the page of data requested",
|
||||
)
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
|
||||
value_type = workflow_draft_var.value_type
|
||||
return value_type.exposed_type().value
|
||||
|
|
@ -201,7 +207,7 @@ def _api_prerequisite(f: Callable[P, R]):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables")
|
||||
class WorkflowVariableCollectionApi(Resource):
|
||||
@console_ns.expect(_create_pagination_parser())
|
||||
@console_ns.expect(console_ns.models[WorkflowDraftVariableListQuery.__name__])
|
||||
@console_ns.doc("get_workflow_variables")
|
||||
@console_ns.doc(description="Get draft workflow variables")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
|
|
@ -215,8 +221,7 @@ class WorkflowVariableCollectionApi(Resource):
|
|||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
parser = _create_pagination_parser()
|
||||
args = parser.parse_args()
|
||||
args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
workflow_service = WorkflowService()
|
||||
|
|
@ -323,15 +328,7 @@ class VariableApi(Resource):
|
|||
|
||||
@console_ns.doc("update_variable")
|
||||
@console_ns.doc(description="Update a workflow variable")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"UpdateVariableRequest",
|
||||
{
|
||||
"name": fields.String(description="Variable name"),
|
||||
"value": fields.Raw(description="Variable value"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[WorkflowDraftVariableUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
|
|
@ -358,16 +355,10 @@ class VariableApi(Resource):
|
|||
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||
# }
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
args = parser.parse_args(strict=True)
|
||||
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
|
|
@ -375,8 +366,8 @@ class VariableApi(Resource):
|
|||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
new_name = args.get(self._PATCH_NAME_FIELD, None)
|
||||
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
|
||||
new_name = args_model.name
|
||||
raw_value = args_model.value
|
||||
if new_name is None and raw_value is None:
|
||||
return variable
|
||||
|
||||
|
|
|
|||
|
|
@ -1,28 +1,53 @@
|
|||
from flask import request
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.error import AlreadyActivateError
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import StrLen, email, extract_remote_ip, timezone
|
||||
from libs.helper import EmailStr, extract_remote_ip, timezone
|
||||
from models import AccountStatus
|
||||
from services.account_service import AccountService, RegisterService
|
||||
|
||||
active_check_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID")
|
||||
.add_argument("email", type=email, required=False, nullable=True, location="args", help="Email address")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="args", help="Activation token")
|
||||
)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ActivateCheckQuery(BaseModel):
|
||||
workspace_id: str | None = Field(default=None)
|
||||
email: EmailStr | None = Field(default=None)
|
||||
token: str
|
||||
|
||||
|
||||
class ActivatePayload(BaseModel):
|
||||
workspace_id: str | None = Field(default=None)
|
||||
email: EmailStr | None = Field(default=None)
|
||||
token: str
|
||||
name: str = Field(..., max_length=30)
|
||||
interface_language: str = Field(...)
|
||||
timezone: str = Field(...)
|
||||
|
||||
@field_validator("interface_language")
|
||||
@classmethod
|
||||
def validate_lang(cls, value: str) -> str:
|
||||
return supported_language(value)
|
||||
|
||||
@field_validator("timezone")
|
||||
@classmethod
|
||||
def validate_tz(cls, value: str) -> str:
|
||||
return timezone(value)
|
||||
|
||||
|
||||
for model in (ActivateCheckQuery, ActivatePayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
@console_ns.route("/activate/check")
|
||||
class ActivateCheckApi(Resource):
|
||||
@console_ns.doc("check_activation_token")
|
||||
@console_ns.doc(description="Check if activation token is valid")
|
||||
@console_ns.expect(active_check_parser)
|
||||
@console_ns.expect(console_ns.models[ActivateCheckQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
|
|
@ -35,11 +60,11 @@ class ActivateCheckApi(Resource):
|
|||
),
|
||||
)
|
||||
def get(self):
|
||||
args = active_check_parser.parse_args()
|
||||
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
workspaceId = args["workspace_id"]
|
||||
reg_email = args["email"]
|
||||
token = args["token"]
|
||||
workspaceId = args.workspace_id
|
||||
reg_email = args.email
|
||||
token = args.token
|
||||
|
||||
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
|
||||
if invitation:
|
||||
|
|
@ -56,22 +81,11 @@ class ActivateCheckApi(Resource):
|
|||
return {"is_valid": False}
|
||||
|
||||
|
||||
active_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("email", type=email, required=False, nullable=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
.add_argument("interface_language", type=supported_language, required=True, nullable=False, location="json")
|
||||
.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/activate")
|
||||
class ActivateApi(Resource):
|
||||
@console_ns.doc("activate_account")
|
||||
@console_ns.doc(description="Activate account with invitation token")
|
||||
@console_ns.expect(active_parser)
|
||||
@console_ns.expect(console_ns.models[ActivatePayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Account activated successfully",
|
||||
|
|
@ -85,19 +99,19 @@ class ActivateApi(Resource):
|
|||
)
|
||||
@console_ns.response(400, "Already activated or invalid token")
|
||||
def post(self):
|
||||
args = active_parser.parse_args()
|
||||
args = ActivatePayload.model_validate(console_ns.payload)
|
||||
|
||||
invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"])
|
||||
invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
|
||||
if invitation is None:
|
||||
raise AlreadyActivateError()
|
||||
|
||||
RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"])
|
||||
RegisterService.revoke_token(args.workspace_id, args.email, args.token)
|
||||
|
||||
account = invitation["account"]
|
||||
account.name = args["name"]
|
||||
account.name = args.name
|
||||
|
||||
account.interface_language = args["interface_language"]
|
||||
account.timezone = args["timezone"]
|
||||
account.interface_language = args.interface_language
|
||||
account.timezone = args.timezone
|
||||
account.interface_theme = "light"
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.initialized_at = naive_utc_now()
|
||||
|
|
|
|||
|
|
@ -1,12 +1,26 @@
|
|||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import ApiKeyAuthFailedError
|
||||
from controllers.console.wraps import is_admin_or_owner_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
from .. import console_ns
|
||||
from ..auth.error import ApiKeyAuthFailedError
|
||||
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ApiKeyAuthBindingPayload(BaseModel):
|
||||
category: str = Field(...)
|
||||
provider: str = Field(...)
|
||||
credentials: dict = Field(...)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ApiKeyAuthBindingPayload.__name__,
|
||||
ApiKeyAuthBindingPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/api-key-auth/data-source")
|
||||
|
|
@ -40,19 +54,15 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
@is_admin_or_owner_required
|
||||
@console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__])
|
||||
def post(self):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
payload = ApiKeyAuthBindingPayload.model_validate(console_ns.payload)
|
||||
data = payload.model_dump()
|
||||
ApiKeyAuthService.validate_api_key_auth_args(data)
|
||||
try:
|
||||
ApiKeyAuthService.create_provider_auth(current_tenant_id, args)
|
||||
ApiKeyAuthService.create_provider_auth(current_tenant_id, data)
|
||||
except Exception as e:
|
||||
raise ApiKeyAuthFailedError(str(e))
|
||||
return {"result": "success"}, 200
|
||||
|
|
|
|||
|
|
@ -5,12 +5,11 @@ from flask import current_app, redirect, request
|
|||
from flask_restx import Resource, fields
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import is_admin_or_owner_required
|
||||
from libs.login import login_required
|
||||
from libs.oauth_data_source import NotionOAuth
|
||||
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
from .. import console_ns
|
||||
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
|
@ -14,16 +15,45 @@ from controllers.console.auth.error import (
|
|||
InvalidTokenError,
|
||||
PasswordMismatchError,
|
||||
)
|
||||
from controllers.console.error import AccountInFreezeError, EmailSendIpLimitError
|
||||
from controllers.console.wraps import email_password_login_enabled, email_register_enabled, setup_required
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import valid_password
|
||||
from models import Account
|
||||
from services.account_service import AccountService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
||||
|
||||
from ..error import AccountInFreezeError, EmailSendIpLimitError
|
||||
from ..wraps import email_password_login_enabled, email_register_enabled, setup_required
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class EmailRegisterSendPayload(BaseModel):
|
||||
email: EmailStr = Field(..., description="Email address")
|
||||
language: str | None = Field(default=None, description="Language code")
|
||||
|
||||
|
||||
class EmailRegisterValidityPayload(BaseModel):
|
||||
email: EmailStr = Field(...)
|
||||
code: str = Field(...)
|
||||
token: str = Field(...)
|
||||
|
||||
|
||||
class EmailRegisterResetPayload(BaseModel):
|
||||
token: str = Field(...)
|
||||
new_password: str = Field(...)
|
||||
password_confirm: str = Field(...)
|
||||
|
||||
@field_validator("new_password", "password_confirm")
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> str:
|
||||
return valid_password(value)
|
||||
|
||||
|
||||
for model in (EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
@console_ns.route("/email-register/send-email")
|
||||
class EmailRegisterSendEmailApi(Resource):
|
||||
|
|
@ -31,27 +61,22 @@ class EmailRegisterSendEmailApi(Resource):
|
|||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = EmailRegisterSendPayload.model_validate(console_ns.payload)
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
language = "en-US"
|
||||
if args["language"] in languages:
|
||||
language = args["language"]
|
||||
if args.language in languages:
|
||||
language = args.language
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
|
||||
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
|
||||
token = None
|
||||
token = AccountService.send_email_register_email(email=args["email"], account=account, language=language)
|
||||
token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
|
|
@ -61,40 +86,34 @@ class EmailRegisterCheckApi(Resource):
|
|||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
|
||||
|
||||
user_email = args["email"]
|
||||
user_email = args.email
|
||||
|
||||
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args["email"])
|
||||
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
|
||||
if is_email_register_error_rate_limit:
|
||||
raise EmailRegisterLimitError()
|
||||
|
||||
token_data = AccountService.get_email_register_data(args["token"])
|
||||
token_data = AccountService.get_email_register_data(args.token)
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args["code"] != token_data.get("code"):
|
||||
AccountService.add_email_register_error_rate_limit(args["email"])
|
||||
if args.code != token_data.get("code"):
|
||||
AccountService.add_email_register_error_rate_limit(args.email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
AccountService.revoke_email_register_token(args["token"])
|
||||
AccountService.revoke_email_register_token(args.token)
|
||||
|
||||
# Refresh token data by generating a new token
|
||||
_, new_token = AccountService.generate_email_register_token(
|
||||
user_email, code=args["code"], additional_data={"phase": "register"}
|
||||
user_email, code=args.code, additional_data={"phase": "register"}
|
||||
)
|
||||
|
||||
AccountService.reset_email_register_error_rate_limit(args["email"])
|
||||
AccountService.reset_email_register_error_rate_limit(args.email)
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
|
||||
|
||||
|
|
@ -104,20 +123,14 @@ class EmailRegisterResetApi(Resource):
|
|||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = EmailRegisterResetPayload.model_validate(console_ns.payload)
|
||||
|
||||
# Validate passwords match
|
||||
if args["new_password"] != args["password_confirm"]:
|
||||
if args.new_password != args.password_confirm:
|
||||
raise PasswordMismatchError()
|
||||
|
||||
# Validate token and get register data
|
||||
register_data = AccountService.get_email_register_data(args["token"])
|
||||
register_data = AccountService.get_email_register_data(args.token)
|
||||
if not register_data:
|
||||
raise InvalidTokenError()
|
||||
# Must use token in reset phase
|
||||
|
|
@ -125,7 +138,7 @@ class EmailRegisterResetApi(Resource):
|
|||
raise InvalidTokenError()
|
||||
|
||||
# Revoke token to prevent reuse
|
||||
AccountService.revoke_email_register_token(args["token"])
|
||||
AccountService.revoke_email_register_token(args.token)
|
||||
|
||||
email = register_data.get("email", "")
|
||||
|
||||
|
|
@ -135,7 +148,7 @@ class EmailRegisterResetApi(Resource):
|
|||
if account:
|
||||
raise EmailAlreadyInUseError()
|
||||
else:
|
||||
account = self._create_new_account(email, args["password_confirm"])
|
||||
account = self._create_new_account(email, args.password_confirm)
|
||||
if not account:
|
||||
raise AccountNotFoundError()
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@ import base64
|
|||
import secrets
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
|
@ -18,26 +19,46 @@ from controllers.console.error import AccountNotFound, EmailSendIpLimitError
|
|||
from controllers.console.wraps import email_password_login_enabled, setup_required
|
||||
from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import hash_password, valid_password
|
||||
from models import Account
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ForgotPasswordSendPayload(BaseModel):
|
||||
email: EmailStr = Field(...)
|
||||
language: str | None = Field(default=None)
|
||||
|
||||
|
||||
class ForgotPasswordCheckPayload(BaseModel):
|
||||
email: EmailStr = Field(...)
|
||||
code: str = Field(...)
|
||||
token: str = Field(...)
|
||||
|
||||
|
||||
class ForgotPasswordResetPayload(BaseModel):
|
||||
token: str = Field(...)
|
||||
new_password: str = Field(...)
|
||||
password_confirm: str = Field(...)
|
||||
|
||||
@field_validator("new_password", "password_confirm")
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> str:
|
||||
return valid_password(value)
|
||||
|
||||
|
||||
for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
@console_ns.route("/forgot-password")
|
||||
class ForgotPasswordSendEmailApi(Resource):
|
||||
@console_ns.doc("send_forgot_password_email")
|
||||
@console_ns.doc(description="Send password reset email")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ForgotPasswordEmailRequest",
|
||||
{
|
||||
"email": fields.String(required=True, description="Email address"),
|
||||
"language": fields.String(description="Language for email (zh-Hans/en-US)"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[ForgotPasswordSendPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Email sent successfully",
|
||||
|
|
@ -54,28 +75,23 @@ class ForgotPasswordSendEmailApi(Resource):
|
|||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
if args.language is not None and args.language == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
|
||||
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
|
||||
|
||||
token = AccountService.send_reset_password_email(
|
||||
account=account,
|
||||
email=args["email"],
|
||||
email=args.email,
|
||||
language=language,
|
||||
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
||||
)
|
||||
|
|
@ -87,16 +103,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
|||
class ForgotPasswordCheckApi(Resource):
|
||||
@console_ns.doc("check_forgot_password_code")
|
||||
@console_ns.doc(description="Verify password reset code")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ForgotPasswordCheckRequest",
|
||||
{
|
||||
"email": fields.String(required=True, description="Email address"),
|
||||
"code": fields.String(required=True, description="Verification code"),
|
||||
"token": fields.String(required=True, description="Reset token"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[ForgotPasswordCheckPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Code verified successfully",
|
||||
|
|
@ -113,40 +120,34 @@ class ForgotPasswordCheckApi(Resource):
|
|||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
|
||||
|
||||
user_email = args["email"]
|
||||
user_email = args.email
|
||||
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
|
||||
if is_forgot_password_error_rate_limit:
|
||||
raise EmailPasswordResetLimitError()
|
||||
|
||||
token_data = AccountService.get_reset_password_data(args["token"])
|
||||
token_data = AccountService.get_reset_password_data(args.token)
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args["code"] != token_data.get("code"):
|
||||
AccountService.add_forgot_password_error_rate_limit(args["email"])
|
||||
if args.code != token_data.get("code"):
|
||||
AccountService.add_forgot_password_error_rate_limit(args.email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
AccountService.revoke_reset_password_token(args["token"])
|
||||
AccountService.revoke_reset_password_token(args.token)
|
||||
|
||||
# Refresh token data by generating a new token
|
||||
_, new_token = AccountService.generate_reset_password_token(
|
||||
user_email, code=args["code"], additional_data={"phase": "reset"}
|
||||
user_email, code=args.code, additional_data={"phase": "reset"}
|
||||
)
|
||||
|
||||
AccountService.reset_forgot_password_error_rate_limit(args["email"])
|
||||
AccountService.reset_forgot_password_error_rate_limit(args.email)
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
|
||||
|
||||
|
|
@ -154,16 +155,7 @@ class ForgotPasswordCheckApi(Resource):
|
|||
class ForgotPasswordResetApi(Resource):
|
||||
@console_ns.doc("reset_password")
|
||||
@console_ns.doc(description="Reset password with verification token")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ForgotPasswordResetRequest",
|
||||
{
|
||||
"token": fields.String(required=True, description="Verification token"),
|
||||
"new_password": fields.String(required=True, description="New password"),
|
||||
"password_confirm": fields.String(required=True, description="Password confirmation"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[ForgotPasswordResetPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Password reset successfully",
|
||||
|
|
@ -173,20 +165,14 @@ class ForgotPasswordResetApi(Resource):
|
|||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = ForgotPasswordResetPayload.model_validate(console_ns.payload)
|
||||
|
||||
# Validate passwords match
|
||||
if args["new_password"] != args["password_confirm"]:
|
||||
if args.new_password != args.password_confirm:
|
||||
raise PasswordMismatchError()
|
||||
|
||||
# Validate token and get reset data
|
||||
reset_data = AccountService.get_reset_password_data(args["token"])
|
||||
reset_data = AccountService.get_reset_password_data(args.token)
|
||||
if not reset_data:
|
||||
raise InvalidTokenError()
|
||||
# Must use token in reset phase
|
||||
|
|
@ -194,11 +180,11 @@ class ForgotPasswordResetApi(Resource):
|
|||
raise InvalidTokenError()
|
||||
|
||||
# Revoke token to prevent reuse
|
||||
AccountService.revoke_reset_password_token(args["token"])
|
||||
AccountService.revoke_reset_password_token(args.token)
|
||||
|
||||
# Generate secure salt and hash password
|
||||
salt = secrets.token_bytes(16)
|
||||
password_hashed = hash_password(args["new_password"], salt)
|
||||
password_hashed = hash_password(args.new_password, salt)
|
||||
|
||||
email = reset_data.get("email", "")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import flask_login
|
||||
from flask import make_response, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
|
|
@ -23,7 +24,7 @@ from controllers.console.error import (
|
|||
)
|
||||
from controllers.console.wraps import email_password_login_enabled, setup_required
|
||||
from events.tenant_event import tenant_was_created
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.token import (
|
||||
clear_access_token_from_cookie,
|
||||
|
|
@ -40,6 +41,36 @@ from services.errors.account import AccountRegisterError
|
|||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class LoginPayload(BaseModel):
|
||||
email: EmailStr = Field(..., description="Email address")
|
||||
password: str = Field(..., description="Password")
|
||||
remember_me: bool = Field(default=False, description="Remember me flag")
|
||||
invite_token: str | None = Field(default=None, description="Invitation token")
|
||||
|
||||
|
||||
class EmailPayload(BaseModel):
|
||||
email: EmailStr = Field(...)
|
||||
language: str | None = Field(default=None)
|
||||
|
||||
|
||||
class EmailCodeLoginPayload(BaseModel):
|
||||
email: EmailStr = Field(...)
|
||||
code: str = Field(...)
|
||||
token: str = Field(...)
|
||||
language: str | None = Field(default=None)
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(LoginPayload)
|
||||
reg(EmailPayload)
|
||||
reg(EmailCodeLoginPayload)
|
||||
|
||||
|
||||
@console_ns.route("/login")
|
||||
class LoginApi(Resource):
|
||||
|
|
@ -47,41 +78,36 @@ class LoginApi(Resource):
|
|||
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@console_ns.expect(console_ns.models[LoginPayload.__name__])
|
||||
def post(self):
|
||||
"""Authenticate user and login."""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("password", type=str, required=True, location="json")
|
||||
.add_argument("remember_me", type=bool, required=False, default=False, location="json")
|
||||
.add_argument("invite_token", type=str, required=False, default=None, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = LoginPayload.model_validate(console_ns.payload)
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"])
|
||||
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
|
||||
if is_login_error_rate_limit:
|
||||
raise EmailPasswordLoginLimitError()
|
||||
|
||||
invitation = args["invite_token"]
|
||||
# TODO: why invitation is re-assigned with different type?
|
||||
invitation = args.invite_token # type: ignore
|
||||
if invitation:
|
||||
invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation)
|
||||
invitation = RegisterService.get_invitation_if_token_valid(None, args.email, invitation) # type: ignore
|
||||
|
||||
try:
|
||||
if invitation:
|
||||
data = invitation.get("data", {})
|
||||
data = invitation.get("data", {}) # type: ignore
|
||||
invitee_email = data.get("email") if data else None
|
||||
if invitee_email != args["email"]:
|
||||
if invitee_email != args.email:
|
||||
raise InvalidEmailError()
|
||||
account = AccountService.authenticate(args["email"], args["password"], args["invite_token"])
|
||||
account = AccountService.authenticate(args.email, args.password, args.invite_token)
|
||||
else:
|
||||
account = AccountService.authenticate(args["email"], args["password"])
|
||||
account = AccountService.authenticate(args.email, args.password)
|
||||
except services.errors.account.AccountLoginError:
|
||||
raise AccountBannedError()
|
||||
except services.errors.account.AccountPasswordError:
|
||||
AccountService.add_login_error_rate_limit(args["email"])
|
||||
AccountService.add_login_error_rate_limit(args.email)
|
||||
raise AuthenticationFailedError()
|
||||
# SELF_HOSTED only have one workspace
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
|
|
@ -97,7 +123,7 @@ class LoginApi(Resource):
|
|||
}
|
||||
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
AccountService.reset_login_error_rate_limit(args.email)
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
response = make_response({"result": "success"})
|
||||
|
|
@ -134,25 +160,21 @@ class LogoutApi(Resource):
|
|||
class ResetPasswordSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@console_ns.expect(console_ns.models[EmailPayload.__name__])
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = EmailPayload.model_validate(console_ns.payload)
|
||||
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
if args.language is not None and args.language == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
try:
|
||||
account = AccountService.get_user_through_email(args["email"])
|
||||
account = AccountService.get_user_through_email(args.email)
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
token = AccountService.send_reset_password_email(
|
||||
email=args["email"],
|
||||
email=args.email,
|
||||
account=account,
|
||||
language=language,
|
||||
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
||||
|
|
@ -164,30 +186,26 @@ class ResetPasswordSendEmailApi(Resource):
|
|||
@console_ns.route("/email-code-login")
|
||||
class EmailCodeLoginSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@console_ns.expect(console_ns.models[EmailPayload.__name__])
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = EmailPayload.model_validate(console_ns.payload)
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
if args.language is not None and args.language == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
try:
|
||||
account = AccountService.get_user_through_email(args["email"])
|
||||
account = AccountService.get_user_through_email(args.email)
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
if account is None:
|
||||
if FeatureService.get_system_features().is_allow_register:
|
||||
token = AccountService.send_email_code_login_email(email=args["email"], language=language)
|
||||
token = AccountService.send_email_code_login_email(email=args.email, language=language)
|
||||
else:
|
||||
raise AccountNotFound()
|
||||
else:
|
||||
|
|
@ -199,30 +217,24 @@ class EmailCodeLoginSendEmailApi(Resource):
|
|||
@console_ns.route("/email-code-login/validity")
|
||||
class EmailCodeLoginApi(Resource):
|
||||
@setup_required
|
||||
@console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__])
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
|
||||
|
||||
user_email = args["email"]
|
||||
language = args["language"]
|
||||
user_email = args.email
|
||||
language = args.language
|
||||
|
||||
token_data = AccountService.get_email_code_login_data(args["token"])
|
||||
token_data = AccountService.get_email_code_login_data(args.token)
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if token_data["email"] != args["email"]:
|
||||
if token_data["email"] != args.email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if token_data["code"] != args["code"]:
|
||||
if token_data["code"] != args.code:
|
||||
raise EmailCodeError()
|
||||
|
||||
AccountService.revoke_email_code_login_token(args["token"])
|
||||
AccountService.revoke_email_code_login_token(args.token)
|
||||
try:
|
||||
account = AccountService.get_user_through_email(user_email)
|
||||
except AccountRegisterError:
|
||||
|
|
@ -255,7 +267,7 @@ class EmailCodeLoginApi(Resource):
|
|||
except WorkspacesLimitExceededError:
|
||||
raise WorkspacesLimitExceeded()
|
||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
AccountService.reset_login_error_rate_limit(args.email)
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
response = make_response({"result": "success"})
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@ from functools import wraps
|
|||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from flask import jsonify, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
|
|
@ -20,15 +21,34 @@ R = TypeVar("R")
|
|||
T = TypeVar("T")
|
||||
|
||||
|
||||
class OAuthClientPayload(BaseModel):
|
||||
client_id: str
|
||||
|
||||
|
||||
class OAuthProviderRequest(BaseModel):
|
||||
client_id: str
|
||||
redirect_uri: str
|
||||
|
||||
|
||||
class OAuthTokenRequest(BaseModel):
|
||||
client_id: str
|
||||
grant_type: str
|
||||
code: str | None = None
|
||||
client_secret: str | None = None
|
||||
redirect_uri: str | None = None
|
||||
refresh_token: str | None = None
|
||||
|
||||
|
||||
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
|
||||
parser = reqparse.RequestParser().add_argument("client_id", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
client_id = parsed_args.get("client_id")
|
||||
if not client_id:
|
||||
json_data = request.get_json()
|
||||
if json_data is None:
|
||||
raise BadRequest("client_id is required")
|
||||
|
||||
payload = OAuthClientPayload.model_validate(json_data)
|
||||
client_id = payload.client_id
|
||||
|
||||
oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id)
|
||||
if not oauth_provider_app:
|
||||
raise NotFound("client_id is invalid")
|
||||
|
|
@ -89,9 +109,8 @@ class OAuthServerAppApi(Resource):
|
|||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
parser = reqparse.RequestParser().add_argument("redirect_uri", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
redirect_uri = parsed_args.get("redirect_uri")
|
||||
payload = OAuthProviderRequest.model_validate(request.get_json())
|
||||
redirect_uri = payload.redirect_uri
|
||||
|
||||
# check if redirect_uri is valid
|
||||
if redirect_uri not in oauth_provider_app.redirect_uris:
|
||||
|
|
@ -130,33 +149,25 @@ class OAuthServerUserTokenApi(Resource):
|
|||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("grant_type", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=False, location="json")
|
||||
.add_argument("client_secret", type=str, required=False, location="json")
|
||||
.add_argument("redirect_uri", type=str, required=False, location="json")
|
||||
.add_argument("refresh_token", type=str, required=False, location="json")
|
||||
)
|
||||
parsed_args = parser.parse_args()
|
||||
payload = OAuthTokenRequest.model_validate(request.get_json())
|
||||
|
||||
try:
|
||||
grant_type = OAuthGrantType(parsed_args["grant_type"])
|
||||
grant_type = OAuthGrantType(payload.grant_type)
|
||||
except ValueError:
|
||||
raise BadRequest("invalid grant_type")
|
||||
|
||||
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
|
||||
if not parsed_args["code"]:
|
||||
if not payload.code:
|
||||
raise BadRequest("code is required")
|
||||
|
||||
if parsed_args["client_secret"] != oauth_provider_app.client_secret:
|
||||
if payload.client_secret != oauth_provider_app.client_secret:
|
||||
raise BadRequest("client_secret is invalid")
|
||||
|
||||
if parsed_args["redirect_uri"] not in oauth_provider_app.redirect_uris:
|
||||
if payload.redirect_uri not in oauth_provider_app.redirect_uris:
|
||||
raise BadRequest("redirect_uri is invalid")
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, code=parsed_args["code"], client_id=oauth_provider_app.client_id
|
||||
grant_type, code=payload.code, client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
|
|
@ -167,11 +178,11 @@ class OAuthServerUserTokenApi(Resource):
|
|||
}
|
||||
)
|
||||
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
|
||||
if not parsed_args["refresh_token"]:
|
||||
if not payload.refresh_token:
|
||||
raise BadRequest("refresh_token is required")
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, refresh_token=parsed_args["refresh_token"], client_id=oauth_provider_app.client_id
|
||||
grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import base64
|
||||
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console import console_ns
|
||||
|
|
@ -9,6 +11,35 @@ from enums.cloud_plan import CloudPlan
|
|||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class SubscriptionQuery(BaseModel):
|
||||
plan: str = Field(..., description="Subscription plan")
|
||||
interval: str = Field(..., description="Billing interval")
|
||||
|
||||
@field_validator("plan")
|
||||
@classmethod
|
||||
def validate_plan(cls, value: str) -> str:
|
||||
if value not in [CloudPlan.PROFESSIONAL, CloudPlan.TEAM]:
|
||||
raise ValueError("Invalid plan")
|
||||
return value
|
||||
|
||||
@field_validator("interval")
|
||||
@classmethod
|
||||
def validate_interval(cls, value: str) -> str:
|
||||
if value not in {"month", "year"}:
|
||||
raise ValueError("Invalid interval")
|
||||
return value
|
||||
|
||||
|
||||
class PartnerTenantsPayload(BaseModel):
|
||||
click_id: str = Field(..., description="Click Id from partner referral link")
|
||||
|
||||
|
||||
for model in (SubscriptionQuery, PartnerTenantsPayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
@console_ns.route("/billing/subscription")
|
||||
class Subscription(Resource):
|
||||
|
|
@ -18,20 +49,9 @@ class Subscription(Resource):
|
|||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"plan",
|
||||
type=str,
|
||||
required=True,
|
||||
location="args",
|
||||
choices=[CloudPlan.PROFESSIONAL, CloudPlan.TEAM],
|
||||
)
|
||||
.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
return BillingService.get_subscription(args["plan"], args["interval"], current_user.email, current_tenant_id)
|
||||
return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id)
|
||||
|
||||
|
||||
@console_ns.route("/billing/invoices")
|
||||
|
|
@ -65,11 +85,10 @@ class PartnerTenants(Resource):
|
|||
@only_edition_cloud
|
||||
def put(self, partner_key: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("click_id", required=True, type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
click_id = args["click_id"]
|
||||
args = PartnerTenantsPayload.model_validate(console_ns.payload or {})
|
||||
click_id = args.click_id
|
||||
decoded_partner_key = base64.b64decode(partner_key).decode("utf-8")
|
||||
except Exception:
|
||||
raise BadRequest("Invalid partner_key")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
|
|
@ -9,16 +10,28 @@ from .. import console_ns
|
|||
from ..wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
|
||||
|
||||
class ComplianceDownloadQuery(BaseModel):
|
||||
doc_name: str = Field(..., description="Compliance document name")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ComplianceDownloadQuery.__name__,
|
||||
ComplianceDownloadQuery.model_json_schema(ref_template="#/definitions/{model}"),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/compliance/download")
|
||||
class ComplianceApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ComplianceDownloadQuery.__name__])
|
||||
@console_ns.doc("download_compliance_document")
|
||||
@console_ns.doc(description="Get compliance document download link")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("doc_name", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
device_info = request.headers.get("User-Agent", "Unknown device")
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from constants.languages import languages
|
||||
from controllers.console import console_ns
|
||||
|
|
@ -35,20 +37,26 @@ recommended_app_list_fields = {
|
|||
}
|
||||
|
||||
|
||||
parser_apps = reqparse.RequestParser().add_argument("language", type=str, location="args")
|
||||
class RecommendedAppsQuery(BaseModel):
|
||||
language: str | None = Field(default=None)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
RecommendedAppsQuery.__name__,
|
||||
RecommendedAppsQuery.model_json_schema(ref_template="#/definitions/{model}"),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/explore/apps")
|
||||
class RecommendedAppListApi(Resource):
|
||||
@console_ns.expect(parser_apps)
|
||||
@console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__])
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(recommended_app_list_fields)
|
||||
def get(self):
|
||||
# language args
|
||||
args = parser_apps.parse_args()
|
||||
|
||||
language = args.get("language")
|
||||
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
language = args.language
|
||||
if language and language in languages:
|
||||
language_prefix = language
|
||||
elif current_user and current_user.interface_language:
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
import os
|
||||
|
||||
from flask import session
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import StrLen
|
||||
from models.model import DifySetup
|
||||
from services.account_service import TenantService
|
||||
|
||||
|
|
@ -15,6 +15,18 @@ from . import console_ns
|
|||
from .error import AlreadySetupError, InitValidateFailedError
|
||||
from .wraps import only_edition_self_hosted
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class InitValidatePayload(BaseModel):
|
||||
password: str = Field(..., max_length=30)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
InitValidatePayload.__name__,
|
||||
InitValidatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/init")
|
||||
class InitValidateAPI(Resource):
|
||||
|
|
@ -37,12 +49,7 @@ class InitValidateAPI(Resource):
|
|||
|
||||
@console_ns.doc("validate_init_password")
|
||||
@console_ns.doc(description="Validate initialization password for self-hosted edition")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"InitValidateRequest",
|
||||
{"password": fields.String(required=True, description="Initialization password", max_length=30)},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[InitValidatePayload.__name__])
|
||||
@console_ns.response(
|
||||
201,
|
||||
"Success",
|
||||
|
|
@ -57,8 +64,8 @@ class InitValidateAPI(Resource):
|
|||
if tenant_count > 0:
|
||||
raise AlreadySetupError()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("password", type=StrLen(30), required=True, location="json")
|
||||
input_password = parser.parse_args()["password"]
|
||||
payload = InitValidatePayload.model_validate(console_ns.payload)
|
||||
input_password = payload.password
|
||||
|
||||
if input_password != os.environ.get("INIT_PASSWORD"):
|
||||
session["is_init_validated"] = False
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import urllib.parse
|
||||
|
||||
import httpx
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import services
|
||||
from controllers.common import helpers
|
||||
|
|
@ -36,17 +37,23 @@ class RemoteFileInfoApi(Resource):
|
|||
}
|
||||
|
||||
|
||||
parser_upload = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required")
|
||||
class RemoteFileUploadPayload(BaseModel):
|
||||
url: str = Field(..., description="URL to fetch")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
RemoteFileUploadPayload.__name__,
|
||||
RemoteFileUploadPayload.model_json_schema(ref_template="#/definitions/{model}"),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/remote-files/upload")
|
||||
class RemoteFileUploadApi(Resource):
|
||||
@console_ns.expect(parser_upload)
|
||||
@console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
|
||||
@marshal_with(file_fields_with_signed_url)
|
||||
def post(self):
|
||||
args = parser_upload.parse_args()
|
||||
|
||||
url = args["url"]
|
||||
args = RemoteFileUploadPayload.model_validate(console_ns.payload)
|
||||
url = args.url
|
||||
|
||||
try:
|
||||
resp = ssrf_proxy.head(url=url)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
from flask import request
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from configs import dify_config
|
||||
from libs.helper import StrLen, email, extract_remote_ip
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import valid_password
|
||||
from models.model import DifySetup, db
|
||||
from services.account_service import RegisterService, TenantService
|
||||
|
|
@ -12,6 +13,26 @@ from .error import AlreadySetupError, NotInitValidateError
|
|||
from .init_validate import get_init_validate_status
|
||||
from .wraps import only_edition_self_hosted
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class SetupRequestPayload(BaseModel):
|
||||
email: EmailStr = Field(..., description="Admin email address")
|
||||
name: str = Field(..., max_length=30, description="Admin name (max 30 characters)")
|
||||
password: str = Field(..., description="Admin password")
|
||||
language: str | None = Field(default=None, description="Admin language")
|
||||
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> str:
|
||||
return valid_password(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
SetupRequestPayload.__name__,
|
||||
SetupRequestPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/setup")
|
||||
class SetupApi(Resource):
|
||||
|
|
@ -42,17 +63,7 @@ class SetupApi(Resource):
|
|||
|
||||
@console_ns.doc("setup_system")
|
||||
@console_ns.doc(description="Initialize system setup with admin account")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"SetupRequest",
|
||||
{
|
||||
"email": fields.String(required=True, description="Admin email address"),
|
||||
"name": fields.String(required=True, description="Admin name (max 30 characters)"),
|
||||
"password": fields.String(required=True, description="Admin password"),
|
||||
"language": fields.String(required=False, description="Admin language"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[SetupRequestPayload.__name__])
|
||||
@console_ns.response(
|
||||
201, "Success", console_ns.model("SetupResponse", {"result": fields.String(description="Setup result")})
|
||||
)
|
||||
|
|
@ -72,22 +83,15 @@ class SetupApi(Resource):
|
|||
if not get_init_validate_status():
|
||||
raise NotInitValidateError()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=True, location="json")
|
||||
.add_argument("password", type=valid_password, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = SetupRequestPayload.model_validate(console_ns.payload)
|
||||
|
||||
# setup
|
||||
RegisterService.setup(
|
||||
email=args["email"],
|
||||
name=args["name"],
|
||||
password=args["password"],
|
||||
email=args.email,
|
||||
name=args.name,
|
||||
password=args.password,
|
||||
ip_address=extract_remote_ip(request),
|
||||
language=args["language"],
|
||||
language=args.language,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
|
|
|||
|
|
@ -2,8 +2,10 @@ import json
|
|||
import logging
|
||||
|
||||
import httpx
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from packaging import version
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
|
|
@ -11,8 +13,14 @@ from . import console_ns
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"current_version", type=str, required=True, location="args", help="Current application version"
|
||||
|
||||
class VersionQuery(BaseModel):
|
||||
current_version: str = Field(..., description="Current application version")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
VersionQuery.__name__,
|
||||
VersionQuery.model_json_schema(ref_template="#/definitions/{model}"),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -20,7 +28,7 @@ parser = reqparse.RequestParser().add_argument(
|
|||
class VersionApi(Resource):
|
||||
@console_ns.doc("check_version_update")
|
||||
@console_ns.doc(description="Check for application version updates")
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[VersionQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
|
|
@ -37,7 +45,7 @@ class VersionApi(Resource):
|
|||
)
|
||||
def get(self):
|
||||
"""Check for application version updates"""
|
||||
args = parser.parse_args()
|
||||
args = VersionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
check_update_url = dify_config.CHECK_UPDATE_URL
|
||||
|
||||
result = {
|
||||
|
|
@ -57,16 +65,16 @@ class VersionApi(Resource):
|
|||
try:
|
||||
response = httpx.get(
|
||||
check_update_url,
|
||||
params={"current_version": args["current_version"]},
|
||||
params={"current_version": args.current_version},
|
||||
timeout=httpx.Timeout(timeout=10.0, connect=3.0),
|
||||
)
|
||||
except Exception as error:
|
||||
logger.warning("Check update version error: %s.", str(error))
|
||||
result["version"] = args["current_version"]
|
||||
result["version"] = args.current_version
|
||||
return result
|
||||
|
||||
content = json.loads(response.content)
|
||||
if _has_new_version(latest_version=content["version"], current_version=f"{args['current_version']}"):
|
||||
if _has_new_version(latest_version=content["version"], current_version=f"{args.current_version}"):
|
||||
result["version"] = content["version"]
|
||||
result["release_date"] = content["releaseDate"]
|
||||
result["release_notes"] = content["releaseNotes"]
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ from controllers.console.wraps import (
|
|||
from extensions.ext_database import db
|
||||
from fields.member_fields import account_fields
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import TimestampField, email, extract_remote_ip, timezone
|
||||
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Account, AccountIntegrate, InvitationCode
|
||||
from services.account_service import AccountService
|
||||
|
|
@ -111,14 +111,9 @@ class AccountDeletePayload(BaseModel):
|
|||
|
||||
|
||||
class AccountDeletionFeedbackPayload(BaseModel):
|
||||
email: str
|
||||
email: EmailStr
|
||||
feedback: str
|
||||
|
||||
@field_validator("email")
|
||||
@classmethod
|
||||
def validate_email(cls, value: str) -> str:
|
||||
return email(value)
|
||||
|
||||
|
||||
class EducationActivatePayload(BaseModel):
|
||||
token: str
|
||||
|
|
@ -133,45 +128,25 @@ class EducationAutocompleteQuery(BaseModel):
|
|||
|
||||
|
||||
class ChangeEmailSendPayload(BaseModel):
|
||||
email: str
|
||||
email: EmailStr
|
||||
language: str | None = None
|
||||
phase: str | None = None
|
||||
token: str | None = None
|
||||
|
||||
@field_validator("email")
|
||||
@classmethod
|
||||
def validate_email(cls, value: str) -> str:
|
||||
return email(value)
|
||||
|
||||
|
||||
class ChangeEmailValidityPayload(BaseModel):
|
||||
email: str
|
||||
email: EmailStr
|
||||
code: str
|
||||
token: str
|
||||
|
||||
@field_validator("email")
|
||||
@classmethod
|
||||
def validate_email(cls, value: str) -> str:
|
||||
return email(value)
|
||||
|
||||
|
||||
class ChangeEmailResetPayload(BaseModel):
|
||||
new_email: str
|
||||
new_email: EmailStr
|
||||
token: str
|
||||
|
||||
@field_validator("new_email")
|
||||
@classmethod
|
||||
def validate_email(cls, value: str) -> str:
|
||||
return email(value)
|
||||
|
||||
|
||||
class CheckEmailUniquePayload(BaseModel):
|
||||
email: str
|
||||
|
||||
@field_validator("email")
|
||||
@classmethod
|
||||
def validate_email(cls, value: str) -> str:
|
||||
return email(value)
|
||||
email: EmailStr
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from urllib.parse import quote
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import services
|
||||
|
|
@ -11,6 +12,26 @@ from extensions.ext_database import db
|
|||
from services.account_service import TenantService
|
||||
from services.file_service import FileService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class FileSignatureQuery(BaseModel):
|
||||
timestamp: str = Field(..., description="Unix timestamp used in the signature")
|
||||
nonce: str = Field(..., description="Random string for signature")
|
||||
sign: str = Field(..., description="HMAC signature")
|
||||
|
||||
|
||||
class FilePreviewQuery(FileSignatureQuery):
|
||||
as_attachment: bool = Field(default=False, description="Whether to download as attachment")
|
||||
|
||||
|
||||
files_ns.schema_model(
|
||||
FileSignatureQuery.__name__, FileSignatureQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
files_ns.schema_model(
|
||||
FilePreviewQuery.__name__, FilePreviewQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
@files_ns.route("/<uuid:file_id>/image-preview")
|
||||
class ImagePreviewApi(Resource):
|
||||
|
|
@ -36,12 +57,10 @@ class ImagePreviewApi(Resource):
|
|||
def get(self, file_id):
|
||||
file_id = str(file_id)
|
||||
|
||||
timestamp = request.args.get("timestamp")
|
||||
nonce = request.args.get("nonce")
|
||||
sign = request.args.get("sign")
|
||||
|
||||
if not timestamp or not nonce or not sign:
|
||||
return {"content": "Invalid request."}, 400
|
||||
args = FileSignatureQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
timestamp = args.timestamp
|
||||
nonce = args.nonce
|
||||
sign = args.sign
|
||||
|
||||
try:
|
||||
generator, mimetype = FileService(db.engine).get_image_preview(
|
||||
|
|
@ -80,25 +99,14 @@ class FilePreviewApi(Resource):
|
|||
def get(self, file_id):
|
||||
file_id = str(file_id)
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("timestamp", type=str, required=True, location="args")
|
||||
.add_argument("nonce", type=str, required=True, location="args")
|
||||
.add_argument("sign", type=str, required=True, location="args")
|
||||
.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args["timestamp"] or not args["nonce"] or not args["sign"]:
|
||||
return {"content": "Invalid request."}, 400
|
||||
args = FilePreviewQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
try:
|
||||
generator, upload_file = FileService(db.engine).get_file_generator_by_file_id(
|
||||
file_id=file_id,
|
||||
timestamp=args["timestamp"],
|
||||
nonce=args["nonce"],
|
||||
sign=args["sign"],
|
||||
timestamp=args.timestamp,
|
||||
nonce=args.nonce,
|
||||
sign=args.sign,
|
||||
)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
|
@ -125,7 +133,7 @@ class FilePreviewApi(Resource):
|
|||
response.headers["Accept-Ranges"] = "bytes"
|
||||
if upload_file.size > 0:
|
||||
response.headers["Content-Length"] = str(upload_file.size)
|
||||
if args["as_attachment"]:
|
||||
if args.as_attachment:
|
||||
encoded_filename = quote(upload_file.name)
|
||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
response.headers["Content-Type"] = "application/octet-stream"
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from urllib.parse import quote
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.common.errors import UnsupportedFileTypeError
|
||||
|
|
@ -10,6 +11,20 @@ from core.tools.signature import verify_tool_file_signature
|
|||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from extensions.ext_database import db as global_db
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ToolFileQuery(BaseModel):
|
||||
timestamp: str = Field(..., description="Unix timestamp")
|
||||
nonce: str = Field(..., description="Random nonce")
|
||||
sign: str = Field(..., description="HMAC signature")
|
||||
as_attachment: bool = Field(default=False, description="Download as attachment")
|
||||
|
||||
|
||||
files_ns.schema_model(
|
||||
ToolFileQuery.__name__, ToolFileQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
@files_ns.route("/tools/<uuid:file_id>.<string:extension>")
|
||||
class ToolFileApi(Resource):
|
||||
|
|
@ -36,18 +51,8 @@ class ToolFileApi(Resource):
|
|||
def get(self, file_id, extension):
|
||||
file_id = str(file_id)
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("timestamp", type=str, required=True, location="args")
|
||||
.add_argument("nonce", type=str, required=True, location="args")
|
||||
.add_argument("sign", type=str, required=True, location="args")
|
||||
.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if not verify_tool_file_signature(
|
||||
file_id=file_id, timestamp=args["timestamp"], nonce=args["nonce"], sign=args["sign"]
|
||||
):
|
||||
args = ToolFileQuery.model_validate(request.args.to_dict())
|
||||
if not verify_tool_file_signature(file_id=file_id, timestamp=args.timestamp, nonce=args.nonce, sign=args.sign):
|
||||
raise Forbidden("Invalid request.")
|
||||
|
||||
try:
|
||||
|
|
@ -69,7 +74,7 @@ class ToolFileApi(Resource):
|
|||
)
|
||||
if tool_file.size > 0:
|
||||
response.headers["Content-Length"] = str(tool_file.size)
|
||||
if args["as_attachment"]:
|
||||
if args.as_attachment:
|
||||
encoded_filename = quote(tool_file.name)
|
||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,40 +1,45 @@
|
|||
from mimetypes import guess_extension
|
||||
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from flask_restx.api import HTTPStatus
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.common.errors import (
|
||||
FileTooLargeError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.files import files_ns
|
||||
from controllers.inner_api.plugin.wraps import get_user
|
||||
from core.file.helpers import verify_plugin_file_signature
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from fields.file_fields import build_file_model
|
||||
|
||||
# Define parser for both documentation and validation
|
||||
upload_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("file", location="files", type=FileStorage, required=True, help="File to upload")
|
||||
.add_argument(
|
||||
"timestamp", type=str, required=True, location="args", help="Unix timestamp for signature verification"
|
||||
)
|
||||
.add_argument("nonce", type=str, required=True, location="args", help="Random string for signature verification")
|
||||
.add_argument("sign", type=str, required=True, location="args", help="HMAC signature for request validation")
|
||||
.add_argument("tenant_id", type=str, required=True, location="args", help="Tenant identifier")
|
||||
.add_argument("user_id", type=str, required=False, location="args", help="User identifier")
|
||||
from ..common.errors import (
|
||||
FileTooLargeError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from ..console.wraps import setup_required
|
||||
from ..files import files_ns
|
||||
from ..inner_api.plugin.wraps import get_user
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class PluginUploadQuery(BaseModel):
|
||||
timestamp: str = Field(..., description="Unix timestamp for signature verification")
|
||||
nonce: str = Field(..., description="Random nonce for signature verification")
|
||||
sign: str = Field(..., description="HMAC signature")
|
||||
tenant_id: str = Field(..., description="Tenant identifier")
|
||||
user_id: str | None = Field(default=None, description="User identifier")
|
||||
|
||||
|
||||
files_ns.schema_model(
|
||||
PluginUploadQuery.__name__, PluginUploadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
@files_ns.route("/upload/for-plugin")
|
||||
class PluginUploadFileApi(Resource):
|
||||
@setup_required
|
||||
@files_ns.expect(upload_parser)
|
||||
@files_ns.expect(files_ns.models[PluginUploadQuery.__name__])
|
||||
@files_ns.doc("upload_plugin_file")
|
||||
@files_ns.doc(description="Upload a file for plugin usage with signature verification")
|
||||
@files_ns.doc(
|
||||
|
|
@ -62,15 +67,17 @@ class PluginUploadFileApi(Resource):
|
|||
FileTooLargeError: File exceeds size limit
|
||||
UnsupportedFileTypeError: File type not supported
|
||||
"""
|
||||
# Parse and validate all arguments
|
||||
args = upload_parser.parse_args()
|
||||
args = PluginUploadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
file: FileStorage = args["file"]
|
||||
timestamp: str = args["timestamp"]
|
||||
nonce: str = args["nonce"]
|
||||
sign: str = args["sign"]
|
||||
tenant_id: str = args["tenant_id"]
|
||||
user_id: str | None = args.get("user_id")
|
||||
file: FileStorage | None = request.files.get("file")
|
||||
if file is None:
|
||||
raise Forbidden("File is required.")
|
||||
|
||||
timestamp = args.timestamp
|
||||
nonce = args.nonce
|
||||
sign = args.sign
|
||||
tenant_id = args.tenant_id
|
||||
user_id = args.user_id
|
||||
user = get_user(tenant_id, user_id)
|
||||
|
||||
filename: str | None = file.filename
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from collections.abc import Sequence
|
|||
from enum import StrEnum, auto
|
||||
from typing import Any, Literal
|
||||
|
||||
from jsonschema import Draft7Validator, SchemaError
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.file import FileTransferMethod, FileType, FileUploadConfig
|
||||
|
|
@ -98,6 +99,7 @@ class VariableEntityType(StrEnum):
|
|||
FILE = "file"
|
||||
FILE_LIST = "file-list"
|
||||
CHECKBOX = "checkbox"
|
||||
JSON_OBJECT = "json_object"
|
||||
|
||||
|
||||
class VariableEntity(BaseModel):
|
||||
|
|
@ -118,6 +120,7 @@ class VariableEntity(BaseModel):
|
|||
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
|
||||
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
|
||||
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
|
||||
json_schema: dict[str, Any] | None = Field(default=None)
|
||||
|
||||
@field_validator("description", mode="before")
|
||||
@classmethod
|
||||
|
|
@ -129,6 +132,17 @@ class VariableEntity(BaseModel):
|
|||
def convert_none_options(cls, v: Any) -> Sequence[str]:
|
||||
return v or []
|
||||
|
||||
@field_validator("json_schema")
|
||||
@classmethod
|
||||
def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
if schema is None:
|
||||
return None
|
||||
try:
|
||||
Draft7Validator.check_schema(schema)
|
||||
except SchemaError as e:
|
||||
raise ValueError(f"Invalid JSON schema: {e.message}")
|
||||
return schema
|
||||
|
||||
|
||||
class RagPipelineVariableEntity(VariableEntity):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -770,7 +770,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
tts_publisher.publish(None)
|
||||
|
||||
if self._conversation_name_generate_thread:
|
||||
self._conversation_name_generate_thread.join()
|
||||
logger.debug("Conversation name generation running as daemon thread")
|
||||
|
||||
def _save_message(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -99,6 +99,15 @@ class BaseAppGenerator:
|
|||
if value is None:
|
||||
return None
|
||||
|
||||
# Treat empty placeholders for optional file inputs as unset
|
||||
if (
|
||||
variable_entity.type in {VariableEntityType.FILE, VariableEntityType.FILE_LIST}
|
||||
and not variable_entity.required
|
||||
):
|
||||
# Treat empty string (frontend default) or empty list as unset
|
||||
if not value and isinstance(value, (str, list)):
|
||||
return None
|
||||
|
||||
if variable_entity.type in {
|
||||
VariableEntityType.TEXT_INPUT,
|
||||
VariableEntityType.SELECT,
|
||||
|
|
|
|||
|
|
@ -156,79 +156,86 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||
query = application_generate_entity.query or "New conversation"
|
||||
conversation_name = (query[:20] + "…") if len(query) > 20 else query
|
||||
|
||||
if not conversation:
|
||||
conversation = Conversation(
|
||||
try:
|
||||
if not conversation:
|
||||
conversation = Conversation(
|
||||
app_id=app_config.app_id,
|
||||
app_model_config_id=app_model_config_id,
|
||||
model_provider=model_provider,
|
||||
model_id=model_id,
|
||||
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
||||
mode=app_config.app_mode.value,
|
||||
name=conversation_name,
|
||||
inputs=application_generate_entity.inputs,
|
||||
introduction=introduction,
|
||||
system_instruction="",
|
||||
system_instruction_tokens=0,
|
||||
status="normal",
|
||||
invoke_from=application_generate_entity.invoke_from.value,
|
||||
from_source=from_source,
|
||||
from_end_user_id=end_user_id,
|
||||
from_account_id=account_id,
|
||||
)
|
||||
|
||||
db.session.add(conversation)
|
||||
db.session.flush()
|
||||
db.session.refresh(conversation)
|
||||
else:
|
||||
conversation.updated_at = naive_utc_now()
|
||||
|
||||
message = Message(
|
||||
app_id=app_config.app_id,
|
||||
app_model_config_id=app_model_config_id,
|
||||
model_provider=model_provider,
|
||||
model_id=model_id,
|
||||
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
||||
mode=app_config.app_mode.value,
|
||||
name=conversation_name,
|
||||
conversation_id=conversation.id,
|
||||
inputs=application_generate_entity.inputs,
|
||||
introduction=introduction,
|
||||
system_instruction="",
|
||||
system_instruction_tokens=0,
|
||||
status="normal",
|
||||
query=application_generate_entity.query,
|
||||
message="",
|
||||
message_tokens=0,
|
||||
message_unit_price=0,
|
||||
message_price_unit=0,
|
||||
answer="",
|
||||
answer_tokens=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0,
|
||||
parent_message_id=getattr(application_generate_entity, "parent_message_id", None),
|
||||
provider_response_latency=0,
|
||||
total_price=0,
|
||||
currency="USD",
|
||||
invoke_from=application_generate_entity.invoke_from.value,
|
||||
from_source=from_source,
|
||||
from_end_user_id=end_user_id,
|
||||
from_account_id=account_id,
|
||||
app_mode=app_config.app_mode,
|
||||
)
|
||||
|
||||
db.session.add(conversation)
|
||||
db.session.add(message)
|
||||
db.session.flush()
|
||||
db.session.refresh(message)
|
||||
|
||||
message_files = []
|
||||
for file in application_generate_entity.files:
|
||||
message_file = MessageFile(
|
||||
message_id=message.id,
|
||||
type=file.type,
|
||||
transfer_method=file.transfer_method,
|
||||
belongs_to="user",
|
||||
url=file.remote_url,
|
||||
upload_file_id=file.related_id,
|
||||
created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),
|
||||
created_by=account_id or end_user_id or "",
|
||||
)
|
||||
message_files.append(message_file)
|
||||
|
||||
if message_files:
|
||||
db.session.add_all(message_files)
|
||||
|
||||
db.session.commit()
|
||||
db.session.refresh(conversation)
|
||||
else:
|
||||
conversation.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
message = Message(
|
||||
app_id=app_config.app_id,
|
||||
model_provider=model_provider,
|
||||
model_id=model_id,
|
||||
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
||||
conversation_id=conversation.id,
|
||||
inputs=application_generate_entity.inputs,
|
||||
query=application_generate_entity.query,
|
||||
message="",
|
||||
message_tokens=0,
|
||||
message_unit_price=0,
|
||||
message_price_unit=0,
|
||||
answer="",
|
||||
answer_tokens=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0,
|
||||
parent_message_id=getattr(application_generate_entity, "parent_message_id", None),
|
||||
provider_response_latency=0,
|
||||
total_price=0,
|
||||
currency="USD",
|
||||
invoke_from=application_generate_entity.invoke_from.value,
|
||||
from_source=from_source,
|
||||
from_end_user_id=end_user_id,
|
||||
from_account_id=account_id,
|
||||
app_mode=app_config.app_mode,
|
||||
)
|
||||
|
||||
db.session.add(message)
|
||||
db.session.commit()
|
||||
db.session.refresh(message)
|
||||
|
||||
for file in application_generate_entity.files:
|
||||
message_file = MessageFile(
|
||||
message_id=message.id,
|
||||
type=file.type,
|
||||
transfer_method=file.transfer_method,
|
||||
belongs_to="user",
|
||||
url=file.remote_url,
|
||||
upload_file_id=file.related_id,
|
||||
created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),
|
||||
created_by=account_id or end_user_id or "",
|
||||
)
|
||||
db.session.add(message_file)
|
||||
db.session.commit()
|
||||
|
||||
return conversation, message
|
||||
return conversation, message
|
||||
except Exception:
|
||||
db.session.rollback()
|
||||
raise
|
||||
|
||||
def _get_conversation_introduction(self, application_generate_entity: AppGenerateEntity) -> str:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -366,7 +366,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||
if publisher:
|
||||
publisher.publish(None)
|
||||
if self._conversation_name_generate_thread:
|
||||
self._conversation_name_generate_thread.join()
|
||||
logger.debug("Conversation name generation running as daemon thread")
|
||||
|
||||
def _save_message(self, *, session: Session, trace_manager: TraceQueueManager | None = None):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
from threading import Thread
|
||||
from typing import Union
|
||||
|
||||
|
|
@ -31,6 +33,7 @@ from core.app.entities.task_entities import (
|
|||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.tools.signature import sign_tool_file
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
|
|
@ -68,6 +71,8 @@ class MessageCycleManager:
|
|||
|
||||
if auto_generate_conversation_name and is_first_message:
|
||||
# start generate thread
|
||||
# time.sleep not block other logic
|
||||
time.sleep(1)
|
||||
thread = Thread(
|
||||
target=self._generate_conversation_name_worker,
|
||||
kwargs={
|
||||
|
|
@ -76,7 +81,7 @@ class MessageCycleManager:
|
|||
"query": query,
|
||||
},
|
||||
)
|
||||
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
|
||||
return thread
|
||||
|
|
@ -98,15 +103,23 @@ class MessageCycleManager:
|
|||
return
|
||||
|
||||
# generate conversation name
|
||||
try:
|
||||
name = LLMGenerator.generate_conversation_name(
|
||||
app_model.tenant_id, query, conversation_id, conversation.app_id
|
||||
)
|
||||
conversation.name = name
|
||||
except Exception:
|
||||
if dify_config.DEBUG:
|
||||
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
|
||||
query_hash = hashlib.md5(query.encode()).hexdigest()[:16]
|
||||
cache_key = f"conv_name:{conversation_id}:{query_hash}"
|
||||
|
||||
cached_name = redis_client.get(cache_key)
|
||||
if cached_name:
|
||||
name = cached_name.decode("utf-8")
|
||||
else:
|
||||
try:
|
||||
name = LLMGenerator.generate_conversation_name(
|
||||
app_model.tenant_id, query, conversation_id, conversation.app_id
|
||||
)
|
||||
redis_client.setex(cache_key, 3600, name)
|
||||
except Exception:
|
||||
if dify_config.DEBUG:
|
||||
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
|
||||
name = query[:47] + "..." if len(query) > 50 else query
|
||||
conversation.name = name
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
|
|
|
|||
|
|
@ -296,7 +296,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
|||
node_span = self.build_workflow_task_span(trace_info, node_execution, trace_metadata)
|
||||
return node_span
|
||||
except Exception as e:
|
||||
logger.debug("Error occurred in build_workflow_node_span: %s", e, exc_info=True)
|
||||
logger.warning("Error occurred in build_workflow_node_span: %s", e, exc_info=True)
|
||||
return None
|
||||
|
||||
def build_workflow_task_span(
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from opentelemetry.trace import Link, SpanContext, TraceFlags
|
|||
|
||||
from configs import dify_config
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
|
||||
from core.ops.aliyun_trace.entities.semconv import ACS_ARMS_SERVICE_FEATURE
|
||||
|
||||
INVALID_SPAN_ID: Final[int] = 0x0000000000000000
|
||||
INVALID_TRACE_ID: Final[int] = 0x00000000000000000000000000000000
|
||||
|
|
@ -48,6 +49,7 @@ class TraceClient:
|
|||
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
|
||||
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
|
||||
ResourceAttributes.HOST_NAME: socket.gethostname(),
|
||||
ACS_ARMS_SERVICE_FEATURE: "genai_app",
|
||||
}
|
||||
)
|
||||
self.span_builder = SpanBuilder(self.resource)
|
||||
|
|
@ -75,10 +77,10 @@ class TraceClient:
|
|||
if response.status_code == 405:
|
||||
return True
|
||||
else:
|
||||
logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
|
||||
logger.warning("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
|
||||
return False
|
||||
except httpx.RequestError as e:
|
||||
logger.debug("AliyunTrace API check failed: %s", str(e))
|
||||
logger.warning("AliyunTrace API check failed: %s", str(e))
|
||||
raise ValueError(f"AliyunTrace API check failed: {str(e)}")
|
||||
|
||||
def get_project_url(self) -> str:
|
||||
|
|
@ -116,7 +118,7 @@ class TraceClient:
|
|||
try:
|
||||
self.exporter.export(spans_to_export)
|
||||
except Exception as e:
|
||||
logger.debug("Error exporting spans: %s", e)
|
||||
logger.warning("Error exporting spans: %s", e)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
with self.condition:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from enum import StrEnum
|
||||
from typing import Final
|
||||
|
||||
ACS_ARMS_SERVICE_FEATURE: Final[str] = "acs.arms.service.feature"
|
||||
|
||||
# Public attributes
|
||||
GEN_AI_SESSION_ID: Final[str] = "gen_ai.session.id"
|
||||
GEN_AI_USER_ID: Final[str] = "gen_ai.user.id"
|
||||
|
|
|
|||
|
|
@ -377,20 +377,20 @@ class OpsTraceManager:
|
|||
return app_model_config
|
||||
|
||||
@classmethod
|
||||
def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider: str):
|
||||
def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider: str | None):
|
||||
"""
|
||||
Update app tracing config
|
||||
:param app_id: app id
|
||||
:param enabled: enabled
|
||||
:param tracing_provider: tracing provider
|
||||
:param tracing_provider: tracing provider (None when disabling)
|
||||
:return:
|
||||
"""
|
||||
# auth check
|
||||
try:
|
||||
if enabled or tracing_provider is not None:
|
||||
if tracing_provider is not None:
|
||||
try:
|
||||
provider_config_map[tracing_provider]
|
||||
except KeyError:
|
||||
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
|
||||
except KeyError:
|
||||
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
|
||||
|
||||
app_config: App | None = db.session.query(App).where(App.id == app_id).first()
|
||||
if not app_config:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
"""Document loader helpers."""
|
||||
|
||||
import concurrent.futures
|
||||
from typing import NamedTuple, cast
|
||||
from typing import NamedTuple
|
||||
|
||||
import charset_normalizer
|
||||
|
||||
|
||||
class FileEncoding(NamedTuple):
|
||||
|
|
@ -27,14 +29,14 @@ def detect_file_encodings(file_path: str, timeout: int = 5, sample_size: int = 1
|
|||
sample_size: The number of bytes to read for encoding detection. Default is 1MB.
|
||||
For large files, reading only a sample is sufficient and prevents timeout.
|
||||
"""
|
||||
import chardet
|
||||
|
||||
def read_and_detect(file_path: str):
|
||||
with open(file_path, "rb") as f:
|
||||
# Read only a sample of the file for encoding detection
|
||||
# This prevents timeout on large files while still providing accurate encoding detection
|
||||
rawdata = f.read(sample_size)
|
||||
return cast(list[dict], chardet.detect_all(rawdata))
|
||||
def read_and_detect(filename: str):
|
||||
rst = charset_normalizer.from_path(filename)
|
||||
best = rst.best()
|
||||
if best is None:
|
||||
return []
|
||||
file_encoding = FileEncoding(encoding=best.encoding, confidence=best.coherence, language=best.language)
|
||||
return [file_encoding]
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(read_and_detect, file_path)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from dataclasses import dataclass
|
|||
from typing import Any, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
import chardet
|
||||
import charset_normalizer
|
||||
import cloudscraper
|
||||
from readabilipy import simple_json_from_html_string
|
||||
|
||||
|
|
@ -69,9 +69,12 @@ def get_url(url: str, user_agent: str | None = None) -> str:
|
|||
if response.status_code != 200:
|
||||
return f"URL returned status code {response.status_code}."
|
||||
|
||||
# Detect encoding using chardet
|
||||
detected_encoding = chardet.detect(response.content)
|
||||
encoding = detected_encoding["encoding"]
|
||||
# Detect encoding using charset_normalizer
|
||||
detected_encoding = charset_normalizer.from_bytes(response.content).best()
|
||||
if detected_encoding:
|
||||
encoding = detected_encoding.encoding
|
||||
else:
|
||||
encoding = "utf-8"
|
||||
if encoding:
|
||||
try:
|
||||
content = response.content.decode(encoding)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import tempfile
|
|||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
import chardet
|
||||
import charset_normalizer
|
||||
import docx
|
||||
import pandas as pd
|
||||
import pypandoc
|
||||
|
|
@ -228,9 +228,12 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str)
|
|||
|
||||
def _extract_text_from_plain_text(file_content: bytes) -> str:
|
||||
try:
|
||||
# Detect encoding using chardet
|
||||
result = chardet.detect(file_content)
|
||||
encoding = result["encoding"]
|
||||
# Detect encoding using charset_normalizer
|
||||
result = charset_normalizer.from_bytes(file_content, cp_isolation=["utf_8", "latin_1", "cp1252"]).best()
|
||||
if result:
|
||||
encoding = result.encoding
|
||||
else:
|
||||
encoding = "utf-8"
|
||||
|
||||
# Fallback to utf-8 if detection fails
|
||||
if not encoding:
|
||||
|
|
@ -247,9 +250,12 @@ def _extract_text_from_plain_text(file_content: bytes) -> str:
|
|||
|
||||
def _extract_text_from_json(file_content: bytes) -> str:
|
||||
try:
|
||||
# Detect encoding using chardet
|
||||
result = chardet.detect(file_content)
|
||||
encoding = result["encoding"]
|
||||
# Detect encoding using charset_normalizer
|
||||
result = charset_normalizer.from_bytes(file_content).best()
|
||||
if result:
|
||||
encoding = result.encoding
|
||||
else:
|
||||
encoding = "utf-8"
|
||||
|
||||
# Fallback to utf-8 if detection fails
|
||||
if not encoding:
|
||||
|
|
@ -269,9 +275,12 @@ def _extract_text_from_json(file_content: bytes) -> str:
|
|||
def _extract_text_from_yaml(file_content: bytes) -> str:
|
||||
"""Extract the content from yaml file"""
|
||||
try:
|
||||
# Detect encoding using chardet
|
||||
result = chardet.detect(file_content)
|
||||
encoding = result["encoding"]
|
||||
# Detect encoding using charset_normalizer
|
||||
result = charset_normalizer.from_bytes(file_content).best()
|
||||
if result:
|
||||
encoding = result.encoding
|
||||
else:
|
||||
encoding = "utf-8"
|
||||
|
||||
# Fallback to utf-8 if detection fails
|
||||
if not encoding:
|
||||
|
|
@ -424,9 +433,12 @@ def _extract_text_from_file(file: File):
|
|||
|
||||
def _extract_text_from_csv(file_content: bytes) -> str:
|
||||
try:
|
||||
# Detect encoding using chardet
|
||||
result = chardet.detect(file_content)
|
||||
encoding = result["encoding"]
|
||||
# Detect encoding using charset_normalizer
|
||||
result = charset_normalizer.from_bytes(file_content).best()
|
||||
if result:
|
||||
encoding = result.encoding
|
||||
else:
|
||||
encoding = "utf-8"
|
||||
|
||||
# Fallback to utf-8 if detection fails
|
||||
if not encoding:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,8 @@
|
|||
from typing import Any
|
||||
|
||||
from jsonschema import Draft7Validator, ValidationError
|
||||
|
||||
from core.app.app_config.entities import VariableEntityType
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
|
|
@ -15,6 +20,7 @@ class StartNode(Node[StartNodeData]):
|
|||
|
||||
def _run(self) -> NodeRunResult:
|
||||
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
|
||||
self._validate_and_normalize_json_object_inputs(node_inputs)
|
||||
system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
|
||||
|
||||
# TODO: System variables should be directly accessible, no need for special handling
|
||||
|
|
@ -24,3 +30,27 @@ class StartNode(Node[StartNodeData]):
|
|||
outputs = dict(node_inputs)
|
||||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs)
|
||||
|
||||
def _validate_and_normalize_json_object_inputs(self, node_inputs: dict[str, Any]) -> None:
|
||||
for variable in self.node_data.variables:
|
||||
if variable.type != VariableEntityType.JSON_OBJECT:
|
||||
continue
|
||||
|
||||
key = variable.variable
|
||||
value = node_inputs.get(key)
|
||||
|
||||
if value is None and variable.required:
|
||||
raise ValueError(f"{key} is required in input form")
|
||||
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError(f"{key} must be a JSON object")
|
||||
|
||||
schema = variable.json_schema
|
||||
if not schema:
|
||||
continue
|
||||
|
||||
try:
|
||||
Draft7Validator(schema).validate(value)
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
|
||||
node_inputs[key] = value
|
||||
|
|
|
|||
|
|
@ -256,7 +256,7 @@ def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]
|
|||
now = datetime_utils.naive_utc_now()
|
||||
last_update = _get_last_update_timestamp(cache_key)
|
||||
|
||||
if last_update is None or (now - last_update).total_seconds() > LAST_USED_UPDATE_WINDOW_SECONDS:
|
||||
if last_update is None or (now - last_update).total_seconds() > LAST_USED_UPDATE_WINDOW_SECONDS: # type: ignore
|
||||
update_values["last_used"] = values.last_used
|
||||
_set_last_update_timestamp(cache_key, now)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import logging
|
|||
import ssl
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, Union
|
||||
|
||||
import redis
|
||||
from redis import RedisError
|
||||
|
|
@ -245,7 +245,12 @@ def init_app(app: DifyApp):
|
|||
app.extensions["redis"] = redis_client
|
||||
|
||||
|
||||
def redis_fallback(default_return: Any | None = None):
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def redis_fallback(default_return: T | None = None): # type: ignore
|
||||
"""
|
||||
decorator to handle Redis operation exceptions and return a default value when Redis is unavailable.
|
||||
|
||||
|
|
@ -253,9 +258,9 @@ def redis_fallback(default_return: Any | None = None):
|
|||
default_return: The value to return when a Redis operation fails. Defaults to None.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable):
|
||||
def decorator(func: Callable[P, R]):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except RedisError as e:
|
||||
|
|
|
|||
|
|
@ -10,12 +10,13 @@ import uuid
|
|||
from collections.abc import Generator, Mapping
|
||||
from datetime import datetime
|
||||
from hashlib import sha256
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast
|
||||
from zoneinfo import available_timezones
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_restx import fields
|
||||
from pydantic import BaseModel
|
||||
from pydantic.functional_validators import AfterValidator
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
|
||||
|
|
@ -103,6 +104,9 @@ def email(email):
|
|||
raise ValueError(error)
|
||||
|
||||
|
||||
EmailStr = Annotated[str, AfterValidator(email)]
|
||||
|
||||
|
||||
def uuid_value(value):
|
||||
if value == "":
|
||||
return str(value)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ dependencies = [
|
|||
"bs4~=0.0.1",
|
||||
"cachetools~=5.3.0",
|
||||
"celery~=5.5.2",
|
||||
"chardet~=5.1.0",
|
||||
"charset-normalizer>=3.4.4",
|
||||
"flask~=3.1.2",
|
||||
"flask-compress>=1.17,<1.18",
|
||||
"flask-cors~=6.0.0",
|
||||
|
|
@ -91,6 +91,7 @@ dependencies = [
|
|||
"weaviate-client==4.17.0",
|
||||
"apscheduler>=3.11.0",
|
||||
"weave>=0.52.16",
|
||||
"jsonschema>=4.25.1",
|
||||
]
|
||||
# Before adding new dependency, consider place it in
|
||||
# alphabet order (a-z) and suitable group.
|
||||
|
|
|
|||
|
|
@ -0,0 +1,10 @@
|
|||
project-includes = ["."]
|
||||
project-excludes = [
|
||||
"tests/",
|
||||
".venv",
|
||||
"migrations/",
|
||||
"core/rag",
|
||||
]
|
||||
python-platform = "linux"
|
||||
python-version = "3.11.0"
|
||||
infer-with-first-use = false
|
||||
|
|
@ -1259,7 +1259,7 @@ class RegisterService:
|
|||
return f"member_invite:token:{token}"
|
||||
|
||||
@classmethod
|
||||
def setup(cls, email: str, name: str, password: str, ip_address: str, language: str):
|
||||
def setup(cls, email: str, name: str, password: str, ip_address: str, language: str | None):
|
||||
"""
|
||||
Setup dify
|
||||
|
||||
|
|
@ -1267,6 +1267,7 @@ class RegisterService:
|
|||
:param name: username
|
||||
:param password: password
|
||||
:param ip_address: ip address
|
||||
:param language: language
|
||||
"""
|
||||
try:
|
||||
account = AccountService.create_account(
|
||||
|
|
@ -1414,7 +1415,7 @@ class RegisterService:
|
|||
return data is not None
|
||||
|
||||
@classmethod
|
||||
def revoke_token(cls, workspace_id: str, email: str, token: str):
|
||||
def revoke_token(cls, workspace_id: str | None, email: str | None, token: str):
|
||||
if workspace_id and email:
|
||||
email_hash = sha256(email.encode()).hexdigest()
|
||||
cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}"
|
||||
|
|
@ -1423,7 +1424,9 @@ class RegisterService:
|
|||
redis_client.delete(cls._get_invitation_token_key(token))
|
||||
|
||||
@classmethod
|
||||
def get_invitation_if_token_valid(cls, workspace_id: str | None, email: str, token: str) -> dict[str, Any] | None:
|
||||
def get_invitation_if_token_valid(
|
||||
cls, workspace_id: str | None, email: str | None, token: str
|
||||
) -> dict[str, Any] | None:
|
||||
invitation_data = cls.get_invitation_by_token(token, workspace_id, email)
|
||||
if not invitation_data:
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -265,3 +265,82 @@ def test_validate_inputs_with_default_value():
|
|||
)
|
||||
|
||||
assert result == [{"id": "file1", "name": "doc1.pdf"}, {"id": "file2", "name": "doc2.pdf"}]
|
||||
|
||||
|
||||
def test_validate_inputs_optional_file_with_empty_string():
|
||||
"""Test that optional FILE variable with empty string returns None"""
|
||||
base_app_generator = BaseAppGenerator()
|
||||
|
||||
var_file = VariableEntity(
|
||||
variable="test_file",
|
||||
label="test_file",
|
||||
type=VariableEntityType.FILE,
|
||||
required=False,
|
||||
)
|
||||
|
||||
result = base_app_generator._validate_inputs(
|
||||
variable_entity=var_file,
|
||||
value="",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_validate_inputs_optional_file_list_with_empty_list():
|
||||
"""Test that optional FILE_LIST variable with empty list returns None"""
|
||||
base_app_generator = BaseAppGenerator()
|
||||
|
||||
var_file_list = VariableEntity(
|
||||
variable="test_file_list",
|
||||
label="test_file_list",
|
||||
type=VariableEntityType.FILE_LIST,
|
||||
required=False,
|
||||
)
|
||||
|
||||
result = base_app_generator._validate_inputs(
|
||||
variable_entity=var_file_list,
|
||||
value=[],
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_validate_inputs_required_file_with_empty_string_fails():
|
||||
"""Test that required FILE variable with empty string still fails validation"""
|
||||
base_app_generator = BaseAppGenerator()
|
||||
|
||||
var_file = VariableEntity(
|
||||
variable="test_file",
|
||||
label="test_file",
|
||||
type=VariableEntityType.FILE,
|
||||
required=True,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
base_app_generator._validate_inputs(
|
||||
variable_entity=var_file,
|
||||
value="",
|
||||
)
|
||||
|
||||
assert "must be a file" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_validate_inputs_optional_file_with_empty_string_ignores_default():
|
||||
"""Test that optional FILE variable with empty string returns None, not the default"""
|
||||
base_app_generator = BaseAppGenerator()
|
||||
|
||||
var_file = VariableEntity(
|
||||
variable="test_file",
|
||||
label="test_file",
|
||||
type=VariableEntityType.FILE,
|
||||
required=False,
|
||||
default={"id": "file123", "name": "default.pdf"},
|
||||
)
|
||||
|
||||
# When value is empty string (from frontend), should return None, not default
|
||||
result = base_app_generator._validate_inputs(
|
||||
variable_entity=var_file,
|
||||
value="",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from core.tools.utils.web_reader_tool import (
|
||||
|
|
@ -103,7 +105,10 @@ def test_get_url_html_flow_with_chardet_and_readability(monkeypatch: pytest.Monk
|
|||
|
||||
monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
|
||||
monkeypatch.setattr(mod.ssrf_proxy, "get", fake_get)
|
||||
monkeypatch.setattr(mod.chardet, "detect", lambda b: {"encoding": "utf-8"})
|
||||
|
||||
mock_best = SimpleNamespace(encoding="utf-8")
|
||||
mock_from_bytes = SimpleNamespace(best=lambda: mock_best)
|
||||
monkeypatch.setattr(mod.charset_normalizer, "from_bytes", lambda _: mock_from_bytes)
|
||||
|
||||
# readability → a dict that maps to Article, then FULL_TEMPLATE
|
||||
def fake_simple_json_from_html_string(html, use_readability=True):
|
||||
|
|
@ -134,7 +139,9 @@ def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch: pytest.
|
|||
|
||||
monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
|
||||
monkeypatch.setattr(mod.ssrf_proxy, "get", fake_get)
|
||||
monkeypatch.setattr(mod.chardet, "detect", lambda b: {"encoding": "utf-8"})
|
||||
mock_best = SimpleNamespace(encoding="utf-8")
|
||||
mock_from_bytes = SimpleNamespace(best=lambda: mock_best)
|
||||
monkeypatch.setattr(mod.charset_normalizer, "from_bytes", lambda _: mock_from_bytes)
|
||||
# readability returns empty plain_text
|
||||
monkeypatch.setattr(mod, "simple_json_from_html_string", lambda html, use_readability=True: {"plain_text": []})
|
||||
|
||||
|
|
@ -162,7 +169,9 @@ def test_get_url_403_cloudscraper_fallback(monkeypatch: pytest.MonkeyPatch, stub
|
|||
|
||||
monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
|
||||
monkeypatch.setattr(mod.cloudscraper, "create_scraper", lambda: FakeScraper())
|
||||
monkeypatch.setattr(mod.chardet, "detect", lambda b: {"encoding": "utf-8"})
|
||||
mock_best = SimpleNamespace(encoding="utf-8")
|
||||
mock_from_bytes = SimpleNamespace(best=lambda: mock_best)
|
||||
monkeypatch.setattr(mod.charset_normalizer, "from_bytes", lambda _: mock_from_bytes)
|
||||
monkeypatch.setattr(
|
||||
mod,
|
||||
"simple_json_from_html_string",
|
||||
|
|
@ -234,7 +243,10 @@ def test_get_url_html_encoding_fallback_when_decode_fails(monkeypatch: pytest.Mo
|
|||
|
||||
monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
|
||||
monkeypatch.setattr(mod.ssrf_proxy, "get", fake_get)
|
||||
monkeypatch.setattr(mod.chardet, "detect", lambda b: {"encoding": "utf-8"})
|
||||
|
||||
mock_best = SimpleNamespace(encoding="utf-8")
|
||||
mock_from_bytes = SimpleNamespace(best=lambda: mock_best)
|
||||
monkeypatch.setattr(mod.charset_normalizer, "from_bytes", lambda _: mock_from_bytes)
|
||||
monkeypatch.setattr(
|
||||
mod,
|
||||
"simple_json_from_html_string",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,227 @@
|
|||
import time
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError as PydanticValidationError
|
||||
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
||||
def make_start_node(user_inputs, variables):
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(),
|
||||
user_inputs=user_inputs,
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
config = {
|
||||
"id": "start",
|
||||
"data": StartNodeData(title="Start", variables=variables).model_dump(),
|
||||
}
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=time.perf_counter(),
|
||||
)
|
||||
|
||||
return StartNode(
|
||||
id="start",
|
||||
config=config,
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="wf",
|
||||
graph_config={},
|
||||
user_id="u",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
|
||||
def test_json_object_valid_schema():
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"age": {"type": "number"},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
"required": ["age"],
|
||||
}
|
||||
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="profile",
|
||||
label="profile",
|
||||
type=VariableEntityType.JSON_OBJECT,
|
||||
required=True,
|
||||
json_schema=schema,
|
||||
)
|
||||
]
|
||||
|
||||
user_inputs = {"profile": {"age": 20, "name": "Tom"}}
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
result = node._run()
|
||||
|
||||
assert result.outputs["profile"] == {"age": 20, "name": "Tom"}
|
||||
|
||||
|
||||
def test_json_object_invalid_json_string():
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="profile",
|
||||
label="profile",
|
||||
type=VariableEntityType.JSON_OBJECT,
|
||||
required=True,
|
||||
)
|
||||
]
|
||||
|
||||
# Missing closing brace makes this invalid JSON
|
||||
user_inputs = {"profile": '{"age": 20, "name": "Tom"'}
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
|
||||
with pytest.raises(ValueError, match="profile must be a JSON object"):
|
||||
node._run()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value", ["[1, 2, 3]", "123"])
|
||||
def test_json_object_valid_json_but_not_object(value):
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="profile",
|
||||
label="profile",
|
||||
type=VariableEntityType.JSON_OBJECT,
|
||||
required=True,
|
||||
)
|
||||
]
|
||||
|
||||
user_inputs = {"profile": value}
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
|
||||
with pytest.raises(ValueError, match="profile must be a JSON object"):
|
||||
node._run()
|
||||
|
||||
|
||||
def test_json_object_does_not_match_schema():
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"age": {"type": "number"},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
"required": ["age", "name"],
|
||||
}
|
||||
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="profile",
|
||||
label="profile",
|
||||
type=VariableEntityType.JSON_OBJECT,
|
||||
required=True,
|
||||
json_schema=schema,
|
||||
)
|
||||
]
|
||||
|
||||
# age is a string, which violates the schema (expects number)
|
||||
user_inputs = {"profile": {"age": "twenty", "name": "Tom"}}
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
|
||||
with pytest.raises(ValueError, match=r"JSON object for 'profile' does not match schema:"):
|
||||
node._run()
|
||||
|
||||
|
||||
def test_json_object_missing_required_schema_field():
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"age": {"type": "number"},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
"required": ["age", "name"],
|
||||
}
|
||||
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="profile",
|
||||
label="profile",
|
||||
type=VariableEntityType.JSON_OBJECT,
|
||||
required=True,
|
||||
json_schema=schema,
|
||||
)
|
||||
]
|
||||
|
||||
# Missing required field "name"
|
||||
user_inputs = {"profile": {"age": 20}}
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match=r"JSON object for 'profile' does not match schema: 'name' is a required property"
|
||||
):
|
||||
node._run()
|
||||
|
||||
|
||||
def test_json_object_required_variable_missing_from_inputs():
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="profile",
|
||||
label="profile",
|
||||
type=VariableEntityType.JSON_OBJECT,
|
||||
required=True,
|
||||
)
|
||||
]
|
||||
|
||||
user_inputs = {}
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
|
||||
with pytest.raises(ValueError, match="profile is required in input form"):
|
||||
node._run()
|
||||
|
||||
|
||||
def test_json_object_invalid_json_schema_string():
|
||||
variable = VariableEntity(
|
||||
variable="profile",
|
||||
label="profile",
|
||||
type=VariableEntityType.JSON_OBJECT,
|
||||
required=True,
|
||||
)
|
||||
|
||||
# Bypass pydantic type validation on assignment to simulate an invalid JSON schema string
|
||||
variable.json_schema = "{invalid-json-schema"
|
||||
|
||||
variables = [variable]
|
||||
user_inputs = {"profile": '{"age": 20}'}
|
||||
|
||||
# Invalid json_schema string should be rejected during node data hydration
|
||||
with pytest.raises(PydanticValidationError):
|
||||
make_start_node(user_inputs, variables)
|
||||
|
||||
|
||||
def test_json_object_optional_variable_not_provided():
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="profile",
|
||||
label="profile",
|
||||
type=VariableEntityType.JSON_OBJECT,
|
||||
required=False,
|
||||
)
|
||||
]
|
||||
|
||||
user_inputs = {}
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
|
||||
# Current implementation raises a validation error even when the variable is optional
|
||||
with pytest.raises(ValueError, match="profile must be a JSON object"):
|
||||
node._run()
|
||||
4629
api/uv.lock
4629
api/uv.lock
File diff suppressed because it is too large
Load Diff
|
|
@ -233,7 +233,7 @@ NEXT_PUBLIC_ENABLE_SINGLE_DOLLAR_LATEX=false
|
|||
|
||||
# Database type, supported values are `postgresql` and `mysql`
|
||||
DB_TYPE=postgresql
|
||||
|
||||
# For MySQL, only `root` user is supported for now
|
||||
DB_USERNAME=postgres
|
||||
DB_PASSWORD=difyai123456
|
||||
DB_HOST=db_postgres
|
||||
|
|
@ -1076,24 +1076,10 @@ MAX_TREE_DEPTH=50
|
|||
# ------------------------------
|
||||
# Environment Variables for database Service
|
||||
# ------------------------------
|
||||
|
||||
# The name of the default postgres user.
|
||||
POSTGRES_USER=${DB_USERNAME}
|
||||
# The password for the default postgres user.
|
||||
POSTGRES_PASSWORD=${DB_PASSWORD}
|
||||
# The name of the default postgres database.
|
||||
POSTGRES_DB=${DB_DATABASE}
|
||||
# Postgres data directory
|
||||
PGDATA=/var/lib/postgresql/data/pgdata
|
||||
|
||||
# MySQL Default Configuration
|
||||
# The name of the default mysql user.
|
||||
MYSQL_USERNAME=${DB_USERNAME}
|
||||
# The password for the default mysql user.
|
||||
MYSQL_PASSWORD=${DB_PASSWORD}
|
||||
# The name of the default mysql database.
|
||||
MYSQL_DATABASE=${DB_DATABASE}
|
||||
# MySQL data directory
|
||||
MYSQL_HOST_VOLUME=./volumes/mysql/data
|
||||
|
||||
# ------------------------------
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ x-shared-env: &shared-api-worker-env
|
|||
services:
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.10.1
|
||||
image: langgenius/dify-api:1.10.1-fix.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
@ -41,7 +41,7 @@ services:
|
|||
# worker service
|
||||
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
|
||||
worker:
|
||||
image: langgenius/dify-api:1.10.1
|
||||
image: langgenius/dify-api:1.10.1-fix.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
@ -78,7 +78,7 @@ services:
|
|||
# worker_beat service
|
||||
# Celery beat for scheduling periodic tasks.
|
||||
worker_beat:
|
||||
image: langgenius/dify-api:1.10.1
|
||||
image: langgenius/dify-api:1.10.1-fix.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
@ -106,7 +106,7 @@ services:
|
|||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.10.1
|
||||
image: langgenius/dify-web:1.10.1-fix.1
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
|
|
@ -139,9 +139,9 @@ services:
|
|||
- postgresql
|
||||
restart: always
|
||||
environment:
|
||||
POSTGRES_USER: ${POSTGRES_USER:-postgres}
|
||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-difyai123456}
|
||||
POSTGRES_DB: ${POSTGRES_DB:-dify}
|
||||
POSTGRES_USER: ${DB_USERNAME:-postgres}
|
||||
POSTGRES_PASSWORD: ${DB_PASSWORD:-difyai123456}
|
||||
POSTGRES_DB: ${DB_DATABASE:-dify}
|
||||
PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata}
|
||||
command: >
|
||||
postgres -c 'max_connections=${POSTGRES_MAX_CONNECTIONS:-100}'
|
||||
|
|
@ -161,7 +161,7 @@ services:
|
|||
"-h",
|
||||
"db_postgres",
|
||||
"-U",
|
||||
"${PGUSER:-postgres}",
|
||||
"${DB_USERNAME:-postgres}",
|
||||
"-d",
|
||||
"${DB_DATABASE:-dify}",
|
||||
]
|
||||
|
|
@ -176,8 +176,8 @@ services:
|
|||
- mysql
|
||||
restart: always
|
||||
environment:
|
||||
MYSQL_ROOT_PASSWORD: ${MYSQL_PASSWORD:-difyai123456}
|
||||
MYSQL_DATABASE: ${MYSQL_DATABASE:-dify}
|
||||
MYSQL_ROOT_PASSWORD: ${DB_PASSWORD:-difyai123456}
|
||||
MYSQL_DATABASE: ${DB_DATABASE:-dify}
|
||||
command: >
|
||||
--max_connections=1000
|
||||
--innodb_buffer_pool_size=${MYSQL_INNODB_BUFFER_POOL_SIZE:-512M}
|
||||
|
|
@ -193,7 +193,7 @@ services:
|
|||
"ping",
|
||||
"-u",
|
||||
"root",
|
||||
"-p${MYSQL_PASSWORD:-difyai123456}",
|
||||
"-p${DB_PASSWORD:-difyai123456}",
|
||||
]
|
||||
interval: 1s
|
||||
timeout: 3s
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@ services:
|
|||
env_file:
|
||||
- ./middleware.env
|
||||
environment:
|
||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-difyai123456}
|
||||
POSTGRES_DB: ${POSTGRES_DB:-dify}
|
||||
POSTGRES_PASSWORD: ${DB_PASSWORD:-difyai123456}
|
||||
POSTGRES_DB: ${DB_DATABASE:-dify}
|
||||
PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata}
|
||||
command: >
|
||||
postgres -c 'max_connections=${POSTGRES_MAX_CONNECTIONS:-100}'
|
||||
|
|
@ -32,9 +32,9 @@ services:
|
|||
"-h",
|
||||
"db_postgres",
|
||||
"-U",
|
||||
"${PGUSER:-postgres}",
|
||||
"${DB_USERNAME:-postgres}",
|
||||
"-d",
|
||||
"${POSTGRES_DB:-dify}",
|
||||
"${DB_DATABASE:-dify}",
|
||||
]
|
||||
interval: 1s
|
||||
timeout: 3s
|
||||
|
|
@ -48,8 +48,8 @@ services:
|
|||
env_file:
|
||||
- ./middleware.env
|
||||
environment:
|
||||
MYSQL_ROOT_PASSWORD: ${MYSQL_PASSWORD:-difyai123456}
|
||||
MYSQL_DATABASE: ${MYSQL_DATABASE:-dify}
|
||||
MYSQL_ROOT_PASSWORD: ${DB_PASSWORD:-difyai123456}
|
||||
MYSQL_DATABASE: ${DB_DATABASE:-dify}
|
||||
command: >
|
||||
--max_connections=1000
|
||||
--innodb_buffer_pool_size=${MYSQL_INNODB_BUFFER_POOL_SIZE:-512M}
|
||||
|
|
@ -67,7 +67,7 @@ services:
|
|||
"ping",
|
||||
"-u",
|
||||
"root",
|
||||
"-p${MYSQL_PASSWORD:-difyai123456}",
|
||||
"-p${DB_PASSWORD:-difyai123456}",
|
||||
]
|
||||
interval: 1s
|
||||
timeout: 3s
|
||||
|
|
|
|||
|
|
@ -455,13 +455,7 @@ x-shared-env: &shared-api-worker-env
|
|||
TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000}
|
||||
ALLOW_UNSAFE_DATA_SCHEME: ${ALLOW_UNSAFE_DATA_SCHEME:-false}
|
||||
MAX_TREE_DEPTH: ${MAX_TREE_DEPTH:-50}
|
||||
POSTGRES_USER: ${POSTGRES_USER:-${DB_USERNAME}}
|
||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-${DB_PASSWORD}}
|
||||
POSTGRES_DB: ${POSTGRES_DB:-${DB_DATABASE}}
|
||||
PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata}
|
||||
MYSQL_USERNAME: ${MYSQL_USERNAME:-${DB_USERNAME}}
|
||||
MYSQL_PASSWORD: ${MYSQL_PASSWORD:-${DB_PASSWORD}}
|
||||
MYSQL_DATABASE: ${MYSQL_DATABASE:-${DB_DATABASE}}
|
||||
MYSQL_HOST_VOLUME: ${MYSQL_HOST_VOLUME:-./volumes/mysql/data}
|
||||
SANDBOX_API_KEY: ${SANDBOX_API_KEY:-dify-sandbox}
|
||||
SANDBOX_GIN_MODE: ${SANDBOX_GIN_MODE:-release}
|
||||
|
|
@ -637,7 +631,7 @@ x-shared-env: &shared-api-worker-env
|
|||
services:
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.10.1
|
||||
image: langgenius/dify-api:1.10.1-fix.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
@ -676,7 +670,7 @@ services:
|
|||
# worker service
|
||||
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
|
||||
worker:
|
||||
image: langgenius/dify-api:1.10.1
|
||||
image: langgenius/dify-api:1.10.1-fix.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
@ -713,7 +707,7 @@ services:
|
|||
# worker_beat service
|
||||
# Celery beat for scheduling periodic tasks.
|
||||
worker_beat:
|
||||
image: langgenius/dify-api:1.10.1
|
||||
image: langgenius/dify-api:1.10.1-fix.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
@ -741,7 +735,7 @@ services:
|
|||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.10.1
|
||||
image: langgenius/dify-web:1.10.1-fix.1
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
|
|
@ -774,9 +768,9 @@ services:
|
|||
- postgresql
|
||||
restart: always
|
||||
environment:
|
||||
POSTGRES_USER: ${POSTGRES_USER:-postgres}
|
||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-difyai123456}
|
||||
POSTGRES_DB: ${POSTGRES_DB:-dify}
|
||||
POSTGRES_USER: ${DB_USERNAME:-postgres}
|
||||
POSTGRES_PASSWORD: ${DB_PASSWORD:-difyai123456}
|
||||
POSTGRES_DB: ${DB_DATABASE:-dify}
|
||||
PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata}
|
||||
command: >
|
||||
postgres -c 'max_connections=${POSTGRES_MAX_CONNECTIONS:-100}'
|
||||
|
|
@ -796,7 +790,7 @@ services:
|
|||
"-h",
|
||||
"db_postgres",
|
||||
"-U",
|
||||
"${PGUSER:-postgres}",
|
||||
"${DB_USERNAME:-postgres}",
|
||||
"-d",
|
||||
"${DB_DATABASE:-dify}",
|
||||
]
|
||||
|
|
@ -811,8 +805,8 @@ services:
|
|||
- mysql
|
||||
restart: always
|
||||
environment:
|
||||
MYSQL_ROOT_PASSWORD: ${MYSQL_PASSWORD:-difyai123456}
|
||||
MYSQL_DATABASE: ${MYSQL_DATABASE:-dify}
|
||||
MYSQL_ROOT_PASSWORD: ${DB_PASSWORD:-difyai123456}
|
||||
MYSQL_DATABASE: ${DB_DATABASE:-dify}
|
||||
command: >
|
||||
--max_connections=1000
|
||||
--innodb_buffer_pool_size=${MYSQL_INNODB_BUFFER_POOL_SIZE:-512M}
|
||||
|
|
@ -828,7 +822,7 @@ services:
|
|||
"ping",
|
||||
"-u",
|
||||
"root",
|
||||
"-p${MYSQL_PASSWORD:-difyai123456}",
|
||||
"-p${DB_PASSWORD:-difyai123456}",
|
||||
]
|
||||
interval: 1s
|
||||
timeout: 3s
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
# Database Configuration
|
||||
# Database type, supported values are `postgresql` and `mysql`
|
||||
DB_TYPE=postgresql
|
||||
# For MySQL, only `root` user is supported for now
|
||||
DB_USERNAME=postgres
|
||||
DB_PASSWORD=difyai123456
|
||||
DB_HOST=db_postgres
|
||||
|
|
@ -11,11 +12,6 @@ DB_PORT=5432
|
|||
DB_DATABASE=dify
|
||||
|
||||
# PostgreSQL Configuration
|
||||
POSTGRES_USER=${DB_USERNAME}
|
||||
# The password for the default postgres user.
|
||||
POSTGRES_PASSWORD=${DB_PASSWORD}
|
||||
# The name of the default postgres database.
|
||||
POSTGRES_DB=${DB_DATABASE}
|
||||
# postgres data directory
|
||||
PGDATA=/var/lib/postgresql/data/pgdata
|
||||
PGDATA_HOST_VOLUME=./volumes/db/data
|
||||
|
|
@ -65,11 +61,6 @@ POSTGRES_STATEMENT_TIMEOUT=0
|
|||
POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT=0
|
||||
|
||||
# MySQL Configuration
|
||||
MYSQL_USERNAME=${DB_USERNAME}
|
||||
# MySQL password
|
||||
MYSQL_PASSWORD=${DB_PASSWORD}
|
||||
# MySQL database name
|
||||
MYSQL_DATABASE=${DB_DATABASE}
|
||||
# MySQL data directory host volume
|
||||
MYSQL_HOST_VOLUME=./volumes/mysql/data
|
||||
|
||||
|
|
|
|||
|
|
@ -61,13 +61,13 @@ if $web_modified; then
|
|||
lint-staged
|
||||
|
||||
if $web_ts_modified; then
|
||||
echo "Running TypeScript type-check"
|
||||
if ! pnpm run type-check; then
|
||||
echo "Type check failed. Please run 'pnpm run type-check' to fix the errors."
|
||||
echo "Running TypeScript type-check:tsgo"
|
||||
if ! pnpm run type-check:tsgo; then
|
||||
echo "Type check failed. Please run 'pnpm run type-check:tsgo' to fix the errors."
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo "No staged TypeScript changes detected, skipping type-check"
|
||||
echo "No staged TypeScript changes detected, skipping type-check:tsgo"
|
||||
fi
|
||||
|
||||
echo "Running unit tests check"
|
||||
|
|
|
|||
|
|
@ -251,6 +251,7 @@ const AgentTools: FC = () => {
|
|||
{!item.notAuthor && (
|
||||
<Tooltip
|
||||
popupContent={t('tools.setBuiltInTools.infoAndSetting')}
|
||||
needsDelay={false}
|
||||
>
|
||||
<div className='cursor-pointer rounded-md p-1 hover:bg-black/5' onClick={() => {
|
||||
setCurrentTool(item)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,57 @@
|
|||
import React from 'react'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import nock from 'nock'
|
||||
import GithubStar from './index'
|
||||
|
||||
const GITHUB_HOST = 'https://api.github.com'
|
||||
const GITHUB_PATH = '/repos/langgenius/dify'
|
||||
|
||||
const renderWithQueryClient = () => {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: { queries: { retry: false } },
|
||||
})
|
||||
return render(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<GithubStar className='test-class' />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
}
|
||||
|
||||
const mockGithubStar = (status: number, body: Record<string, unknown>, delayMs = 0) => {
|
||||
return nock(GITHUB_HOST).get(GITHUB_PATH).delay(delayMs).reply(status, body)
|
||||
}
|
||||
|
||||
describe('GithubStar', () => {
|
||||
beforeEach(() => {
|
||||
nock.cleanAll()
|
||||
})
|
||||
|
||||
// Shows fetched star count when request succeeds
|
||||
it('should render fetched star count', async () => {
|
||||
mockGithubStar(200, { stargazers_count: 123456 })
|
||||
|
||||
renderWithQueryClient()
|
||||
|
||||
expect(await screen.findByText('123,456')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Falls back to default star count when request fails
|
||||
it('should render default star count on error', async () => {
|
||||
mockGithubStar(500, {})
|
||||
|
||||
renderWithQueryClient()
|
||||
|
||||
expect(await screen.findByText('110,918')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Renders loader while fetching data
|
||||
it('should show loader while fetching', async () => {
|
||||
mockGithubStar(200, { stargazers_count: 222222 }, 50)
|
||||
|
||||
const { container } = renderWithQueryClient()
|
||||
|
||||
expect(container.querySelector('.animate-spin')).toBeInTheDocument()
|
||||
await waitFor(() => expect(screen.getByText('222,222')).toBeInTheDocument())
|
||||
})
|
||||
})
|
||||
|
|
@ -76,7 +76,7 @@ const ProviderCard: FC<Props> = ({
|
|||
className='grow'
|
||||
variant='secondary'
|
||||
>
|
||||
<a href={`${getPluginLinkInMarketplace(payload)}?language=${locale}${theme ? `&theme=${theme}` : ''}`} target='_blank' className='flex items-center gap-0.5'>
|
||||
<a href={getPluginLinkInMarketplace(payload, { language: locale, theme })} target='_blank' className='flex items-center gap-0.5'>
|
||||
{t('plugin.detailPanel.operation.detail')}
|
||||
<RiArrowRightUpLine className='h-4 w-4' />
|
||||
</a>
|
||||
|
|
|
|||
|
|
@ -94,7 +94,6 @@ export default combine(
|
|||
// orignal ts/no-var-requires
|
||||
'ts/no-require-imports': 'off',
|
||||
'no-console': 'off',
|
||||
'react-hooks/exhaustive-deps': 'warn',
|
||||
'react/display-name': 'off',
|
||||
'array-callback-return': ['error', {
|
||||
allowImplicit: false,
|
||||
|
|
@ -257,4 +256,9 @@ export default combine(
|
|||
},
|
||||
},
|
||||
oxlint.configs['flat/recommended'],
|
||||
{
|
||||
rules: {
|
||||
'react-hooks/exhaustive-deps': 'warn',
|
||||
},
|
||||
},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -197,258 +197,6 @@ const translation = {
|
|||
},
|
||||
contentEnableLabel: 'مدیریت محتوا فعال شده است',
|
||||
},
|
||||
generate: {
|
||||
title: 'تولید کننده دستورالعمل',
|
||||
description: 'تولید کننده دستورالعمل از مدل تنظیم شده برای بهینه سازی دستورالعملها برای کیفیت بالاتر و ساختار بهتر استفاده میکند. لطفاً دستورالعملهای واضح و دقیقی بنویسید.',
|
||||
tryIt: 'امتحان کنید',
|
||||
instruction: 'دستورالعملها',
|
||||
instructionPlaceHolder: 'دستورالعملهای واضح و خاصی بنویسید.',
|
||||
generate: 'تولید',
|
||||
resTitle: 'دستورالعمل تولید شده',
|
||||
noDataLine1: 'موارد استفاده خود را در سمت چپ توصیف کنید،',
|
||||
noDataLine2: 'پیشنمایش ارکستراسیون در اینجا نشان داده خواهد شد.',
|
||||
apply: 'اعمال',
|
||||
loading: 'در حال ارکستراسیون برنامه برای شما...',
|
||||
overwriteTitle: 'آیا تنظیمات موجود را لغو میکنید؟',
|
||||
overwriteMessage: 'اعمال این دستورالعمل تنظیمات موجود را لغو خواهد کرد.',
|
||||
template: {
|
||||
pythonDebugger: {
|
||||
name: 'اشکالزدای پایتون',
|
||||
instruction: 'یک بات که میتواند بر اساس دستورالعمل شما کد تولید و اشکالزدایی کند',
|
||||
},
|
||||
translation: {
|
||||
name: 'ترجمه',
|
||||
instruction: 'یک مترجم که میتواند چندین زبان را ترجمه کند',
|
||||
},
|
||||
professionalAnalyst: {
|
||||
name: 'تحلیلگر حرفهای',
|
||||
instruction: 'استخراج بینشها، شناسایی ریسک و خلاصهسازی اطلاعات کلیدی از گزارشهای طولانی به یک یادداشت کوتاه',
|
||||
},
|
||||
excelFormulaExpert: {
|
||||
name: 'کارشناس فرمول اکسل',
|
||||
instruction: 'یک چتبات که میتواند به کاربران مبتدی کمک کند فرمولهای اکسل را بر اساس دستورالعملهای کاربر درک، استفاده و ایجاد کنند',
|
||||
},
|
||||
travelPlanning: {
|
||||
name: 'برنامهریزی سفر',
|
||||
instruction: 'دستیار برنامهریزی سفر یک ابزار هوشمند است که به کاربران کمک میکند سفرهای خود را به راحتی برنامهریزی کنند',
|
||||
},
|
||||
SQLSorcerer: {
|
||||
name: 'جادوگر SQL',
|
||||
instruction: 'تبدیل زبان روزمره به پرس و جوهای SQL',
|
||||
},
|
||||
GitGud: {
|
||||
name: 'Git gud',
|
||||
instruction: 'تولید دستورات مناسب Git بر اساس اقدامات توصیف شده توسط کاربر در کنترل نسخه',
|
||||
},
|
||||
meetingTakeaways: {
|
||||
name: 'نتایج جلسات',
|
||||
instruction: 'خلاصهسازی جلسات به صورت مختصر شامل موضوعات بحث، نکات کلیدی و موارد اقدام',
|
||||
},
|
||||
writingsPolisher: {
|
||||
name: 'پولیشگر نوشتهها',
|
||||
instruction: 'استفاده از تکنیکهای ویرایش پیشرفته برای بهبود نوشتههای شما',
|
||||
},
|
||||
},
|
||||
},
|
||||
resetConfig: {
|
||||
title: 'بازنشانی تأیید میشود؟',
|
||||
message: 'بازنشانی تغییرات را لغو کرده و تنظیمات منتشر شده آخر را بازیابی میکند.',
|
||||
},
|
||||
errorMessage: {
|
||||
nameOfKeyRequired: 'نام کلید: {{key}} مورد نیاز است',
|
||||
valueOfVarRequired: 'مقدار {{key}} نمیتواند خالی باشد',
|
||||
queryRequired: 'متن درخواست مورد نیاز است.',
|
||||
waitForResponse: 'لطفاً منتظر پاسخ به پیام قبلی بمانید.',
|
||||
waitForBatchResponse: 'لطفاً منتظر پاسخ به کار دستهای بمانید.',
|
||||
notSelectModel: 'لطفاً یک مدل را انتخاب کنید',
|
||||
waitForImgUpload: 'لطفاً منتظر بارگذاری تصویر بمانید',
|
||||
},
|
||||
chatSubTitle: 'دستورالعملها',
|
||||
completionSubTitle: 'پیشوند پرس و جو',
|
||||
promptTip: 'دستورالعملها و محدودیتها پاسخهای AI را هدایت میکنند. متغیرهایی مانند {{input}} را درج کنید. این دستورالعمل برای کاربران قابل مشاهده نخواهد بود.',
|
||||
formattingChangedTitle: 'قالببندی تغییر کرد',
|
||||
formattingChangedText: 'تغییر قالببندی منطقه اشکالزدایی را بازنشانی خواهد کرد، آیا مطمئن هستید؟',
|
||||
variableTitle: 'متغیرها',
|
||||
variableTip: 'کاربران متغیرها را در فرم پر میکنند و به طور خودکار متغیرها را در دستورالعملها جایگزین میکنند.',
|
||||
notSetVar: 'متغیرها به کاربران اجازه میدهند که کلمات پرس و جو یا جملات ابتدایی را هنگام پر کردن فرم معرفی کنند. شما میتوانید سعی کنید "{{input}}" را در کلمات پرس و جو وارد کنید.',
|
||||
autoAddVar: 'متغیرهای تعریف نشدهای که در پیشپرسش ذکر شدهاند، آیا میخواهید آنها را به فرم ورودی کاربر اضافه کنید؟',
|
||||
variableTable: {
|
||||
key: 'کلید متغیر',
|
||||
name: 'نام فیلد ورودی کاربر',
|
||||
optional: 'اختیاری',
|
||||
type: 'نوع ورودی',
|
||||
action: 'اقدامات',
|
||||
typeString: 'رشته',
|
||||
typeSelect: 'انتخاب',
|
||||
},
|
||||
varKeyError: {
|
||||
canNoBeEmpty: '{{key}} مطلوب',
|
||||
tooLong: '{{key}} طولانی است. نمیتواند بیش از 30 کاراکتر باشد',
|
||||
notValid: '{{key}} نامعتبر است. فقط میتواند شامل حروف، اعداد و زیرخط باشد',
|
||||
notStartWithNumber: '{{key}} نمیتواند با عدد شروع شود',
|
||||
keyAlreadyExists: '{{key}} از قبل وجود دارد',
|
||||
},
|
||||
otherError: {
|
||||
promptNoBeEmpty: 'پرس و جو نمیتواند خالی باشد',
|
||||
historyNoBeEmpty: 'تاریخچه مکالمه باید در پرس و جو تنظیم شود',
|
||||
queryNoBeEmpty: 'پرس و جو باید در پرس و جو تنظیم شود',
|
||||
},
|
||||
variableConfig: {
|
||||
'addModalTitle': 'افزودن فیلد ورودی',
|
||||
'editModalTitle': 'ویرایش فیلد ورودی',
|
||||
'description': 'تنظیم برای متغیر {{varName}}',
|
||||
'fieldType': 'نوع فیلد',
|
||||
'string': 'متن کوتاه',
|
||||
'text-input': 'متن کوتاه',
|
||||
'paragraph': 'پاراگراف',
|
||||
'select': 'انتخاب',
|
||||
'number': 'عدد',
|
||||
'notSet': 'تنظیم نشده، سعی کنید {{input}} را در پرس و جو وارد کنید',
|
||||
'stringTitle': 'گزینههای جعبه متن فرم',
|
||||
'maxLength': 'حداکثر طول',
|
||||
'options': 'گزینهها',
|
||||
'addOption': 'افزودن گزینه',
|
||||
'apiBasedVar': 'متغیر مبتنی بر API',
|
||||
'varName': 'نام متغیر',
|
||||
'labelName': 'نام برچسب',
|
||||
'inputPlaceholder': 'لطفاً وارد کنید',
|
||||
'content': 'محتوا',
|
||||
'required': 'مورد نیاز',
|
||||
'hide': 'مخفی کردن',
|
||||
'errorMsg': {
|
||||
labelNameRequired: 'نام برچسب مورد نیاز است',
|
||||
varNameCanBeRepeat: 'نام متغیر نمیتواند تکراری باشد',
|
||||
atLeastOneOption: 'حداقل یک گزینه مورد نیاز است',
|
||||
optionRepeat: 'گزینههای تکراری وجود دارد',
|
||||
},
|
||||
},
|
||||
vision: {
|
||||
name: 'بینایی',
|
||||
description: 'فعال کردن بینایی به مدل اجازه میدهد تصاویر را دریافت کند و به سوالات مربوط به آنها پاسخ دهد.',
|
||||
settings: 'تنظیمات',
|
||||
visionSettings: {
|
||||
title: 'تنظیمات بینایی',
|
||||
resolution: 'وضوح',
|
||||
resolutionTooltip: `وضوح پایین به مدل اجازه میدهد نسخه 512x512 کموضوح تصویر را دریافت کند و تصویر را با بودجه 65 توکن نمایش دهد. این به API اجازه میدهد پاسخهای سریعتری بدهد و توکنهای ورودی کمتری برای موارد استفاده که نیاز به جزئیات بالا ندارند مصرف کند.
|
||||
\n
|
||||
وضوح بالا ابتدا به مدل اجازه میدهد تصویر کموضوح را ببیند و سپس قطعات جزئیات تصویر ورودی را به عنوان مربعهای 512px ایجاد کند. هر کدام از قطعات جزئیات از بودجه توکن دو برابر استفاده میکنند که در مجموع 129 توکن است.`,
|
||||
high: 'بالا',
|
||||
low: 'پایین',
|
||||
uploadMethod: 'روش بارگذاری',
|
||||
both: 'هر دو',
|
||||
localUpload: 'بارگذاری محلی',
|
||||
url: 'URL',
|
||||
uploadLimit: 'محدودیت بارگذاری',
|
||||
},
|
||||
},
|
||||
voice: {
|
||||
name: 'صدا',
|
||||
defaultDisplay: 'صدا پیش فرض',
|
||||
description: 'تنظیمات تبدیل متن به گفتار',
|
||||
settings: 'تنظیمات',
|
||||
voiceSettings: {
|
||||
title: 'تنظیمات صدا',
|
||||
language: 'زبان',
|
||||
resolutionTooltip: 'پشتیبانی از زبان صدای تبدیل متن به گفتار.',
|
||||
voice: 'صدا',
|
||||
autoPlay: 'پخش خودکار',
|
||||
autoPlayEnabled: 'روشن کردن',
|
||||
autoPlayDisabled: 'خاموش کردن',
|
||||
},
|
||||
},
|
||||
openingStatement: {
|
||||
title: 'شروع مکالمه',
|
||||
add: 'افزودن',
|
||||
writeOpener: 'نوشتن آغازگر',
|
||||
placeholder: 'پیام آغازگر خود را اینجا بنویسید، میتوانید از متغیرها استفاده کنید، سعی کنید {{variable}} را تایپ کنید.',
|
||||
openingQuestion: 'سوالات آغازین',
|
||||
openingQuestionPlaceholder: 'میتوانید از متغیرها استفاده کنید، سعی کنید {{variable}} را تایپ کنید.',
|
||||
noDataPlaceHolder: 'شروع مکالمه با کاربر میتواند به AI کمک کند تا ارتباط نزدیکتری با آنها برقرار کند.',
|
||||
varTip: 'میتوانید از متغیرها استفاده کنید، سعی کنید {{variable}} را تایپ کنید',
|
||||
tooShort: 'حداقل 20 کلمه از پرسش اولیه برای تولید نظرات آغازین مکالمه مورد نیاز است.',
|
||||
notIncludeKey: 'پرسش اولیه شامل متغیر: {{key}} نمیشود. لطفاً آن را به پرسش اولیه اضافه کنید.',
|
||||
},
|
||||
modelConfig: {
|
||||
model: 'مدل',
|
||||
setTone: 'تنظیم لحن پاسخها',
|
||||
title: 'مدل و پارامترها',
|
||||
modeType: {
|
||||
chat: 'چت',
|
||||
completion: 'تکمیل',
|
||||
},
|
||||
},
|
||||
inputs: {
|
||||
title: 'اشکالزدایی و پیشنمایش',
|
||||
noPrompt: 'سعی کنید پرسشهایی را در ورودی پیشپرسش بنویسید',
|
||||
userInputField: 'فیلد ورودی کاربر',
|
||||
noVar: 'مقدار متغیر را پر کنید، که به طور خودکار در کلمات پرس و جو در هر بار شروع یک جلسه جدید جایگزین میشود.',
|
||||
chatVarTip: 'مقدار متغیر را پر کنید، که به طور خودکار در کلمات پرس و جو در هر بار شروع یک جلسه جدید جایگزین میشود',
|
||||
completionVarTip: 'مقدار متغیر را پر کنید، که به طور خودکار در کلمات پرس و جو در هر بار ارسال سوال جایگزین میشود.',
|
||||
previewTitle: 'پیشنمایش پرس و جو',
|
||||
queryTitle: 'محتوای پرس و جو',
|
||||
queryPlaceholder: 'لطفاً متن درخواست را وارد کنید.',
|
||||
run: 'اجرا',
|
||||
},
|
||||
result: 'متن خروجی',
|
||||
datasetConfig: {
|
||||
settingTitle: 'تنظیمات بازیابی',
|
||||
knowledgeTip: 'روی دکمه "+" کلیک کنید تا دانش اضافه شود',
|
||||
retrieveOneWay: {
|
||||
title: 'بازیابی N به 1',
|
||||
description: 'بر اساس نیت کاربر و توصیفات دانش، عامل بهترین دانش را برای پرس و جو به طور خودکار انتخاب میکند. بهترین برای برنامههایی با دانش محدود و مشخص.',
|
||||
},
|
||||
retrieveMultiWay: {
|
||||
title: 'بازیابی چند مسیره',
|
||||
description: 'بر اساس نیت کاربر، از تمام دانش پرس و جو میکند، متنهای مرتبط از منابع چندگانه بازیابی میکند و بهترین نتایج مطابقت با پرس و جوی کاربر را پس از مرتبسازی مجدد انتخاب میکند.',
|
||||
},
|
||||
rerankModelRequired: 'مدل مرتبسازی مجدد مورد نیاز است',
|
||||
params: 'پارامترها',
|
||||
top_k: 'Top K',
|
||||
top_kTip: 'برای فیلتر کردن تکههایی که بیشترین شباهت به سوالات کاربر دارند استفاده میشود. سیستم همچنین به طور دینامیک مقدار Top K را بر اساس max_tokens مدل انتخاب شده تنظیم میکند.',
|
||||
score_threshold: 'آستانه نمره',
|
||||
score_thresholdTip: 'برای تنظیم آستانه شباهت برای فیلتر کردن تکهها استفاده میشود.',
|
||||
retrieveChangeTip: 'تغییر حالت شاخص و حالت بازیابی ممکن است بر برنامههای مرتبط با این دانش تأثیر بگذارد.',
|
||||
},
|
||||
debugAsSingleModel: 'اشکالزدایی به عنوان مدل تک',
|
||||
debugAsMultipleModel: 'اشکالزدایی به عنوان مدل چندگانه',
|
||||
duplicateModel: 'تکراری',
|
||||
publishAs: 'انتشار به عنوان',
|
||||
assistantType: {
|
||||
name: 'نوع دستیار',
|
||||
chatAssistant: {
|
||||
name: 'دستیار پایه',
|
||||
description: 'ساخت دستیار مبتنی بر چت با استفاده از مدل زبان بزرگ',
|
||||
},
|
||||
agentAssistant: {
|
||||
name: 'دستیار عامل',
|
||||
description: 'ساخت یک عامل هوشمند که میتواند ابزارها را به طور خودکار برای تکمیل وظایف انتخاب کند',
|
||||
},
|
||||
},
|
||||
agent: {
|
||||
agentMode: 'حالت عامل',
|
||||
agentModeDes: 'تنظیم نوع حالت استنتاج برای عامل',
|
||||
agentModeType: {
|
||||
ReACT: 'ReAct',
|
||||
functionCall: 'فراخوانی تابع',
|
||||
},
|
||||
setting: {
|
||||
name: 'تنظیمات عامل',
|
||||
description: 'تنظیمات دستیار عامل به شما اجازه میدهد حالت عامل و ویژگیهای پیشرفته مانند پرسشهای ساخته شده را تنظیم کنید، فقط در نوع عامل موجود است.',
|
||||
maximumIterations: {
|
||||
name: 'حداکثر تکرارها',
|
||||
description: 'محدود کردن تعداد تکرارهایی که دستیار عامل میتواند اجرا کند',
|
||||
},
|
||||
},
|
||||
buildInPrompt: 'پرسشهای ساخته شده',
|
||||
firstPrompt: 'اولین پرسش',
|
||||
nextIteration: 'تکرار بعدی',
|
||||
promptPlaceholder: 'پرسش خود را اینجا بنویسید',
|
||||
tools: {
|
||||
name: 'ابزارها',
|
||||
description: 'استفاده از ابزارها میتواند قابلیتهای LLM را گسترش دهد، مانند جستجو در اینترنت یا انجام محاسبات علمی',
|
||||
enabled: 'فعال',
|
||||
},
|
||||
},
|
||||
fileUpload: {
|
||||
title: 'آپلود فایل',
|
||||
description: 'جعبه ورودی چت امکان آپلود تصاویر، اسناد و سایر فایلها را فراهم میکند.',
|
||||
|
|
@ -536,13 +284,10 @@ const translation = {
|
|||
resTitle: 'اعلان تولید شده',
|
||||
overwriteTitle: 'پیکربندی موجود را لغو کنید؟',
|
||||
generate: 'تولید',
|
||||
noDataLine1: 'مورد استفاده خود را در سمت چپ شرح دهید،',
|
||||
apply: 'درخواست',
|
||||
instruction: 'دستورالعمل',
|
||||
overwriteMessage: 'اعمال این اعلان پیکربندی موجود را لغو می کند.',
|
||||
instructionPlaceHolder: 'دستورالعمل های واضح و مشخص بنویسید.',
|
||||
tryIt: 'آن را امتحان کنید',
|
||||
noDataLine2: 'پیش نمایش ارکستراسیون در اینجا نشان داده می شود.',
|
||||
loading: 'هماهنگ کردن برنامه برای شما...',
|
||||
description: 'Prompt Generator از مدل پیکربندی شده برای بهینه سازی درخواست ها برای کیفیت بالاتر و ساختار بهتر استفاده می کند. لطفا دستورالعمل های واضح و دقیق بنویسید.',
|
||||
press: 'فشار',
|
||||
|
|
|
|||
|
|
@ -393,6 +393,7 @@ const translation = {
|
|||
writeOpener: 'Scrieți deschizătorul',
|
||||
placeholder: 'Scrieți aici mesajul de deschidere, puteți utiliza variabile, încercați să tastați {{variable}}.',
|
||||
openingQuestion: 'Întrebări de deschidere',
|
||||
openingQuestionPlaceholder: 'Puteți utiliza variabile, încercați să tastați {{variable}}.',
|
||||
noDataPlaceHolder:
|
||||
'Începerea conversației cu utilizatorul poate ajuta AI să stabilească o conexiune mai strânsă cu ei în aplicațiile conversaționale.',
|
||||
varTip: 'Puteți utiliza variabile, încercați să tastați {{variable}}',
|
||||
|
|
|
|||
|
|
@ -479,87 +479,6 @@ const translation = {
|
|||
loadBalancingLeastKeyWarning: 'Za omogočanje uravnoteženja obremenitev morata biti omogočena vsaj 2 ključa.',
|
||||
loadBalancingInfo: 'Privzeto uravnoteženje obremenitev uporablja strategijo Round-robin. Če se sproži omejitev hitrosti, se uporabi 1-minutno obdobje ohlajanja.',
|
||||
upgradeForLoadBalancing: 'Nadgradite svoj načrt, da omogočite uravnoteženje obremenitev.',
|
||||
dataSource: {
|
||||
notion: {
|
||||
selector: {
|
||||
},
|
||||
},
|
||||
website: {
|
||||
},
|
||||
},
|
||||
plugin: {
|
||||
serpapi: {
|
||||
},
|
||||
},
|
||||
apiBasedExtension: {
|
||||
selector: {
|
||||
},
|
||||
modal: {
|
||||
name: {
|
||||
},
|
||||
apiEndpoint: {
|
||||
},
|
||||
apiKey: {
|
||||
},
|
||||
},
|
||||
},
|
||||
about: {
|
||||
},
|
||||
appMenus: {
|
||||
},
|
||||
environment: {
|
||||
},
|
||||
appModes: {
|
||||
},
|
||||
datasetMenus: {
|
||||
},
|
||||
voiceInput: {
|
||||
},
|
||||
modelName: {
|
||||
'gpt-3.5-turbo': 'GPT-3.5-Turbo',
|
||||
'gpt-3.5-turbo-16k': 'GPT-3.5-Turbo-16K',
|
||||
'gpt-4': 'GPT-4',
|
||||
'gpt-4-32k': 'GPT-4-32K',
|
||||
'text-davinci-003': 'Text-Davinci-003',
|
||||
'text-embedding-ada-002': 'Text-Embedding-Ada-002',
|
||||
'whisper-1': 'Whisper-1',
|
||||
'claude-instant-1': 'Claude-Instant',
|
||||
'claude-2': 'Claude-2',
|
||||
},
|
||||
chat: {
|
||||
citation: {
|
||||
},
|
||||
},
|
||||
promptEditor: {
|
||||
context: {
|
||||
item: {
|
||||
},
|
||||
modal: {
|
||||
},
|
||||
},
|
||||
history: {
|
||||
item: {
|
||||
},
|
||||
modal: {
|
||||
},
|
||||
},
|
||||
variable: {
|
||||
item: {
|
||||
},
|
||||
outputToolDisabledItem: {
|
||||
},
|
||||
modal: {
|
||||
},
|
||||
},
|
||||
query: {
|
||||
item: {
|
||||
},
|
||||
},
|
||||
},
|
||||
imageUploader: {
|
||||
},
|
||||
tag: {
|
||||
},
|
||||
discoverMore: 'Odkrijte več v',
|
||||
installProvider: 'Namestitev ponudnikov modelov',
|
||||
emptyProviderTitle: 'Ponudnik modelov ni nastavljen',
|
||||
|
|
|
|||
|
|
@ -348,7 +348,6 @@ const translation = {
|
|||
'description': 'Değişken ayarı {{varName}}',
|
||||
'fieldType': 'Alan türü',
|
||||
'string': 'Kısa Metin',
|
||||
'textInput': 'Kısa Metin',
|
||||
'paragraph': 'Paragraf',
|
||||
'select': 'Seçim',
|
||||
'number': 'Numara',
|
||||
|
|
@ -364,7 +363,6 @@ const translation = {
|
|||
'content': 'İçerik',
|
||||
'required': 'Gerekli',
|
||||
'errorMsg': {
|
||||
varNameRequired: 'Değişken adı gereklidir',
|
||||
labelNameRequired: 'Etiket adı gereklidir',
|
||||
varNameCanBeRepeat: 'Değişken adı tekrar edemez',
|
||||
atLeastOneOption: 'En az bir seçenek gereklidir',
|
||||
|
|
|
|||
|
|
@ -596,7 +596,6 @@ const translation = {
|
|||
'authorizationType': 'Yetkilendirme Türü',
|
||||
'no-auth': 'Yok',
|
||||
'api-key': 'API Anahtarı',
|
||||
'authType': 'Yetki Türü',
|
||||
'basic': 'Temel',
|
||||
'bearer': 'Bearer',
|
||||
'custom': 'Özel',
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@
|
|||
"lint:quiet": "eslint --cache --cache-location node_modules/.cache/eslint/.eslint-cache --quiet",
|
||||
"lint:complexity": "eslint --cache --cache-location node_modules/.cache/eslint/.eslint-cache --rule 'complexity: [error, {max: 15}]' --quiet",
|
||||
"type-check": "tsc --noEmit",
|
||||
"type-check:tsgo": "tsgo --noEmit",
|
||||
"prepare": "cd ../ && node -e \"if (process.env.NODE_ENV !== 'production'){process.exit(1)} \" || husky ./web/.husky",
|
||||
"gen-icons": "node ./app/components/base/icons/script.mjs",
|
||||
"uglify-embed": "node ./bin/uglify-embed",
|
||||
|
|
@ -110,9 +111,9 @@
|
|||
"pinyin-pro": "^3.27.0",
|
||||
"qrcode.react": "^4.2.0",
|
||||
"qs": "^6.14.0",
|
||||
"react": "19.1.1",
|
||||
"react": "19.2.1",
|
||||
"react-18-input-autosize": "^3.0.0",
|
||||
"react-dom": "19.1.1",
|
||||
"react-dom": "19.2.1",
|
||||
"react-easy-crop": "^5.5.3",
|
||||
"react-hook-form": "^7.65.0",
|
||||
"react-hotkeys-hook": "^4.6.2",
|
||||
|
|
@ -153,9 +154,9 @@
|
|||
"@happy-dom/jest-environment": "^20.0.8",
|
||||
"@mdx-js/loader": "^3.1.1",
|
||||
"@mdx-js/react": "^3.1.1",
|
||||
"@next/bundle-analyzer": "15.5.4",
|
||||
"@next/eslint-plugin-next": "15.5.4",
|
||||
"@next/mdx": "15.5.4",
|
||||
"@next/bundle-analyzer": "15.5.7",
|
||||
"@next/eslint-plugin-next": "15.5.7",
|
||||
"@next/mdx": "15.5.7",
|
||||
"@rgrove/parse-xml": "^4.2.0",
|
||||
"@storybook/addon-docs": "9.1.13",
|
||||
"@storybook/addon-links": "9.1.13",
|
||||
|
|
@ -173,8 +174,8 @@
|
|||
"@types/negotiator": "^0.6.4",
|
||||
"@types/node": "18.15.0",
|
||||
"@types/qs": "^6.14.0",
|
||||
"@types/react": "~19.1.17",
|
||||
"@types/react-dom": "~19.1.11",
|
||||
"@types/react": "~19.2.7",
|
||||
"@types/react-dom": "~19.2.3",
|
||||
"@types/react-slider": "^1.3.6",
|
||||
"@types/react-syntax-highlighter": "^15.5.13",
|
||||
"@types/react-window": "^1.8.8",
|
||||
|
|
@ -200,18 +201,20 @@
|
|||
"lint-staged": "^15.5.2",
|
||||
"lodash": "^4.17.21",
|
||||
"magicast": "^0.3.5",
|
||||
"nock": "^14.0.10",
|
||||
"postcss": "^8.5.6",
|
||||
"react-scan": "^0.4.3",
|
||||
"sass": "^1.93.2",
|
||||
"storybook": "9.1.13",
|
||||
"tailwindcss": "^3.4.18",
|
||||
"@typescript/native-preview": "^7.0.0-dev",
|
||||
"ts-node": "^10.9.2",
|
||||
"typescript": "^5.9.3",
|
||||
"uglify-js": "^3.19.3"
|
||||
},
|
||||
"resolutions": {
|
||||
"@types/react": "~19.1.17",
|
||||
"@types/react-dom": "~19.1.11",
|
||||
"@types/react": "~19.2.7",
|
||||
"@types/react-dom": "~19.2.3",
|
||||
"string-width": "~4.2.3",
|
||||
"@eslint/plugin-kit": "~0.3",
|
||||
"canvas": "^3.2.0",
|
||||
|
|
@ -282,4 +285,4 @@
|
|||
"sharp"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
1476
web/pnpm-lock.yaml
1476
web/pnpm-lock.yaml
File diff suppressed because it is too large
Load Diff
|
|
@ -144,7 +144,7 @@ function requiredWebSSOLogin(message?: string, code?: number) {
|
|||
params.append('message', message)
|
||||
if (code)
|
||||
params.append('code', String(code))
|
||||
globalThis.location.href = `${globalThis.location.origin}${basePath}/${WBB_APP_LOGIN_PATH}?${params.toString()}`
|
||||
globalThis.location.href = `${globalThis.location.origin}${basePath}${WBB_APP_LOGIN_PATH}?${params.toString()}`
|
||||
}
|
||||
|
||||
export function format(text: string) {
|
||||
|
|
|
|||
|
|
@ -612,12 +612,11 @@ export const usePluginTaskList = (category?: PluginCategoryEnum | string) => {
|
|||
const taskAllFailed = lastData?.tasks.every(task => task.status === TaskStatus.failed)
|
||||
if (taskDone && lastData?.tasks.length && !taskAllFailed)
|
||||
refreshPluginList(category ? { category } as any : undefined, !category)
|
||||
}, [initialized, isRefetching, data, category, refreshPluginList])
|
||||
}, [isRefetching])
|
||||
|
||||
useEffect(() => {
|
||||
if (isFetched && !initialized)
|
||||
setInitialized(true)
|
||||
}, [isFetched, initialized])
|
||||
setInitialized(true)
|
||||
}, [])
|
||||
|
||||
const handleRefetch = useCallback(() => {
|
||||
refetch()
|
||||
|
|
|
|||
|
|
@ -123,7 +123,7 @@ export const useInvalidLastRun = (flowType: FlowType, flowId: string, nodeId: st
|
|||
|
||||
// Rerun workflow or change the version of workflow
|
||||
export const useInvalidAllLastRun = (flowType?: FlowType, flowId?: string) => {
|
||||
return useInvalid([NAME_SPACE, flowType, 'last-run', flowId])
|
||||
return useInvalid([...useLastRunKey, flowType, flowId])
|
||||
}
|
||||
|
||||
export const useConversationVarValues = (flowType?: FlowType, flowId?: string) => {
|
||||
|
|
|
|||
|
|
@ -202,6 +202,16 @@ Reserve snapshots for static, deterministic fragments (icons, badges, layout chr
|
|||
|
||||
**Note**: Dify is a desktop application. **No need for** responsive/mobile testing.
|
||||
|
||||
### 12. Mock API
|
||||
|
||||
Use Nock to mock API calls. Example:
|
||||
|
||||
```ts
|
||||
const mockGithubStar = (status: number, body: Record<string, unknown>, delayMs = 0) => {
|
||||
return nock(GITHUB_HOST).get(GITHUB_PATH).delay(delayMs).reply(status, body)
|
||||
}
|
||||
```
|
||||
|
||||
## Code Style
|
||||
|
||||
### Example Structure
|
||||
|
|
|
|||
Loading…
Reference in New Issue