mirror of
https://github.com/langgenius/dify.git
synced 2026-05-10 05:56:31 +08:00
Merge branch 'main' into feat/hitl
This commit is contained in:
commit
b2ceb41dd6
10
.github/CODEOWNERS
vendored
10
.github/CODEOWNERS
vendored
@ -9,6 +9,9 @@
|
||||
# CODEOWNERS file
|
||||
/.github/CODEOWNERS @laipz8200 @crazywoola
|
||||
|
||||
# Agents
|
||||
/.agents/skills/ @hyoban
|
||||
|
||||
# Docs
|
||||
/docs/ @crazywoola
|
||||
|
||||
@ -21,6 +24,10 @@
|
||||
/api/services/tools/mcp_tools_manage_service.py @Nov1c444
|
||||
/api/controllers/mcp/ @Nov1c444
|
||||
/api/controllers/console/app/mcp_server.py @Nov1c444
|
||||
|
||||
# Backend - Tests
|
||||
/api/tests/ @laipz8200 @QuantumGhost
|
||||
|
||||
/api/tests/**/*mcp* @Nov1c444
|
||||
|
||||
# Backend - Workflow - Engine (Core graph execution engine)
|
||||
@ -231,6 +238,9 @@
|
||||
# Frontend - Base Components
|
||||
/web/app/components/base/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - Base Components Tests
|
||||
/web/app/components/base/**/*.spec.tsx @hyoban @CodingOnStar
|
||||
|
||||
# Frontend - Utils and Hooks
|
||||
/web/utils/classnames.ts @iamjoel @zxhlyh
|
||||
/web/utils/time.ts @iamjoel @zxhlyh
|
||||
|
||||
203
api/commands.py
203
api/commands.py
@ -1450,54 +1450,58 @@ def clear_orphaned_file_records(force: bool):
|
||||
all_ids_in_tables = []
|
||||
for ids_table in ids_tables:
|
||||
query = ""
|
||||
if ids_table["type"] == "uuid":
|
||||
click.echo(
|
||||
click.style(
|
||||
f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}", fg="white"
|
||||
match ids_table["type"]:
|
||||
case "uuid":
|
||||
click.echo(
|
||||
click.style(
|
||||
f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
)
|
||||
query = (
|
||||
f"SELECT {ids_table['column']} FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])})
|
||||
elif ids_table["type"] == "text":
|
||||
click.echo(
|
||||
click.style(
|
||||
f"- Listing file-id-like strings in column {ids_table['column']} in table {ids_table['table']}",
|
||||
fg="white",
|
||||
c = ids_table["column"]
|
||||
query = f"SELECT {c} FROM {ids_table['table']} WHERE {c} IS NOT NULL"
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])})
|
||||
case "text":
|
||||
t = ids_table["table"]
|
||||
click.echo(
|
||||
click.style(
|
||||
f"- Listing file-id-like strings in column {ids_table['column']} in table {t}",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
)
|
||||
query = (
|
||||
f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id "
|
||||
f"FROM {ids_table['table']}"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
for j in i[0]:
|
||||
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
|
||||
elif ids_table["type"] == "json":
|
||||
click.echo(
|
||||
click.style(
|
||||
(
|
||||
f"- Listing file-id-like JSON string in column {ids_table['column']} "
|
||||
f"in table {ids_table['table']}"
|
||||
),
|
||||
fg="white",
|
||||
query = (
|
||||
f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id "
|
||||
f"FROM {ids_table['table']}"
|
||||
)
|
||||
)
|
||||
query = (
|
||||
f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id "
|
||||
f"FROM {ids_table['table']}"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
for j in i[0]:
|
||||
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
for j in i[0]:
|
||||
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
|
||||
case "json":
|
||||
click.echo(
|
||||
click.style(
|
||||
(
|
||||
f"- Listing file-id-like JSON string in column {ids_table['column']} "
|
||||
f"in table {ids_table['table']}"
|
||||
),
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
query = (
|
||||
f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id "
|
||||
f"FROM {ids_table['table']}"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
for j in i[0]:
|
||||
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
|
||||
case _:
|
||||
pass
|
||||
click.echo(click.style(f"Found {len(all_ids_in_tables)} file ids in tables.", fg="white"))
|
||||
|
||||
except Exception as e:
|
||||
@ -1737,59 +1741,18 @@ def file_usage(
|
||||
if src_filter != src:
|
||||
continue
|
||||
|
||||
if ids_table["type"] == "uuid":
|
||||
# Direct UUID match
|
||||
query = (
|
||||
f"SELECT {ids_table['pk_column']}, {ids_table['column']} "
|
||||
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for row in rs:
|
||||
record_id = str(row[0])
|
||||
ref_file_id = str(row[1])
|
||||
if ref_file_id not in file_key_map:
|
||||
continue
|
||||
storage_key = file_key_map[ref_file_id]
|
||||
|
||||
# Apply filters
|
||||
if file_id and ref_file_id != file_id:
|
||||
continue
|
||||
if key and not storage_key.endswith(key):
|
||||
continue
|
||||
|
||||
# Only collect items within the requested page range
|
||||
if offset <= total_count < offset + limit:
|
||||
paginated_usages.append(
|
||||
{
|
||||
"src": f"{ids_table['table']}.{ids_table['column']}",
|
||||
"record_id": record_id,
|
||||
"file_id": ref_file_id,
|
||||
"key": storage_key,
|
||||
}
|
||||
)
|
||||
total_count += 1
|
||||
|
||||
elif ids_table["type"] in ("text", "json"):
|
||||
# Extract UUIDs from text/json content
|
||||
column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"]
|
||||
query = (
|
||||
f"SELECT {ids_table['pk_column']}, {column_cast} "
|
||||
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for row in rs:
|
||||
record_id = str(row[0])
|
||||
content = str(row[1])
|
||||
|
||||
# Find all UUIDs in the content
|
||||
import re
|
||||
|
||||
uuid_pattern = re.compile(guid_regexp, re.IGNORECASE)
|
||||
matches = uuid_pattern.findall(content)
|
||||
|
||||
for ref_file_id in matches:
|
||||
match ids_table["type"]:
|
||||
case "uuid":
|
||||
# Direct UUID match
|
||||
query = (
|
||||
f"SELECT {ids_table['pk_column']}, {ids_table['column']} "
|
||||
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for row in rs:
|
||||
record_id = str(row[0])
|
||||
ref_file_id = str(row[1])
|
||||
if ref_file_id not in file_key_map:
|
||||
continue
|
||||
storage_key = file_key_map[ref_file_id]
|
||||
@ -1812,6 +1775,50 @@ def file_usage(
|
||||
)
|
||||
total_count += 1
|
||||
|
||||
case "text" | "json":
|
||||
# Extract UUIDs from text/json content
|
||||
column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"]
|
||||
query = (
|
||||
f"SELECT {ids_table['pk_column']}, {column_cast} "
|
||||
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for row in rs:
|
||||
record_id = str(row[0])
|
||||
content = str(row[1])
|
||||
|
||||
# Find all UUIDs in the content
|
||||
import re
|
||||
|
||||
uuid_pattern = re.compile(guid_regexp, re.IGNORECASE)
|
||||
matches = uuid_pattern.findall(content)
|
||||
|
||||
for ref_file_id in matches:
|
||||
if ref_file_id not in file_key_map:
|
||||
continue
|
||||
storage_key = file_key_map[ref_file_id]
|
||||
|
||||
# Apply filters
|
||||
if file_id and ref_file_id != file_id:
|
||||
continue
|
||||
if key and not storage_key.endswith(key):
|
||||
continue
|
||||
|
||||
# Only collect items within the requested page range
|
||||
if offset <= total_count < offset + limit:
|
||||
paginated_usages.append(
|
||||
{
|
||||
"src": f"{ids_table['table']}.{ids_table['column']}",
|
||||
"record_id": record_id,
|
||||
"file_id": ref_file_id,
|
||||
"key": storage_key,
|
||||
}
|
||||
)
|
||||
total_count += 1
|
||||
case _:
|
||||
pass
|
||||
|
||||
# Output results
|
||||
if output_json:
|
||||
result = {
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import abort, make_response, request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
|
||||
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
@ -16,9 +17,11 @@ from controllers.console.wraps import (
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.annotation_fields import (
|
||||
annotation_fields,
|
||||
annotation_hit_history_fields,
|
||||
build_annotation_model,
|
||||
Annotation,
|
||||
AnnotationExportList,
|
||||
AnnotationHitHistory,
|
||||
AnnotationHitHistoryList,
|
||||
AnnotationList,
|
||||
)
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
@ -89,6 +92,14 @@ reg(CreateAnnotationPayload)
|
||||
reg(UpdateAnnotationPayload)
|
||||
reg(AnnotationReplyStatusQuery)
|
||||
reg(AnnotationFilePayload)
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
Annotation,
|
||||
AnnotationList,
|
||||
AnnotationExportList,
|
||||
AnnotationHitHistory,
|
||||
AnnotationHitHistoryList,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
|
||||
@ -107,10 +118,11 @@ class AnnotationReplyActionApi(Resource):
|
||||
def post(self, app_id, action: Literal["enable", "disable"]):
|
||||
app_id = str(app_id)
|
||||
args = AnnotationReplyPayload.model_validate(console_ns.payload)
|
||||
if action == "enable":
|
||||
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
|
||||
elif action == "disable":
|
||||
result = AppAnnotationService.disable_app_annotation(app_id)
|
||||
match action:
|
||||
case "enable":
|
||||
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
|
||||
case "disable":
|
||||
result = AppAnnotationService.disable_app_annotation(app_id)
|
||||
return result, 200
|
||||
|
||||
|
||||
@ -201,33 +213,33 @@ class AnnotationApi(Resource):
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
|
||||
response = {
|
||||
"data": marshal(annotation_list, annotation_fields),
|
||||
"has_more": len(annotation_list) == limit,
|
||||
"limit": limit,
|
||||
"total": total,
|
||||
"page": page,
|
||||
}
|
||||
return response, 200
|
||||
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
|
||||
response = AnnotationList(
|
||||
data=annotation_models,
|
||||
has_more=len(annotation_list) == limit,
|
||||
limit=limit,
|
||||
total=total,
|
||||
page=page,
|
||||
)
|
||||
return response.model_dump(mode="json"), 200
|
||||
|
||||
@console_ns.doc("create_annotation")
|
||||
@console_ns.doc(description="Create a new annotation for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__])
|
||||
@console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns))
|
||||
@console_ns.response(201, "Annotation created successfully", console_ns.models[Annotation.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@marshal_with(annotation_fields)
|
||||
@edit_permission_required
|
||||
def post(self, app_id):
|
||||
app_id = str(app_id)
|
||||
args = CreateAnnotationPayload.model_validate(console_ns.payload)
|
||||
data = args.model_dump(exclude_none=True)
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
|
||||
return annotation
|
||||
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -264,7 +276,7 @@ class AnnotationExportApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Annotations exported successfully",
|
||||
console_ns.model("AnnotationList", {"data": fields.List(fields.Nested(build_annotation_model(console_ns)))}),
|
||||
console_ns.models[AnnotationExportList.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@ -274,7 +286,8 @@ class AnnotationExportApi(Resource):
|
||||
def get(self, app_id):
|
||||
app_id = str(app_id)
|
||||
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
|
||||
response_data = {"data": marshal(annotation_list, annotation_fields)}
|
||||
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
|
||||
response_data = AnnotationExportList(data=annotation_models).model_dump(mode="json")
|
||||
|
||||
# Create response with secure headers for CSV export
|
||||
response = make_response(response_data, 200)
|
||||
@ -289,7 +302,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@console_ns.doc("update_delete_annotation")
|
||||
@console_ns.doc(description="Update or delete an annotation")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
|
||||
@console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns))
|
||||
@console_ns.response(200, "Annotation updated successfully", console_ns.models[Annotation.__name__])
|
||||
@console_ns.response(204, "Annotation deleted successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__])
|
||||
@ -298,7 +311,6 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_id, annotation_id):
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
@ -306,7 +318,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(
|
||||
args.model_dump(exclude_none=True), app_id, annotation_id
|
||||
)
|
||||
return annotation
|
||||
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -414,14 +426,7 @@ class AnnotationHitHistoryListApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Hit histories retrieved successfully",
|
||||
console_ns.model(
|
||||
"AnnotationHitHistoryList",
|
||||
{
|
||||
"data": fields.List(
|
||||
fields.Nested(console_ns.model("AnnotationHitHistoryItem", annotation_hit_history_fields))
|
||||
)
|
||||
},
|
||||
),
|
||||
console_ns.models[AnnotationHitHistoryList.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@ -436,11 +441,14 @@ class AnnotationHitHistoryListApi(Resource):
|
||||
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(
|
||||
app_id, annotation_id, page, limit
|
||||
)
|
||||
response = {
|
||||
"data": marshal(annotation_hit_history_list, annotation_hit_history_fields),
|
||||
"has_more": len(annotation_hit_history_list) == limit,
|
||||
"limit": limit,
|
||||
"total": total,
|
||||
"page": page,
|
||||
}
|
||||
return response
|
||||
history_models = TypeAdapter(list[AnnotationHitHistory]).validate_python(
|
||||
annotation_hit_history_list, from_attributes=True
|
||||
)
|
||||
response = AnnotationHitHistoryList(
|
||||
data=history_models,
|
||||
has_more=len(annotation_hit_history_list) == limit,
|
||||
limit=limit,
|
||||
total=total,
|
||||
page=page,
|
||||
)
|
||||
return response.model_dump(mode="json")
|
||||
|
||||
@ -6,6 +6,7 @@ from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
@ -33,7 +34,6 @@ from services.errors.audio import (
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class TextToSpeechPayload(BaseModel):
|
||||
@ -47,13 +47,11 @@ class TextToSpeechVoiceQuery(BaseModel):
|
||||
language: str = Field(..., description="Language code")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
TextToSpeechVoiceQuery.__name__,
|
||||
TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
class AudioTranscriptResponse(BaseModel):
|
||||
text: str = Field(description="Transcribed text from audio")
|
||||
|
||||
|
||||
register_schema_models(console_ns, AudioTranscriptResponse, TextToSpeechPayload, TextToSpeechVoiceQuery)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/audio-to-text")
|
||||
@ -64,7 +62,7 @@ class ChatMessageAudioApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Audio transcription successful",
|
||||
console_ns.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}),
|
||||
console_ns.models[AudioTranscriptResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Bad request - No audio uploaded or unsupported type")
|
||||
@console_ns.response(413, "Audio file too large")
|
||||
|
||||
@ -509,16 +509,19 @@ class ChatConversationApi(Resource):
|
||||
case "created_at" | "-created_at" | _:
|
||||
query = query.where(Conversation.created_at <= end_datetime_utc)
|
||||
|
||||
if args.annotation_status == "annotated":
|
||||
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
elif args.annotation_status == "not_annotated":
|
||||
query = (
|
||||
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
|
||||
.group_by(Conversation.id)
|
||||
.having(func.count(MessageAnnotation.id) == 0)
|
||||
)
|
||||
match args.annotation_status:
|
||||
case "annotated":
|
||||
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
case "not_annotated":
|
||||
query = (
|
||||
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
|
||||
.group_by(Conversation.id)
|
||||
.having(func.count(MessageAnnotation.id) == 0)
|
||||
)
|
||||
case "all":
|
||||
pass
|
||||
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
|
||||
|
||||
@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import exists, select
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
@ -35,7 +36,6 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAft
|
||||
from services.message_service import MessageService, attach_message_extra_contents
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ChatMessagesQuery(BaseModel):
|
||||
@ -90,13 +90,22 @@ class FeedbackExportQuery(BaseModel):
|
||||
raise ValueError("has_comment must be a boolean value")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
class AnnotationCountResponse(BaseModel):
|
||||
count: int = Field(description="Number of annotations")
|
||||
|
||||
|
||||
reg(ChatMessagesQuery)
|
||||
reg(MessageFeedbackPayload)
|
||||
reg(FeedbackExportQuery)
|
||||
class SuggestedQuestionsResponse(BaseModel):
|
||||
data: list[str] = Field(description="Suggested question")
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
ChatMessagesQuery,
|
||||
MessageFeedbackPayload,
|
||||
FeedbackExportQuery,
|
||||
AnnotationCountResponse,
|
||||
SuggestedQuestionsResponse,
|
||||
)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
@ -232,7 +241,7 @@ class ChatMessageListApi(Resource):
|
||||
@marshal_with(message_infinite_scroll_pagination_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = ChatMessagesQuery.model_validate(request.args.to_dict())
|
||||
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
@ -358,7 +367,7 @@ class MessageAnnotationCountApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Annotation count retrieved successfully",
|
||||
console_ns.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}),
|
||||
console_ns.models[AnnotationCountResponse.__name__],
|
||||
)
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@ -378,9 +387,7 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Suggested questions retrieved successfully",
|
||||
console_ns.model(
|
||||
"SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}
|
||||
),
|
||||
console_ns.models[SuggestedQuestionsResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Message or conversation not found")
|
||||
@setup_required
|
||||
@ -430,7 +437,7 @@ class MessageFeedbackExportApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = FeedbackExportQuery.model_validate(request.args.to_dict())
|
||||
|
||||
# Import the service function
|
||||
from services.feedback_service import FeedbackService
|
||||
|
||||
@ -2,9 +2,11 @@ import logging
|
||||
|
||||
import httpx
|
||||
from flask import current_app, redirect, request
|
||||
from flask_restx import Resource, fields
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import register_schema_models
|
||||
from libs.login import login_required
|
||||
from libs.oauth_data_source import NotionOAuth
|
||||
|
||||
@ -14,6 +16,26 @@ from ..wraps import account_initialization_required, is_admin_or_owner_required,
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthDataSourceResponse(BaseModel):
|
||||
data: str = Field(description="Authorization URL or 'internal' for internal setup")
|
||||
|
||||
|
||||
class OAuthDataSourceBindingResponse(BaseModel):
|
||||
result: str = Field(description="Operation result")
|
||||
|
||||
|
||||
class OAuthDataSourceSyncResponse(BaseModel):
|
||||
result: str = Field(description="Operation result")
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
OAuthDataSourceResponse,
|
||||
OAuthDataSourceBindingResponse,
|
||||
OAuthDataSourceSyncResponse,
|
||||
)
|
||||
|
||||
|
||||
def get_oauth_providers():
|
||||
with current_app.app_context():
|
||||
notion_oauth = NotionOAuth(
|
||||
@ -34,10 +56,7 @@ class OAuthDataSource(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Authorization URL or internal setup success",
|
||||
console_ns.model(
|
||||
"OAuthDataSourceResponse",
|
||||
{"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")},
|
||||
),
|
||||
console_ns.models[OAuthDataSourceResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Invalid provider")
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@ -101,7 +120,7 @@ class OAuthDataSourceBinding(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Data source binding success",
|
||||
console_ns.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}),
|
||||
console_ns.models[OAuthDataSourceBindingResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Invalid provider or code")
|
||||
def get(self, provider: str):
|
||||
@ -133,7 +152,7 @@ class OAuthDataSourceSync(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Data source sync success",
|
||||
console_ns.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}),
|
||||
console_ns.models[OAuthDataSourceSyncResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Invalid provider or sync failed")
|
||||
@setup_required
|
||||
|
||||
@ -2,10 +2,11 @@ import base64
|
||||
import secrets
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailCodeError,
|
||||
@ -48,8 +49,31 @@ class ForgotPasswordResetPayload(BaseModel):
|
||||
return valid_password(value)
|
||||
|
||||
|
||||
for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
class ForgotPasswordEmailResponse(BaseModel):
|
||||
result: str = Field(description="Operation result")
|
||||
data: str | None = Field(default=None, description="Reset token")
|
||||
code: str | None = Field(default=None, description="Error code if account not found")
|
||||
|
||||
|
||||
class ForgotPasswordCheckResponse(BaseModel):
|
||||
is_valid: bool = Field(description="Whether code is valid")
|
||||
email: EmailStr = Field(description="Email address")
|
||||
token: str = Field(description="New reset token")
|
||||
|
||||
|
||||
class ForgotPasswordResetResponse(BaseModel):
|
||||
result: str = Field(description="Operation result")
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
ForgotPasswordSendPayload,
|
||||
ForgotPasswordCheckPayload,
|
||||
ForgotPasswordResetPayload,
|
||||
ForgotPasswordEmailResponse,
|
||||
ForgotPasswordCheckResponse,
|
||||
ForgotPasswordResetResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/forgot-password")
|
||||
@ -60,14 +84,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Email sent successfully",
|
||||
console_ns.model(
|
||||
"ForgotPasswordEmailResponse",
|
||||
{
|
||||
"result": fields.String(description="Operation result"),
|
||||
"data": fields.String(description="Reset token"),
|
||||
"code": fields.String(description="Error code if account not found"),
|
||||
},
|
||||
),
|
||||
console_ns.models[ForgotPasswordEmailResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Invalid email or rate limit exceeded")
|
||||
@setup_required
|
||||
@ -106,14 +123,7 @@ class ForgotPasswordCheckApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Code verified successfully",
|
||||
console_ns.model(
|
||||
"ForgotPasswordCheckResponse",
|
||||
{
|
||||
"is_valid": fields.Boolean(description="Whether code is valid"),
|
||||
"email": fields.String(description="Email address"),
|
||||
"token": fields.String(description="New reset token"),
|
||||
},
|
||||
),
|
||||
console_ns.models[ForgotPasswordCheckResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Invalid code or token")
|
||||
@setup_required
|
||||
@ -163,7 +173,7 @@ class ForgotPasswordResetApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Password reset successfully",
|
||||
console_ns.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}),
|
||||
console_ns.models[ForgotPasswordResetResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Invalid token or password mismatch")
|
||||
@setup_required
|
||||
|
||||
@ -155,43 +155,43 @@ class OAuthServerUserTokenApi(Resource):
|
||||
grant_type = OAuthGrantType(payload.grant_type)
|
||||
except ValueError:
|
||||
raise BadRequest("invalid grant_type")
|
||||
match grant_type:
|
||||
case OAuthGrantType.AUTHORIZATION_CODE:
|
||||
if not payload.code:
|
||||
raise BadRequest("code is required")
|
||||
|
||||
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
|
||||
if not payload.code:
|
||||
raise BadRequest("code is required")
|
||||
if payload.client_secret != oauth_provider_app.client_secret:
|
||||
raise BadRequest("client_secret is invalid")
|
||||
|
||||
if payload.client_secret != oauth_provider_app.client_secret:
|
||||
raise BadRequest("client_secret is invalid")
|
||||
if payload.redirect_uri not in oauth_provider_app.redirect_uris:
|
||||
raise BadRequest("redirect_uri is invalid")
|
||||
|
||||
if payload.redirect_uri not in oauth_provider_app.redirect_uris:
|
||||
raise BadRequest("redirect_uri is invalid")
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, code=payload.code, client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
)
|
||||
case OAuthGrantType.REFRESH_TOKEN:
|
||||
if not payload.refresh_token:
|
||||
raise BadRequest("refresh_token is required")
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, code=payload.code, client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
)
|
||||
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
|
||||
if not payload.refresh_token:
|
||||
raise BadRequest("refresh_token is required")
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
)
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/oauth/provider/account")
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
@ -157,9 +157,8 @@ class DataSourceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, binding_id, action):
|
||||
def patch(self, binding_id, action: Literal["enable", "disable"]):
|
||||
binding_id = str(binding_id)
|
||||
action = str(action)
|
||||
with Session(db.engine) as session:
|
||||
data_source_binding = session.execute(
|
||||
select(DataSourceOauthBinding).filter_by(id=binding_id)
|
||||
@ -167,23 +166,24 @@ class DataSourceApi(Resource):
|
||||
if data_source_binding is None:
|
||||
raise NotFound("Data source binding not found.")
|
||||
# enable binding
|
||||
if action == "enable":
|
||||
if data_source_binding.disabled:
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.add(data_source_binding)
|
||||
db.session.commit()
|
||||
else:
|
||||
raise ValueError("Data source is not disabled.")
|
||||
# disable binding
|
||||
if action == "disable":
|
||||
if not data_source_binding.disabled:
|
||||
data_source_binding.disabled = True
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.add(data_source_binding)
|
||||
db.session.commit()
|
||||
else:
|
||||
raise ValueError("Data source is disabled.")
|
||||
match action:
|
||||
case "enable":
|
||||
if data_source_binding.disabled:
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.add(data_source_binding)
|
||||
db.session.commit()
|
||||
else:
|
||||
raise ValueError("Data source is not disabled.")
|
||||
# disable binding
|
||||
case "disable":
|
||||
if not data_source_binding.disabled:
|
||||
data_source_binding.disabled = True
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.add(data_source_binding)
|
||||
db.session.commit()
|
||||
else:
|
||||
raise ValueError("Data source is disabled.")
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
|
||||
@ -576,63 +576,62 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
if document.indexing_status in {"completed", "error"}:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
data_source_info = document.data_source_info_dict
|
||||
match document.data_source_type:
|
||||
case "upload_file":
|
||||
if not data_source_info:
|
||||
continue
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
file_detail = (
|
||||
db.session.query(UploadFile)
|
||||
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if document.data_source_type == "upload_file":
|
||||
if not data_source_info:
|
||||
continue
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
file_detail = (
|
||||
db.session.query(UploadFile)
|
||||
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
|
||||
.first()
|
||||
)
|
||||
if file_detail is None:
|
||||
raise NotFound("File not found.")
|
||||
|
||||
if file_detail is None:
|
||||
raise NotFound("File not found.")
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
case "notion_import":
|
||||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": data_source_info.get("credential_id"),
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
"tenant_id": current_tenant_id,
|
||||
}
|
||||
),
|
||||
document_model=document.doc_form,
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
case "website_crawl":
|
||||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE,
|
||||
website_info=WebsiteInfo.model_validate(
|
||||
{
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
"url": data_source_info["url"],
|
||||
"tenant_id": current_tenant_id,
|
||||
"mode": data_source_info["mode"],
|
||||
"only_main_content": data_source_info["only_main_content"],
|
||||
}
|
||||
),
|
||||
document_model=document.doc_form,
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
||||
elif document.data_source_type == "notion_import":
|
||||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": data_source_info.get("credential_id"),
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
"tenant_id": current_tenant_id,
|
||||
}
|
||||
),
|
||||
document_model=document.doc_form,
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
elif document.data_source_type == "website_crawl":
|
||||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE,
|
||||
website_info=WebsiteInfo.model_validate(
|
||||
{
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
"url": data_source_info["url"],
|
||||
"tenant_id": current_tenant_id,
|
||||
"mode": data_source_info["mode"],
|
||||
"only_main_content": data_source_info["only_main_content"],
|
||||
}
|
||||
),
|
||||
document_model=document.doc_form,
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
||||
else:
|
||||
raise ValueError("Data source type not support")
|
||||
case _:
|
||||
raise ValueError("Data source type not support")
|
||||
indexing_runner = IndexingRunner()
|
||||
try:
|
||||
response = indexing_runner.indexing_estimate(
|
||||
@ -954,23 +953,24 @@ class DocumentProcessingApi(DocumentResource):
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if action == "pause":
|
||||
if document.indexing_status != "indexing":
|
||||
raise InvalidActionError("Document not in indexing state.")
|
||||
match action:
|
||||
case "pause":
|
||||
if document.indexing_status != "indexing":
|
||||
raise InvalidActionError("Document not in indexing state.")
|
||||
|
||||
document.paused_by = current_user.id
|
||||
document.paused_at = naive_utc_now()
|
||||
document.is_paused = True
|
||||
db.session.commit()
|
||||
document.paused_by = current_user.id
|
||||
document.paused_at = naive_utc_now()
|
||||
document.is_paused = True
|
||||
db.session.commit()
|
||||
|
||||
elif action == "resume":
|
||||
if document.indexing_status not in {"paused", "error"}:
|
||||
raise InvalidActionError("Document not in paused or error state.")
|
||||
case "resume":
|
||||
if document.indexing_status not in {"paused", "error"}:
|
||||
raise InvalidActionError("Document not in paused or error state.")
|
||||
|
||||
document.paused_by = None
|
||||
document.paused_at = None
|
||||
document.is_paused = False
|
||||
db.session.commit()
|
||||
document.paused_by = None
|
||||
document.paused_at = None
|
||||
document.is_paused = False
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@ -1339,6 +1339,18 @@ class DocumentGenerateSummaryApi(Resource):
|
||||
missing_ids = set(document_list) - found_ids
|
||||
raise NotFound(f"Some documents not found: {list(missing_ids)}")
|
||||
|
||||
# Update need_summary to True for documents that don't have it set
|
||||
# This handles the case where documents were created when summary_index_setting was disabled
|
||||
documents_to_update = [doc for doc in documents if not doc.need_summary and doc.doc_form != "qa_model"]
|
||||
|
||||
if documents_to_update:
|
||||
document_ids_to_update = [str(doc.id) for doc in documents_to_update]
|
||||
DocumentService.update_documents_need_summary(
|
||||
dataset_id=dataset_id,
|
||||
document_ids=document_ids_to_update,
|
||||
need_summary=True,
|
||||
)
|
||||
|
||||
# Dispatch async tasks for each document
|
||||
for document in documents:
|
||||
# Skip qa_model documents as they don't generate summaries
|
||||
|
||||
@ -126,10 +126,11 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
if action == "enable":
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
elif action == "disable":
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
match action:
|
||||
case "enable":
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
case "disable":
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
@ -51,7 +52,7 @@ from fields.app_fields import (
|
||||
tag_fields,
|
||||
)
|
||||
from fields.dataset_fields import dataset_fields
|
||||
from fields.member_fields import build_simple_account_model
|
||||
from fields.member_fields import simple_account_fields
|
||||
from fields.workflow_fields import (
|
||||
conversation_variable_fields,
|
||||
pipeline_variable_fields,
|
||||
@ -103,7 +104,7 @@ app_detail_fields_with_site_copy["tags"] = fields.List(fields.Nested(tag_model))
|
||||
app_detail_fields_with_site_copy["site"] = fields.Nested(site_model)
|
||||
app_detail_with_site_model = get_or_create_model("TrialAppDetailWithSite", app_detail_fields_with_site_copy)
|
||||
|
||||
simple_account_model = build_simple_account_model(console_ns)
|
||||
simple_account_model = get_or_create_model("SimpleAccount", simple_account_fields)
|
||||
conversation_variable_model = get_or_create_model("TrialConversationVariable", conversation_variable_fields)
|
||||
pipeline_variable_model = get_or_create_model("TrialPipelineVariable", pipeline_variable_fields)
|
||||
|
||||
@ -117,7 +118,56 @@ workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipel
|
||||
workflow_model = get_or_create_model("TrialWorkflow", workflow_fields_copy)
|
||||
|
||||
|
||||
# Pydantic models for request validation
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowRunRequest(BaseModel):
|
||||
inputs: dict
|
||||
files: list | None = None
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
inputs: dict
|
||||
query: str
|
||||
files: list | None = None
|
||||
conversation_id: str | None = None
|
||||
parent_message_id: str | None = None
|
||||
retriever_from: str = "explore_app"
|
||||
|
||||
|
||||
class TextToSpeechRequest(BaseModel):
|
||||
message_id: str | None = None
|
||||
voice: str | None = None
|
||||
text: str | None = None
|
||||
streaming: bool | None = None
|
||||
|
||||
|
||||
class CompletionRequest(BaseModel):
|
||||
inputs: dict
|
||||
query: str = ""
|
||||
files: list | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
retriever_from: str = "explore_app"
|
||||
|
||||
|
||||
# Register schemas for Swagger documentation
|
||||
console_ns.schema_model(
|
||||
WorkflowRunRequest.__name__, WorkflowRunRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ChatRequest.__name__, ChatRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
TextToSpeechRequest.__name__, TextToSpeechRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
CompletionRequest.__name__, CompletionRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
class TrialAppWorkflowRunApi(TrialAppResource):
|
||||
@console_ns.expect(console_ns.models[WorkflowRunRequest.__name__])
|
||||
def post(self, trial_app):
|
||||
"""
|
||||
Run workflow
|
||||
@ -129,10 +179,8 @@ class TrialAppWorkflowRunApi(TrialAppResource):
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
request_data = WorkflowRunRequest.model_validate(console_ns.payload)
|
||||
args = request_data.model_dump()
|
||||
assert current_user is not None
|
||||
try:
|
||||
app_id = app_model.id
|
||||
@ -183,6 +231,7 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource):
|
||||
|
||||
|
||||
class TrialChatApi(TrialAppResource):
|
||||
@console_ns.expect(console_ns.models[ChatRequest.__name__])
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
@ -190,14 +239,14 @@ class TrialChatApi(TrialAppResource):
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, required=True, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
request_data = ChatRequest.model_validate(console_ns.payload)
|
||||
args = request_data.model_dump()
|
||||
|
||||
# Validate UUID values if provided
|
||||
if args.get("conversation_id"):
|
||||
args["conversation_id"] = uuid_value(args["conversation_id"])
|
||||
if args.get("parent_message_id"):
|
||||
args["parent_message_id"] = uuid_value(args["parent_message_id"])
|
||||
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
@ -320,20 +369,16 @@ class TrialChatAudioApi(TrialAppResource):
|
||||
|
||||
|
||||
class TrialChatTextApi(TrialAppResource):
|
||||
@console_ns.expect(console_ns.models[TextToSpeechRequest.__name__])
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", type=str, required=False, location="json")
|
||||
parser.add_argument("voice", type=str, location="json")
|
||||
parser.add_argument("text", type=str, location="json")
|
||||
parser.add_argument("streaming", type=bool, location="json")
|
||||
args = parser.parse_args()
|
||||
request_data = TextToSpeechRequest.model_validate(console_ns.payload)
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
voice = args.get("voice", None)
|
||||
message_id = request_data.message_id
|
||||
text = request_data.text
|
||||
voice = request_data.voice
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
@ -371,19 +416,15 @@ class TrialChatTextApi(TrialAppResource):
|
||||
|
||||
|
||||
class TrialCompletionApi(TrialAppResource):
|
||||
@console_ns.expect(console_ns.models[CompletionRequest.__name__])
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, location="json", default="")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
request_data = CompletionRequest.model_validate(console_ns.payload)
|
||||
args = request_data.model_dump()
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
@ -1,58 +1,60 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource, fields
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.fastopenapi import console_router
|
||||
from libs.login import current_account_with_tenant, current_user, login_required
|
||||
from services.feature_service import FeatureModel, FeatureService, SystemFeatureModel
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, cloud_utm_record, setup_required
|
||||
|
||||
|
||||
class FeatureResponse(BaseModel):
|
||||
features: FeatureModel = Field(description="Feature configuration object")
|
||||
@console_ns.route("/features")
|
||||
class FeatureApi(Resource):
|
||||
@console_ns.doc("get_tenant_features")
|
||||
@console_ns.doc(description="Get feature configuration for current tenant")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.model("FeatureResponse", {"features": fields.Raw(description="Feature configuration object")}),
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_utm_record
|
||||
def get(self):
|
||||
"""Get feature configuration for current tenant"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
return FeatureService.get_features(current_tenant_id).model_dump()
|
||||
|
||||
|
||||
class SystemFeatureResponse(BaseModel):
|
||||
features: SystemFeatureModel = Field(description="System feature configuration object")
|
||||
@console_ns.route("/system-features")
|
||||
class SystemFeatureApi(Resource):
|
||||
@console_ns.doc("get_system_features")
|
||||
@console_ns.doc(description="Get system-wide feature configuration")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.model(
|
||||
"SystemFeatureResponse", {"features": fields.Raw(description="System feature configuration object")}
|
||||
),
|
||||
)
|
||||
def get(self):
|
||||
"""Get system-wide feature configuration
|
||||
|
||||
NOTE: This endpoint is unauthenticated by design, as it provides system features
|
||||
data required for dashboard initialization.
|
||||
|
||||
@console_router.get(
|
||||
"/features",
|
||||
response_model=FeatureResponse,
|
||||
tags=["console"],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_utm_record
|
||||
def get_tenant_features() -> FeatureResponse:
|
||||
"""Get feature configuration for current tenant."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
Authentication would create circular dependency (can't login without dashboard loading).
|
||||
|
||||
return FeatureResponse(features=FeatureService.get_features(current_tenant_id))
|
||||
|
||||
|
||||
@console_router.get(
|
||||
"/system-features",
|
||||
response_model=SystemFeatureResponse,
|
||||
tags=["console"],
|
||||
)
|
||||
def get_system_features() -> SystemFeatureResponse:
|
||||
"""Get system-wide feature configuration
|
||||
|
||||
NOTE: This endpoint is unauthenticated by design, as it provides system features
|
||||
data required for dashboard initialization.
|
||||
|
||||
Authentication would create circular dependency (can't login without dashboard loading).
|
||||
|
||||
Only non-sensitive configuration data should be returned by this endpoint.
|
||||
"""
|
||||
# NOTE(QuantumGhost): ideally we should access `current_user.is_authenticated`
|
||||
# without a try-catch. However, due to the implementation of user loader (the `load_user_from_request`
|
||||
# in api/extensions/ext_login.py), accessing `current_user.is_authenticated` will
|
||||
# raise `Unauthorized` exception if authentication token is not provided.
|
||||
try:
|
||||
is_authenticated = current_user.is_authenticated
|
||||
except Unauthorized:
|
||||
is_authenticated = False
|
||||
return SystemFeatureResponse(features=FeatureService.get_system_features(is_authenticated=is_authenticated))
|
||||
Only non-sensitive configuration data should be returned by this endpoint.
|
||||
"""
|
||||
# NOTE(QuantumGhost): ideally we should access `current_user.is_authenticated`
|
||||
# without a try-catch. However, due to the implementation of user loader (the `load_user_from_request`
|
||||
# in api/extensions/ext_login.py), accessing `current_user.is_authenticated` will
|
||||
# raise `Unauthorized` exception if authentication token is not provided.
|
||||
try:
|
||||
is_authenticated = current_user.is_authenticated
|
||||
except Unauthorized:
|
||||
is_authenticated = False
|
||||
return FeatureService.get_system_features(is_authenticated=is_authenticated).model_dump()
|
||||
|
||||
@ -1,14 +1,27 @@
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Namespace, Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.fastopenapi import console_router
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.tag_service import TagService
|
||||
|
||||
dataset_tag_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"type": fields.String,
|
||||
"binding_count": fields.String,
|
||||
}
|
||||
|
||||
|
||||
def build_dataset_tag_fields(api_or_ns: Namespace):
|
||||
return api_or_ns.model("DataSetTag", dataset_tag_fields)
|
||||
|
||||
|
||||
class TagBasePayload(BaseModel):
|
||||
name: str = Field(description="Tag name", min_length=1, max_length=50)
|
||||
@ -32,129 +45,115 @@ class TagListQueryParam(BaseModel):
|
||||
keyword: str | None = Field(None, description="Search keyword")
|
||||
|
||||
|
||||
class TagResponse(BaseModel):
|
||||
id: str = Field(description="Tag ID")
|
||||
name: str = Field(description="Tag name")
|
||||
type: str = Field(description="Tag type")
|
||||
binding_count: int = Field(description="Number of bindings")
|
||||
|
||||
|
||||
class TagBindingResult(BaseModel):
|
||||
result: Literal["success"] = Field(description="Operation result", examples=["success"])
|
||||
|
||||
|
||||
@console_router.get(
|
||||
"/tags",
|
||||
response_model=list[TagResponse],
|
||||
tags=["console"],
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
TagBasePayload,
|
||||
TagBindingPayload,
|
||||
TagBindingRemovePayload,
|
||||
TagListQueryParam,
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def list_tags(query: TagListQueryParam) -> list[TagResponse]:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
tags = TagService.get_tags(query.type, current_tenant_id, query.keyword)
|
||||
|
||||
return [
|
||||
TagResponse(
|
||||
id=tag.id,
|
||||
name=tag.name,
|
||||
type=tag.type,
|
||||
binding_count=int(tag.binding_count),
|
||||
)
|
||||
for tag in tags
|
||||
]
|
||||
|
||||
|
||||
@console_router.post(
|
||||
"/tags",
|
||||
response_model=TagResponse,
|
||||
tags=["console"],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def create_tag(payload: TagBasePayload) -> TagResponse:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the tag table must be admin, owner, or editor
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
@console_ns.route("/tags")
|
||||
class TagListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.doc(
|
||||
params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
|
||||
)
|
||||
@marshal_with(dataset_tag_fields)
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
raw_args = request.args.to_dict()
|
||||
param = TagListQueryParam.model_validate(raw_args)
|
||||
tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
|
||||
|
||||
tag = TagService.save_tags(payload.model_dump())
|
||||
return tags, 200
|
||||
|
||||
return TagResponse(id=tag.id, name=tag.name, type=tag.type, binding_count=0)
|
||||
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagBasePayload.model_validate(console_ns.payload or {})
|
||||
tag = TagService.save_tags(payload.model_dump())
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
|
||||
return response, 200
|
||||
|
||||
|
||||
@console_router.patch(
|
||||
"/tags/<uuid:tag_id>",
|
||||
response_model=TagResponse,
|
||||
tags=["console"],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def update_tag(tag_id: UUID, payload: TagBasePayload) -> TagResponse:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
tag_id_str = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
@console_ns.route("/tags/<uuid:tag_id>")
|
||||
class TagUpdateDeleteApi(Resource):
|
||||
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, tag_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
tag_id = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
tag = TagService.update_tags(payload.model_dump(), tag_id_str)
|
||||
payload = TagBasePayload.model_validate(console_ns.payload or {})
|
||||
tag = TagService.update_tags(payload.model_dump(), tag_id)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id_str)
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
return TagResponse(id=tag.id, name=tag.name, type=tag.type, binding_count=binding_count)
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
|
||||
return response, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, tag_id):
|
||||
tag_id = str(tag_id)
|
||||
|
||||
TagService.delete_tag(tag_id)
|
||||
|
||||
return 204
|
||||
|
||||
|
||||
@console_router.delete(
|
||||
"/tags/<uuid:tag_id>",
|
||||
tags=["console"],
|
||||
status_code=204,
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete_tag(tag_id: UUID) -> None:
|
||||
tag_id_str = str(tag_id)
|
||||
@console_ns.route("/tag-bindings/create")
|
||||
class TagBindingCreateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
TagService.delete_tag(tag_id_str)
|
||||
payload = TagBindingPayload.model_validate(console_ns.payload or {})
|
||||
TagService.save_tag_binding(payload.model_dump())
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_router.post(
|
||||
"/tag-bindings/create",
|
||||
response_model=TagBindingResult,
|
||||
tags=["console"],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def create_tag_binding(payload: TagBindingPayload) -> TagBindingResult:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the tag table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
@console_ns.route("/tag-bindings/remove")
|
||||
class TagBindingDeleteApi(Resource):
|
||||
@console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
TagService.save_tag_binding(payload.model_dump())
|
||||
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
|
||||
TagService.delete_tag_binding(payload.model_dump())
|
||||
|
||||
return TagBindingResult(result="success")
|
||||
|
||||
|
||||
@console_router.post(
|
||||
"/tag-bindings/remove",
|
||||
response_model=TagBindingResult,
|
||||
tags=["console"],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete_tag_binding(payload: TagBindingRemovePayload) -> TagBindingResult:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the tag table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
TagService.delete_tag_binding(payload.model_dump())
|
||||
|
||||
return TagBindingResult(result="success")
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailAlreadyInUseError,
|
||||
@ -37,7 +38,7 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.member_fields import account_fields
|
||||
from fields.member_fields import Account as AccountResponse
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
@ -170,6 +171,12 @@ reg(ChangeEmailSendPayload)
|
||||
reg(ChangeEmailValidityPayload)
|
||||
reg(ChangeEmailResetPayload)
|
||||
reg(CheckEmailUniquePayload)
|
||||
register_schema_models(console_ns, AccountResponse)
|
||||
|
||||
|
||||
def _serialize_account(account) -> dict:
|
||||
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
integrate_fields = {
|
||||
"provider": fields.String,
|
||||
@ -236,11 +243,11 @@ class AccountProfileApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
return current_user
|
||||
return _serialize_account(current_user)
|
||||
|
||||
|
||||
@console_ns.route("/account/name")
|
||||
@ -249,14 +256,14 @@ class AccountNameApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountNamePayload.model_validate(payload)
|
||||
updated_account = AccountService.update_account(current_user, name=args.name)
|
||||
|
||||
return updated_account
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
|
||||
@console_ns.route("/account/avatar")
|
||||
@ -265,7 +272,7 @@ class AccountAvatarApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = console_ns.payload or {}
|
||||
@ -273,7 +280,7 @@ class AccountAvatarApi(Resource):
|
||||
|
||||
updated_account = AccountService.update_account(current_user, avatar=args.avatar)
|
||||
|
||||
return updated_account
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
|
||||
@console_ns.route("/account/interface-language")
|
||||
@ -282,7 +289,7 @@ class AccountInterfaceLanguageApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = console_ns.payload or {}
|
||||
@ -290,7 +297,7 @@ class AccountInterfaceLanguageApi(Resource):
|
||||
|
||||
updated_account = AccountService.update_account(current_user, interface_language=args.interface_language)
|
||||
|
||||
return updated_account
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
|
||||
@console_ns.route("/account/interface-theme")
|
||||
@ -299,7 +306,7 @@ class AccountInterfaceThemeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = console_ns.payload or {}
|
||||
@ -307,7 +314,7 @@ class AccountInterfaceThemeApi(Resource):
|
||||
|
||||
updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme)
|
||||
|
||||
return updated_account
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
|
||||
@console_ns.route("/account/timezone")
|
||||
@ -316,7 +323,7 @@ class AccountTimezoneApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = console_ns.payload or {}
|
||||
@ -324,7 +331,7 @@ class AccountTimezoneApi(Resource):
|
||||
|
||||
updated_account = AccountService.update_account(current_user, timezone=args.timezone)
|
||||
|
||||
return updated_account
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
|
||||
@console_ns.route("/account/password")
|
||||
@ -333,7 +340,7 @@ class AccountPasswordApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = console_ns.payload or {}
|
||||
@ -344,7 +351,7 @@ class AccountPasswordApi(Resource):
|
||||
except ServiceCurrentPasswordIncorrectError:
|
||||
raise CurrentPasswordIncorrectError()
|
||||
|
||||
return {"result": "success"}
|
||||
return _serialize_account(current_user)
|
||||
|
||||
|
||||
@console_ns.route("/account/integrates")
|
||||
@ -620,7 +627,7 @@ class ChangeEmailResetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
payload = console_ns.payload or {}
|
||||
args = ChangeEmailResetPayload.model_validate(payload)
|
||||
@ -649,7 +656,7 @@ class ChangeEmailResetApi(Resource):
|
||||
email=normalized_new_email,
|
||||
)
|
||||
|
||||
return updated_account
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
|
||||
@console_ns.route("/account/change-email/check-email-unique")
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
@ -38,15 +39,53 @@ class EndpointListForPluginQuery(EndpointListQuery):
|
||||
plugin_id: str
|
||||
|
||||
|
||||
class EndpointCreateResponse(BaseModel):
|
||||
success: bool = Field(description="Operation success")
|
||||
|
||||
|
||||
class EndpointListResponse(BaseModel):
|
||||
endpoints: list[dict[str, Any]] = Field(description="Endpoint information")
|
||||
|
||||
|
||||
class PluginEndpointListResponse(BaseModel):
|
||||
endpoints: list[dict[str, Any]] = Field(description="Endpoint information")
|
||||
|
||||
|
||||
class EndpointDeleteResponse(BaseModel):
|
||||
success: bool = Field(description="Operation success")
|
||||
|
||||
|
||||
class EndpointUpdateResponse(BaseModel):
|
||||
success: bool = Field(description="Operation success")
|
||||
|
||||
|
||||
class EndpointEnableResponse(BaseModel):
|
||||
success: bool = Field(description="Operation success")
|
||||
|
||||
|
||||
class EndpointDisableResponse(BaseModel):
|
||||
success: bool = Field(description="Operation success")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(EndpointCreatePayload)
|
||||
reg(EndpointIdPayload)
|
||||
reg(EndpointUpdatePayload)
|
||||
reg(EndpointListQuery)
|
||||
reg(EndpointListForPluginQuery)
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
EndpointCreatePayload,
|
||||
EndpointIdPayload,
|
||||
EndpointUpdatePayload,
|
||||
EndpointListQuery,
|
||||
EndpointListForPluginQuery,
|
||||
EndpointCreateResponse,
|
||||
EndpointListResponse,
|
||||
PluginEndpointListResponse,
|
||||
EndpointDeleteResponse,
|
||||
EndpointUpdateResponse,
|
||||
EndpointEnableResponse,
|
||||
EndpointDisableResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/create")
|
||||
@ -57,7 +96,7 @@ class EndpointCreateApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint created successfully",
|
||||
console_ns.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
console_ns.models[EndpointCreateResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@ -91,9 +130,7 @@ class EndpointListApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.model(
|
||||
"EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
|
||||
),
|
||||
console_ns.models[EndpointListResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -126,9 +163,7 @@ class EndpointListForSinglePluginApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.model(
|
||||
"PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
|
||||
),
|
||||
console_ns.models[PluginEndpointListResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -163,7 +198,7 @@ class EndpointDeleteApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint deleted successfully",
|
||||
console_ns.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
console_ns.models[EndpointDeleteResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@ -190,7 +225,7 @@ class EndpointUpdateApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint updated successfully",
|
||||
console_ns.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
console_ns.models[EndpointUpdateResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@ -221,7 +256,7 @@ class EndpointEnableApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint enabled successfully",
|
||||
console_ns.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
console_ns.models[EndpointEnableResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@ -248,7 +283,7 @@ class EndpointDisableApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint disabled successfully",
|
||||
console_ns.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
console_ns.models[EndpointDisableResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
from urllib import parse
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import get_or_create_model, register_enum_models
|
||||
from controllers.common.schema import register_enum_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
CannotTransferOwnerToSelfError,
|
||||
@ -25,7 +25,7 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.member_fields import account_with_role_fields, account_with_role_list_fields
|
||||
from fields.member_fields import AccountWithRole, AccountWithRoleList
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import Account, TenantAccountRole
|
||||
@ -69,12 +69,7 @@ reg(OwnerTransferEmailPayload)
|
||||
reg(OwnerTransferCheckPayload)
|
||||
reg(OwnerTransferPayload)
|
||||
register_enum_models(console_ns, TenantAccountRole)
|
||||
|
||||
account_with_role_model = get_or_create_model("AccountWithRole", account_with_role_fields)
|
||||
|
||||
account_with_role_list_fields_copy = account_with_role_list_fields.copy()
|
||||
account_with_role_list_fields_copy["accounts"] = fields.List(fields.Nested(account_with_role_model))
|
||||
account_with_role_list_model = get_or_create_model("AccountWithRoleList", account_with_role_list_fields_copy)
|
||||
register_schema_models(console_ns, AccountWithRole, AccountWithRoleList)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members")
|
||||
@ -84,13 +79,15 @@ class MemberListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_with_role_list_model)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
return {"result": "success", "accounts": members}, 200
|
||||
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
|
||||
response = AccountWithRoleList(accounts=member_models)
|
||||
return response.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members/invite-email")
|
||||
@ -235,13 +232,15 @@ class DatasetOperatorMemberListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_with_role_list_model)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
|
||||
return {"result": "success", "accounts": members}, 200
|
||||
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
|
||||
response = AccountWithRoleList(accounts=member_models)
|
||||
return response.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")
|
||||
|
||||
@ -1,16 +1,16 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
from flask_restx import Resource
|
||||
from flask_restx.api import HTTPStatus
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.wraps import edit_permission_required
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.annotation_fields import annotation_fields, build_annotation_model
|
||||
from fields.annotation_fields import Annotation, AnnotationList
|
||||
from models.model import App
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
@ -26,7 +26,9 @@ class AnnotationReplyActionPayload(BaseModel):
|
||||
embedding_model_name: str = Field(description="Embedding model name")
|
||||
|
||||
|
||||
register_schema_models(service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload)
|
||||
register_schema_models(
|
||||
service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload, Annotation, AnnotationList
|
||||
)
|
||||
|
||||
|
||||
@service_api_ns.route("/apps/annotation-reply/<string:action>")
|
||||
@ -45,10 +47,11 @@ class AnnotationReplyActionApi(Resource):
|
||||
def post(self, app_model: App, action: Literal["enable", "disable"]):
|
||||
"""Enable or disable annotation reply feature."""
|
||||
args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump()
|
||||
if action == "enable":
|
||||
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
|
||||
elif action == "disable":
|
||||
result = AppAnnotationService.disable_app_annotation(app_model.id)
|
||||
match action:
|
||||
case "enable":
|
||||
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
|
||||
case "disable":
|
||||
result = AppAnnotationService.disable_app_annotation(app_model.id)
|
||||
return result, 200
|
||||
|
||||
|
||||
@ -82,23 +85,6 @@ class AnnotationReplyActionStatusApi(Resource):
|
||||
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
|
||||
|
||||
|
||||
# Define annotation list response model
|
||||
annotation_list_fields = {
|
||||
"data": fields.List(fields.Nested(annotation_fields)),
|
||||
"has_more": fields.Boolean,
|
||||
"limit": fields.Integer,
|
||||
"total": fields.Integer,
|
||||
"page": fields.Integer,
|
||||
}
|
||||
|
||||
|
||||
def build_annotation_list_model(api_or_ns: Namespace):
|
||||
"""Build the annotation list model for the API or Namespace."""
|
||||
copied_annotation_list_fields = annotation_list_fields.copy()
|
||||
copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns)))
|
||||
return api_or_ns.model("AnnotationList", copied_annotation_list_fields)
|
||||
|
||||
|
||||
@service_api_ns.route("/apps/annotations")
|
||||
class AnnotationListApi(Resource):
|
||||
@service_api_ns.doc("list_annotations")
|
||||
@ -109,8 +95,12 @@ class AnnotationListApi(Resource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Annotations retrieved successfully",
|
||||
service_api_ns.models[AnnotationList.__name__],
|
||||
)
|
||||
@validate_app_token
|
||||
@service_api_ns.marshal_with(build_annotation_list_model(service_api_ns))
|
||||
def get(self, app_model: App):
|
||||
"""List annotations for the application."""
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
@ -118,13 +108,15 @@ class AnnotationListApi(Resource):
|
||||
keyword = request.args.get("keyword", default="", type=str)
|
||||
|
||||
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword)
|
||||
return {
|
||||
"data": annotation_list,
|
||||
"has_more": len(annotation_list) == limit,
|
||||
"limit": limit,
|
||||
"total": total,
|
||||
"page": page,
|
||||
}
|
||||
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
|
||||
response = AnnotationList(
|
||||
data=annotation_models,
|
||||
has_more=len(annotation_list) == limit,
|
||||
limit=limit,
|
||||
total=total,
|
||||
page=page,
|
||||
)
|
||||
return response.model_dump(mode="json")
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_annotation")
|
||||
@ -135,13 +127,18 @@ class AnnotationListApi(Resource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
HTTPStatus.CREATED,
|
||||
"Annotation created successfully",
|
||||
service_api_ns.models[Annotation.__name__],
|
||||
)
|
||||
@validate_app_token
|
||||
@service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED)
|
||||
def post(self, app_model: App):
|
||||
"""Create a new annotation."""
|
||||
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
|
||||
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
|
||||
return annotation, 201
|
||||
response = Annotation.model_validate(annotation, from_attributes=True)
|
||||
return response.model_dump(mode="json"), HTTPStatus.CREATED
|
||||
|
||||
|
||||
@service_api_ns.route("/apps/annotations/<uuid:annotation_id>")
|
||||
@ -158,14 +155,19 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
404: "Annotation not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Annotation updated successfully",
|
||||
service_api_ns.models[Annotation.__name__],
|
||||
)
|
||||
@validate_app_token
|
||||
@edit_permission_required
|
||||
@service_api_ns.marshal_with(build_annotation_model(service_api_ns))
|
||||
def put(self, app_model: App, annotation_id: str):
|
||||
"""Update an existing annotation."""
|
||||
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
|
||||
return annotation
|
||||
response = Annotation.model_validate(annotation, from_attributes=True)
|
||||
return response.model_dump(mode="json")
|
||||
|
||||
@service_api_ns.doc("delete_annotation")
|
||||
@service_api_ns.doc(description="Delete an annotation")
|
||||
|
||||
@ -17,7 +17,7 @@ from controllers.service_api.wraps import (
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from fields.tag_fields import build_dataset_tag_fields
|
||||
from fields.tag_fields import DataSetTag
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
@ -114,6 +114,7 @@ register_schema_models(
|
||||
TagBindingPayload,
|
||||
TagUnbindingPayload,
|
||||
DatasetListQuery,
|
||||
DataSetTag,
|
||||
)
|
||||
|
||||
|
||||
@ -480,15 +481,14 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
|
||||
def get(self, _):
|
||||
"""Get all knowledge type tags."""
|
||||
assert isinstance(current_user, Account)
|
||||
cid = current_user.current_tenant_id
|
||||
assert cid is not None
|
||||
tags = TagService.get_tags("knowledge", cid)
|
||||
|
||||
return tags, 200
|
||||
tag_models = TypeAdapter(list[DataSetTag]).validate_python(tags, from_attributes=True)
|
||||
return [tag.model_dump(mode="json") for tag in tag_models], 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_dataset_tag")
|
||||
@ -500,7 +500,6 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
)
|
||||
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
|
||||
def post(self, _):
|
||||
"""Add a knowledge type tag."""
|
||||
assert isinstance(current_user, Account)
|
||||
@ -510,7 +509,9 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
tag = TagService.save_tags({"name": payload.name, "type": "knowledge"})
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
response = DataSetTag.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
).model_dump(mode="json")
|
||||
return response, 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__])
|
||||
@ -523,7 +524,6 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
)
|
||||
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
|
||||
def patch(self, _):
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
@ -536,8 +536,9 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
|
||||
response = DataSetTag.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
).model_dump(mode="json")
|
||||
return response, 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__])
|
||||
|
||||
@ -168,10 +168,11 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
if action == "enable":
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
elif action == "disable":
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
match action:
|
||||
case "enable":
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
case "disable":
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
|
||||
@ -73,14 +73,14 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
|
||||
|
||||
# If caller needs end-user context, attach EndUser to current_user
|
||||
if fetch_user_arg:
|
||||
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
|
||||
user_id = request.args.get("user")
|
||||
elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
|
||||
user_id = request.get_json().get("user")
|
||||
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
|
||||
user_id = request.form.get("user")
|
||||
else:
|
||||
user_id = None
|
||||
user_id = None
|
||||
match fetch_user_arg.fetch_from:
|
||||
case WhereisUserArg.QUERY:
|
||||
user_id = request.args.get("user")
|
||||
case WhereisUserArg.JSON:
|
||||
user_id = request.get_json().get("user")
|
||||
case WhereisUserArg.FORM:
|
||||
user_id = request.form.get("user")
|
||||
|
||||
if not user_id and fetch_user_arg.required:
|
||||
raise ValueError("Arg user must be provided.")
|
||||
|
||||
@ -14,16 +14,17 @@ class AgentConfigManager:
|
||||
agent_dict = config.get("agent_mode", {})
|
||||
agent_strategy = agent_dict.get("strategy", "cot")
|
||||
|
||||
if agent_strategy == "function_call":
|
||||
strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
elif agent_strategy in {"cot", "react"}:
|
||||
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||
else:
|
||||
# old configs, try to detect default strategy
|
||||
if config["model"]["provider"] == "openai":
|
||||
match agent_strategy:
|
||||
case "function_call":
|
||||
strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
else:
|
||||
case "cot" | "react":
|
||||
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||
case _:
|
||||
# old configs, try to detect default strategy
|
||||
if config["model"]["provider"] == "openai":
|
||||
strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
else:
|
||||
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||
|
||||
agent_tools = []
|
||||
for tool in agent_dict.get("tools", []):
|
||||
|
||||
@ -268,7 +268,7 @@ class WorkflowResponseConverter:
|
||||
data=WorkflowFinishStreamResponse.Data(
|
||||
id=run_id,
|
||||
workflow_id=workflow_id,
|
||||
status=status.value,
|
||||
status=status,
|
||||
outputs=encoded_outputs,
|
||||
error=error,
|
||||
elapsed_time=elapsed_time,
|
||||
@ -512,13 +512,13 @@ class WorkflowResponseConverter:
|
||||
metadata = self._merge_metadata(event.execution_metadata, snapshot)
|
||||
|
||||
if isinstance(event, QueueNodeSucceededEvent):
|
||||
status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
error_message = event.error
|
||||
elif isinstance(event, QueueNodeFailedEvent):
|
||||
status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
status = WorkflowNodeExecutionStatus.FAILED
|
||||
error_message = event.error
|
||||
else:
|
||||
status = WorkflowNodeExecutionStatus.EXCEPTION.value
|
||||
status = WorkflowNodeExecutionStatus.EXCEPTION
|
||||
error_message = event.error
|
||||
|
||||
return NodeFinishStreamResponse(
|
||||
@ -585,7 +585,7 @@ class WorkflowResponseConverter:
|
||||
process_data_truncated=process_data_truncated,
|
||||
outputs=outputs,
|
||||
outputs_truncated=outputs_truncated,
|
||||
status=WorkflowNodeExecutionStatus.RETRY.value,
|
||||
status=WorkflowNodeExecutionStatus.RETRY,
|
||||
error=event.error,
|
||||
elapsed_time=elapsed_time,
|
||||
execution_metadata=metadata,
|
||||
|
||||
@ -120,7 +120,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
raise ValueError("Pipeline dataset is required")
|
||||
inputs: Mapping[str, Any] = args["inputs"]
|
||||
start_node_id: str = args["start_node_id"]
|
||||
datasource_type: str = args["datasource_type"]
|
||||
datasource_type = DatasourceProviderType(args["datasource_type"])
|
||||
datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list(
|
||||
datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user
|
||||
)
|
||||
@ -660,7 +660,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
tenant_id: str,
|
||||
dataset_id: str,
|
||||
built_in_field_enabled: bool,
|
||||
datasource_type: str,
|
||||
datasource_type: DatasourceProviderType,
|
||||
datasource_info: Mapping[str, Any],
|
||||
created_from: str,
|
||||
position: int,
|
||||
@ -668,17 +668,17 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
batch: str,
|
||||
document_form: str,
|
||||
):
|
||||
if datasource_type == "local_file":
|
||||
name = datasource_info.get("name", "untitled")
|
||||
elif datasource_type == "online_document":
|
||||
name = datasource_info.get("page", {}).get("page_name", "untitled")
|
||||
elif datasource_type == "website_crawl":
|
||||
name = datasource_info.get("title", "untitled")
|
||||
elif datasource_type == "online_drive":
|
||||
name = datasource_info.get("name", "untitled")
|
||||
else:
|
||||
raise ValueError(f"Unsupported datasource type: {datasource_type}")
|
||||
|
||||
match datasource_type:
|
||||
case DatasourceProviderType.LOCAL_FILE:
|
||||
name = datasource_info.get("name", "untitled")
|
||||
case DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
name = datasource_info.get("page", {}).get("page_name", "untitled")
|
||||
case DatasourceProviderType.WEBSITE_CRAWL:
|
||||
name = datasource_info.get("title", "untitled")
|
||||
case DatasourceProviderType.ONLINE_DRIVE:
|
||||
name = datasource_info.get("name", "untitled")
|
||||
case _:
|
||||
raise ValueError(f"Unsupported datasource type: {datasource_type}")
|
||||
document = Document(
|
||||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
@ -706,7 +706,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
|
||||
def _format_datasource_info_list(
|
||||
self,
|
||||
datasource_type: str,
|
||||
datasource_type: DatasourceProviderType,
|
||||
datasource_info_list: list[Mapping[str, Any]],
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
@ -716,7 +716,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
"""
|
||||
Format datasource info list.
|
||||
"""
|
||||
if datasource_type == "online_drive":
|
||||
if datasource_type == DatasourceProviderType.ONLINE_DRIVE:
|
||||
all_files: list[Mapping[str, Any]] = []
|
||||
datasource_node_data = None
|
||||
datasource_nodes = workflow.graph_dict.get("nodes", [])
|
||||
|
||||
@ -8,7 +8,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.human_input.entities import FormInput, UserAction
|
||||
|
||||
|
||||
@ -231,7 +231,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
||||
|
||||
id: str
|
||||
workflow_id: str
|
||||
status: str
|
||||
status: WorkflowExecutionStatus
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
error: str | None = None
|
||||
elapsed_time: float
|
||||
@ -398,7 +398,7 @@ class NodeFinishStreamResponse(StreamResponse):
|
||||
process_data_truncated: bool = False
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
outputs_truncated: bool = True
|
||||
status: str
|
||||
status: WorkflowNodeExecutionStatus
|
||||
error: str | None = None
|
||||
elapsed_time: float
|
||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
||||
@ -462,7 +462,7 @@ class NodeRetryStreamResponse(StreamResponse):
|
||||
process_data_truncated: bool = False
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
outputs_truncated: bool = False
|
||||
status: str
|
||||
status: WorkflowNodeExecutionStatus
|
||||
error: str | None = None
|
||||
elapsed_time: float
|
||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
||||
@ -806,7 +806,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
|
||||
|
||||
id: str
|
||||
workflow_id: str
|
||||
status: str
|
||||
status: WorkflowExecutionStatus
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
error: str | None = None
|
||||
elapsed_time: float
|
||||
|
||||
@ -369,77 +369,78 @@ class IndexingRunner:
|
||||
# Generate summary preview
|
||||
summary_index_setting = tmp_processing_rule.get("summary_index_setting")
|
||||
if summary_index_setting and summary_index_setting.get("enable") and preview_texts:
|
||||
preview_texts = index_processor.generate_summary_preview(tenant_id, preview_texts, summary_index_setting)
|
||||
preview_texts = index_processor.generate_summary_preview(
|
||||
tenant_id, preview_texts, summary_index_setting, doc_language
|
||||
)
|
||||
|
||||
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
|
||||
|
||||
def _extract(
|
||||
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
|
||||
) -> list[Document]:
|
||||
# load file
|
||||
if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}:
|
||||
return []
|
||||
|
||||
data_source_info = dataset_document.data_source_info_dict
|
||||
text_docs = []
|
||||
if dataset_document.data_source_type == "upload_file":
|
||||
if not data_source_info or "upload_file_id" not in data_source_info:
|
||||
raise ValueError("no upload file found")
|
||||
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
|
||||
file_detail = db.session.scalars(stmt).one_or_none()
|
||||
match dataset_document.data_source_type:
|
||||
case "upload_file":
|
||||
if not data_source_info or "upload_file_id" not in data_source_info:
|
||||
raise ValueError("no upload file found")
|
||||
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
|
||||
file_detail = db.session.scalars(stmt).one_or_none()
|
||||
|
||||
if file_detail:
|
||||
if file_detail:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE,
|
||||
upload_file=file_detail,
|
||||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
|
||||
case "notion_import":
|
||||
if (
|
||||
not data_source_info
|
||||
or "notion_workspace_id" not in data_source_info
|
||||
or "notion_page_id" not in data_source_info
|
||||
):
|
||||
raise ValueError("no notion import info found")
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE,
|
||||
upload_file=file_detail,
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": data_source_info.get("credential_id"),
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
"document": dataset_document,
|
||||
"tenant_id": dataset_document.tenant_id,
|
||||
}
|
||||
),
|
||||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
|
||||
elif dataset_document.data_source_type == "notion_import":
|
||||
if (
|
||||
not data_source_info
|
||||
or "notion_workspace_id" not in data_source_info
|
||||
or "notion_page_id" not in data_source_info
|
||||
):
|
||||
raise ValueError("no notion import info found")
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": data_source_info.get("credential_id"),
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
"document": dataset_document,
|
||||
"tenant_id": dataset_document.tenant_id,
|
||||
}
|
||||
),
|
||||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
|
||||
elif dataset_document.data_source_type == "website_crawl":
|
||||
if (
|
||||
not data_source_info
|
||||
or "provider" not in data_source_info
|
||||
or "url" not in data_source_info
|
||||
or "job_id" not in data_source_info
|
||||
):
|
||||
raise ValueError("no website import info found")
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE,
|
||||
website_info=WebsiteInfo.model_validate(
|
||||
{
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
"tenant_id": dataset_document.tenant_id,
|
||||
"url": data_source_info["url"],
|
||||
"mode": data_source_info["mode"],
|
||||
"only_main_content": data_source_info["only_main_content"],
|
||||
}
|
||||
),
|
||||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
|
||||
case "website_crawl":
|
||||
if (
|
||||
not data_source_info
|
||||
or "provider" not in data_source_info
|
||||
or "url" not in data_source_info
|
||||
or "job_id" not in data_source_info
|
||||
):
|
||||
raise ValueError("no website import info found")
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE,
|
||||
website_info=WebsiteInfo.model_validate(
|
||||
{
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
"tenant_id": dataset_document.tenant_id,
|
||||
"url": data_source_info["url"],
|
||||
"mode": data_source_info["mode"],
|
||||
"only_main_content": data_source_info["only_main_content"],
|
||||
}
|
||||
),
|
||||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
|
||||
case _:
|
||||
return []
|
||||
# update document status to splitting
|
||||
self._update_document_index_status(
|
||||
document_id=dataset_document.id,
|
||||
|
||||
@ -441,11 +441,13 @@ DEFAULT_GENERATOR_SUMMARY_PROMPT = (
|
||||
|
||||
Requirements:
|
||||
1. Write a concise summary in plain text
|
||||
2. Use the same language as the input content
|
||||
2. You must write in {language}. No language other than {language} should be used.
|
||||
3. Focus on important facts, concepts, and details
|
||||
4. If images are included, describe their key information
|
||||
5. Do not use words like "好的", "ok", "I understand", "This text discusses", "The content mentions"
|
||||
6. Write directly without extra words
|
||||
7. If there is not enough content to generate a meaningful summary,
|
||||
return an empty string without any explanation or prompt
|
||||
|
||||
Output only the summary text. Start summarizing now:
|
||||
|
||||
|
||||
@ -48,12 +48,22 @@ class BaseIndexProcessor(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def generate_summary_preview(
|
||||
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
|
||||
self,
|
||||
tenant_id: str,
|
||||
preview_texts: list[PreviewDetail],
|
||||
summary_index_setting: dict,
|
||||
doc_language: str | None = None,
|
||||
) -> list[PreviewDetail]:
|
||||
"""
|
||||
For each segment in preview_texts, generate a summary using LLM and attach it to the segment.
|
||||
The summary can be stored in a new attribute, e.g., summary.
|
||||
This method should be implemented by subclasses.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
preview_texts: List of preview details to generate summaries for
|
||||
summary_index_setting: Summary index configuration
|
||||
doc_language: Optional document language to ensure summary is generated in the correct language
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -275,7 +275,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
raise ValueError("Chunks is not a list")
|
||||
|
||||
def generate_summary_preview(
|
||||
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
|
||||
self,
|
||||
tenant_id: str,
|
||||
preview_texts: list[PreviewDetail],
|
||||
summary_index_setting: dict,
|
||||
doc_language: str | None = None,
|
||||
) -> list[PreviewDetail]:
|
||||
"""
|
||||
For each segment, concurrently call generate_summary to generate a summary
|
||||
@ -298,11 +302,15 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
if flask_app:
|
||||
# Ensure Flask app context in worker thread
|
||||
with flask_app.app_context():
|
||||
summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting)
|
||||
summary, _ = self.generate_summary(
|
||||
tenant_id, preview.content, summary_index_setting, document_language=doc_language
|
||||
)
|
||||
preview.summary = summary
|
||||
else:
|
||||
# Fallback: try without app context (may fail)
|
||||
summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting)
|
||||
summary, _ = self.generate_summary(
|
||||
tenant_id, preview.content, summary_index_setting, document_language=doc_language
|
||||
)
|
||||
preview.summary = summary
|
||||
|
||||
# Generate summaries concurrently using ThreadPoolExecutor
|
||||
@ -356,6 +364,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
text: str,
|
||||
summary_index_setting: dict | None = None,
|
||||
segment_id: str | None = None,
|
||||
document_language: str | None = None,
|
||||
) -> tuple[str, LLMUsage]:
|
||||
"""
|
||||
Generate summary for the given text using ModelInstance.invoke_llm and the default or custom summary prompt,
|
||||
@ -366,6 +375,8 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
text: Text content to summarize
|
||||
summary_index_setting: Summary index configuration
|
||||
segment_id: Optional segment ID to fetch attachments from SegmentAttachmentBinding table
|
||||
document_language: Optional document language (e.g., "Chinese", "English")
|
||||
to ensure summary is generated in the correct language
|
||||
|
||||
Returns:
|
||||
Tuple of (summary_content, llm_usage) where llm_usage is LLMUsage object
|
||||
@ -381,8 +392,22 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
raise ValueError("model_name and model_provider_name are required in summary_index_setting")
|
||||
|
||||
# Import default summary prompt
|
||||
is_default_prompt = False
|
||||
if not summary_prompt:
|
||||
summary_prompt = DEFAULT_GENERATOR_SUMMARY_PROMPT
|
||||
is_default_prompt = True
|
||||
|
||||
# Format prompt with document language only for default prompt
|
||||
# Custom prompts are used as-is to avoid interfering with user-defined templates
|
||||
# If document_language is provided, use it; otherwise, use "the same language as the input content"
|
||||
# This is especially important for image-only chunks where text is empty or minimal
|
||||
if is_default_prompt:
|
||||
language_for_prompt = document_language or "the same language as the input content"
|
||||
try:
|
||||
summary_prompt = summary_prompt.format(language=language_for_prompt)
|
||||
except KeyError:
|
||||
# If default prompt doesn't have {language} placeholder, use it as-is
|
||||
pass
|
||||
|
||||
provider_manager = ProviderManager()
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
|
||||
@ -358,7 +358,11 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
}
|
||||
|
||||
def generate_summary_preview(
|
||||
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
|
||||
self,
|
||||
tenant_id: str,
|
||||
preview_texts: list[PreviewDetail],
|
||||
summary_index_setting: dict,
|
||||
doc_language: str | None = None,
|
||||
) -> list[PreviewDetail]:
|
||||
"""
|
||||
For each parent chunk in preview_texts, concurrently call generate_summary to generate a summary
|
||||
@ -389,6 +393,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
tenant_id=tenant_id,
|
||||
text=preview.content,
|
||||
summary_index_setting=summary_index_setting,
|
||||
document_language=doc_language,
|
||||
)
|
||||
preview.summary = summary
|
||||
else:
|
||||
@ -397,6 +402,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
tenant_id=tenant_id,
|
||||
text=preview.content,
|
||||
summary_index_setting=summary_index_setting,
|
||||
document_language=doc_language,
|
||||
)
|
||||
preview.summary = summary
|
||||
|
||||
|
||||
@ -241,7 +241,11 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
}
|
||||
|
||||
def generate_summary_preview(
|
||||
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
|
||||
self,
|
||||
tenant_id: str,
|
||||
preview_texts: list[PreviewDetail],
|
||||
summary_index_setting: dict,
|
||||
doc_language: str | None = None,
|
||||
) -> list[PreviewDetail]:
|
||||
"""
|
||||
QA model doesn't generate summaries, so this method returns preview_texts unchanged.
|
||||
|
||||
@ -192,32 +192,33 @@ class AgentNode(Node[AgentNodeData]):
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
agent_input = node_data.agent_parameters[parameter_name]
|
||||
if agent_input.type == "variable":
|
||||
variable = variable_pool.get(agent_input.value) # type: ignore
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(agent_input.value))
|
||||
parameter_value = variable.value
|
||||
elif agent_input.type in {"mixed", "constant"}:
|
||||
# variable_pool.convert_template expects a string template,
|
||||
# but if passing a dict, convert to JSON string first before rendering
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
|
||||
else:
|
||||
match agent_input.type:
|
||||
case "variable":
|
||||
variable = variable_pool.get(agent_input.value) # type: ignore
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(agent_input.value))
|
||||
parameter_value = variable.value
|
||||
case "mixed" | "constant":
|
||||
# variable_pool.convert_template expects a string template,
|
||||
# but if passing a dict, convert to JSON string first before rendering
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
|
||||
else:
|
||||
parameter_value = str(agent_input.value)
|
||||
except TypeError:
|
||||
parameter_value = str(agent_input.value)
|
||||
except TypeError:
|
||||
parameter_value = str(agent_input.value)
|
||||
segment_group = variable_pool.convert_template(parameter_value)
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
# variable_pool.convert_template returns a string,
|
||||
# so we need to convert it back to a dictionary
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.loads(parameter_value)
|
||||
except json.JSONDecodeError:
|
||||
parameter_value = parameter_value
|
||||
else:
|
||||
raise AgentInputTypeError(agent_input.type)
|
||||
segment_group = variable_pool.convert_template(parameter_value)
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
# variable_pool.convert_template returns a string,
|
||||
# so we need to convert it back to a dictionary
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.loads(parameter_value)
|
||||
except json.JSONDecodeError:
|
||||
parameter_value = parameter_value
|
||||
case _:
|
||||
raise AgentInputTypeError(agent_input.type)
|
||||
value = parameter_value
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
@ -374,12 +375,13 @@ class AgentNode(Node[AgentNodeData]):
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in typed_node_data.agent_parameters:
|
||||
input = typed_node_data.agent_parameters[parameter_name]
|
||||
if input.type in ["mixed", "constant"]:
|
||||
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == "variable":
|
||||
result[parameter_name] = input.value
|
||||
match input.type:
|
||||
case "mixed" | "constant":
|
||||
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
case "variable":
|
||||
result[parameter_name] = input.value
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
|
||||
@ -270,15 +270,18 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
if typed_node_data.datasource_parameters:
|
||||
for parameter_name in typed_node_data.datasource_parameters:
|
||||
input = typed_node_data.datasource_parameters[parameter_name]
|
||||
if input.type == "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == "variable":
|
||||
result[parameter_name] = input.value
|
||||
elif input.type == "constant":
|
||||
pass
|
||||
match input.type:
|
||||
case "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
case "variable":
|
||||
result[parameter_name] = input.value
|
||||
case "constant":
|
||||
pass
|
||||
case None:
|
||||
pass
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
@ -308,99 +311,107 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
if message.type in {
|
||||
DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
DatasourceMessage.MessageType.BINARY_LINK,
|
||||
DatasourceMessage.MessageType.IMAGE,
|
||||
}:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
match message.type:
|
||||
case (
|
||||
DatasourceMessage.MessageType.IMAGE_LINK
|
||||
| DatasourceMessage.MessageType.BINARY_LINK
|
||||
| DatasourceMessage.MessageType.IMAGE
|
||||
):
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
url = message.message.text
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
datasource_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
datasource_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
elif message.type == DatasourceMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, DatasourceMessage.JsonMessage)
|
||||
json.append(message.message.json_object)
|
||||
elif message.type == DatasourceMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, DatasourceMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
files.append(file)
|
||||
case DatasourceMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
)
|
||||
case DatasourceMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, variable_name],
|
||||
chunk=variable_value,
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
elif message.type == DatasourceMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
files.append(message.meta["file"])
|
||||
case DatasourceMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, DatasourceMessage.JsonMessage)
|
||||
json.append(message.message.json_object)
|
||||
case DatasourceMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
case DatasourceMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, DatasourceMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, variable_name],
|
||||
chunk=variable_value,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
case DatasourceMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
files.append(message.meta["file"])
|
||||
case (
|
||||
DatasourceMessage.MessageType.BLOB_CHUNK
|
||||
| DatasourceMessage.MessageType.LOG
|
||||
| DatasourceMessage.MessageType.RETRIEVER_RESOURCES
|
||||
):
|
||||
pass
|
||||
|
||||
# mark the end of the stream
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
|
||||
@ -78,12 +78,21 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
indexing_technique = node_data.indexing_technique or dataset.indexing_technique
|
||||
summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting
|
||||
|
||||
# Try to get document language if document_id is available
|
||||
doc_language = None
|
||||
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
|
||||
if document_id:
|
||||
document = db.session.query(Document).filter_by(id=document_id.value).first()
|
||||
if document and document.doc_language:
|
||||
doc_language = document.doc_language
|
||||
|
||||
outputs = self._get_preview_output_with_summaries(
|
||||
node_data.chunk_structure,
|
||||
chunks,
|
||||
dataset=dataset,
|
||||
indexing_technique=indexing_technique,
|
||||
summary_index_setting=summary_index_setting,
|
||||
doc_language=doc_language,
|
||||
)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
@ -315,6 +324,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
dataset: Dataset,
|
||||
indexing_technique: str | None = None,
|
||||
summary_index_setting: dict | None = None,
|
||||
doc_language: str | None = None,
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Generate preview output with summaries for chunks in preview mode.
|
||||
@ -326,6 +336,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
dataset: Dataset object (for tenant_id)
|
||||
indexing_technique: Indexing technique from node config or dataset
|
||||
summary_index_setting: Summary index setting from node config or dataset
|
||||
doc_language: Optional document language to ensure summary is generated in the correct language
|
||||
"""
|
||||
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
|
||||
preview_output = index_processor.format_preview(chunks)
|
||||
@ -365,6 +376,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
tenant_id=dataset.tenant_id,
|
||||
text=preview_item["content"],
|
||||
summary_index_setting=summary_index_setting,
|
||||
document_language=doc_language,
|
||||
)
|
||||
if summary:
|
||||
preview_item["summary"] = summary
|
||||
@ -374,6 +386,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
tenant_id=dataset.tenant_id,
|
||||
text=preview_item["content"],
|
||||
summary_index_setting=summary_index_setting,
|
||||
document_language=doc_language,
|
||||
)
|
||||
if summary:
|
||||
preview_item["summary"] = summary
|
||||
|
||||
@ -303,33 +303,34 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
|
||||
if node_data.multiple_retrieval_config is None:
|
||||
raise ValueError("multiple_retrieval_config is required")
|
||||
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
|
||||
if node_data.multiple_retrieval_config.reranking_model:
|
||||
reranking_model = {
|
||||
"reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
|
||||
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
|
||||
}
|
||||
else:
|
||||
match node_data.multiple_retrieval_config.reranking_mode:
|
||||
case "reranking_model":
|
||||
if node_data.multiple_retrieval_config.reranking_model:
|
||||
reranking_model = {
|
||||
"reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
|
||||
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
|
||||
}
|
||||
else:
|
||||
reranking_model = None
|
||||
weights = None
|
||||
case "weighted_score":
|
||||
if node_data.multiple_retrieval_config.weights is None:
|
||||
raise ValueError("weights is required")
|
||||
reranking_model = None
|
||||
weights = None
|
||||
elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score":
|
||||
if node_data.multiple_retrieval_config.weights is None:
|
||||
raise ValueError("weights is required")
|
||||
reranking_model = None
|
||||
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
|
||||
weights = {
|
||||
"vector_setting": {
|
||||
"vector_weight": vector_setting.vector_weight,
|
||||
"embedding_provider_name": vector_setting.embedding_provider_name,
|
||||
"embedding_model_name": vector_setting.embedding_model_name,
|
||||
},
|
||||
"keyword_setting": {
|
||||
"keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
|
||||
},
|
||||
}
|
||||
else:
|
||||
reranking_model = None
|
||||
weights = None
|
||||
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
|
||||
weights = {
|
||||
"vector_setting": {
|
||||
"vector_weight": vector_setting.vector_weight,
|
||||
"embedding_provider_name": vector_setting.embedding_provider_name,
|
||||
"embedding_model_name": vector_setting.embedding_model_name,
|
||||
},
|
||||
"keyword_setting": {
|
||||
"keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
|
||||
},
|
||||
}
|
||||
case _:
|
||||
reranking_model = None
|
||||
weights = None
|
||||
all_documents = dataset_retrieval.multiple_retrieve(
|
||||
app_id=self.app_id,
|
||||
tenant_id=self.tenant_id,
|
||||
@ -453,73 +454,74 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
)
|
||||
filters: list[Any] = []
|
||||
metadata_condition = None
|
||||
if node_data.metadata_filtering_mode == "disabled":
|
||||
return None, None, usage
|
||||
elif node_data.metadata_filtering_mode == "automatic":
|
||||
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
|
||||
dataset_ids, query, node_data
|
||||
)
|
||||
usage = self._merge_usage(usage, automatic_usage)
|
||||
if automatic_metadata_filters:
|
||||
conditions = []
|
||||
for sequence, filter in enumerate(automatic_metadata_filters):
|
||||
DatasetRetrieval.process_metadata_filter_func(
|
||||
sequence,
|
||||
filter.get("condition", ""),
|
||||
filter.get("metadata_name", ""),
|
||||
filter.get("value"),
|
||||
filters,
|
||||
)
|
||||
conditions.append(
|
||||
Condition(
|
||||
name=filter.get("metadata_name"), # type: ignore
|
||||
comparison_operator=filter.get("condition"), # type: ignore
|
||||
value=filter.get("value"),
|
||||
)
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
logical_operator=node_data.metadata_filtering_conditions.logical_operator
|
||||
if node_data.metadata_filtering_conditions
|
||||
else "or",
|
||||
conditions=conditions,
|
||||
match node_data.metadata_filtering_mode:
|
||||
case "disabled":
|
||||
return None, None, usage
|
||||
case "automatic":
|
||||
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
|
||||
dataset_ids, query, node_data
|
||||
)
|
||||
elif node_data.metadata_filtering_mode == "manual":
|
||||
if node_data.metadata_filtering_conditions:
|
||||
conditions = []
|
||||
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
|
||||
metadata_name = condition.name
|
||||
expected_value = condition.value
|
||||
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
|
||||
if isinstance(expected_value, str):
|
||||
expected_value = self.graph_runtime_state.variable_pool.convert_template(
|
||||
expected_value
|
||||
).value[0]
|
||||
if expected_value.value_type in {"number", "integer", "float"}:
|
||||
expected_value = expected_value.value
|
||||
elif expected_value.value_type == "string":
|
||||
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
|
||||
else:
|
||||
raise ValueError("Invalid expected metadata value type")
|
||||
conditions.append(
|
||||
Condition(
|
||||
name=metadata_name,
|
||||
comparison_operator=condition.comparison_operator,
|
||||
value=expected_value,
|
||||
usage = self._merge_usage(usage, automatic_usage)
|
||||
if automatic_metadata_filters:
|
||||
conditions = []
|
||||
for sequence, filter in enumerate(automatic_metadata_filters):
|
||||
DatasetRetrieval.process_metadata_filter_func(
|
||||
sequence,
|
||||
filter.get("condition", ""),
|
||||
filter.get("metadata_name", ""),
|
||||
filter.get("value"),
|
||||
filters,
|
||||
)
|
||||
conditions.append(
|
||||
Condition(
|
||||
name=filter.get("metadata_name"), # type: ignore
|
||||
comparison_operator=filter.get("condition"), # type: ignore
|
||||
value=filter.get("value"),
|
||||
)
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
logical_operator=node_data.metadata_filtering_conditions.logical_operator
|
||||
if node_data.metadata_filtering_conditions
|
||||
else "or",
|
||||
conditions=conditions,
|
||||
)
|
||||
filters = DatasetRetrieval.process_metadata_filter_func(
|
||||
sequence,
|
||||
condition.comparison_operator,
|
||||
metadata_name,
|
||||
expected_value,
|
||||
filters,
|
||||
case "manual":
|
||||
if node_data.metadata_filtering_conditions:
|
||||
conditions = []
|
||||
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
|
||||
metadata_name = condition.name
|
||||
expected_value = condition.value
|
||||
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
|
||||
if isinstance(expected_value, str):
|
||||
expected_value = self.graph_runtime_state.variable_pool.convert_template(
|
||||
expected_value
|
||||
).value[0]
|
||||
if expected_value.value_type in {"number", "integer", "float"}:
|
||||
expected_value = expected_value.value
|
||||
elif expected_value.value_type == "string":
|
||||
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
|
||||
else:
|
||||
raise ValueError("Invalid expected metadata value type")
|
||||
conditions.append(
|
||||
Condition(
|
||||
name=metadata_name,
|
||||
comparison_operator=condition.comparison_operator,
|
||||
value=expected_value,
|
||||
)
|
||||
)
|
||||
filters = DatasetRetrieval.process_metadata_filter_func(
|
||||
sequence,
|
||||
condition.comparison_operator,
|
||||
metadata_name,
|
||||
expected_value,
|
||||
filters,
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
|
||||
conditions=conditions,
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
|
||||
conditions=conditions,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid metadata filtering mode")
|
||||
case _:
|
||||
raise ValueError("Invalid metadata filtering mode")
|
||||
if filters:
|
||||
if (
|
||||
node_data.metadata_filtering_conditions
|
||||
|
||||
@ -482,16 +482,17 @@ class ToolNode(Node[ToolNodeData]):
|
||||
result = {}
|
||||
for parameter_name in typed_node_data.tool_parameters:
|
||||
input = typed_node_data.tool_parameters[parameter_name]
|
||||
if input.type == "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == "variable":
|
||||
selector_key = ".".join(input.value)
|
||||
result[f"#{selector_key}#"] = input.value
|
||||
elif input.type == "constant":
|
||||
pass
|
||||
match input.type:
|
||||
case "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
case "variable":
|
||||
selector_key = ".".join(input.value)
|
||||
result[f"#{selector_key}#"] = input.value
|
||||
case "constant":
|
||||
pass
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
|
||||
@ -390,8 +390,7 @@ class ClickZettaVolumeStorage(BaseStorage):
|
||||
"""
|
||||
content = self.load_once(filename)
|
||||
|
||||
with Path(target_filepath).open("wb") as f:
|
||||
f.write(content)
|
||||
Path(target_filepath).write_bytes(content)
|
||||
|
||||
logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath)
|
||||
|
||||
|
||||
@ -1,36 +1,69 @@
|
||||
from flask_restx import Namespace, fields
|
||||
from __future__ import annotations
|
||||
|
||||
from libs.helper import TimestampField
|
||||
from datetime import datetime
|
||||
|
||||
annotation_fields = {
|
||||
"id": fields.String,
|
||||
"question": fields.String,
|
||||
"answer": fields.Raw(attribute="content"),
|
||||
"hit_count": fields.Integer,
|
||||
"created_at": TimestampField,
|
||||
# 'account': fields.Nested(simple_account_fields, allow_null=True)
|
||||
}
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
def build_annotation_model(api_or_ns: Namespace):
|
||||
"""Build the annotation model for the API or Namespace."""
|
||||
return api_or_ns.model("Annotation", annotation_fields)
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
annotation_list_fields = {
|
||||
"data": fields.List(fields.Nested(annotation_fields)),
|
||||
}
|
||||
class ResponseModel(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
extra="ignore",
|
||||
populate_by_name=True,
|
||||
serialize_by_alias=True,
|
||||
protected_namespaces=(),
|
||||
)
|
||||
|
||||
annotation_hit_history_fields = {
|
||||
"id": fields.String,
|
||||
"source": fields.String,
|
||||
"score": fields.Float,
|
||||
"question": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"match": fields.String(attribute="annotation_question"),
|
||||
"response": fields.String(attribute="annotation_content"),
|
||||
}
|
||||
|
||||
annotation_hit_history_list_fields = {
|
||||
"data": fields.List(fields.Nested(annotation_hit_history_fields)),
|
||||
}
|
||||
class Annotation(ResponseModel):
|
||||
id: str
|
||||
question: str | None = None
|
||||
answer: str | None = Field(default=None, validation_alias="content")
|
||||
hit_count: int | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class AnnotationList(ResponseModel):
|
||||
data: list[Annotation]
|
||||
has_more: bool
|
||||
limit: int
|
||||
total: int
|
||||
page: int
|
||||
|
||||
|
||||
class AnnotationExportList(ResponseModel):
|
||||
data: list[Annotation]
|
||||
|
||||
|
||||
class AnnotationHitHistory(ResponseModel):
|
||||
id: str
|
||||
source: str | None = None
|
||||
score: float | None = None
|
||||
question: str | None = None
|
||||
created_at: int | None = None
|
||||
match: str | None = Field(default=None, validation_alias="annotation_question")
|
||||
response: str | None = Field(default=None, validation_alias="annotation_content")
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class AnnotationHitHistoryList(ResponseModel):
|
||||
data: list[AnnotationHitHistory]
|
||||
has_more: bool
|
||||
limit: int
|
||||
total: int
|
||||
page: int
|
||||
|
||||
@ -1,4 +1,7 @@
|
||||
from flask_restx import Namespace, fields
|
||||
from __future__ import annotations
|
||||
|
||||
from flask_restx import fields
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
simple_end_user_fields = {
|
||||
"id": fields.String,
|
||||
@ -8,5 +11,18 @@ simple_end_user_fields = {
|
||||
}
|
||||
|
||||
|
||||
def build_simple_end_user_model(api_or_ns: Namespace):
|
||||
return api_or_ns.model("SimpleEndUser", simple_end_user_fields)
|
||||
class ResponseModel(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
extra="ignore",
|
||||
populate_by_name=True,
|
||||
serialize_by_alias=True,
|
||||
protected_namespaces=(),
|
||||
)
|
||||
|
||||
|
||||
class SimpleEndUser(ResponseModel):
|
||||
id: str
|
||||
type: str
|
||||
is_anonymous: bool
|
||||
session_id: str | None = None
|
||||
|
||||
@ -1,6 +1,11 @@
|
||||
from flask_restx import Namespace, fields
|
||||
from __future__ import annotations
|
||||
|
||||
from libs.helper import AvatarUrlField, TimestampField
|
||||
from datetime import datetime
|
||||
|
||||
from flask_restx import fields
|
||||
from pydantic import BaseModel, ConfigDict, computed_field, field_validator
|
||||
|
||||
from core.file import helpers as file_helpers
|
||||
|
||||
simple_account_fields = {
|
||||
"id": fields.String,
|
||||
@ -9,36 +14,78 @@ simple_account_fields = {
|
||||
}
|
||||
|
||||
|
||||
def build_simple_account_model(api_or_ns: Namespace):
|
||||
return api_or_ns.model("SimpleAccount", simple_account_fields)
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
account_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"avatar": fields.String,
|
||||
"avatar_url": AvatarUrlField,
|
||||
"email": fields.String,
|
||||
"is_password_set": fields.Boolean,
|
||||
"interface_language": fields.String,
|
||||
"interface_theme": fields.String,
|
||||
"timezone": fields.String,
|
||||
"last_login_at": TimestampField,
|
||||
"last_login_ip": fields.String,
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
def _build_avatar_url(avatar: str | None) -> str | None:
|
||||
if avatar is None:
|
||||
return None
|
||||
if avatar.startswith(("http://", "https://")):
|
||||
return avatar
|
||||
return file_helpers.get_signed_file_url(avatar)
|
||||
|
||||
account_with_role_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"avatar": fields.String,
|
||||
"avatar_url": AvatarUrlField,
|
||||
"email": fields.String,
|
||||
"last_login_at": TimestampField,
|
||||
"last_active_at": TimestampField,
|
||||
"created_at": TimestampField,
|
||||
"role": fields.String,
|
||||
"status": fields.String,
|
||||
}
|
||||
|
||||
account_with_role_list_fields = {"accounts": fields.List(fields.Nested(account_with_role_fields))}
|
||||
class ResponseModel(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
extra="ignore",
|
||||
populate_by_name=True,
|
||||
serialize_by_alias=True,
|
||||
protected_namespaces=(),
|
||||
)
|
||||
|
||||
|
||||
class SimpleAccount(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
|
||||
|
||||
class _AccountAvatar(ResponseModel):
|
||||
avatar: str | None = None
|
||||
|
||||
@computed_field(return_type=str | None) # type: ignore[prop-decorator]
|
||||
@property
|
||||
def avatar_url(self) -> str | None:
|
||||
return _build_avatar_url(self.avatar)
|
||||
|
||||
|
||||
class Account(_AccountAvatar):
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
is_password_set: bool
|
||||
interface_language: str | None = None
|
||||
interface_theme: str | None = None
|
||||
timezone: str | None = None
|
||||
last_login_at: int | None = None
|
||||
last_login_ip: str | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("last_login_at", "created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class AccountWithRole(_AccountAvatar):
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
last_login_at: int | None = None
|
||||
last_active_at: int | None = None
|
||||
created_at: int | None = None
|
||||
role: str
|
||||
status: str
|
||||
|
||||
@field_validator("last_login_at", "last_active_at", "created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class AccountWithRoleList(ResponseModel):
|
||||
accounts: list[AccountWithRole]
|
||||
|
||||
@ -1,12 +1,20 @@
|
||||
from flask_restx import Namespace, fields
|
||||
from __future__ import annotations
|
||||
|
||||
dataset_tag_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"type": fields.String,
|
||||
"binding_count": fields.String,
|
||||
}
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
def build_dataset_tag_fields(api_or_ns: Namespace):
|
||||
return api_or_ns.model("DataSetTag", dataset_tag_fields)
|
||||
class ResponseModel(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
extra="ignore",
|
||||
populate_by_name=True,
|
||||
serialize_by_alias=True,
|
||||
protected_namespaces=(),
|
||||
)
|
||||
|
||||
|
||||
class DataSetTag(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
type: str
|
||||
binding_count: str | None = None
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from flask_restx import Namespace, fields
|
||||
|
||||
from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields
|
||||
from fields.member_fields import build_simple_account_model, simple_account_fields
|
||||
from fields.end_user_fields import simple_end_user_fields
|
||||
from fields.member_fields import simple_account_fields
|
||||
from fields.workflow_run_fields import (
|
||||
build_workflow_run_for_archived_log_model,
|
||||
build_workflow_run_for_log_model,
|
||||
@ -25,17 +25,9 @@ workflow_app_log_partial_fields = {
|
||||
def build_workflow_app_log_partial_model(api_or_ns: Namespace):
|
||||
"""Build the workflow app log partial model for the API or Namespace."""
|
||||
workflow_run_model = build_workflow_run_for_log_model(api_or_ns)
|
||||
simple_account_model = build_simple_account_model(api_or_ns)
|
||||
simple_end_user_model = build_simple_end_user_model(api_or_ns)
|
||||
|
||||
copied_fields = workflow_app_log_partial_fields.copy()
|
||||
copied_fields["workflow_run"] = fields.Nested(workflow_run_model, attribute="workflow_run", allow_null=True)
|
||||
copied_fields["created_by_account"] = fields.Nested(
|
||||
simple_account_model, attribute="created_by_account", allow_null=True
|
||||
)
|
||||
copied_fields["created_by_end_user"] = fields.Nested(
|
||||
simple_end_user_model, attribute="created_by_end_user", allow_null=True
|
||||
)
|
||||
return api_or_ns.model("WorkflowAppLogPartial", copied_fields)
|
||||
|
||||
|
||||
@ -52,17 +44,9 @@ workflow_archived_log_partial_fields = {
|
||||
def build_workflow_archived_log_partial_model(api_or_ns: Namespace):
|
||||
"""Build the workflow archived log partial model for the API or Namespace."""
|
||||
workflow_run_model = build_workflow_run_for_archived_log_model(api_or_ns)
|
||||
simple_account_model = build_simple_account_model(api_or_ns)
|
||||
simple_end_user_model = build_simple_end_user_model(api_or_ns)
|
||||
|
||||
copied_fields = workflow_archived_log_partial_fields.copy()
|
||||
copied_fields["workflow_run"] = fields.Nested(workflow_run_model, allow_null=True)
|
||||
copied_fields["created_by_account"] = fields.Nested(
|
||||
simple_account_model, attribute="created_by_account", allow_null=True
|
||||
)
|
||||
copied_fields["created_by_end_user"] = fields.Nested(
|
||||
simple_end_user_model, attribute="created_by_end_user", allow_null=True
|
||||
)
|
||||
return api_or_ns.model("WorkflowArchivedLogPartial", copied_fields)
|
||||
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dify-api"
|
||||
version = "1.12.0"
|
||||
version = "1.12.1"
|
||||
requires-python = ">=3.11,<3.13"
|
||||
|
||||
dependencies = [
|
||||
@ -145,7 +145,7 @@ dev = [
|
||||
"types-openpyxl~=3.1.5",
|
||||
"types-pexpect~=4.9.0",
|
||||
"types-protobuf~=5.29.1",
|
||||
"types-psutil~=7.0.0",
|
||||
"types-psutil~=7.2.2",
|
||||
"types-psycopg2~=2.9.21",
|
||||
"types-pygments~=2.19.0",
|
||||
"types-pymysql~=1.1.0",
|
||||
|
||||
@ -327,6 +327,17 @@ class AccountService:
|
||||
@staticmethod
|
||||
def delete_account(account: Account):
|
||||
"""Delete account. This method only adds a task to the queue for deletion."""
|
||||
# Queue account deletion sync tasks for all workspaces BEFORE account deletion (enterprise only)
|
||||
from services.enterprise.account_deletion_sync import sync_account_deletion
|
||||
|
||||
sync_success = sync_account_deletion(account_id=account.id, source="account_deleted")
|
||||
if not sync_success:
|
||||
logger.warning(
|
||||
"Enterprise account deletion sync failed for account %s; proceeding with local deletion.",
|
||||
account.id,
|
||||
)
|
||||
|
||||
# Now proceed with async account deletion
|
||||
delete_account_task.delay(account.id)
|
||||
|
||||
@staticmethod
|
||||
@ -1230,6 +1241,19 @@ class TenantService:
|
||||
if dify_config.BILLING_ENABLED:
|
||||
BillingService.clean_billing_info_cache(tenant.id)
|
||||
|
||||
# Queue account deletion sync task for enterprise backend to reassign resources (enterprise only)
|
||||
from services.enterprise.account_deletion_sync import sync_workspace_member_removal
|
||||
|
||||
sync_success = sync_workspace_member_removal(
|
||||
workspace_id=tenant.id, member_id=account.id, source="workspace_member_removed"
|
||||
)
|
||||
if not sync_success:
|
||||
logger.warning(
|
||||
"Enterprise workspace member removal sync failed: workspace_id=%s, member_id=%s",
|
||||
tenant.id,
|
||||
account.id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account):
|
||||
"""Update member role"""
|
||||
|
||||
@ -158,7 +158,7 @@ class AppAnnotationService:
|
||||
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
|
||||
)
|
||||
annotations = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
return annotations.items, annotations.total
|
||||
return annotations.items, annotations.total or 0
|
||||
|
||||
@classmethod
|
||||
def export_annotation_list_by_app_id(cls, app_id: str):
|
||||
@ -524,7 +524,7 @@ class AppAnnotationService:
|
||||
annotation_hit_histories = db.paginate(
|
||||
select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False
|
||||
)
|
||||
return annotation_hit_histories.items, annotation_hit_histories.total
|
||||
return annotation_hit_histories.items, annotation_hit_histories.total or 0
|
||||
|
||||
@classmethod
|
||||
def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None:
|
||||
|
||||
@ -16,6 +16,7 @@ from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.file import helpers as file_helpers
|
||||
from core.helper.name_generator import generate_incremental_name
|
||||
@ -1388,6 +1389,46 @@ class DocumentService:
|
||||
).all()
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def update_documents_need_summary(dataset_id: str, document_ids: Sequence[str], need_summary: bool = True) -> int:
|
||||
"""
|
||||
Update need_summary field for multiple documents.
|
||||
|
||||
This method handles the case where documents were created when summary_index_setting was disabled,
|
||||
and need to be updated when summary_index_setting is later enabled.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset ID
|
||||
document_ids: List of document IDs to update
|
||||
need_summary: Value to set for need_summary field (default: True)
|
||||
|
||||
Returns:
|
||||
Number of documents updated
|
||||
"""
|
||||
if not document_ids:
|
||||
return 0
|
||||
|
||||
document_id_list: list[str] = [str(document_id) for document_id in document_ids]
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
updated_count = (
|
||||
session.query(Document)
|
||||
.filter(
|
||||
Document.id.in_(document_id_list),
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.doc_form != "qa_model", # Skip qa_model documents
|
||||
)
|
||||
.update({Document.need_summary: need_summary}, synchronize_session=False)
|
||||
)
|
||||
session.commit()
|
||||
logger.info(
|
||||
"Updated need_summary to %s for %d documents in dataset %s",
|
||||
need_summary,
|
||||
updated_count,
|
||||
dataset_id,
|
||||
)
|
||||
return updated_count
|
||||
|
||||
@staticmethod
|
||||
def get_document_download_url(document: Document) -> str:
|
||||
"""
|
||||
@ -2937,14 +2978,15 @@ class DocumentService:
|
||||
"""
|
||||
now = naive_utc_now()
|
||||
|
||||
if action == "enable":
|
||||
return DocumentService._prepare_enable_update(document, now)
|
||||
elif action == "disable":
|
||||
return DocumentService._prepare_disable_update(document, user, now)
|
||||
elif action == "archive":
|
||||
return DocumentService._prepare_archive_update(document, user, now)
|
||||
elif action == "un_archive":
|
||||
return DocumentService._prepare_unarchive_update(document, now)
|
||||
match action:
|
||||
case "enable":
|
||||
return DocumentService._prepare_enable_update(document, now)
|
||||
case "disable":
|
||||
return DocumentService._prepare_disable_update(document, user, now)
|
||||
case "archive":
|
||||
return DocumentService._prepare_archive_update(document, user, now)
|
||||
case "un_archive":
|
||||
return DocumentService._prepare_unarchive_update(document, now)
|
||||
|
||||
return None
|
||||
|
||||
@ -3581,56 +3623,57 @@ class SegmentService:
|
||||
# Check if segment_ids is not empty to avoid WHERE false condition
|
||||
if not segment_ids or len(segment_ids) == 0:
|
||||
return
|
||||
if action == "enable":
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.enabled == False,
|
||||
)
|
||||
).all()
|
||||
if not segments:
|
||||
return
|
||||
real_deal_segment_ids = []
|
||||
for segment in segments:
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is not None:
|
||||
continue
|
||||
segment.enabled = True
|
||||
segment.disabled_at = None
|
||||
segment.disabled_by = None
|
||||
db.session.add(segment)
|
||||
real_deal_segment_ids.append(segment.id)
|
||||
db.session.commit()
|
||||
match action:
|
||||
case "enable":
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.enabled == False,
|
||||
)
|
||||
).all()
|
||||
if not segments:
|
||||
return
|
||||
real_deal_segment_ids = []
|
||||
for segment in segments:
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is not None:
|
||||
continue
|
||||
segment.enabled = True
|
||||
segment.disabled_at = None
|
||||
segment.disabled_by = None
|
||||
db.session.add(segment)
|
||||
real_deal_segment_ids.append(segment.id)
|
||||
db.session.commit()
|
||||
|
||||
enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
|
||||
elif action == "disable":
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
).all()
|
||||
if not segments:
|
||||
return
|
||||
real_deal_segment_ids = []
|
||||
for segment in segments:
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is not None:
|
||||
continue
|
||||
segment.enabled = False
|
||||
segment.disabled_at = naive_utc_now()
|
||||
segment.disabled_by = current_user.id
|
||||
db.session.add(segment)
|
||||
real_deal_segment_ids.append(segment.id)
|
||||
db.session.commit()
|
||||
enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
|
||||
case "disable":
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
).all()
|
||||
if not segments:
|
||||
return
|
||||
real_deal_segment_ids = []
|
||||
for segment in segments:
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is not None:
|
||||
continue
|
||||
segment.enabled = False
|
||||
segment.disabled_at = naive_utc_now()
|
||||
segment.disabled_by = current_user.id
|
||||
db.session.add(segment)
|
||||
real_deal_segment_ids.append(segment.id)
|
||||
db.session.commit()
|
||||
|
||||
disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
|
||||
disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
|
||||
|
||||
@classmethod
|
||||
def create_child_chunk(
|
||||
|
||||
115
api/services/enterprise/account_deletion_sync.py
Normal file
115
api/services/enterprise/account_deletion_sync.py
Normal file
@ -0,0 +1,115 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from redis import RedisError
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import TenantAccountJoin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ACCOUNT_DELETION_SYNC_QUEUE = "enterprise:member:sync:queue"
|
||||
ACCOUNT_DELETION_SYNC_TASK_TYPE = "sync_member_deletion_from_workspace"
|
||||
|
||||
|
||||
def _queue_task(workspace_id: str, member_id: str, *, source: str) -> bool:
|
||||
"""
|
||||
Queue an account deletion sync task to Redis.
|
||||
|
||||
Internal helper function. Do not call directly - use the public functions instead.
|
||||
|
||||
Args:
|
||||
workspace_id: The workspace/tenant ID to sync
|
||||
member_id: The member/account ID that was removed
|
||||
source: Source of the sync request (for debugging/tracking)
|
||||
|
||||
Returns:
|
||||
bool: True if task was queued successfully, False otherwise
|
||||
"""
|
||||
try:
|
||||
task = {
|
||||
"task_id": str(uuid.uuid4()),
|
||||
"workspace_id": workspace_id,
|
||||
"member_id": member_id,
|
||||
"retry_count": 0,
|
||||
"created_at": datetime.now(UTC).isoformat(),
|
||||
"source": source,
|
||||
"type": ACCOUNT_DELETION_SYNC_TASK_TYPE,
|
||||
}
|
||||
|
||||
# Push to Redis list (queue) - LPUSH adds to the head, worker consumes from tail with RPOP
|
||||
redis_client.lpush(ACCOUNT_DELETION_SYNC_QUEUE, json.dumps(task))
|
||||
|
||||
logger.info(
|
||||
"Queued account deletion sync task for workspace %s, member %s, task_id: %s, source: %s",
|
||||
workspace_id,
|
||||
member_id,
|
||||
task["task_id"],
|
||||
source,
|
||||
)
|
||||
return True
|
||||
|
||||
except (RedisError, TypeError) as e:
|
||||
logger.error(
|
||||
"Failed to queue account deletion sync for workspace %s, member %s: %s",
|
||||
workspace_id,
|
||||
member_id,
|
||||
str(e),
|
||||
exc_info=True,
|
||||
)
|
||||
# Don't raise - we don't want to fail member deletion if queueing fails
|
||||
return False
|
||||
|
||||
|
||||
def sync_workspace_member_removal(workspace_id: str, member_id: str, *, source: str) -> bool:
|
||||
"""
|
||||
Sync a single workspace member removal (enterprise only).
|
||||
|
||||
Queues a task for the enterprise backend to reassign resources from the removed member.
|
||||
Handles enterprise edition check internally. Safe to call in community edition (no-op).
|
||||
|
||||
Args:
|
||||
workspace_id: The workspace/tenant ID
|
||||
member_id: The member/account ID that was removed
|
||||
source: Source of the sync request (e.g., "workspace_member_removed")
|
||||
|
||||
Returns:
|
||||
bool: True if task was queued (or skipped in community), False if queueing failed
|
||||
"""
|
||||
if not dify_config.ENTERPRISE_ENABLED:
|
||||
return True
|
||||
|
||||
return _queue_task(workspace_id=workspace_id, member_id=member_id, source=source)
|
||||
|
||||
|
||||
def sync_account_deletion(account_id: str, *, source: str) -> bool:
|
||||
"""
|
||||
Sync full account deletion across all workspaces (enterprise only).
|
||||
|
||||
Fetches all workspace memberships for the account and queues a sync task for each.
|
||||
Handles enterprise edition check internally. Safe to call in community edition (no-op).
|
||||
|
||||
Args:
|
||||
account_id: The account ID being deleted
|
||||
source: Source of the sync request (e.g., "account_deleted")
|
||||
|
||||
Returns:
|
||||
bool: True if all tasks were queued (or skipped in community), False if any queueing failed
|
||||
"""
|
||||
if not dify_config.ENTERPRISE_ENABLED:
|
||||
return True
|
||||
|
||||
# Fetch all workspaces the account belongs to
|
||||
workspace_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account_id).all()
|
||||
|
||||
# Queue sync task for each workspace
|
||||
success = True
|
||||
for join in workspace_joins:
|
||||
if not _queue_task(workspace_id=join.tenant_id, member_id=account_id, source=source):
|
||||
success = False
|
||||
|
||||
return success
|
||||
@ -174,6 +174,10 @@ class RagPipelineTransformService:
|
||||
else:
|
||||
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
|
||||
|
||||
# Copy summary_index_setting from dataset to knowledge_index node configuration
|
||||
if dataset.summary_index_setting:
|
||||
knowledge_configuration.summary_index_setting = dataset.summary_index_setting
|
||||
|
||||
knowledge_configuration_dict.update(knowledge_configuration.model_dump())
|
||||
node["data"] = knowledge_configuration_dict
|
||||
return node
|
||||
|
||||
@ -49,11 +49,18 @@ class SummaryIndexService:
|
||||
# Use lazy import to avoid circular import
|
||||
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
|
||||
|
||||
# Get document language to ensure summary is generated in the correct language
|
||||
# This is especially important for image-only chunks where text is empty or minimal
|
||||
document_language = None
|
||||
if segment.document and segment.document.doc_language:
|
||||
document_language = segment.document.doc_language
|
||||
|
||||
summary_content, usage = ParagraphIndexProcessor.generate_summary(
|
||||
tenant_id=dataset.tenant_id,
|
||||
text=segment.content,
|
||||
summary_index_setting=summary_index_setting,
|
||||
segment_id=segment.id,
|
||||
document_language=document_language,
|
||||
)
|
||||
|
||||
if not summary_content:
|
||||
@ -558,6 +565,9 @@ class SummaryIndexService:
|
||||
)
|
||||
session.add(summary_record)
|
||||
|
||||
# Commit the batch created records
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def update_summary_record_error(
|
||||
segment: DocumentSegment,
|
||||
@ -762,7 +772,6 @@ class SummaryIndexService:
|
||||
dataset=dataset,
|
||||
status="not_started",
|
||||
)
|
||||
session.commit() # Commit initial records
|
||||
|
||||
summary_records = []
|
||||
|
||||
|
||||
@ -24,7 +24,7 @@ class TagService:
|
||||
escaped_keyword = escape_like_pattern(keyword)
|
||||
query = query.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\")))
|
||||
query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
|
||||
results = query.order_by(Tag.created_at.desc()).all()
|
||||
results: list = query.order_by(Tag.created_at.desc()).all()
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -259,8 +259,8 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
|
||||
|
||||
|
||||
def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str):
|
||||
def del_workflow_archive_log(workflow_archive_log_id: str):
|
||||
db.session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
|
||||
def del_workflow_archive_log(session, workflow_archive_log_id: str):
|
||||
session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
@ -420,7 +420,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
|
||||
total_files_deleted = 0
|
||||
|
||||
while True:
|
||||
with session_factory.create_session() as session:
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
# Get a batch of draft variable IDs along with their file_ids
|
||||
query_sql = """
|
||||
SELECT id, file_id FROM workflow_draft_variables
|
||||
|
||||
@ -10,7 +10,10 @@ from models import Tenant
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, UploadFile
|
||||
from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
|
||||
from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch
|
||||
from tasks.remove_app_and_related_data_task import (
|
||||
_delete_draft_variables,
|
||||
delete_draft_variables_batch,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -297,12 +300,18 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
|
||||
def test_delete_draft_variables_with_offload_data(self, mock_storage, setup_offload_test_data):
|
||||
data = setup_offload_test_data
|
||||
app_id = data["app"].id
|
||||
upload_file_ids = [uf.id for uf in data["upload_files"]]
|
||||
variable_file_ids = [vf.id for vf in data["variable_files"]]
|
||||
mock_storage.delete.return_value = None
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
var_files_before = session.query(WorkflowDraftVariableFile).count()
|
||||
upload_files_before = session.query(UploadFile).count()
|
||||
var_files_before = (
|
||||
session.query(WorkflowDraftVariableFile)
|
||||
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
|
||||
.count()
|
||||
)
|
||||
upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
|
||||
assert draft_vars_before == 3
|
||||
assert var_files_before == 2
|
||||
assert upload_files_before == 2
|
||||
@ -315,8 +324,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
|
||||
assert draft_vars_after == 0
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
var_files_after = session.query(WorkflowDraftVariableFile).count()
|
||||
upload_files_after = session.query(UploadFile).count()
|
||||
var_files_after = (
|
||||
session.query(WorkflowDraftVariableFile)
|
||||
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
|
||||
.count()
|
||||
)
|
||||
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
|
||||
assert var_files_after == 0
|
||||
assert upload_files_after == 0
|
||||
|
||||
@ -329,6 +342,8 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
|
||||
def test_delete_draft_variables_storage_failure_continues_cleanup(self, mock_storage, setup_offload_test_data):
|
||||
data = setup_offload_test_data
|
||||
app_id = data["app"].id
|
||||
upload_file_ids = [uf.id for uf in data["upload_files"]]
|
||||
variable_file_ids = [vf.id for vf in data["variable_files"]]
|
||||
mock_storage.delete.side_effect = [Exception("Storage error"), None]
|
||||
|
||||
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
|
||||
@ -339,8 +354,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
|
||||
assert draft_vars_after == 0
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
var_files_after = session.query(WorkflowDraftVariableFile).count()
|
||||
upload_files_after = session.query(UploadFile).count()
|
||||
var_files_after = (
|
||||
session.query(WorkflowDraftVariableFile)
|
||||
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
|
||||
.count()
|
||||
)
|
||||
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
|
||||
assert var_files_after == 0
|
||||
assert upload_files_after == 0
|
||||
|
||||
@ -395,3 +414,275 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
|
||||
if app2_obj:
|
||||
session.delete(app2_obj)
|
||||
session.commit()
|
||||
|
||||
|
||||
class TestDeleteDraftVariablesSessionCommit:
|
||||
"""Test suite to verify session commit behavior in delete_draft_variables_batch."""
|
||||
|
||||
@pytest.fixture
|
||||
def setup_offload_test_data(self, app_and_tenant):
|
||||
"""Create test data with offload files for session commit tests."""
|
||||
from core.variables.types import SegmentType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
tenant, app = app_and_tenant
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
upload_file1 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
key="test/file1.json",
|
||||
name="file1.json",
|
||||
size=1024,
|
||||
extension="json",
|
||||
mime_type="application/json",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid.uuid4()),
|
||||
created_at=naive_utc_now(),
|
||||
used=False,
|
||||
)
|
||||
upload_file2 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
key="test/file2.json",
|
||||
name="file2.json",
|
||||
size=2048,
|
||||
extension="json",
|
||||
mime_type="application/json",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid.uuid4()),
|
||||
created_at=naive_utc_now(),
|
||||
used=False,
|
||||
)
|
||||
session.add(upload_file1)
|
||||
session.add(upload_file2)
|
||||
session.flush()
|
||||
|
||||
var_file1 = WorkflowDraftVariableFile(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app.id,
|
||||
user_id=str(uuid.uuid4()),
|
||||
upload_file_id=upload_file1.id,
|
||||
size=1024,
|
||||
length=10,
|
||||
value_type=SegmentType.STRING,
|
||||
)
|
||||
var_file2 = WorkflowDraftVariableFile(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app.id,
|
||||
user_id=str(uuid.uuid4()),
|
||||
upload_file_id=upload_file2.id,
|
||||
size=2048,
|
||||
length=20,
|
||||
value_type=SegmentType.OBJECT,
|
||||
)
|
||||
session.add(var_file1)
|
||||
session.add(var_file2)
|
||||
session.flush()
|
||||
|
||||
draft_var1 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id="node_1",
|
||||
name="large_var_1",
|
||||
value=StringSegment(value="truncated..."),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
file_id=var_file1.id,
|
||||
)
|
||||
draft_var2 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id="node_2",
|
||||
name="large_var_2",
|
||||
value=StringSegment(value="truncated..."),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
file_id=var_file2.id,
|
||||
)
|
||||
draft_var3 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id="node_3",
|
||||
name="regular_var",
|
||||
value=StringSegment(value="regular_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
session.add(draft_var1)
|
||||
session.add(draft_var2)
|
||||
session.add(draft_var3)
|
||||
session.commit()
|
||||
|
||||
data = {
|
||||
"app": app,
|
||||
"tenant": tenant,
|
||||
"upload_files": [upload_file1, upload_file2],
|
||||
"variable_files": [var_file1, var_file2],
|
||||
"draft_variables": [draft_var1, draft_var2, draft_var3],
|
||||
}
|
||||
|
||||
yield data
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
for table, ids in [
|
||||
(WorkflowDraftVariable, [v.id for v in data["draft_variables"]]),
|
||||
(WorkflowDraftVariableFile, [vf.id for vf in data["variable_files"]]),
|
||||
(UploadFile, [uf.id for uf in data["upload_files"]]),
|
||||
]:
|
||||
cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False)
|
||||
session.execute(cleanup_query)
|
||||
session.commit()
|
||||
|
||||
@pytest.fixture
|
||||
def setup_commit_test_data(self, app_and_tenant):
|
||||
"""Create test data for session commit tests."""
|
||||
tenant, app = app_and_tenant
|
||||
variable_ids: list[str] = []
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
variables = []
|
||||
for i in range(10):
|
||||
var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id=f"node_{i}",
|
||||
name=f"var_{i}",
|
||||
value=StringSegment(value="test_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
session.add(var)
|
||||
variables.append(var)
|
||||
session.commit()
|
||||
variable_ids = [v.id for v in variables]
|
||||
|
||||
yield {
|
||||
"app": app,
|
||||
"tenant": tenant,
|
||||
"variable_ids": variable_ids,
|
||||
}
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
cleanup_query = (
|
||||
delete(WorkflowDraftVariable)
|
||||
.where(WorkflowDraftVariable.id.in_(variable_ids))
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
session.execute(cleanup_query)
|
||||
session.commit()
|
||||
|
||||
def test_session_commit_is_called_after_each_batch(self, setup_commit_test_data):
|
||||
"""Test that session.begin() is used for automatic transaction management."""
|
||||
data = setup_commit_test_data
|
||||
app_id = data["app"].id
|
||||
|
||||
# Since session.begin() is used, the transaction is automatically committed
|
||||
# when the with block exits successfully. We verify this by checking that
|
||||
# data is actually persisted.
|
||||
deleted_count = delete_draft_variables_batch(app_id, batch_size=3)
|
||||
|
||||
# Verify all data was deleted (proves transaction was committed)
|
||||
with session_factory.create_session() as session:
|
||||
remaining_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
|
||||
assert deleted_count == 10
|
||||
assert remaining_count == 0
|
||||
|
||||
def test_data_persisted_after_batch_deletion(self, setup_commit_test_data):
|
||||
"""Test that data is actually persisted to database after batch deletion with commits."""
|
||||
data = setup_commit_test_data
|
||||
app_id = data["app"].id
|
||||
variable_ids = data["variable_ids"]
|
||||
|
||||
# Verify initial state
|
||||
with session_factory.create_session() as session:
|
||||
initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
assert initial_count == 10
|
||||
|
||||
# Perform deletion with small batch size to force multiple commits
|
||||
deleted_count = delete_draft_variables_batch(app_id, batch_size=3)
|
||||
|
||||
assert deleted_count == 10
|
||||
|
||||
# Verify all data is deleted in a new session (proves commits worked)
|
||||
with session_factory.create_session() as session:
|
||||
final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
assert final_count == 0
|
||||
|
||||
# Verify specific IDs are deleted
|
||||
with session_factory.create_session() as session:
|
||||
remaining_vars = (
|
||||
session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id.in_(variable_ids)).count()
|
||||
)
|
||||
assert remaining_vars == 0
|
||||
|
||||
def test_session_commit_with_empty_dataset(self, setup_commit_test_data):
|
||||
"""Test session behavior when deleting from an empty dataset."""
|
||||
nonexistent_app_id = str(uuid.uuid4())
|
||||
|
||||
# Should not raise any errors and should return 0
|
||||
deleted_count = delete_draft_variables_batch(nonexistent_app_id, batch_size=10)
|
||||
assert deleted_count == 0
|
||||
|
||||
def test_session_commit_with_single_batch(self, setup_commit_test_data):
|
||||
"""Test that commit happens correctly when all data fits in a single batch."""
|
||||
data = setup_commit_test_data
|
||||
app_id = data["app"].id
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
assert initial_count == 10
|
||||
|
||||
# Delete all in a single batch
|
||||
deleted_count = delete_draft_variables_batch(app_id, batch_size=100)
|
||||
assert deleted_count == 10
|
||||
|
||||
# Verify data is persisted
|
||||
with session_factory.create_session() as session:
|
||||
final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
assert final_count == 0
|
||||
|
||||
def test_invalid_batch_size_raises_error(self, setup_commit_test_data):
|
||||
"""Test that invalid batch size raises ValueError."""
|
||||
data = setup_commit_test_data
|
||||
app_id = data["app"].id
|
||||
|
||||
with pytest.raises(ValueError, match="batch_size must be positive"):
|
||||
delete_draft_variables_batch(app_id, batch_size=0)
|
||||
|
||||
with pytest.raises(ValueError, match="batch_size must be positive"):
|
||||
delete_draft_variables_batch(app_id, batch_size=-1)
|
||||
|
||||
@patch("extensions.ext_storage.storage")
|
||||
def test_session_commit_with_offload_data_cleanup(self, mock_storage, setup_offload_test_data):
|
||||
"""Test that session commits correctly when cleaning up offload data."""
|
||||
data = setup_offload_test_data
|
||||
app_id = data["app"].id
|
||||
upload_file_ids = [uf.id for uf in data["upload_files"]]
|
||||
mock_storage.delete.return_value = None
|
||||
|
||||
# Verify initial state
|
||||
with session_factory.create_session() as session:
|
||||
draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
var_files_before = (
|
||||
session.query(WorkflowDraftVariableFile)
|
||||
.where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
|
||||
.count()
|
||||
)
|
||||
upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
|
||||
assert draft_vars_before == 3
|
||||
assert var_files_before == 2
|
||||
assert upload_files_before == 2
|
||||
|
||||
# Delete variables with offload data
|
||||
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
|
||||
assert deleted_count == 3
|
||||
|
||||
# Verify all data is persisted (deleted) in new session
|
||||
with session_factory.create_session() as session:
|
||||
draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
var_files_after = (
|
||||
session.query(WorkflowDraftVariableFile)
|
||||
.where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
|
||||
.count()
|
||||
)
|
||||
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
|
||||
assert draft_vars_after == 0
|
||||
assert var_files_after == 0
|
||||
assert upload_files_after == 0
|
||||
|
||||
# Verify storage cleanup was called
|
||||
assert mock_storage.delete.call_count == 2
|
||||
|
||||
@ -1016,7 +1016,7 @@ class TestAccountService:
|
||||
|
||||
def test_delete_account(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test account deletion (should add task to queue).
|
||||
Test account deletion (should add task to queue and sync to enterprise).
|
||||
"""
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
@ -1034,10 +1034,18 @@ class TestAccountService:
|
||||
password=password,
|
||||
)
|
||||
|
||||
with patch("services.account_service.delete_account_task") as mock_delete_task:
|
||||
with (
|
||||
patch("services.account_service.delete_account_task") as mock_delete_task,
|
||||
patch("services.enterprise.account_deletion_sync.sync_account_deletion") as mock_sync,
|
||||
):
|
||||
mock_sync.return_value = True
|
||||
|
||||
# Delete account
|
||||
AccountService.delete_account(account)
|
||||
|
||||
# Verify sync was called
|
||||
mock_sync.assert_called_once_with(account_id=account.id, source="account_deleted")
|
||||
|
||||
# Verify task was added to queue
|
||||
mock_delete_task.delay.assert_called_once_with(account.id)
|
||||
|
||||
@ -1716,7 +1724,7 @@ class TestTenantService:
|
||||
|
||||
def test_remove_member_from_tenant_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful member removal from tenant.
|
||||
Test successful member removal from tenant (should sync to enterprise).
|
||||
"""
|
||||
fake = Faker()
|
||||
tenant_name = fake.company()
|
||||
@ -1751,7 +1759,15 @@ class TestTenantService:
|
||||
TenantService.create_tenant_member(tenant, member_account, role="normal")
|
||||
|
||||
# Remove member
|
||||
TenantService.remove_member_from_tenant(tenant, member_account, owner_account)
|
||||
with patch("services.enterprise.account_deletion_sync.sync_workspace_member_removal") as mock_sync:
|
||||
mock_sync.return_value = True
|
||||
|
||||
TenantService.remove_member_from_tenant(tenant, member_account, owner_account)
|
||||
|
||||
# Verify sync was called
|
||||
mock_sync.assert_called_once_with(
|
||||
workspace_id=tenant.id, member_id=member_account.id, source="workspace_member_removed"
|
||||
)
|
||||
|
||||
# Verify member was removed
|
||||
from extensions.ext_database import db
|
||||
|
||||
@ -1,291 +0,0 @@
|
||||
import builtins
|
||||
import contextlib
|
||||
import importlib
|
||||
import sys
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from extensions import ext_fastopenapi
|
||||
from extensions.ext_database import db
|
||||
from services.feature_service import FeatureModel, SystemFeatureModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""
|
||||
Creates a Flask application instance configured for testing.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["SECRET_KEY"] = "test-secret"
|
||||
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
|
||||
|
||||
# Initialize the database with the app
|
||||
db.init_app(app)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def fix_method_view_issue(monkeypatch):
|
||||
"""
|
||||
Automatic fixture to patch 'builtins.MethodView'.
|
||||
|
||||
Why this is needed:
|
||||
The official legacy codebase contains a global patch in its initialization logic:
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView
|
||||
|
||||
Some dependencies (like ext_fastopenapi or older Flask extensions) might implicitly
|
||||
rely on 'MethodView' being available in the global builtins namespace.
|
||||
|
||||
Refactoring Note:
|
||||
While patching builtins is generally discouraged due to global side effects,
|
||||
this fixture reproduces the production environment's state to ensure tests are realistic.
|
||||
We use 'monkeypatch' to ensure that this change is undone after the test finishes,
|
||||
keeping other tests isolated.
|
||||
"""
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
# 'raising=False' allows us to set an attribute that doesn't exist yet
|
||||
monkeypatch.setattr(builtins, "MethodView", MethodView, raising=False)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Helper Functions for Fixture Complexity Reduction
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _create_isolated_router():
|
||||
"""
|
||||
Creates a fresh, isolated router instance to prevent route pollution.
|
||||
"""
|
||||
import controllers.fastopenapi
|
||||
|
||||
# Dynamically get the class type (e.g., FlaskRouter) to avoid hardcoding dependencies
|
||||
RouterClass = type(controllers.fastopenapi.console_router)
|
||||
return RouterClass()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _patch_auth_and_router(temp_router):
|
||||
"""
|
||||
Context manager that applies all necessary patches for:
|
||||
1. The console_router (redirecting to our isolated temp_router)
|
||||
2. Authentication decorators (disabling them with no-ops)
|
||||
3. User/Account loaders (mocking authenticated state)
|
||||
"""
|
||||
|
||||
def noop(f):
|
||||
return f
|
||||
|
||||
# We patch the SOURCE of the decorators/functions, not the destination module.
|
||||
# This ensures that when 'controllers.console.feature' imports them, it gets the mocks.
|
||||
with (
|
||||
patch("controllers.fastopenapi.console_router", temp_router),
|
||||
patch("extensions.ext_fastopenapi.console_router", temp_router),
|
||||
patch("controllers.console.wraps.setup_required", side_effect=noop),
|
||||
patch("libs.login.login_required", side_effect=noop),
|
||||
patch("controllers.console.wraps.account_initialization_required", side_effect=noop),
|
||||
patch("controllers.console.wraps.cloud_utm_record", side_effect=noop),
|
||||
patch("libs.login.current_account_with_tenant", return_value=(MagicMock(), "tenant-id")),
|
||||
patch("libs.login.current_user", MagicMock(is_authenticated=True)),
|
||||
):
|
||||
# Explicitly reload ext_fastopenapi to ensure it uses the patched console_router
|
||||
import extensions.ext_fastopenapi
|
||||
|
||||
importlib.reload(extensions.ext_fastopenapi)
|
||||
|
||||
yield
|
||||
|
||||
|
||||
def _force_reload_module(target_module: str, alias_module: str):
|
||||
"""
|
||||
Forces a reload of the specified module and handles sys.modules aliasing.
|
||||
|
||||
Why reload?
|
||||
Python decorators (like @route, @login_required) run at IMPORT time.
|
||||
To apply our patches (mocks/no-ops) to these decorators, we must re-import
|
||||
the module while the patches are active.
|
||||
|
||||
Why alias?
|
||||
If 'ext_fastopenapi' imports the controller as 'api.controllers...', but we import
|
||||
it as 'controllers...', Python treats them as two separate modules. This causes:
|
||||
1. Double execution of decorators (registering routes twice -> AssertionError).
|
||||
2. Type mismatch errors (Class A from module X is not Class A from module Y).
|
||||
|
||||
This function ensures both names point to the SAME loaded module instance.
|
||||
"""
|
||||
# 1. Clean existing entries to force re-import
|
||||
if target_module in sys.modules:
|
||||
del sys.modules[target_module]
|
||||
if alias_module in sys.modules:
|
||||
del sys.modules[alias_module]
|
||||
|
||||
# 2. Import the module (triggering decorators with active patches)
|
||||
module = importlib.import_module(target_module)
|
||||
|
||||
# 3. Alias the module in sys.modules to prevent double loading
|
||||
sys.modules[alias_module] = sys.modules[target_module]
|
||||
|
||||
return module
|
||||
|
||||
|
||||
def _cleanup_modules(target_module: str, alias_module: str):
|
||||
"""
|
||||
Removes the module and its alias from sys.modules to prevent side effects
|
||||
on other tests.
|
||||
"""
|
||||
if target_module in sys.modules:
|
||||
del sys.modules[target_module]
|
||||
if alias_module in sys.modules:
|
||||
del sys.modules[alias_module]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_feature_module_env():
|
||||
"""
|
||||
Sets up a mocked environment for the feature module.
|
||||
|
||||
This fixture orchestrates:
|
||||
1. Creating an isolated router.
|
||||
2. Patching authentication and global dependencies.
|
||||
3. Reloading the controller module to apply patches to decorators.
|
||||
4. cleaning up sys.modules afterwards.
|
||||
"""
|
||||
target_module = "controllers.console.feature"
|
||||
alias_module = "api.controllers.console.feature"
|
||||
|
||||
# 1. Prepare isolated router
|
||||
temp_router = _create_isolated_router()
|
||||
|
||||
# 2. Apply patches
|
||||
try:
|
||||
with _patch_auth_and_router(temp_router):
|
||||
# 3. Reload module to register routes on the temp_router
|
||||
feature_module = _force_reload_module(target_module, alias_module)
|
||||
|
||||
yield feature_module
|
||||
|
||||
finally:
|
||||
# 4. Teardown: Clean up sys.modules
|
||||
_cleanup_modules(target_module, alias_module)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Test Cases
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("url", "service_mock_path", "mock_model_instance", "json_key"),
|
||||
[
|
||||
(
|
||||
"/console/api/features",
|
||||
"controllers.console.feature.FeatureService.get_features",
|
||||
FeatureModel(can_replace_logo=True),
|
||||
"features",
|
||||
),
|
||||
(
|
||||
"/console/api/system-features",
|
||||
"controllers.console.feature.FeatureService.get_system_features",
|
||||
SystemFeatureModel(enable_marketplace=True),
|
||||
"features",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_console_features_success(app, mock_feature_module_env, url, service_mock_path, mock_model_instance, json_key):
|
||||
"""
|
||||
Tests that the feature APIs return a 200 OK status and correct JSON structure.
|
||||
"""
|
||||
# Patch the service layer to return our mock model instance
|
||||
with patch(service_mock_path, return_value=mock_model_instance):
|
||||
# Initialize the API extension
|
||||
ext_fastopenapi.init_app(app)
|
||||
|
||||
client = app.test_client()
|
||||
response = client.get(url)
|
||||
|
||||
# Assertions
|
||||
assert response.status_code == 200, f"Request failed with status {response.status_code}: {response.text}"
|
||||
|
||||
# Verify the JSON response matches the Pydantic model dump
|
||||
expected_data = mock_model_instance.model_dump(mode="json")
|
||||
assert response.get_json() == {json_key: expected_data}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("url", "service_mock_path"),
|
||||
[
|
||||
("/console/api/features", "controllers.console.feature.FeatureService.get_features"),
|
||||
("/console/api/system-features", "controllers.console.feature.FeatureService.get_system_features"),
|
||||
],
|
||||
)
|
||||
def test_console_features_service_error(app, mock_feature_module_env, url, service_mock_path):
|
||||
"""
|
||||
Tests how the application handles Service layer errors.
|
||||
|
||||
Note: When an exception occurs in the view, it is typically caught by the framework
|
||||
(Flask or the OpenAPI wrapper) and converted to a 500 error response.
|
||||
This test verifies that the application returns a 500 status code.
|
||||
"""
|
||||
# Simulate a service failure
|
||||
with patch(service_mock_path, side_effect=ValueError("Service Failure")):
|
||||
ext_fastopenapi.init_app(app)
|
||||
client = app.test_client()
|
||||
|
||||
# When an exception occurs in the view, it is typically caught by the framework
|
||||
# (Flask or the OpenAPI wrapper) and converted to a 500 error response.
|
||||
response = client.get(url)
|
||||
|
||||
assert response.status_code == 500
|
||||
# Check if the error details are exposed in the response (depends on error handler config)
|
||||
# We accept either generic 500 or the specific error message
|
||||
assert "Service Failure" in response.text or "Internal Server Error" in response.text
|
||||
|
||||
|
||||
def test_system_features_unauthenticated(app, mock_feature_module_env):
|
||||
"""
|
||||
Tests that /console/api/system-features endpoint works without authentication.
|
||||
|
||||
This test verifies the try-except block in get_system_features that handles
|
||||
unauthenticated requests by passing is_authenticated=False to the service layer.
|
||||
"""
|
||||
feature_module = mock_feature_module_env
|
||||
|
||||
# Override the behavior of the current_user mock
|
||||
# The fixture patched 'libs.login.current_user', so 'controllers.console.feature.current_user'
|
||||
# refers to that same Mock object.
|
||||
mock_user = feature_module.current_user
|
||||
|
||||
# Simulate property access raising Unauthorized
|
||||
# Note: We must reset side_effect if it was set, or set it here.
|
||||
# The fixture initialized it as MagicMock(is_authenticated=True).
|
||||
# We want type(mock_user).is_authenticated to raise Unauthorized.
|
||||
type(mock_user).is_authenticated = PropertyMock(side_effect=Unauthorized)
|
||||
|
||||
# Patch the service layer for this specific test
|
||||
with patch("controllers.console.feature.FeatureService.get_system_features") as mock_service:
|
||||
# Setup mock service return value
|
||||
mock_model = SystemFeatureModel(enable_marketplace=True)
|
||||
mock_service.return_value = mock_model
|
||||
|
||||
# Initialize app
|
||||
ext_fastopenapi.init_app(app)
|
||||
client = app.test_client()
|
||||
|
||||
# Act
|
||||
response = client.get("/console/api/system-features")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200, f"Request failed: {response.text}"
|
||||
|
||||
# Verify service was called with is_authenticated=False
|
||||
mock_service.assert_called_once_with(is_authenticated=False)
|
||||
|
||||
# Verify response body
|
||||
expected_data = mock_model.model_dump(mode="json")
|
||||
assert response.get_json() == {"features": expected_data}
|
||||
@ -1,222 +0,0 @@
|
||||
import builtins
|
||||
import contextlib
|
||||
import importlib
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from extensions import ext_fastopenapi
|
||||
from extensions.ext_database import db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["SECRET_KEY"] = "test-secret"
|
||||
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
|
||||
|
||||
db.init_app(app)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def fix_method_view_issue(monkeypatch):
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
monkeypatch.setattr(builtins, "MethodView", MethodView, raising=False)
|
||||
|
||||
|
||||
def _create_isolated_router():
|
||||
import controllers.fastopenapi
|
||||
|
||||
router_class = type(controllers.fastopenapi.console_router)
|
||||
return router_class()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _patch_auth_and_router(temp_router):
|
||||
def noop(func):
|
||||
return func
|
||||
|
||||
default_user = MagicMock(has_edit_permission=True, is_dataset_editor=False)
|
||||
|
||||
with (
|
||||
patch("controllers.fastopenapi.console_router", temp_router),
|
||||
patch("extensions.ext_fastopenapi.console_router", temp_router),
|
||||
patch("controllers.console.wraps.setup_required", side_effect=noop),
|
||||
patch("libs.login.login_required", side_effect=noop),
|
||||
patch("controllers.console.wraps.account_initialization_required", side_effect=noop),
|
||||
patch("controllers.console.wraps.edit_permission_required", side_effect=noop),
|
||||
patch("libs.login.current_account_with_tenant", return_value=(default_user, "tenant-id")),
|
||||
patch("configs.dify_config.EDITION", "CLOUD"),
|
||||
):
|
||||
import extensions.ext_fastopenapi
|
||||
|
||||
importlib.reload(extensions.ext_fastopenapi)
|
||||
|
||||
yield
|
||||
|
||||
|
||||
def _force_reload_module(target_module: str, alias_module: str):
|
||||
if target_module in sys.modules:
|
||||
del sys.modules[target_module]
|
||||
if alias_module in sys.modules:
|
||||
del sys.modules[alias_module]
|
||||
|
||||
module = importlib.import_module(target_module)
|
||||
sys.modules[alias_module] = sys.modules[target_module]
|
||||
|
||||
return module
|
||||
|
||||
|
||||
def _dedupe_routes(router):
|
||||
seen = set()
|
||||
unique_routes = []
|
||||
for path, method, endpoint in reversed(router.get_routes()):
|
||||
key = (path, method, endpoint.__name__)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
unique_routes.append((path, method, endpoint))
|
||||
router._routes = list(reversed(unique_routes))
|
||||
|
||||
|
||||
def _cleanup_modules(target_module: str, alias_module: str):
|
||||
if target_module in sys.modules:
|
||||
del sys.modules[target_module]
|
||||
if alias_module in sys.modules:
|
||||
del sys.modules[alias_module]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tags_module_env():
|
||||
target_module = "controllers.console.tag.tags"
|
||||
alias_module = "api.controllers.console.tag.tags"
|
||||
temp_router = _create_isolated_router()
|
||||
|
||||
try:
|
||||
with _patch_auth_and_router(temp_router):
|
||||
tags_module = _force_reload_module(target_module, alias_module)
|
||||
_dedupe_routes(temp_router)
|
||||
yield tags_module
|
||||
finally:
|
||||
_cleanup_modules(target_module, alias_module)
|
||||
|
||||
|
||||
def test_list_tags_success(app: Flask, mock_tags_module_env):
|
||||
# Arrange
|
||||
tag = SimpleNamespace(id="tag-1", name="Alpha", type="app", binding_count=2)
|
||||
with patch("controllers.console.tag.tags.TagService.get_tags", return_value=[tag]):
|
||||
ext_fastopenapi.init_app(app)
|
||||
client = app.test_client()
|
||||
|
||||
# Act
|
||||
response = client.get("/console/api/tags?type=app&keyword=Alpha")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == [
|
||||
{"id": "tag-1", "name": "Alpha", "type": "app", "binding_count": 2},
|
||||
]
|
||||
|
||||
|
||||
def test_create_tag_success(app: Flask, mock_tags_module_env):
|
||||
# Arrange
|
||||
tag = SimpleNamespace(id="tag-2", name="Beta", type="app")
|
||||
with patch("controllers.console.tag.tags.TagService.save_tags", return_value=tag) as mock_save:
|
||||
ext_fastopenapi.init_app(app)
|
||||
client = app.test_client()
|
||||
|
||||
# Act
|
||||
response = client.post("/console/api/tags", json={"name": "Beta", "type": "app"})
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {
|
||||
"id": "tag-2",
|
||||
"name": "Beta",
|
||||
"type": "app",
|
||||
"binding_count": 0,
|
||||
}
|
||||
mock_save.assert_called_once_with({"name": "Beta", "type": "app"})
|
||||
|
||||
|
||||
def test_update_tag_success(app: Flask, mock_tags_module_env):
|
||||
# Arrange
|
||||
tag = SimpleNamespace(id="tag-3", name="Gamma", type="app")
|
||||
with (
|
||||
patch("controllers.console.tag.tags.TagService.update_tags", return_value=tag) as mock_update,
|
||||
patch("controllers.console.tag.tags.TagService.get_tag_binding_count", return_value=4),
|
||||
):
|
||||
ext_fastopenapi.init_app(app)
|
||||
client = app.test_client()
|
||||
|
||||
# Act
|
||||
response = client.patch(
|
||||
"/console/api/tags/11111111-1111-1111-1111-111111111111",
|
||||
json={"name": "Gamma", "type": "app"},
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {
|
||||
"id": "tag-3",
|
||||
"name": "Gamma",
|
||||
"type": "app",
|
||||
"binding_count": 4,
|
||||
}
|
||||
mock_update.assert_called_once_with(
|
||||
{"name": "Gamma", "type": "app"},
|
||||
"11111111-1111-1111-1111-111111111111",
|
||||
)
|
||||
|
||||
|
||||
def test_delete_tag_success(app: Flask, mock_tags_module_env):
|
||||
# Arrange
|
||||
with patch("controllers.console.tag.tags.TagService.delete_tag") as mock_delete:
|
||||
ext_fastopenapi.init_app(app)
|
||||
client = app.test_client()
|
||||
|
||||
# Act
|
||||
response = client.delete("/console/api/tags/11111111-1111-1111-1111-111111111111")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 204
|
||||
mock_delete.assert_called_once_with("11111111-1111-1111-1111-111111111111")
|
||||
|
||||
|
||||
def test_create_tag_binding_success(app: Flask, mock_tags_module_env):
|
||||
# Arrange
|
||||
payload = {"tag_ids": ["tag-1", "tag-2"], "target_id": "target-1", "type": "app"}
|
||||
with patch("controllers.console.tag.tags.TagService.save_tag_binding") as mock_bind:
|
||||
ext_fastopenapi.init_app(app)
|
||||
client = app.test_client()
|
||||
|
||||
# Act
|
||||
response = client.post("/console/api/tag-bindings/create", json=payload)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"result": "success"}
|
||||
mock_bind.assert_called_once_with(payload)
|
||||
|
||||
|
||||
def test_delete_tag_binding_success(app: Flask, mock_tags_module_env):
|
||||
# Arrange
|
||||
payload = {"tag_id": "tag-1", "target_id": "target-1", "type": "app"}
|
||||
with patch("controllers.console.tag.tags.TagService.delete_tag_binding") as mock_unbind:
|
||||
ext_fastopenapi.init_app(app)
|
||||
client = app.test_client()
|
||||
|
||||
# Act
|
||||
response = client.post("/console/api/tag-bindings/remove", json=payload)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"result": "success"}
|
||||
mock_unbind.assert_called_once_with(payload)
|
||||
@ -0,0 +1,364 @@
|
||||
"""Endpoint tests for controllers.console.workspace.tool_providers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import importlib
|
||||
from contextlib import contextmanager
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
_CONTROLLER_MODULE: ModuleType | None = None
|
||||
_WRAPS_MODULE: ModuleType | None = None
|
||||
_CONTROLLER_PATCHERS: list[patch] = []
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _mock_db():
|
||||
mock_session = SimpleNamespace(query=lambda *args, **kwargs: SimpleNamespace(first=lambda: True))
|
||||
with patch("extensions.ext_database.db.session", mock_session):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def controller_module(monkeypatch: pytest.MonkeyPatch):
|
||||
module_name = "controllers.console.workspace.tool_providers"
|
||||
global _CONTROLLER_MODULE
|
||||
if _CONTROLLER_MODULE is None:
|
||||
|
||||
def _noop(func):
|
||||
return func
|
||||
|
||||
patch_targets = [
|
||||
("libs.login.login_required", _noop),
|
||||
("controllers.console.wraps.setup_required", _noop),
|
||||
("controllers.console.wraps.account_initialization_required", _noop),
|
||||
("controllers.console.wraps.is_admin_or_owner_required", _noop),
|
||||
("controllers.console.wraps.enterprise_license_required", _noop),
|
||||
]
|
||||
for target, value in patch_targets:
|
||||
patcher = patch(target, value)
|
||||
patcher.start()
|
||||
_CONTROLLER_PATCHERS.append(patcher)
|
||||
monkeypatch.setenv("DIFY_SETUP_READY", "true")
|
||||
with _mock_db():
|
||||
_CONTROLLER_MODULE = importlib.import_module(module_name)
|
||||
|
||||
module = _CONTROLLER_MODULE
|
||||
monkeypatch.setattr(module, "jsonable_encoder", lambda payload: payload)
|
||||
|
||||
# Ensure decorators that consult deployment edition do not reach the database.
|
||||
global _WRAPS_MODULE
|
||||
wraps_module = importlib.import_module("controllers.console.wraps")
|
||||
_WRAPS_MODULE = wraps_module
|
||||
monkeypatch.setattr(module.dify_config, "EDITION", "CLOUD")
|
||||
monkeypatch.setattr(wraps_module.dify_config, "EDITION", "CLOUD")
|
||||
|
||||
login_module = importlib.import_module("libs.login")
|
||||
monkeypatch.setattr(login_module, "check_csrf_token", lambda *args, **kwargs: None)
|
||||
return module
|
||||
|
||||
|
||||
def _mock_account(user_id: str = "user-123") -> SimpleNamespace:
|
||||
return SimpleNamespace(id=user_id, status="active", is_authenticated=True, current_tenant_id=None)
|
||||
|
||||
|
||||
def _set_current_account(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
controller_module: ModuleType,
|
||||
user: SimpleNamespace,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
def _getter():
|
||||
return user, tenant_id
|
||||
|
||||
user.current_tenant_id = tenant_id
|
||||
|
||||
monkeypatch.setattr(controller_module, "current_account_with_tenant", _getter)
|
||||
if _WRAPS_MODULE is not None:
|
||||
monkeypatch.setattr(_WRAPS_MODULE, "current_account_with_tenant", _getter)
|
||||
|
||||
login_module = importlib.import_module("libs.login")
|
||||
monkeypatch.setattr(login_module, "_get_user", lambda: user)
|
||||
|
||||
|
||||
def test_tool_provider_list_calls_service_with_query(
|
||||
app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-456")
|
||||
|
||||
service_mock = MagicMock(return_value=[{"provider": "builtin"}])
|
||||
monkeypatch.setattr(controller_module.ToolCommonService, "list_tool_providers", service_mock)
|
||||
|
||||
with app.test_request_context("/workspaces/current/tool-providers?type=builtin"):
|
||||
response = controller_module.ToolProviderListApi().get()
|
||||
|
||||
assert response == [{"provider": "builtin"}]
|
||||
service_mock.assert_called_once_with(user.id, "tenant-456", "builtin")
|
||||
|
||||
|
||||
def test_builtin_provider_add_passes_payload(
|
||||
app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-456")
|
||||
|
||||
service_mock = MagicMock(return_value={"status": "ok"})
|
||||
monkeypatch.setattr(controller_module.BuiltinToolManageService, "add_builtin_tool_provider", service_mock)
|
||||
|
||||
payload = {
|
||||
"credentials": {"api_key": "sk-test"},
|
||||
"name": "MyTool",
|
||||
"type": controller_module.CredentialType.API_KEY,
|
||||
}
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/tool-provider/builtin/openai/add",
|
||||
method="POST",
|
||||
json=payload,
|
||||
):
|
||||
response = controller_module.ToolBuiltinProviderAddApi().post(provider="openai")
|
||||
|
||||
assert response == {"status": "ok"}
|
||||
service_mock.assert_called_once_with(
|
||||
user_id="user-123",
|
||||
tenant_id="tenant-456",
|
||||
provider="openai",
|
||||
credentials={"api_key": "sk-test"},
|
||||
name="MyTool",
|
||||
api_type=controller_module.CredentialType.API_KEY,
|
||||
)
|
||||
|
||||
|
||||
def test_builtin_provider_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-tenant-789")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-789")
|
||||
|
||||
service_mock = MagicMock(return_value=[{"name": "tool-a"}])
|
||||
monkeypatch.setattr(controller_module.BuiltinToolManageService, "list_builtin_tool_provider_tools", service_mock)
|
||||
monkeypatch.setattr(controller_module, "jsonable_encoder", lambda payload: payload)
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/tool-provider/builtin/my-provider/tools",
|
||||
method="GET",
|
||||
):
|
||||
response = controller_module.ToolBuiltinProviderListToolsApi().get(provider="my-provider")
|
||||
|
||||
assert response == [{"name": "tool-a"}]
|
||||
service_mock.assert_called_once_with("tenant-789", "my-provider")
|
||||
|
||||
|
||||
def test_builtin_provider_info_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-tenant-9")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-9")
|
||||
service_mock = MagicMock(return_value={"info": True})
|
||||
monkeypatch.setattr(controller_module.BuiltinToolManageService, "get_builtin_tool_provider_info", service_mock)
|
||||
|
||||
with app.test_request_context("/info", method="GET"):
|
||||
resp = controller_module.ToolBuiltinProviderInfoApi().get(provider="demo")
|
||||
|
||||
assert resp == {"info": True}
|
||||
service_mock.assert_called_once_with("tenant-9", "demo")
|
||||
|
||||
|
||||
def test_builtin_provider_credentials_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-tenant-cred")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-cred")
|
||||
service_mock = MagicMock(return_value=[{"cred": 1}])
|
||||
monkeypatch.setattr(
|
||||
controller_module.BuiltinToolManageService,
|
||||
"get_builtin_tool_provider_credentials",
|
||||
service_mock,
|
||||
)
|
||||
|
||||
with app.test_request_context("/creds", method="GET"):
|
||||
resp = controller_module.ToolBuiltinProviderGetCredentialsApi().get(provider="demo")
|
||||
|
||||
assert resp == [{"cred": 1}]
|
||||
service_mock.assert_called_once_with(tenant_id="tenant-cred", provider_name="demo")
|
||||
|
||||
|
||||
def test_api_provider_remote_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-10")
|
||||
service_mock = MagicMock(return_value={"schema": "ok"})
|
||||
monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider_remote_schema", service_mock)
|
||||
|
||||
with app.test_request_context("/remote?url=https://example.com/"):
|
||||
resp = controller_module.ToolApiProviderGetRemoteSchemaApi().get()
|
||||
|
||||
assert resp == {"schema": "ok"}
|
||||
service_mock.assert_called_once_with(user.id, "tenant-10", "https://example.com/")
|
||||
|
||||
|
||||
def test_api_provider_list_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-11")
|
||||
service_mock = MagicMock(return_value=[{"tool": "t"}])
|
||||
monkeypatch.setattr(controller_module.ApiToolManageService, "list_api_tool_provider_tools", service_mock)
|
||||
|
||||
with app.test_request_context("/tools?provider=foo"):
|
||||
resp = controller_module.ToolApiProviderListToolsApi().get()
|
||||
|
||||
assert resp == [{"tool": "t"}]
|
||||
service_mock.assert_called_once_with(user.id, "tenant-11", "foo")
|
||||
|
||||
|
||||
def test_api_provider_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-12")
|
||||
service_mock = MagicMock(return_value={"provider": "foo"})
|
||||
monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider", service_mock)
|
||||
|
||||
with app.test_request_context("/get?provider=foo"):
|
||||
resp = controller_module.ToolApiProviderGetApi().get()
|
||||
|
||||
assert resp == {"provider": "foo"}
|
||||
service_mock.assert_called_once_with(user.id, "tenant-12", "foo")
|
||||
|
||||
|
||||
def test_builtin_provider_credentials_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-tenant-13")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-13")
|
||||
service_mock = MagicMock(return_value={"schema": True})
|
||||
monkeypatch.setattr(
|
||||
controller_module.BuiltinToolManageService,
|
||||
"list_builtin_provider_credentials_schema",
|
||||
service_mock,
|
||||
)
|
||||
|
||||
with app.test_request_context("/schema", method="GET"):
|
||||
resp = controller_module.ToolBuiltinProviderCredentialsSchemaApi().get(
|
||||
provider="demo", credential_type="api-key"
|
||||
)
|
||||
|
||||
assert resp == {"schema": True}
|
||||
service_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_workflow_provider_get_by_tool(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-wf")
|
||||
tool_service = MagicMock(return_value={"wf": 1})
|
||||
monkeypatch.setattr(
|
||||
controller_module.WorkflowToolManageService,
|
||||
"get_workflow_tool_by_tool_id",
|
||||
tool_service,
|
||||
)
|
||||
|
||||
tool_id = "00000000-0000-0000-0000-000000000001"
|
||||
with app.test_request_context(f"/workflow?workflow_tool_id={tool_id}"):
|
||||
resp = controller_module.ToolWorkflowProviderGetApi().get()
|
||||
|
||||
assert resp == {"wf": 1}
|
||||
tool_service.assert_called_once_with(user.id, "tenant-wf", tool_id)
|
||||
|
||||
|
||||
def test_workflow_provider_get_by_app(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-wf2")
|
||||
service_mock = MagicMock(return_value={"app": 1})
|
||||
monkeypatch.setattr(
|
||||
controller_module.WorkflowToolManageService,
|
||||
"get_workflow_tool_by_app_id",
|
||||
service_mock,
|
||||
)
|
||||
|
||||
app_id = "00000000-0000-0000-0000-000000000002"
|
||||
with app.test_request_context(f"/workflow?workflow_app_id={app_id}"):
|
||||
resp = controller_module.ToolWorkflowProviderGetApi().get()
|
||||
|
||||
assert resp == {"app": 1}
|
||||
service_mock.assert_called_once_with(user.id, "tenant-wf2", app_id)
|
||||
|
||||
|
||||
def test_workflow_provider_list_tools(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-wf3")
|
||||
service_mock = MagicMock(return_value=[{"id": 1}])
|
||||
monkeypatch.setattr(controller_module.WorkflowToolManageService, "list_single_workflow_tools", service_mock)
|
||||
|
||||
tool_id = "00000000-0000-0000-0000-000000000003"
|
||||
with app.test_request_context(f"/workflow/tools?workflow_tool_id={tool_id}"):
|
||||
resp = controller_module.ToolWorkflowProviderListToolApi().get()
|
||||
|
||||
assert resp == [{"id": 1}]
|
||||
service_mock.assert_called_once_with(user.id, "tenant-wf3", tool_id)
|
||||
|
||||
|
||||
def test_builtin_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-bt")
|
||||
|
||||
provider = SimpleNamespace(to_dict=lambda: {"name": "builtin"})
|
||||
monkeypatch.setattr(
|
||||
controller_module.BuiltinToolManageService,
|
||||
"list_builtin_tools",
|
||||
MagicMock(return_value=[provider]),
|
||||
)
|
||||
|
||||
with app.test_request_context("/tools/builtin"):
|
||||
resp = controller_module.ToolBuiltinListApi().get()
|
||||
|
||||
assert resp == [{"name": "builtin"}]
|
||||
|
||||
|
||||
def test_api_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-tenant-api")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-api")
|
||||
|
||||
provider = SimpleNamespace(to_dict=lambda: {"name": "api"})
|
||||
monkeypatch.setattr(
|
||||
controller_module.ApiToolManageService,
|
||||
"list_api_tools",
|
||||
MagicMock(return_value=[provider]),
|
||||
)
|
||||
|
||||
with app.test_request_context("/tools/api"):
|
||||
resp = controller_module.ToolApiListApi().get()
|
||||
|
||||
assert resp == [{"name": "api"}]
|
||||
|
||||
|
||||
def test_workflow_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-wf4")
|
||||
|
||||
provider = SimpleNamespace(to_dict=lambda: {"name": "wf"})
|
||||
monkeypatch.setattr(
|
||||
controller_module.WorkflowToolManageService,
|
||||
"list_tenant_workflow_tools",
|
||||
MagicMock(return_value=[provider]),
|
||||
)
|
||||
|
||||
with app.test_request_context("/tools/workflow"):
|
||||
resp = controller_module.ToolWorkflowListApi().get()
|
||||
|
||||
assert resp == [{"name": "wf"}]
|
||||
|
||||
|
||||
def test_tool_labels_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-label")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-labels")
|
||||
monkeypatch.setattr(controller_module.ToolLabelsService, "list_tool_labels", lambda: ["a", "b"])
|
||||
|
||||
with app.test_request_context("/tool-labels"):
|
||||
resp = controller_module.ToolLabelsApi().get()
|
||||
|
||||
assert resp == ["a", "b"]
|
||||
@ -0,0 +1,276 @@
|
||||
"""Unit tests for account deletion synchronization.
|
||||
|
||||
This test module verifies the enterprise account deletion sync functionality,
|
||||
including Redis queuing, error handling, and community vs enterprise behavior.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from redis import RedisError
|
||||
|
||||
from services.enterprise.account_deletion_sync import (
|
||||
_queue_task,
|
||||
sync_account_deletion,
|
||||
sync_workspace_member_removal,
|
||||
)
|
||||
|
||||
|
||||
class TestQueueTask:
|
||||
"""Unit tests for the _queue_task helper function."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client(self):
|
||||
"""Mock redis_client for testing."""
|
||||
with patch("services.enterprise.account_deletion_sync.redis_client") as mock_redis:
|
||||
yield mock_redis
|
||||
|
||||
@pytest.fixture
|
||||
def mock_uuid(self):
|
||||
"""Mock UUID generation for predictable task IDs."""
|
||||
with patch("services.enterprise.account_deletion_sync.uuid.uuid4") as mock_uuid_gen:
|
||||
mock_uuid_gen.return_value = MagicMock(hex="test-task-id-1234")
|
||||
yield mock_uuid_gen
|
||||
|
||||
def test_queue_task_success(self, mock_redis_client, mock_uuid):
|
||||
"""Test successful task queueing to Redis."""
|
||||
# Arrange
|
||||
workspace_id = "ws-123"
|
||||
member_id = "member-456"
|
||||
source = "test_source"
|
||||
|
||||
# Act
|
||||
result = _queue_task(workspace_id=workspace_id, member_id=member_id, source=source)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_redis_client.lpush.assert_called_once()
|
||||
|
||||
# Verify the task payload structure
|
||||
call_args = mock_redis_client.lpush.call_args[0]
|
||||
assert call_args[0] == "enterprise:member:sync:queue"
|
||||
|
||||
import json
|
||||
|
||||
task_data = json.loads(call_args[1])
|
||||
assert task_data["workspace_id"] == workspace_id
|
||||
assert task_data["member_id"] == member_id
|
||||
assert task_data["source"] == source
|
||||
assert task_data["type"] == "sync_member_deletion_from_workspace"
|
||||
assert task_data["retry_count"] == 0
|
||||
assert "task_id" in task_data
|
||||
assert "created_at" in task_data
|
||||
|
||||
def test_queue_task_redis_error(self, mock_redis_client, caplog):
|
||||
"""Test handling of Redis connection errors."""
|
||||
# Arrange
|
||||
mock_redis_client.lpush.side_effect = RedisError("Connection failed")
|
||||
|
||||
# Act
|
||||
result = _queue_task(workspace_id="ws-123", member_id="member-456", source="test_source")
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
assert "Failed to queue account deletion sync" in caplog.text
|
||||
|
||||
def test_queue_task_type_error(self, mock_redis_client, caplog):
|
||||
"""Test handling of JSON serialization errors."""
|
||||
# Arrange
|
||||
mock_redis_client.lpush.side_effect = TypeError("Cannot serialize")
|
||||
|
||||
# Act
|
||||
result = _queue_task(workspace_id="ws-123", member_id="member-456", source="test_source")
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
assert "Failed to queue account deletion sync" in caplog.text
|
||||
|
||||
|
||||
class TestSyncWorkspaceMemberRemoval:
|
||||
"""Unit tests for sync_workspace_member_removal function."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_queue_task(self):
|
||||
"""Mock _queue_task for testing."""
|
||||
with patch("services.enterprise.account_deletion_sync._queue_task") as mock_queue:
|
||||
mock_queue.return_value = True
|
||||
yield mock_queue
|
||||
|
||||
def test_sync_workspace_member_removal_enterprise_enabled(self, mock_queue_task):
|
||||
"""Test sync when ENTERPRISE_ENABLED is True."""
|
||||
# Arrange
|
||||
workspace_id = "ws-123"
|
||||
member_id = "member-456"
|
||||
source = "workspace_member_removed"
|
||||
|
||||
with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config:
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
|
||||
# Act
|
||||
result = sync_workspace_member_removal(workspace_id=workspace_id, member_id=member_id, source=source)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_queue_task.assert_called_once_with(workspace_id=workspace_id, member_id=member_id, source=source)
|
||||
|
||||
def test_sync_workspace_member_removal_enterprise_disabled(self, mock_queue_task):
|
||||
"""Test sync when ENTERPRISE_ENABLED is False (community edition)."""
|
||||
# Arrange
|
||||
with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config:
|
||||
mock_config.ENTERPRISE_ENABLED = False
|
||||
|
||||
# Act
|
||||
result = sync_workspace_member_removal(workspace_id="ws-123", member_id="member-456", source="test_source")
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_queue_task.assert_not_called()
|
||||
|
||||
def test_sync_workspace_member_removal_queue_failure(self, mock_queue_task):
|
||||
"""Test handling of queue task failures."""
|
||||
# Arrange
|
||||
mock_queue_task.return_value = False
|
||||
|
||||
with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config:
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
|
||||
# Act
|
||||
result = sync_workspace_member_removal(workspace_id="ws-123", member_id="member-456", source="test_source")
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestSyncAccountDeletion:
|
||||
"""Unit tests for sync_account_deletion function."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Mock database session for testing."""
|
||||
with patch("services.enterprise.account_deletion_sync.db.session") as mock_session:
|
||||
yield mock_session
|
||||
|
||||
@pytest.fixture
|
||||
def mock_queue_task(self):
|
||||
"""Mock _queue_task for testing."""
|
||||
with patch("services.enterprise.account_deletion_sync._queue_task") as mock_queue:
|
||||
mock_queue.return_value = True
|
||||
yield mock_queue
|
||||
|
||||
def test_sync_account_deletion_enterprise_disabled(self, mock_db_session, mock_queue_task):
|
||||
"""Test sync when ENTERPRISE_ENABLED is False (community edition)."""
|
||||
# Arrange
|
||||
with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config:
|
||||
mock_config.ENTERPRISE_ENABLED = False
|
||||
|
||||
# Act
|
||||
result = sync_account_deletion(account_id="acc-123", source="account_deleted")
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_db_session.query.assert_not_called()
|
||||
mock_queue_task.assert_not_called()
|
||||
|
||||
def test_sync_account_deletion_multiple_workspaces(self, mock_db_session, mock_queue_task):
|
||||
"""Test sync for account with multiple workspace memberships."""
|
||||
# Arrange
|
||||
account_id = "acc-123"
|
||||
|
||||
# Mock workspace joins
|
||||
mock_join1 = MagicMock()
|
||||
mock_join1.tenant_id = "tenant-1"
|
||||
mock_join2 = MagicMock()
|
||||
mock_join2.tenant_id = "tenant-2"
|
||||
mock_join3 = MagicMock()
|
||||
mock_join3.tenant_id = "tenant-3"
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter_by.return_value.all.return_value = [mock_join1, mock_join2, mock_join3]
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config:
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
|
||||
# Act
|
||||
result = sync_account_deletion(account_id=account_id, source="account_deleted")
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert mock_queue_task.call_count == 3
|
||||
|
||||
# Verify each workspace was queued
|
||||
mock_queue_task.assert_any_call(workspace_id="tenant-1", member_id=account_id, source="account_deleted")
|
||||
mock_queue_task.assert_any_call(workspace_id="tenant-2", member_id=account_id, source="account_deleted")
|
||||
mock_queue_task.assert_any_call(workspace_id="tenant-3", member_id=account_id, source="account_deleted")
|
||||
|
||||
def test_sync_account_deletion_no_workspaces(self, mock_db_session, mock_queue_task):
|
||||
"""Test sync for account with no workspace memberships."""
|
||||
# Arrange
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter_by.return_value.all.return_value = []
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config:
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
|
||||
# Act
|
||||
result = sync_account_deletion(account_id="acc-123", source="account_deleted")
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_queue_task.assert_not_called()
|
||||
|
||||
def test_sync_account_deletion_partial_failure(self, mock_db_session, mock_queue_task):
|
||||
"""Test sync when some tasks fail to queue."""
|
||||
# Arrange
|
||||
account_id = "acc-123"
|
||||
|
||||
# Mock workspace joins
|
||||
mock_join1 = MagicMock()
|
||||
mock_join1.tenant_id = "tenant-1"
|
||||
mock_join2 = MagicMock()
|
||||
mock_join2.tenant_id = "tenant-2"
|
||||
mock_join3 = MagicMock()
|
||||
mock_join3.tenant_id = "tenant-3"
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter_by.return_value.all.return_value = [mock_join1, mock_join2, mock_join3]
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
# Mock queue_task to fail for second workspace
|
||||
def queue_side_effect(workspace_id, member_id, source):
|
||||
return workspace_id != "tenant-2"
|
||||
|
||||
mock_queue_task.side_effect = queue_side_effect
|
||||
|
||||
with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config:
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
|
||||
# Act
|
||||
result = sync_account_deletion(account_id=account_id, source="account_deleted")
|
||||
|
||||
# Assert
|
||||
assert result is False # Should return False if any task fails
|
||||
assert mock_queue_task.call_count == 3
|
||||
|
||||
def test_sync_account_deletion_all_failures(self, mock_db_session, mock_queue_task):
|
||||
"""Test sync when all tasks fail to queue."""
|
||||
# Arrange
|
||||
mock_join = MagicMock()
|
||||
mock_join.tenant_id = "tenant-1"
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter_by.return_value.all.return_value = [mock_join]
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
mock_queue_task.return_value = False
|
||||
|
||||
with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config:
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
|
||||
# Act
|
||||
result = sync_account_deletion(account_id="acc-123", source="account_deleted")
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_queue_task.assert_called_once()
|
||||
@ -350,7 +350,7 @@ class TestDeleteWorkflowArchiveLogs:
|
||||
mock_query.where.return_value = mock_delete_query
|
||||
mock_db.session.query.return_value = mock_query
|
||||
|
||||
delete_func("log-1")
|
||||
delete_func(mock_db.session, "log-1")
|
||||
|
||||
mock_db.session.query.assert_called_once_with(WorkflowArchiveLog)
|
||||
mock_query.where.assert_called_once()
|
||||
|
||||
10
api/uv.lock
generated
10
api/uv.lock
generated
@ -1368,7 +1368,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "dify-api"
|
||||
version = "1.12.0"
|
||||
version = "1.12.1"
|
||||
source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "aliyun-log-python-sdk" },
|
||||
@ -1707,7 +1707,7 @@ dev = [
|
||||
{ name = "types-openpyxl", specifier = "~=3.1.5" },
|
||||
{ name = "types-pexpect", specifier = "~=4.9.0" },
|
||||
{ name = "types-protobuf", specifier = "~=5.29.1" },
|
||||
{ name = "types-psutil", specifier = "~=7.0.0" },
|
||||
{ name = "types-psutil", specifier = "~=7.2.2" },
|
||||
{ name = "types-psycopg2", specifier = "~=2.9.21" },
|
||||
{ name = "types-pygments", specifier = "~=2.19.0" },
|
||||
{ name = "types-pymysql", specifier = "~=1.1.0" },
|
||||
@ -6508,11 +6508,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "types-psutil"
|
||||
version = "7.0.0.20251116"
|
||||
version = "7.2.2.20260130"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/47/ec/c1e9308b91582cad1d7e7d3007fd003ef45a62c2500f8219313df5fc3bba/types_psutil-7.0.0.20251116.tar.gz", hash = "sha256:92b5c78962e55ce1ed7b0189901a4409ece36ab9fd50c3029cca7e681c606c8a", size = 22192, upload-time = "2025-11-16T03:10:32.859Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/69/14/fc5fb0a6ddfadf68c27e254a02ececd4d5c7fdb0efcb7e7e917a183497fb/types_psutil-7.2.2.20260130.tar.gz", hash = "sha256:15b0ab69c52841cf9ce3c383e8480c620a4d13d6a8e22b16978ebddac5590950", size = 26535, upload-time = "2026-01-30T03:58:14.116Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c3/0e/11ba08a5375c21039ed5f8e6bba41e9452fb69f0e2f7ee05ed5cca2a2cdf/types_psutil-7.0.0.20251116-py3-none-any.whl", hash = "sha256:74c052de077c2024b85cd435e2cba971165fe92a5eace79cbeb821e776dbc047", size = 25376, upload-time = "2025-11-16T03:10:31.813Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/17/d7/60974b7e31545d3768d1770c5fe6e093182c3bfd819429b33133ba6b3e89/types_psutil-7.2.2.20260130-py3-none-any.whl", hash = "sha256:15523a3caa7b3ff03ac7f9b78a6470a59f88f48df1d74a39e70e06d2a99107da", size = 32876, upload-time = "2026-01-30T03:58:13.172Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
#!/bin/bash
|
||||
set -x
|
||||
set -euxo pipefail
|
||||
|
||||
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
|
||||
cd "$SCRIPT_DIR/../.."
|
||||
|
||||
@ -21,7 +21,7 @@ services:
|
||||
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.12.0
|
||||
image: langgenius/dify-api:1.12.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@ -63,7 +63,7 @@ services:
|
||||
# worker service
|
||||
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
|
||||
worker:
|
||||
image: langgenius/dify-api:1.12.0
|
||||
image: langgenius/dify-api:1.12.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@ -102,7 +102,7 @@ services:
|
||||
# worker_beat service
|
||||
# Celery beat for scheduling periodic tasks.
|
||||
worker_beat:
|
||||
image: langgenius/dify-api:1.12.0
|
||||
image: langgenius/dify-api:1.12.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@ -132,7 +132,7 @@ services:
|
||||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.12.0
|
||||
image: langgenius/dify-web:1.12.1
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
@ -662,13 +662,14 @@ services:
|
||||
- "${IRIS_SUPER_SERVER_PORT:-1972}:1972"
|
||||
- "${IRIS_WEB_SERVER_PORT:-52773}:52773"
|
||||
volumes:
|
||||
- ./volumes/iris:/opt/iris
|
||||
- ./volumes/iris:/durable
|
||||
- ./iris/iris-init.script:/iris-init.script
|
||||
- ./iris/docker-entrypoint.sh:/custom-entrypoint.sh
|
||||
entrypoint: ["/custom-entrypoint.sh"]
|
||||
tty: true
|
||||
environment:
|
||||
TZ: ${IRIS_TIMEZONE:-UTC}
|
||||
ISC_DATA_DIRECTORY: /durable/iris
|
||||
|
||||
# Oracle vector database
|
||||
oracle:
|
||||
|
||||
@ -712,7 +712,7 @@ services:
|
||||
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.12.0
|
||||
image: langgenius/dify-api:1.12.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@ -754,7 +754,7 @@ services:
|
||||
# worker service
|
||||
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
|
||||
worker:
|
||||
image: langgenius/dify-api:1.12.0
|
||||
image: langgenius/dify-api:1.12.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@ -793,7 +793,7 @@ services:
|
||||
# worker_beat service
|
||||
# Celery beat for scheduling periodic tasks.
|
||||
worker_beat:
|
||||
image: langgenius/dify-api:1.12.0
|
||||
image: langgenius/dify-api:1.12.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@ -823,7 +823,7 @@ services:
|
||||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.12.0
|
||||
image: langgenius/dify-web:1.12.1
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
@ -1353,13 +1353,14 @@ services:
|
||||
- "${IRIS_SUPER_SERVER_PORT:-1972}:1972"
|
||||
- "${IRIS_WEB_SERVER_PORT:-52773}:52773"
|
||||
volumes:
|
||||
- ./volumes/iris:/opt/iris
|
||||
- ./volumes/iris:/durable
|
||||
- ./iris/iris-init.script:/iris-init.script
|
||||
- ./iris/docker-entrypoint.sh:/custom-entrypoint.sh
|
||||
entrypoint: ["/custom-entrypoint.sh"]
|
||||
tty: true
|
||||
environment:
|
||||
TZ: ${IRIS_TIMEZONE:-UTC}
|
||||
ISC_DATA_DIRECTORY: /durable/iris
|
||||
|
||||
# Oracle vector database
|
||||
oracle:
|
||||
|
||||
@ -1,15 +1,33 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# IRIS configuration flag file
|
||||
IRIS_CONFIG_DONE="/opt/iris/.iris-configured"
|
||||
# IRIS configuration flag file (stored in durable directory to persist with data)
|
||||
IRIS_CONFIG_DONE="/durable/.iris-configured"
|
||||
|
||||
# Function to wait for IRIS to be ready
|
||||
wait_for_iris() {
|
||||
echo "Waiting for IRIS to be ready..."
|
||||
local max_attempts=30
|
||||
local attempt=1
|
||||
while [ "$attempt" -le "$max_attempts" ]; do
|
||||
if iris qlist IRIS 2>/dev/null | grep -q "running"; then
|
||||
echo "IRIS is ready."
|
||||
return 0
|
||||
fi
|
||||
echo "Attempt $attempt/$max_attempts: IRIS not ready yet, waiting..."
|
||||
sleep 2
|
||||
attempt=$((attempt + 1))
|
||||
done
|
||||
echo "ERROR: IRIS failed to start within expected time." >&2
|
||||
return 1
|
||||
}
|
||||
|
||||
# Function to configure IRIS
|
||||
configure_iris() {
|
||||
echo "Configuring IRIS for first-time setup..."
|
||||
|
||||
# Wait for IRIS to be fully started
|
||||
sleep 5
|
||||
wait_for_iris
|
||||
|
||||
# Execute the initialization script
|
||||
iris session IRIS < /iris-init.script
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import Cookies from 'js-cookie'
|
||||
import { usePathname, useRouter, useSearchParams } from 'next/navigation'
|
||||
import { parseAsString, useQueryState } from 'nuqs'
|
||||
import { parseAsBoolean, useQueryState } from 'nuqs'
|
||||
import { useCallback, useEffect, useState } from 'react'
|
||||
import {
|
||||
EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION,
|
||||
@ -28,7 +28,7 @@ export const AppInitializer = ({
|
||||
const [init, setInit] = useState(false)
|
||||
const [oauthNewUser, setOauthNewUser] = useQueryState(
|
||||
'oauth_new_user',
|
||||
parseAsString.withOptions({ history: 'replace' }),
|
||||
parseAsBoolean.withOptions({ history: 'replace' }),
|
||||
)
|
||||
|
||||
const isSetupFinished = useCallback(async () => {
|
||||
@ -46,7 +46,7 @@ export const AppInitializer = ({
|
||||
(async () => {
|
||||
const action = searchParams.get('action')
|
||||
|
||||
if (oauthNewUser === 'true') {
|
||||
if (oauthNewUser) {
|
||||
let utmInfo = null
|
||||
const utmInfoStr = Cookies.get('utm_info')
|
||||
if (utmInfoStr) {
|
||||
|
||||
@ -62,19 +62,19 @@ const AppCard = ({
|
||||
{app.description}
|
||||
</div>
|
||||
</div>
|
||||
{canCreate && (
|
||||
{(canCreate || isTrialApp) && (
|
||||
<div className={cn('absolute bottom-0 left-0 right-0 hidden bg-gradient-to-t from-components-panel-gradient-2 from-[60.27%] to-transparent p-4 pt-8 group-hover:flex')}>
|
||||
<div className={cn('grid h-8 w-full grid-cols-1 items-center space-x-2', isTrialApp && 'grid-cols-2')}>
|
||||
<Button variant="primary" onClick={() => onCreate()}>
|
||||
<PlusIcon className="mr-1 h-4 w-4" />
|
||||
<span className="text-xs">{t('newApp.useTemplate', { ns: 'app' })}</span>
|
||||
</Button>
|
||||
{isTrialApp && (
|
||||
<Button onClick={showTryAPPPanel(app.app_id)}>
|
||||
<RiInformation2Line className="mr-1 size-4" />
|
||||
<span>{t('appCard.try', { ns: 'explore' })}</span>
|
||||
<div className={cn('grid h-8 w-full grid-cols-1 items-center space-x-2', canCreate && 'grid-cols-2')}>
|
||||
{canCreate && (
|
||||
<Button variant="primary" onClick={() => onCreate()}>
|
||||
<PlusIcon className="mr-1 h-4 w-4" />
|
||||
<span className="text-xs">{t('newApp.useTemplate', { ns: 'app' })}</span>
|
||||
</Button>
|
||||
)}
|
||||
<Button onClick={showTryAPPPanel(app.app_id)}>
|
||||
<RiInformation2Line className="mr-1 size-4" />
|
||||
<span>{t('appCard.try', { ns: 'explore' })}</span>
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import type { App } from '@/types/app'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { useRouter } from 'next/navigation'
|
||||
import { afterAll, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
@ -13,8 +14,8 @@ import { getRedirection } from '@/utils/app-redirection'
|
||||
import CreateAppModal from './index'
|
||||
|
||||
vi.mock('ahooks', () => ({
|
||||
useDebounceFn: (fn: (...args: any[]) => any) => {
|
||||
const run = (...args: any[]) => fn(...args)
|
||||
useDebounceFn: <T extends (...args: unknown[]) => unknown>(fn: T) => {
|
||||
const run = (...args: Parameters<T>) => fn(...args)
|
||||
const cancel = vi.fn()
|
||||
const flush = vi.fn()
|
||||
return { run, cancel, flush }
|
||||
@ -83,7 +84,7 @@ describe('CreateAppModal', () => {
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUseRouter.mockReturnValue({ push: mockPush } as any)
|
||||
mockUseRouter.mockReturnValue({ push: mockPush } as unknown as ReturnType<typeof useRouter>)
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
plan: {
|
||||
type: AppModeEnum.ADVANCED_CHAT,
|
||||
@ -92,10 +93,10 @@ describe('CreateAppModal', () => {
|
||||
reset: {},
|
||||
},
|
||||
enableBilling: true,
|
||||
} as any)
|
||||
} as unknown as ReturnType<typeof useProviderContext>)
|
||||
mockUseAppContext.mockReturnValue({
|
||||
isCurrentWorkspaceEditor: true,
|
||||
} as any)
|
||||
} as unknown as ReturnType<typeof useAppContext>)
|
||||
mockSetItem.mockClear()
|
||||
Object.defineProperty(window, 'localStorage', {
|
||||
value: {
|
||||
@ -118,13 +119,13 @@ describe('CreateAppModal', () => {
|
||||
})
|
||||
|
||||
it('creates an app, notifies success, and fires callbacks', async () => {
|
||||
const mockApp = { id: 'app-1', mode: AppModeEnum.ADVANCED_CHAT }
|
||||
mockCreateApp.mockResolvedValue(mockApp as any)
|
||||
const mockApp: Partial<App> = { id: 'app-1', mode: AppModeEnum.ADVANCED_CHAT }
|
||||
mockCreateApp.mockResolvedValue(mockApp as App)
|
||||
const { onClose, onSuccess } = renderModal()
|
||||
|
||||
const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder')
|
||||
fireEvent.change(nameInput, { target: { value: 'My App' } })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: /app\.newApp\.Create/ }))
|
||||
|
||||
await waitFor(() => expect(mockCreateApp).toHaveBeenCalledWith({
|
||||
name: 'My App',
|
||||
@ -152,7 +153,7 @@ describe('CreateAppModal', () => {
|
||||
|
||||
const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder')
|
||||
fireEvent.change(nameInput, { target: { value: 'My App' } })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: /app\.newApp\.Create/ }))
|
||||
|
||||
await waitFor(() => expect(mockCreateApp).toHaveBeenCalled())
|
||||
expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'boom' })
|
||||
|
||||
@ -216,13 +216,22 @@ describe('image-uploader utils', () => {
|
||||
type FileCallback = (file: MockFile) => void
|
||||
type EntriesCallback = (entries: FileSystemEntry[]) => void
|
||||
|
||||
// Helper to create mock FileSystemEntry with required properties
|
||||
const createMockEntry = (props: {
|
||||
isFile: boolean
|
||||
isDirectory: boolean
|
||||
name?: string
|
||||
file?: (callback: FileCallback) => void
|
||||
createReader?: () => { readEntries: (callback: EntriesCallback) => void }
|
||||
}): FileSystemEntry => props as unknown as FileSystemEntry
|
||||
|
||||
it('should resolve with file array for file entry', async () => {
|
||||
const mockFile: MockFile = { name: 'test.png' }
|
||||
const mockEntry = {
|
||||
const mockEntry = createMockEntry({
|
||||
isFile: true,
|
||||
isDirectory: false,
|
||||
file: (callback: FileCallback) => callback(mockFile),
|
||||
}
|
||||
})
|
||||
|
||||
const result = await traverseFileEntry(mockEntry)
|
||||
expect(result).toHaveLength(1)
|
||||
@ -232,11 +241,11 @@ describe('image-uploader utils', () => {
|
||||
|
||||
it('should resolve with file array with prefix for nested file', async () => {
|
||||
const mockFile: MockFile = { name: 'test.png' }
|
||||
const mockEntry = {
|
||||
const mockEntry = createMockEntry({
|
||||
isFile: true,
|
||||
isDirectory: false,
|
||||
file: (callback: FileCallback) => callback(mockFile),
|
||||
}
|
||||
})
|
||||
|
||||
const result = await traverseFileEntry(mockEntry, 'folder/')
|
||||
expect(result).toHaveLength(1)
|
||||
@ -244,24 +253,24 @@ describe('image-uploader utils', () => {
|
||||
})
|
||||
|
||||
it('should resolve empty array for unknown entry type', async () => {
|
||||
const mockEntry = {
|
||||
const mockEntry = createMockEntry({
|
||||
isFile: false,
|
||||
isDirectory: false,
|
||||
}
|
||||
})
|
||||
|
||||
const result = await traverseFileEntry(mockEntry)
|
||||
expect(result).toEqual([])
|
||||
})
|
||||
|
||||
it('should handle directory with no files', async () => {
|
||||
const mockEntry = {
|
||||
const mockEntry = createMockEntry({
|
||||
isFile: false,
|
||||
isDirectory: true,
|
||||
name: 'empty-folder',
|
||||
createReader: () => ({
|
||||
readEntries: (callback: EntriesCallback) => callback([]),
|
||||
}),
|
||||
}
|
||||
})
|
||||
|
||||
const result = await traverseFileEntry(mockEntry)
|
||||
expect(result).toEqual([])
|
||||
@ -271,20 +280,20 @@ describe('image-uploader utils', () => {
|
||||
const mockFile1: MockFile = { name: 'file1.png' }
|
||||
const mockFile2: MockFile = { name: 'file2.png' }
|
||||
|
||||
const mockFileEntry1 = {
|
||||
const mockFileEntry1 = createMockEntry({
|
||||
isFile: true,
|
||||
isDirectory: false,
|
||||
file: (callback: FileCallback) => callback(mockFile1),
|
||||
}
|
||||
})
|
||||
|
||||
const mockFileEntry2 = {
|
||||
const mockFileEntry2 = createMockEntry({
|
||||
isFile: true,
|
||||
isDirectory: false,
|
||||
file: (callback: FileCallback) => callback(mockFile2),
|
||||
}
|
||||
})
|
||||
|
||||
let readCount = 0
|
||||
const mockEntry = {
|
||||
const mockEntry = createMockEntry({
|
||||
isFile: false,
|
||||
isDirectory: true,
|
||||
name: 'folder',
|
||||
@ -292,14 +301,14 @@ describe('image-uploader utils', () => {
|
||||
readEntries: (callback: EntriesCallback) => {
|
||||
if (readCount === 0) {
|
||||
readCount++
|
||||
callback([mockFileEntry1, mockFileEntry2] as unknown as FileSystemEntry[])
|
||||
callback([mockFileEntry1, mockFileEntry2])
|
||||
}
|
||||
else {
|
||||
callback([])
|
||||
}
|
||||
},
|
||||
}),
|
||||
}
|
||||
})
|
||||
|
||||
const result = await traverseFileEntry(mockEntry)
|
||||
expect(result).toHaveLength(2)
|
||||
|
||||
@ -18,17 +18,17 @@ type FileWithPath = {
|
||||
relativePath?: string
|
||||
} & File
|
||||
|
||||
export const traverseFileEntry = (entry: any, prefix = ''): Promise<FileWithPath[]> => {
|
||||
export const traverseFileEntry = (entry: FileSystemEntry, prefix = ''): Promise<FileWithPath[]> => {
|
||||
return new Promise((resolve) => {
|
||||
if (entry.isFile) {
|
||||
entry.file((file: FileWithPath) => {
|
||||
(entry as FileSystemFileEntry).file((file: FileWithPath) => {
|
||||
file.relativePath = `${prefix}${file.name}`
|
||||
resolve([file])
|
||||
})
|
||||
}
|
||||
else if (entry.isDirectory) {
|
||||
const reader = entry.createReader()
|
||||
const entries: any[] = []
|
||||
const reader = (entry as FileSystemDirectoryEntry).createReader()
|
||||
const entries: FileSystemEntry[] = []
|
||||
const read = () => {
|
||||
reader.readEntries(async (results: FileSystemEntry[]) => {
|
||||
if (!results.length) {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,218 @@
|
||||
'use client'
|
||||
import { useDebounceFn } from 'ahooks'
|
||||
import { useRouter } from 'next/navigation'
|
||||
import { useCallback, useMemo, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import { ToastContext } from '@/app/components/base/toast'
|
||||
import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks'
|
||||
import {
|
||||
DSLImportMode,
|
||||
DSLImportStatus,
|
||||
} from '@/models/app'
|
||||
import { useImportPipelineDSL, useImportPipelineDSLConfirm } from '@/service/use-pipeline'
|
||||
|
||||
export enum CreateFromDSLModalTab {
|
||||
FROM_FILE = 'from-file',
|
||||
FROM_URL = 'from-url',
|
||||
}
|
||||
|
||||
export type UseDSLImportOptions = {
|
||||
activeTab?: CreateFromDSLModalTab
|
||||
dslUrl?: string
|
||||
onSuccess?: () => void
|
||||
onClose?: () => void
|
||||
}
|
||||
|
||||
export type DSLVersions = {
|
||||
importedVersion: string
|
||||
systemVersion: string
|
||||
}
|
||||
|
||||
export const useDSLImport = ({
|
||||
activeTab = CreateFromDSLModalTab.FROM_FILE,
|
||||
dslUrl = '',
|
||||
onSuccess,
|
||||
onClose,
|
||||
}: UseDSLImportOptions) => {
|
||||
const { push } = useRouter()
|
||||
const { t } = useTranslation()
|
||||
const { notify } = useContext(ToastContext)
|
||||
|
||||
const [currentFile, setDSLFile] = useState<File>()
|
||||
const [fileContent, setFileContent] = useState<string>()
|
||||
const [currentTab, setCurrentTab] = useState(activeTab)
|
||||
const [dslUrlValue, setDslUrlValue] = useState(dslUrl)
|
||||
const [showConfirmModal, setShowConfirmModal] = useState(false)
|
||||
const [versions, setVersions] = useState<DSLVersions>()
|
||||
const [importId, setImportId] = useState<string>()
|
||||
const [isConfirming, setIsConfirming] = useState(false)
|
||||
|
||||
const { handleCheckPluginDependencies } = usePluginDependencies()
|
||||
const isCreatingRef = useRef(false)
|
||||
|
||||
const { mutateAsync: importDSL } = useImportPipelineDSL()
|
||||
const { mutateAsync: importDSLConfirm } = useImportPipelineDSLConfirm()
|
||||
|
||||
const readFile = useCallback((file: File) => {
|
||||
const reader = new FileReader()
|
||||
reader.onload = (event) => {
|
||||
const content = event.target?.result
|
||||
setFileContent(content as string)
|
||||
}
|
||||
reader.readAsText(file)
|
||||
}, [])
|
||||
|
||||
const handleFile = useCallback((file?: File) => {
|
||||
setDSLFile(file)
|
||||
if (file)
|
||||
readFile(file)
|
||||
if (!file)
|
||||
setFileContent('')
|
||||
}, [readFile])
|
||||
|
||||
const onCreate = useCallback(async () => {
|
||||
if (currentTab === CreateFromDSLModalTab.FROM_FILE && !currentFile)
|
||||
return
|
||||
if (currentTab === CreateFromDSLModalTab.FROM_URL && !dslUrlValue)
|
||||
return
|
||||
if (isCreatingRef.current)
|
||||
return
|
||||
|
||||
isCreatingRef.current = true
|
||||
|
||||
let response
|
||||
if (currentTab === CreateFromDSLModalTab.FROM_FILE) {
|
||||
response = await importDSL({
|
||||
mode: DSLImportMode.YAML_CONTENT,
|
||||
yaml_content: fileContent || '',
|
||||
})
|
||||
}
|
||||
if (currentTab === CreateFromDSLModalTab.FROM_URL) {
|
||||
response = await importDSL({
|
||||
mode: DSLImportMode.YAML_URL,
|
||||
yaml_url: dslUrlValue || '',
|
||||
})
|
||||
}
|
||||
|
||||
if (!response) {
|
||||
notify({ type: 'error', message: t('creation.errorTip', { ns: 'datasetPipeline' }) })
|
||||
isCreatingRef.current = false
|
||||
return
|
||||
}
|
||||
|
||||
const { id, status, pipeline_id, dataset_id, imported_dsl_version, current_dsl_version } = response
|
||||
|
||||
if (status === DSLImportStatus.COMPLETED || status === DSLImportStatus.COMPLETED_WITH_WARNINGS) {
|
||||
onSuccess?.()
|
||||
onClose?.()
|
||||
|
||||
notify({
|
||||
type: status === DSLImportStatus.COMPLETED ? 'success' : 'warning',
|
||||
message: t(status === DSLImportStatus.COMPLETED ? 'creation.successTip' : 'creation.caution', { ns: 'datasetPipeline' }),
|
||||
children: status === DSLImportStatus.COMPLETED_WITH_WARNINGS && t('newApp.appCreateDSLWarning', { ns: 'app' }),
|
||||
})
|
||||
|
||||
if (pipeline_id)
|
||||
await handleCheckPluginDependencies(pipeline_id, true)
|
||||
|
||||
push(`/datasets/${dataset_id}/pipeline`)
|
||||
isCreatingRef.current = false
|
||||
}
|
||||
else if (status === DSLImportStatus.PENDING) {
|
||||
setVersions({
|
||||
importedVersion: imported_dsl_version ?? '',
|
||||
systemVersion: current_dsl_version ?? '',
|
||||
})
|
||||
onClose?.()
|
||||
setTimeout(() => {
|
||||
setShowConfirmModal(true)
|
||||
}, 300)
|
||||
setImportId(id)
|
||||
isCreatingRef.current = false
|
||||
}
|
||||
else {
|
||||
notify({ type: 'error', message: t('creation.errorTip', { ns: 'datasetPipeline' }) })
|
||||
isCreatingRef.current = false
|
||||
}
|
||||
}, [
|
||||
currentTab,
|
||||
currentFile,
|
||||
dslUrlValue,
|
||||
fileContent,
|
||||
importDSL,
|
||||
notify,
|
||||
t,
|
||||
onSuccess,
|
||||
onClose,
|
||||
handleCheckPluginDependencies,
|
||||
push,
|
||||
])
|
||||
|
||||
const { run: handleCreateApp } = useDebounceFn(onCreate, { wait: 300 })
|
||||
|
||||
const onDSLConfirm = useCallback(async () => {
|
||||
if (!importId)
|
||||
return
|
||||
|
||||
setIsConfirming(true)
|
||||
const response = await importDSLConfirm(importId)
|
||||
setIsConfirming(false)
|
||||
|
||||
if (!response) {
|
||||
notify({ type: 'error', message: t('creation.errorTip', { ns: 'datasetPipeline' }) })
|
||||
return
|
||||
}
|
||||
|
||||
const { status, pipeline_id, dataset_id } = response
|
||||
|
||||
if (status === DSLImportStatus.COMPLETED) {
|
||||
onSuccess?.()
|
||||
setShowConfirmModal(false)
|
||||
|
||||
notify({
|
||||
type: 'success',
|
||||
message: t('creation.successTip', { ns: 'datasetPipeline' }),
|
||||
})
|
||||
|
||||
if (pipeline_id)
|
||||
await handleCheckPluginDependencies(pipeline_id, true)
|
||||
|
||||
push(`/datasets/${dataset_id}/pipeline`)
|
||||
}
|
||||
else if (status === DSLImportStatus.FAILED) {
|
||||
notify({ type: 'error', message: t('creation.errorTip', { ns: 'datasetPipeline' }) })
|
||||
}
|
||||
}, [importId, importDSLConfirm, notify, t, onSuccess, handleCheckPluginDependencies, push])
|
||||
|
||||
const handleCancelConfirm = useCallback(() => {
|
||||
setShowConfirmModal(false)
|
||||
}, [])
|
||||
|
||||
const buttonDisabled = useMemo(() => {
|
||||
if (currentTab === CreateFromDSLModalTab.FROM_FILE)
|
||||
return !currentFile
|
||||
if (currentTab === CreateFromDSLModalTab.FROM_URL)
|
||||
return !dslUrlValue
|
||||
return false
|
||||
}, [currentTab, currentFile, dslUrlValue])
|
||||
|
||||
return {
|
||||
// State
|
||||
currentFile,
|
||||
currentTab,
|
||||
dslUrlValue,
|
||||
showConfirmModal,
|
||||
versions,
|
||||
buttonDisabled,
|
||||
isConfirming,
|
||||
|
||||
// Actions
|
||||
setCurrentTab,
|
||||
setDslUrlValue,
|
||||
handleFile,
|
||||
handleCreateApp,
|
||||
onDSLConfirm,
|
||||
handleCancelConfirm,
|
||||
}
|
||||
}
|
||||
@ -1,24 +1,18 @@
|
||||
'use client'
|
||||
import { useDebounceFn, useKeyPress } from 'ahooks'
|
||||
import { useKeyPress } from 'ahooks'
|
||||
import { noop } from 'es-toolkit/function'
|
||||
import { useRouter } from 'next/navigation'
|
||||
import { useMemo, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Input from '@/app/components/base/input'
|
||||
import Modal from '@/app/components/base/modal'
|
||||
import { ToastContext } from '@/app/components/base/toast'
|
||||
import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks'
|
||||
import {
|
||||
DSLImportMode,
|
||||
DSLImportStatus,
|
||||
} from '@/models/app'
|
||||
import { useImportPipelineDSL, useImportPipelineDSLConfirm } from '@/service/use-pipeline'
|
||||
import DSLConfirmModal from './dsl-confirm-modal'
|
||||
import Header from './header'
|
||||
import { CreateFromDSLModalTab, useDSLImport } from './hooks/use-dsl-import'
|
||||
import Tab from './tab'
|
||||
import Uploader from './uploader'
|
||||
|
||||
export { CreateFromDSLModalTab }
|
||||
|
||||
type CreateFromDSLModalProps = {
|
||||
show: boolean
|
||||
onSuccess?: () => void
|
||||
@ -27,11 +21,6 @@ type CreateFromDSLModalProps = {
|
||||
dslUrl?: string
|
||||
}
|
||||
|
||||
export enum CreateFromDSLModalTab {
|
||||
FROM_FILE = 'from-file',
|
||||
FROM_URL = 'from-url',
|
||||
}
|
||||
|
||||
const CreateFromDSLModal = ({
|
||||
show,
|
||||
onSuccess,
|
||||
@ -39,149 +28,33 @@ const CreateFromDSLModal = ({
|
||||
activeTab = CreateFromDSLModalTab.FROM_FILE,
|
||||
dslUrl = '',
|
||||
}: CreateFromDSLModalProps) => {
|
||||
const { push } = useRouter()
|
||||
const { t } = useTranslation()
|
||||
const { notify } = useContext(ToastContext)
|
||||
const [currentFile, setDSLFile] = useState<File>()
|
||||
const [fileContent, setFileContent] = useState<string>()
|
||||
const [currentTab, setCurrentTab] = useState(activeTab)
|
||||
const [dslUrlValue, setDslUrlValue] = useState(dslUrl)
|
||||
const [showErrorModal, setShowErrorModal] = useState(false)
|
||||
const [versions, setVersions] = useState<{ importedVersion: string, systemVersion: string }>()
|
||||
const [importId, setImportId] = useState<string>()
|
||||
const { handleCheckPluginDependencies } = usePluginDependencies()
|
||||
|
||||
const readFile = (file: File) => {
|
||||
const reader = new FileReader()
|
||||
reader.onload = function (event) {
|
||||
const content = event.target?.result
|
||||
setFileContent(content as string)
|
||||
}
|
||||
reader.readAsText(file)
|
||||
}
|
||||
|
||||
const handleFile = (file?: File) => {
|
||||
setDSLFile(file)
|
||||
if (file)
|
||||
readFile(file)
|
||||
if (!file)
|
||||
setFileContent('')
|
||||
}
|
||||
|
||||
const isCreatingRef = useRef(false)
|
||||
|
||||
const { mutateAsync: importDSL } = useImportPipelineDSL()
|
||||
|
||||
const onCreate = async () => {
|
||||
if (currentTab === CreateFromDSLModalTab.FROM_FILE && !currentFile)
|
||||
return
|
||||
if (currentTab === CreateFromDSLModalTab.FROM_URL && !dslUrlValue)
|
||||
return
|
||||
if (isCreatingRef.current)
|
||||
return
|
||||
isCreatingRef.current = true
|
||||
let response
|
||||
if (currentTab === CreateFromDSLModalTab.FROM_FILE) {
|
||||
response = await importDSL({
|
||||
mode: DSLImportMode.YAML_CONTENT,
|
||||
yaml_content: fileContent || '',
|
||||
})
|
||||
}
|
||||
if (currentTab === CreateFromDSLModalTab.FROM_URL) {
|
||||
response = await importDSL({
|
||||
mode: DSLImportMode.YAML_URL,
|
||||
yaml_url: dslUrlValue || '',
|
||||
})
|
||||
}
|
||||
|
||||
if (!response) {
|
||||
notify({ type: 'error', message: t('creation.errorTip', { ns: 'datasetPipeline' }) })
|
||||
isCreatingRef.current = false
|
||||
return
|
||||
}
|
||||
const { id, status, pipeline_id, dataset_id, imported_dsl_version, current_dsl_version } = response
|
||||
if (status === DSLImportStatus.COMPLETED || status === DSLImportStatus.COMPLETED_WITH_WARNINGS) {
|
||||
if (onSuccess)
|
||||
onSuccess()
|
||||
if (onClose)
|
||||
onClose()
|
||||
|
||||
notify({
|
||||
type: status === DSLImportStatus.COMPLETED ? 'success' : 'warning',
|
||||
message: t(status === DSLImportStatus.COMPLETED ? 'creation.successTip' : 'creation.caution', { ns: 'datasetPipeline' }),
|
||||
children: status === DSLImportStatus.COMPLETED_WITH_WARNINGS && t('newApp.appCreateDSLWarning', { ns: 'app' }),
|
||||
})
|
||||
if (pipeline_id)
|
||||
await handleCheckPluginDependencies(pipeline_id, true)
|
||||
push(`/datasets/${dataset_id}/pipeline`)
|
||||
isCreatingRef.current = false
|
||||
}
|
||||
else if (status === DSLImportStatus.PENDING) {
|
||||
setVersions({
|
||||
importedVersion: imported_dsl_version ?? '',
|
||||
systemVersion: current_dsl_version ?? '',
|
||||
})
|
||||
if (onClose)
|
||||
onClose()
|
||||
setTimeout(() => {
|
||||
setShowErrorModal(true)
|
||||
}, 300)
|
||||
setImportId(id)
|
||||
isCreatingRef.current = false
|
||||
}
|
||||
else {
|
||||
notify({ type: 'error', message: t('creation.errorTip', { ns: 'datasetPipeline' }) })
|
||||
isCreatingRef.current = false
|
||||
}
|
||||
}
|
||||
|
||||
const { run: handleCreateApp } = useDebounceFn(onCreate, { wait: 300 })
|
||||
|
||||
useKeyPress('esc', () => {
|
||||
if (show && !showErrorModal)
|
||||
onClose()
|
||||
const {
|
||||
currentFile,
|
||||
currentTab,
|
||||
dslUrlValue,
|
||||
showConfirmModal,
|
||||
versions,
|
||||
buttonDisabled,
|
||||
isConfirming,
|
||||
setCurrentTab,
|
||||
setDslUrlValue,
|
||||
handleFile,
|
||||
handleCreateApp,
|
||||
onDSLConfirm,
|
||||
handleCancelConfirm,
|
||||
} = useDSLImport({
|
||||
activeTab,
|
||||
dslUrl,
|
||||
onSuccess,
|
||||
onClose,
|
||||
})
|
||||
|
||||
const { mutateAsync: importDSLConfirm } = useImportPipelineDSLConfirm()
|
||||
|
||||
const onDSLConfirm = async () => {
|
||||
if (!importId)
|
||||
return
|
||||
const response = await importDSLConfirm(importId)
|
||||
|
||||
if (!response) {
|
||||
notify({ type: 'error', message: t('creation.errorTip', { ns: 'datasetPipeline' }) })
|
||||
return
|
||||
}
|
||||
|
||||
const { status, pipeline_id, dataset_id } = response
|
||||
|
||||
if (status === DSLImportStatus.COMPLETED) {
|
||||
if (onSuccess)
|
||||
onSuccess()
|
||||
if (onClose)
|
||||
onClose()
|
||||
|
||||
notify({
|
||||
type: 'success',
|
||||
message: t('creation.successTip', { ns: 'datasetPipeline' }),
|
||||
})
|
||||
if (pipeline_id)
|
||||
await handleCheckPluginDependencies(pipeline_id, true)
|
||||
push(`datasets/${dataset_id}/pipeline`)
|
||||
}
|
||||
else if (status === DSLImportStatus.FAILED) {
|
||||
notify({ type: 'error', message: t('creation.errorTip', { ns: 'datasetPipeline' }) })
|
||||
}
|
||||
}
|
||||
|
||||
const buttonDisabled = useMemo(() => {
|
||||
if (currentTab === CreateFromDSLModalTab.FROM_FILE)
|
||||
return !currentFile
|
||||
if (currentTab === CreateFromDSLModalTab.FROM_URL)
|
||||
return !dslUrlValue
|
||||
return false
|
||||
}, [currentTab, currentFile, dslUrlValue])
|
||||
useKeyPress('esc', () => {
|
||||
if (show && !showConfirmModal)
|
||||
onClose()
|
||||
})
|
||||
|
||||
return (
|
||||
<>
|
||||
@ -196,29 +69,25 @@ const CreateFromDSLModal = ({
|
||||
setCurrentTab={setCurrentTab}
|
||||
/>
|
||||
<div className="px-6 py-4">
|
||||
{
|
||||
currentTab === CreateFromDSLModalTab.FROM_FILE && (
|
||||
<Uploader
|
||||
className="mt-0"
|
||||
file={currentFile}
|
||||
updateFile={handleFile}
|
||||
/>
|
||||
)
|
||||
}
|
||||
{
|
||||
currentTab === CreateFromDSLModalTab.FROM_URL && (
|
||||
<div>
|
||||
<div className="system-md-semibold leading6 mb-1 text-text-secondary">
|
||||
DSL URL
|
||||
</div>
|
||||
<Input
|
||||
placeholder={t('importFromDSLUrlPlaceholder', { ns: 'app' }) || ''}
|
||||
value={dslUrlValue}
|
||||
onChange={e => setDslUrlValue(e.target.value)}
|
||||
/>
|
||||
{currentTab === CreateFromDSLModalTab.FROM_FILE && (
|
||||
<Uploader
|
||||
className="mt-0"
|
||||
file={currentFile}
|
||||
updateFile={handleFile}
|
||||
/>
|
||||
)}
|
||||
{currentTab === CreateFromDSLModalTab.FROM_URL && (
|
||||
<div>
|
||||
<div className="system-md-semibold leading6 mb-1 text-text-secondary">
|
||||
DSL URL
|
||||
</div>
|
||||
)
|
||||
}
|
||||
<Input
|
||||
placeholder={t('importFromDSLUrlPlaceholder', { ns: 'app' }) || ''}
|
||||
value={dslUrlValue}
|
||||
onChange={e => setDslUrlValue(e.target.value)}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex justify-end gap-x-2 p-6 pt-5">
|
||||
<Button onClick={onClose}>
|
||||
@ -234,32 +103,14 @@ const CreateFromDSLModal = ({
|
||||
</Button>
|
||||
</div>
|
||||
</Modal>
|
||||
<Modal
|
||||
isShow={showErrorModal}
|
||||
onClose={() => setShowErrorModal(false)}
|
||||
className="w-[480px]"
|
||||
>
|
||||
<div className="flex flex-col items-start gap-2 self-stretch pb-4">
|
||||
<div className="title-2xl-semi-bold text-text-primary">{t('newApp.appCreateDSLErrorTitle', { ns: 'app' })}</div>
|
||||
<div className="system-md-regular flex grow flex-col text-text-secondary">
|
||||
<div>{t('newApp.appCreateDSLErrorPart1', { ns: 'app' })}</div>
|
||||
<div>{t('newApp.appCreateDSLErrorPart2', { ns: 'app' })}</div>
|
||||
<br />
|
||||
<div>
|
||||
{t('newApp.appCreateDSLErrorPart3', { ns: 'app' })}
|
||||
<span className="system-md-medium">{versions?.importedVersion}</span>
|
||||
</div>
|
||||
<div>
|
||||
{t('newApp.appCreateDSLErrorPart4', { ns: 'app' })}
|
||||
<span className="system-md-medium">{versions?.systemVersion}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-start justify-end gap-2 self-stretch pt-6">
|
||||
<Button variant="secondary" onClick={() => setShowErrorModal(false)}>{t('newApp.Cancel', { ns: 'app' })}</Button>
|
||||
<Button variant="primary" destructive onClick={onDSLConfirm}>{t('newApp.Confirm', { ns: 'app' })}</Button>
|
||||
</div>
|
||||
</Modal>
|
||||
{showConfirmModal && (
|
||||
<DSLConfirmModal
|
||||
versions={versions}
|
||||
onCancel={handleCancelConfirm}
|
||||
onConfirm={onDSLConfirm}
|
||||
confirmDisabled={isConfirming}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
@ -0,0 +1,334 @@
|
||||
import type { FileListItemProps } from './file-list-item'
|
||||
import type { CustomFile as File, FileItem } from '@/models/datasets'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { PROGRESS_COMPLETE, PROGRESS_ERROR, PROGRESS_NOT_STARTED } from '../constants'
|
||||
import FileListItem from './file-list-item'
|
||||
|
||||
// Mock theme hook - can be changed per test
|
||||
let mockTheme = 'light'
|
||||
vi.mock('@/hooks/use-theme', () => ({
|
||||
default: () => ({ theme: mockTheme }),
|
||||
}))
|
||||
|
||||
// Mock theme types
|
||||
vi.mock('@/types/app', () => ({
|
||||
Theme: { dark: 'dark', light: 'light' },
|
||||
}))
|
||||
|
||||
// Mock SimplePieChart with dynamic import handling
|
||||
vi.mock('next/dynamic', () => ({
|
||||
default: () => {
|
||||
const DynamicComponent = ({ percentage, stroke, fill }: { percentage: number, stroke: string, fill: string }) => (
|
||||
<div data-testid="pie-chart" data-percentage={percentage} data-stroke={stroke} data-fill={fill}>
|
||||
Pie Chart:
|
||||
{' '}
|
||||
{percentage}
|
||||
%
|
||||
</div>
|
||||
)
|
||||
DynamicComponent.displayName = 'SimplePieChart'
|
||||
return DynamicComponent
|
||||
},
|
||||
}))
|
||||
|
||||
// Mock DocumentFileIcon
|
||||
vi.mock('@/app/components/datasets/common/document-file-icon', () => ({
|
||||
default: ({ name, extension, size }: { name: string, extension: string, size: string }) => (
|
||||
<div data-testid="document-icon" data-name={name} data-extension={extension} data-size={size}>
|
||||
Document Icon
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
describe('FileListItem', () => {
|
||||
const createMockFile = (overrides: Partial<File> = {}): File => ({
|
||||
name: 'test-document.pdf',
|
||||
size: 1024 * 100, // 100KB
|
||||
type: 'application/pdf',
|
||||
lastModified: Date.now(),
|
||||
...overrides,
|
||||
} as File)
|
||||
|
||||
const createMockFileItem = (overrides: Partial<FileItem> = {}): FileItem => ({
|
||||
fileID: 'file-123',
|
||||
file: createMockFile(overrides.file as Partial<File>),
|
||||
progress: PROGRESS_NOT_STARTED,
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const defaultProps: FileListItemProps = {
|
||||
fileItem: createMockFileItem(),
|
||||
onPreview: vi.fn(),
|
||||
onRemove: vi.fn(),
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockTheme = 'light'
|
||||
})
|
||||
|
||||
describe('rendering', () => {
|
||||
it('should render the file item container', () => {
|
||||
const { container } = render(<FileListItem {...defaultProps} />)
|
||||
const item = container.firstChild as HTMLElement
|
||||
expect(item).toHaveClass('flex', 'h-12', 'items-center', 'rounded-lg')
|
||||
})
|
||||
|
||||
it('should render document icon with correct props', () => {
|
||||
render(<FileListItem {...defaultProps} />)
|
||||
const icon = screen.getByTestId('document-icon')
|
||||
expect(icon).toBeInTheDocument()
|
||||
expect(icon).toHaveAttribute('data-name', 'test-document.pdf')
|
||||
expect(icon).toHaveAttribute('data-extension', 'pdf')
|
||||
expect(icon).toHaveAttribute('data-size', 'xl')
|
||||
})
|
||||
|
||||
it('should render file name', () => {
|
||||
render(<FileListItem {...defaultProps} />)
|
||||
expect(screen.getByText('test-document.pdf')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render file extension in uppercase via CSS class', () => {
|
||||
render(<FileListItem {...defaultProps} />)
|
||||
const extensionSpan = screen.getByText('pdf')
|
||||
expect(extensionSpan).toBeInTheDocument()
|
||||
expect(extensionSpan).toHaveClass('uppercase')
|
||||
})
|
||||
|
||||
it('should render file size', () => {
|
||||
render(<FileListItem {...defaultProps} />)
|
||||
// Default mock file is 100KB (1024 * 100 bytes)
|
||||
expect(screen.getByText('100.00 KB')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render delete button', () => {
|
||||
const { container } = render(<FileListItem {...defaultProps} />)
|
||||
const deleteButton = container.querySelector('.cursor-pointer')
|
||||
expect(deleteButton).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('progress states', () => {
|
||||
it('should show progress chart when uploading (0-99)', () => {
|
||||
const fileItem = createMockFileItem({ progress: 50 })
|
||||
render(<FileListItem {...defaultProps} fileItem={fileItem} />)
|
||||
|
||||
const pieChart = screen.getByTestId('pie-chart')
|
||||
expect(pieChart).toBeInTheDocument()
|
||||
expect(pieChart).toHaveAttribute('data-percentage', '50')
|
||||
})
|
||||
|
||||
it('should show progress chart at 0%', () => {
|
||||
const fileItem = createMockFileItem({ progress: 0 })
|
||||
render(<FileListItem {...defaultProps} fileItem={fileItem} />)
|
||||
|
||||
const pieChart = screen.getByTestId('pie-chart')
|
||||
expect(pieChart).toHaveAttribute('data-percentage', '0')
|
||||
})
|
||||
|
||||
it('should not show progress chart when complete (100)', () => {
|
||||
const fileItem = createMockFileItem({ progress: PROGRESS_COMPLETE })
|
||||
render(<FileListItem {...defaultProps} fileItem={fileItem} />)
|
||||
|
||||
expect(screen.queryByTestId('pie-chart')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show progress chart when not started (-1)', () => {
|
||||
const fileItem = createMockFileItem({ progress: PROGRESS_NOT_STARTED })
|
||||
render(<FileListItem {...defaultProps} fileItem={fileItem} />)
|
||||
|
||||
expect(screen.queryByTestId('pie-chart')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('error state', () => {
|
||||
it('should show error indicator when progress is PROGRESS_ERROR', () => {
|
||||
const fileItem = createMockFileItem({ progress: PROGRESS_ERROR })
|
||||
const { container } = render(<FileListItem {...defaultProps} fileItem={fileItem} />)
|
||||
|
||||
const errorIndicator = container.querySelector('.text-text-destructive')
|
||||
expect(errorIndicator).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show error indicator when not in error state', () => {
|
||||
const { container } = render(<FileListItem {...defaultProps} />)
|
||||
const errorIndicator = container.querySelector('.text-text-destructive')
|
||||
expect(errorIndicator).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('theme handling', () => {
|
||||
it('should use correct chart color for light theme', () => {
|
||||
mockTheme = 'light'
|
||||
const fileItem = createMockFileItem({ progress: 50 })
|
||||
render(<FileListItem {...defaultProps} fileItem={fileItem} />)
|
||||
|
||||
const pieChart = screen.getByTestId('pie-chart')
|
||||
expect(pieChart).toHaveAttribute('data-stroke', '#296dff')
|
||||
expect(pieChart).toHaveAttribute('data-fill', '#296dff')
|
||||
})
|
||||
|
||||
it('should use correct chart color for dark theme', () => {
|
||||
mockTheme = 'dark'
|
||||
const fileItem = createMockFileItem({ progress: 50 })
|
||||
render(<FileListItem {...defaultProps} fileItem={fileItem} />)
|
||||
|
||||
const pieChart = screen.getByTestId('pie-chart')
|
||||
expect(pieChart).toHaveAttribute('data-stroke', '#5289ff')
|
||||
expect(pieChart).toHaveAttribute('data-fill', '#5289ff')
|
||||
})
|
||||
})
|
||||
|
||||
describe('event handlers', () => {
|
||||
it('should call onPreview when item is clicked with file id', () => {
|
||||
const onPreview = vi.fn()
|
||||
const fileItem = createMockFileItem({
|
||||
file: createMockFile({ id: 'uploaded-id' } as Partial<File>),
|
||||
})
|
||||
render(<FileListItem {...defaultProps} fileItem={fileItem} onPreview={onPreview} />)
|
||||
|
||||
const item = screen.getByText('test-document.pdf').closest('[class*="flex h-12"]')!
|
||||
fireEvent.click(item)
|
||||
|
||||
expect(onPreview).toHaveBeenCalledTimes(1)
|
||||
expect(onPreview).toHaveBeenCalledWith(fileItem.file)
|
||||
})
|
||||
|
||||
it('should not call onPreview when file has no id', () => {
|
||||
const onPreview = vi.fn()
|
||||
const fileItem = createMockFileItem()
|
||||
render(<FileListItem {...defaultProps} fileItem={fileItem} onPreview={onPreview} />)
|
||||
|
||||
const item = screen.getByText('test-document.pdf').closest('[class*="flex h-12"]')!
|
||||
fireEvent.click(item)
|
||||
|
||||
expect(onPreview).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call onRemove when delete button is clicked', () => {
|
||||
const onRemove = vi.fn()
|
||||
const fileItem = createMockFileItem()
|
||||
const { container } = render(<FileListItem {...defaultProps} fileItem={fileItem} onRemove={onRemove} />)
|
||||
|
||||
const deleteButton = container.querySelector('.cursor-pointer')!
|
||||
fireEvent.click(deleteButton)
|
||||
|
||||
expect(onRemove).toHaveBeenCalledTimes(1)
|
||||
expect(onRemove).toHaveBeenCalledWith('file-123')
|
||||
})
|
||||
|
||||
it('should stop propagation when delete button is clicked', () => {
|
||||
const onPreview = vi.fn()
|
||||
const onRemove = vi.fn()
|
||||
const fileItem = createMockFileItem({
|
||||
file: createMockFile({ id: 'uploaded-id' } as Partial<File>),
|
||||
})
|
||||
const { container } = render(<FileListItem {...defaultProps} fileItem={fileItem} onPreview={onPreview} onRemove={onRemove} />)
|
||||
|
||||
const deleteButton = container.querySelector('.cursor-pointer')!
|
||||
fireEvent.click(deleteButton)
|
||||
|
||||
expect(onRemove).toHaveBeenCalledTimes(1)
|
||||
expect(onPreview).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('file type handling', () => {
|
||||
it('should handle files with multiple dots in name', () => {
|
||||
const fileItem = createMockFileItem({
|
||||
file: createMockFile({ name: 'my.document.file.docx' }),
|
||||
})
|
||||
render(<FileListItem {...defaultProps} fileItem={fileItem} />)
|
||||
|
||||
expect(screen.getByText('my.document.file.docx')).toBeInTheDocument()
|
||||
expect(screen.getByText('docx')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle files without extension', () => {
|
||||
const fileItem = createMockFileItem({
|
||||
file: createMockFile({ name: 'README' }),
|
||||
})
|
||||
render(<FileListItem {...defaultProps} fileItem={fileItem} />)
|
||||
|
||||
// File name appears once, and extension area shows empty string
|
||||
expect(screen.getByText('README')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle various file extensions', () => {
|
||||
const extensions = ['txt', 'md', 'json', 'csv', 'xlsx']
|
||||
|
||||
extensions.forEach((ext) => {
|
||||
const fileItem = createMockFileItem({
|
||||
file: createMockFile({ name: `file.${ext}` }),
|
||||
})
|
||||
const { unmount } = render(<FileListItem {...defaultProps} fileItem={fileItem} />)
|
||||
expect(screen.getByText(ext)).toBeInTheDocument()
|
||||
unmount()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('file size display', () => {
|
||||
it('should display size in KB for small files', () => {
|
||||
const fileItem = createMockFileItem({
|
||||
file: createMockFile({ size: 5 * 1024 }),
|
||||
})
|
||||
render(<FileListItem {...defaultProps} fileItem={fileItem} />)
|
||||
expect(screen.getByText('5.00 KB')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should display size in MB for larger files', () => {
|
||||
const fileItem = createMockFileItem({
|
||||
file: createMockFile({ size: 5 * 1024 * 1024 }),
|
||||
})
|
||||
render(<FileListItem {...defaultProps} fileItem={fileItem} />)
|
||||
expect(screen.getByText('5.00 MB')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('upload progress values', () => {
|
||||
it('should show chart at progress 1', () => {
|
||||
const fileItem = createMockFileItem({ progress: 1 })
|
||||
render(<FileListItem {...defaultProps} fileItem={fileItem} />)
|
||||
expect(screen.getByTestId('pie-chart')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show chart at progress 99', () => {
|
||||
const fileItem = createMockFileItem({ progress: 99 })
|
||||
render(<FileListItem {...defaultProps} fileItem={fileItem} />)
|
||||
expect(screen.getByTestId('pie-chart')).toHaveAttribute('data-percentage', '99')
|
||||
})
|
||||
|
||||
it('should not show chart at progress 100', () => {
|
||||
const fileItem = createMockFileItem({ progress: 100 })
|
||||
render(<FileListItem {...defaultProps} fileItem={fileItem} />)
|
||||
expect(screen.queryByTestId('pie-chart')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('styling', () => {
|
||||
it('should have proper shadow styling', () => {
|
||||
const { container } = render(<FileListItem {...defaultProps} />)
|
||||
const item = container.firstChild as HTMLElement
|
||||
expect(item).toHaveClass('shadow-xs')
|
||||
})
|
||||
|
||||
it('should have proper border styling', () => {
|
||||
const { container } = render(<FileListItem {...defaultProps} />)
|
||||
const item = container.firstChild as HTMLElement
|
||||
expect(item).toHaveClass('border', 'border-components-panel-border')
|
||||
})
|
||||
|
||||
it('should truncate long file names', () => {
|
||||
const longFileName = 'this-is-a-very-long-file-name-that-should-be-truncated.pdf'
|
||||
const fileItem = createMockFileItem({
|
||||
file: createMockFile({ name: longFileName }),
|
||||
})
|
||||
render(<FileListItem {...defaultProps} fileItem={fileItem} />)
|
||||
|
||||
const nameElement = screen.getByText(longFileName)
|
||||
expect(nameElement).toHaveClass('truncate')
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,89 @@
|
||||
'use client'
|
||||
import type { CustomFile as File, FileItem } from '@/models/datasets'
|
||||
import { RiDeleteBinLine, RiErrorWarningFill } from '@remixicon/react'
|
||||
import dynamic from 'next/dynamic'
|
||||
import { useMemo } from 'react'
|
||||
import DocumentFileIcon from '@/app/components/datasets/common/document-file-icon'
|
||||
import useTheme from '@/hooks/use-theme'
|
||||
import { Theme } from '@/types/app'
|
||||
import { formatFileSize, getFileExtension } from '@/utils/format'
|
||||
import { PROGRESS_COMPLETE, PROGRESS_ERROR } from '../constants'
|
||||
|
||||
const SimplePieChart = dynamic(() => import('@/app/components/base/simple-pie-chart'), { ssr: false })
|
||||
|
||||
export type FileListItemProps = {
|
||||
fileItem: FileItem
|
||||
onPreview: (file: File) => void
|
||||
onRemove: (fileID: string) => void
|
||||
}
|
||||
|
||||
const FileListItem = ({
|
||||
fileItem,
|
||||
onPreview,
|
||||
onRemove,
|
||||
}: FileListItemProps) => {
|
||||
const { theme } = useTheme()
|
||||
const chartColor = useMemo(() => theme === Theme.dark ? '#5289ff' : '#296dff', [theme])
|
||||
|
||||
const isUploading = fileItem.progress >= 0 && fileItem.progress < PROGRESS_COMPLETE
|
||||
const isError = fileItem.progress === PROGRESS_ERROR
|
||||
|
||||
const handleClick = () => {
|
||||
if (fileItem.file?.id)
|
||||
onPreview(fileItem.file)
|
||||
}
|
||||
|
||||
const handleRemove = (e: React.MouseEvent) => {
|
||||
e.stopPropagation()
|
||||
onRemove(fileItem.fileID)
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
onClick={handleClick}
|
||||
className="flex h-12 max-w-[640px] items-center rounded-lg border border-components-panel-border bg-components-panel-on-panel-item-bg text-xs leading-3 text-text-tertiary shadow-xs"
|
||||
>
|
||||
<div className="flex w-12 shrink-0 items-center justify-center">
|
||||
<DocumentFileIcon
|
||||
size="xl"
|
||||
className="shrink-0"
|
||||
name={fileItem.file.name}
|
||||
extension={getFileExtension(fileItem.file.name)}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex shrink grow flex-col gap-0.5">
|
||||
<div className="flex w-full">
|
||||
<div className="w-0 grow truncate text-sm leading-4 text-text-secondary">
|
||||
{fileItem.file.name}
|
||||
</div>
|
||||
</div>
|
||||
<div className="w-full truncate leading-3 text-text-tertiary">
|
||||
<span className="uppercase">{getFileExtension(fileItem.file.name)}</span>
|
||||
<span className="px-1 text-text-quaternary">·</span>
|
||||
<span>{formatFileSize(fileItem.file.size)}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex w-16 shrink-0 items-center justify-end gap-1 pr-3">
|
||||
{isUploading && (
|
||||
<SimplePieChart
|
||||
percentage={fileItem.progress}
|
||||
stroke={chartColor}
|
||||
fill={chartColor}
|
||||
animationDuration={0}
|
||||
/>
|
||||
)}
|
||||
{isError && (
|
||||
<RiErrorWarningFill className="size-4 text-text-destructive" />
|
||||
)}
|
||||
<span
|
||||
className="flex h-6 w-6 cursor-pointer items-center justify-center"
|
||||
onClick={handleRemove}
|
||||
>
|
||||
<RiDeleteBinLine className="size-4 text-text-tertiary" />
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default FileListItem
|
||||
@ -0,0 +1,210 @@
|
||||
import type { RefObject } from 'react'
|
||||
import type { UploadDropzoneProps } from './upload-dropzone'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import UploadDropzone from './upload-dropzone'
|
||||
|
||||
// Helper to create mock ref objects for testing
|
||||
const createMockRef = <T,>(value: T | null = null): RefObject<T | null> => ({ current: value })
|
||||
|
||||
// Mock react-i18next
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string, options?: Record<string, unknown>) => {
|
||||
const translations: Record<string, string> = {
|
||||
'stepOne.uploader.button': 'Drag and drop files, or',
|
||||
'stepOne.uploader.buttonSingleFile': 'Drag and drop file, or',
|
||||
'stepOne.uploader.browse': 'Browse',
|
||||
'stepOne.uploader.tip': 'Supports {{supportTypes}}, Max {{size}}MB each, up to {{batchCount}} files at a time, {{totalCount}} files total',
|
||||
}
|
||||
let result = translations[key] || key
|
||||
if (options && typeof options === 'object') {
|
||||
Object.entries(options).forEach(([k, v]) => {
|
||||
result = result.replace(`{{${k}}}`, String(v))
|
||||
})
|
||||
}
|
||||
return result
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
describe('UploadDropzone', () => {
|
||||
const defaultProps: UploadDropzoneProps = {
|
||||
dropRef: createMockRef<HTMLDivElement>() as RefObject<HTMLDivElement | null>,
|
||||
dragRef: createMockRef<HTMLDivElement>() as RefObject<HTMLDivElement | null>,
|
||||
fileUploaderRef: createMockRef<HTMLInputElement>() as RefObject<HTMLInputElement | null>,
|
||||
dragging: false,
|
||||
supportBatchUpload: true,
|
||||
supportTypesShowNames: 'PDF, DOCX, TXT',
|
||||
fileUploadConfig: {
|
||||
file_size_limit: 15,
|
||||
batch_count_limit: 5,
|
||||
file_upload_limit: 10,
|
||||
},
|
||||
acceptTypes: ['.pdf', '.docx', '.txt'],
|
||||
onSelectFile: vi.fn(),
|
||||
onFileChange: vi.fn(),
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('rendering', () => {
|
||||
it('should render the dropzone container', () => {
|
||||
const { container } = render(<UploadDropzone {...defaultProps} />)
|
||||
const dropzone = container.querySelector('[class*="border-dashed"]')
|
||||
expect(dropzone).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render hidden file input', () => {
|
||||
render(<UploadDropzone {...defaultProps} />)
|
||||
const input = document.getElementById('fileUploader') as HTMLInputElement
|
||||
expect(input).toBeInTheDocument()
|
||||
expect(input).toHaveClass('hidden')
|
||||
expect(input).toHaveAttribute('type', 'file')
|
||||
})
|
||||
|
||||
it('should render upload icon', () => {
|
||||
render(<UploadDropzone {...defaultProps} />)
|
||||
const icon = document.querySelector('svg')
|
||||
expect(icon).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render browse label when extensions are allowed', () => {
|
||||
render(<UploadDropzone {...defaultProps} />)
|
||||
expect(screen.getByText('Browse')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render browse label when no extensions allowed', () => {
|
||||
render(<UploadDropzone {...defaultProps} acceptTypes={[]} />)
|
||||
expect(screen.queryByText('Browse')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render file size and count limits', () => {
|
||||
render(<UploadDropzone {...defaultProps} />)
|
||||
const tipText = screen.getByText(/Supports.*Max.*15MB/i)
|
||||
expect(tipText).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('file input configuration', () => {
|
||||
it('should allow multiple files when supportBatchUpload is true', () => {
|
||||
render(<UploadDropzone {...defaultProps} supportBatchUpload={true} />)
|
||||
const input = document.getElementById('fileUploader') as HTMLInputElement
|
||||
expect(input).toHaveAttribute('multiple')
|
||||
})
|
||||
|
||||
it('should not allow multiple files when supportBatchUpload is false', () => {
|
||||
render(<UploadDropzone {...defaultProps} supportBatchUpload={false} />)
|
||||
const input = document.getElementById('fileUploader') as HTMLInputElement
|
||||
expect(input).not.toHaveAttribute('multiple')
|
||||
})
|
||||
|
||||
it('should set accept attribute with correct types', () => {
|
||||
render(<UploadDropzone {...defaultProps} acceptTypes={['.pdf', '.docx']} />)
|
||||
const input = document.getElementById('fileUploader') as HTMLInputElement
|
||||
expect(input).toHaveAttribute('accept', '.pdf,.docx')
|
||||
})
|
||||
})
|
||||
|
||||
describe('text content', () => {
|
||||
it('should show batch upload text when supportBatchUpload is true', () => {
|
||||
render(<UploadDropzone {...defaultProps} supportBatchUpload={true} />)
|
||||
expect(screen.getByText(/Drag and drop files/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show single file text when supportBatchUpload is false', () => {
|
||||
render(<UploadDropzone {...defaultProps} supportBatchUpload={false} />)
|
||||
expect(screen.getByText(/Drag and drop file/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('dragging state', () => {
|
||||
it('should apply dragging styles when dragging is true', () => {
|
||||
const { container } = render(<UploadDropzone {...defaultProps} dragging={true} />)
|
||||
const dropzone = container.querySelector('[class*="border-components-dropzone-border-accent"]')
|
||||
expect(dropzone).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render drag overlay when dragging', () => {
|
||||
const dragRef = createMockRef<HTMLDivElement>()
|
||||
render(<UploadDropzone {...defaultProps} dragging={true} dragRef={dragRef as RefObject<HTMLDivElement | null>} />)
|
||||
const overlay = document.querySelector('.absolute.left-0.top-0')
|
||||
expect(overlay).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render drag overlay when not dragging', () => {
|
||||
render(<UploadDropzone {...defaultProps} dragging={false} />)
|
||||
const overlay = document.querySelector('.absolute.left-0.top-0')
|
||||
expect(overlay).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('event handlers', () => {
|
||||
it('should call onSelectFile when browse label is clicked', () => {
|
||||
const onSelectFile = vi.fn()
|
||||
render(<UploadDropzone {...defaultProps} onSelectFile={onSelectFile} />)
|
||||
|
||||
const browseLabel = screen.getByText('Browse')
|
||||
fireEvent.click(browseLabel)
|
||||
|
||||
expect(onSelectFile).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should call onFileChange when files are selected', () => {
|
||||
const onFileChange = vi.fn()
|
||||
render(<UploadDropzone {...defaultProps} onFileChange={onFileChange} />)
|
||||
|
||||
const input = document.getElementById('fileUploader') as HTMLInputElement
|
||||
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
|
||||
|
||||
fireEvent.change(input, { target: { files: [file] } })
|
||||
|
||||
expect(onFileChange).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
|
||||
describe('refs', () => {
|
||||
it('should attach dropRef to drop container', () => {
|
||||
const dropRef = createMockRef<HTMLDivElement>()
|
||||
render(<UploadDropzone {...defaultProps} dropRef={dropRef as RefObject<HTMLDivElement | null>} />)
|
||||
expect(dropRef.current).toBeInstanceOf(HTMLDivElement)
|
||||
})
|
||||
|
||||
it('should attach fileUploaderRef to input element', () => {
|
||||
const fileUploaderRef = createMockRef<HTMLInputElement>()
|
||||
render(<UploadDropzone {...defaultProps} fileUploaderRef={fileUploaderRef as RefObject<HTMLInputElement | null>} />)
|
||||
expect(fileUploaderRef.current).toBeInstanceOf(HTMLInputElement)
|
||||
})
|
||||
|
||||
it('should attach dragRef to overlay when dragging', () => {
|
||||
const dragRef = createMockRef<HTMLDivElement>()
|
||||
render(<UploadDropzone {...defaultProps} dragging={true} dragRef={dragRef as RefObject<HTMLDivElement | null>} />)
|
||||
expect(dragRef.current).toBeInstanceOf(HTMLDivElement)
|
||||
})
|
||||
})
|
||||
|
||||
describe('styling', () => {
|
||||
it('should have base dropzone styling', () => {
|
||||
const { container } = render(<UploadDropzone {...defaultProps} />)
|
||||
const dropzone = container.querySelector('[class*="border-dashed"]')
|
||||
expect(dropzone).toBeInTheDocument()
|
||||
expect(dropzone).toHaveClass('rounded-xl')
|
||||
})
|
||||
|
||||
it('should have cursor-pointer on browse label', () => {
|
||||
render(<UploadDropzone {...defaultProps} />)
|
||||
const browseLabel = screen.getByText('Browse')
|
||||
expect(browseLabel).toHaveClass('cursor-pointer')
|
||||
})
|
||||
})
|
||||
|
||||
describe('accessibility', () => {
|
||||
it('should have an accessible file input', () => {
|
||||
render(<UploadDropzone {...defaultProps} />)
|
||||
const input = document.getElementById('fileUploader') as HTMLInputElement
|
||||
expect(input).toHaveAttribute('id', 'fileUploader')
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,84 @@
|
||||
'use client'
|
||||
import type { RefObject } from 'react'
|
||||
import type { FileUploadConfig } from '../hooks/use-file-upload'
|
||||
import { RiUploadCloud2Line } from '@remixicon/react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
export type UploadDropzoneProps = {
|
||||
dropRef: RefObject<HTMLDivElement | null>
|
||||
dragRef: RefObject<HTMLDivElement | null>
|
||||
fileUploaderRef: RefObject<HTMLInputElement | null>
|
||||
dragging: boolean
|
||||
supportBatchUpload: boolean
|
||||
supportTypesShowNames: string
|
||||
fileUploadConfig: FileUploadConfig
|
||||
acceptTypes: string[]
|
||||
onSelectFile: () => void
|
||||
onFileChange: (e: React.ChangeEvent<HTMLInputElement>) => void
|
||||
}
|
||||
|
||||
const UploadDropzone = ({
|
||||
dropRef,
|
||||
dragRef,
|
||||
fileUploaderRef,
|
||||
dragging,
|
||||
supportBatchUpload,
|
||||
supportTypesShowNames,
|
||||
fileUploadConfig,
|
||||
acceptTypes,
|
||||
onSelectFile,
|
||||
onFileChange,
|
||||
}: UploadDropzoneProps) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
return (
|
||||
<>
|
||||
<input
|
||||
ref={fileUploaderRef}
|
||||
id="fileUploader"
|
||||
className="hidden"
|
||||
type="file"
|
||||
multiple={supportBatchUpload}
|
||||
accept={acceptTypes.join(',')}
|
||||
onChange={onFileChange}
|
||||
/>
|
||||
<div
|
||||
ref={dropRef}
|
||||
className={cn(
|
||||
'relative mb-2 box-border flex min-h-20 max-w-[640px] flex-col items-center justify-center gap-1 rounded-xl border border-dashed border-components-dropzone-border bg-components-dropzone-bg px-4 py-3 text-xs leading-4 text-text-tertiary',
|
||||
dragging && 'border-components-dropzone-border-accent bg-components-dropzone-bg-accent',
|
||||
)}
|
||||
>
|
||||
<div className="flex min-h-5 items-center justify-center text-sm leading-4 text-text-secondary">
|
||||
<RiUploadCloud2Line className="mr-2 size-5" />
|
||||
<span>
|
||||
{supportBatchUpload
|
||||
? t('stepOne.uploader.button', { ns: 'datasetCreation' })
|
||||
: t('stepOne.uploader.buttonSingleFile', { ns: 'datasetCreation' })}
|
||||
{acceptTypes.length > 0 && (
|
||||
<label
|
||||
className="ml-1 cursor-pointer text-text-accent"
|
||||
onClick={onSelectFile}
|
||||
>
|
||||
{t('stepOne.uploader.browse', { ns: 'datasetCreation' })}
|
||||
</label>
|
||||
)}
|
||||
</span>
|
||||
</div>
|
||||
<div>
|
||||
{t('stepOne.uploader.tip', {
|
||||
ns: 'datasetCreation',
|
||||
size: fileUploadConfig.file_size_limit,
|
||||
supportTypes: supportTypesShowNames,
|
||||
batchCount: fileUploadConfig.batch_count_limit,
|
||||
totalCount: fileUploadConfig.file_upload_limit,
|
||||
})}
|
||||
</div>
|
||||
{dragging && <div ref={dragRef} className="absolute left-0 top-0 h-full w-full" />}
|
||||
</div>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
export default UploadDropzone
|
||||
@ -0,0 +1,3 @@
|
||||
export const PROGRESS_NOT_STARTED = -1
|
||||
export const PROGRESS_ERROR = -2
|
||||
export const PROGRESS_COMPLETE = 100
|
||||
@ -0,0 +1,921 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import type { CustomFile, FileItem } from '@/models/datasets'
|
||||
import { act, render, renderHook, waitFor } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { ToastContext } from '@/app/components/base/toast'
|
||||
|
||||
import { PROGRESS_COMPLETE, PROGRESS_ERROR, PROGRESS_NOT_STARTED } from '../constants'
|
||||
// Import after mocks
|
||||
import { useFileUpload } from './use-file-upload'
|
||||
|
||||
// Mock notify function
|
||||
const mockNotify = vi.fn()
|
||||
const mockClose = vi.fn()
|
||||
|
||||
// Mock ToastContext
|
||||
vi.mock('use-context-selector', async () => {
|
||||
const actual = await vi.importActual<typeof import('use-context-selector')>('use-context-selector')
|
||||
return {
|
||||
...actual,
|
||||
useContext: vi.fn(() => ({ notify: mockNotify, close: mockClose })),
|
||||
}
|
||||
})
|
||||
|
||||
// Mock upload service
|
||||
const mockUpload = vi.fn()
|
||||
vi.mock('@/service/base', () => ({
|
||||
upload: (...args: unknown[]) => mockUpload(...args),
|
||||
}))
|
||||
|
||||
// Mock file upload config
|
||||
const mockFileUploadConfig = {
|
||||
file_size_limit: 15,
|
||||
batch_count_limit: 5,
|
||||
file_upload_limit: 10,
|
||||
}
|
||||
|
||||
const mockSupportTypes = {
|
||||
allowed_extensions: ['pdf', 'docx', 'txt', 'md'],
|
||||
}
|
||||
|
||||
vi.mock('@/service/use-common', () => ({
|
||||
useFileUploadConfig: () => ({ data: mockFileUploadConfig }),
|
||||
useFileSupportTypes: () => ({ data: mockSupportTypes }),
|
||||
}))
|
||||
|
||||
// Mock i18n
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key,
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock locale
|
||||
vi.mock('@/context/i18n', () => ({
|
||||
useLocale: () => 'en-US',
|
||||
}))
|
||||
|
||||
vi.mock('@/i18n-config/language', () => ({
|
||||
LanguagesSupported: ['en-US', 'zh-Hans'],
|
||||
}))
|
||||
|
||||
// Mock config
|
||||
vi.mock('@/config', () => ({
|
||||
IS_CE_EDITION: false,
|
||||
}))
|
||||
|
||||
// Mock file upload error message
|
||||
vi.mock('@/app/components/base/file-uploader/utils', () => ({
|
||||
getFileUploadErrorMessage: (_e: unknown, defaultMsg: string) => defaultMsg,
|
||||
}))
|
||||
|
||||
const createWrapper = () => {
|
||||
return ({ children }: { children: ReactNode }) => (
|
||||
<ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}>
|
||||
{children}
|
||||
</ToastContext.Provider>
|
||||
)
|
||||
}
|
||||
|
||||
describe('useFileUpload', () => {
|
||||
const defaultOptions = {
|
||||
fileList: [] as FileItem[],
|
||||
prepareFileList: vi.fn(),
|
||||
onFileUpdate: vi.fn(),
|
||||
onFileListUpdate: vi.fn(),
|
||||
onPreview: vi.fn(),
|
||||
supportBatchUpload: true,
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUpload.mockReset()
|
||||
// Default mock to return a resolved promise to avoid unhandled rejections
|
||||
mockUpload.mockResolvedValue({ id: 'default-id' })
|
||||
mockNotify.mockReset()
|
||||
})
|
||||
|
||||
describe('initialization', () => {
|
||||
it('should initialize with default values', () => {
|
||||
const { result } = renderHook(
|
||||
() => useFileUpload(defaultOptions),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
expect(result.current.dragging).toBe(false)
|
||||
expect(result.current.hideUpload).toBe(false)
|
||||
expect(result.current.dropRef.current).toBeNull()
|
||||
expect(result.current.dragRef.current).toBeNull()
|
||||
expect(result.current.fileUploaderRef.current).toBeNull()
|
||||
})
|
||||
|
||||
it('should set hideUpload true when not batch upload and has files', () => {
|
||||
const { result } = renderHook(
|
||||
() => useFileUpload({
|
||||
...defaultOptions,
|
||||
supportBatchUpload: false,
|
||||
fileList: [{ fileID: 'file-1', file: {} as CustomFile, progress: 100 }],
|
||||
}),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
expect(result.current.hideUpload).toBe(true)
|
||||
})
|
||||
|
||||
it('should compute acceptTypes correctly', () => {
|
||||
const { result } = renderHook(
|
||||
() => useFileUpload(defaultOptions),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
expect(result.current.acceptTypes).toEqual(['.pdf', '.docx', '.txt', '.md'])
|
||||
})
|
||||
|
||||
it('should compute supportTypesShowNames correctly', () => {
|
||||
const { result } = renderHook(
|
||||
() => useFileUpload(defaultOptions),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
expect(result.current.supportTypesShowNames).toContain('PDF')
|
||||
expect(result.current.supportTypesShowNames).toContain('DOCX')
|
||||
expect(result.current.supportTypesShowNames).toContain('TXT')
|
||||
// 'md' is mapped to 'markdown' in the extensionMap
|
||||
expect(result.current.supportTypesShowNames).toContain('MARKDOWN')
|
||||
})
|
||||
|
||||
it('should set batch limit to 1 when not batch upload', () => {
|
||||
const { result } = renderHook(
|
||||
() => useFileUpload({
|
||||
...defaultOptions,
|
||||
supportBatchUpload: false,
|
||||
}),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
expect(result.current.fileUploadConfig.batch_count_limit).toBe(1)
|
||||
expect(result.current.fileUploadConfig.file_upload_limit).toBe(1)
|
||||
})
|
||||
})
|
||||
|
||||
describe('selectHandle', () => {
|
||||
it('should trigger click on file input', () => {
|
||||
const { result } = renderHook(
|
||||
() => useFileUpload(defaultOptions),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
const mockClick = vi.fn()
|
||||
const mockInput = { click: mockClick } as unknown as HTMLInputElement
|
||||
Object.defineProperty(result.current.fileUploaderRef, 'current', {
|
||||
value: mockInput,
|
||||
writable: true,
|
||||
})
|
||||
|
||||
act(() => {
|
||||
result.current.selectHandle()
|
||||
})
|
||||
|
||||
expect(mockClick).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should do nothing when file input ref is null', () => {
|
||||
const { result } = renderHook(
|
||||
() => useFileUpload(defaultOptions),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
expect(() => {
|
||||
act(() => {
|
||||
result.current.selectHandle()
|
||||
})
|
||||
}).not.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('handlePreview', () => {
|
||||
it('should call onPreview when file has id', () => {
|
||||
const onPreview = vi.fn()
|
||||
const { result } = renderHook(
|
||||
() => useFileUpload({ ...defaultOptions, onPreview }),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
const mockFile = { id: 'file-123', name: 'test.pdf', size: 1024 } as CustomFile
|
||||
|
||||
act(() => {
|
||||
result.current.handlePreview(mockFile)
|
||||
})
|
||||
|
||||
expect(onPreview).toHaveBeenCalledWith(mockFile)
|
||||
})
|
||||
|
||||
it('should not call onPreview when file has no id', () => {
|
||||
const onPreview = vi.fn()
|
||||
const { result } = renderHook(
|
||||
() => useFileUpload({ ...defaultOptions, onPreview }),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
const mockFile = { name: 'test.pdf', size: 1024 } as CustomFile
|
||||
|
||||
act(() => {
|
||||
result.current.handlePreview(mockFile)
|
||||
})
|
||||
|
||||
expect(onPreview).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('removeFile', () => {
|
||||
it('should call onFileListUpdate with filtered list', () => {
|
||||
const onFileListUpdate = vi.fn()
|
||||
const { result } = renderHook(
|
||||
() => useFileUpload({ ...defaultOptions, onFileListUpdate }),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.removeFile('file-to-remove')
|
||||
})
|
||||
|
||||
expect(onFileListUpdate).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should clear file input value', () => {
|
||||
const { result } = renderHook(
|
||||
() => useFileUpload(defaultOptions),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
const mockInput = { value: 'some-file' } as HTMLInputElement
|
||||
Object.defineProperty(result.current.fileUploaderRef, 'current', {
|
||||
value: mockInput,
|
||||
writable: true,
|
||||
})
|
||||
|
||||
act(() => {
|
||||
result.current.removeFile('file-123')
|
||||
})
|
||||
|
||||
expect(mockInput.value).toBe('')
|
||||
})
|
||||
})
|
||||
|
||||
describe('fileChangeHandle', () => {
|
||||
it('should handle valid files', async () => {
|
||||
mockUpload.mockResolvedValue({ id: 'uploaded-id' })
|
||||
|
||||
const prepareFileList = vi.fn()
|
||||
const { result } = renderHook(
|
||||
() => useFileUpload({ ...defaultOptions, prepareFileList }),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' })
|
||||
const event = {
|
||||
target: { files: [mockFile] },
|
||||
} as unknown as React.ChangeEvent<HTMLInputElement>
|
||||
|
||||
act(() => {
|
||||
result.current.fileChangeHandle(event)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(prepareFileList).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('should limit files to batch count', () => {
|
||||
const prepareFileList = vi.fn()
|
||||
const { result } = renderHook(
|
||||
() => useFileUpload({ ...defaultOptions, prepareFileList }),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
const files = Array.from({ length: 10 }, (_, i) =>
|
||||
new File(['content'], `file${i}.pdf`, { type: 'application/pdf' }))
|
||||
|
||||
const event = {
|
||||
target: { files },
|
||||
} as unknown as React.ChangeEvent<HTMLInputElement>
|
||||
|
||||
act(() => {
|
||||
result.current.fileChangeHandle(event)
|
||||
})
|
||||
|
||||
// Should be called with at most batch_count_limit files
|
||||
if (prepareFileList.mock.calls.length > 0) {
|
||||
const calledFiles = prepareFileList.mock.calls[0][0]
|
||||
expect(calledFiles.length).toBeLessThanOrEqual(mockFileUploadConfig.batch_count_limit)
|
||||
}
|
||||
})
|
||||
|
||||
it('should reject invalid file types', () => {
|
||||
const { result } = renderHook(
|
||||
() => useFileUpload(defaultOptions),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
const mockFile = new File(['content'], 'test.exe', { type: 'application/x-msdownload' })
|
||||
const event = {
|
||||
target: { files: [mockFile] },
|
||||
} as unknown as React.ChangeEvent<HTMLInputElement>
|
||||
|
||||
act(() => {
|
||||
result.current.fileChangeHandle(event)
|
||||
})
|
||||
|
||||
expect(mockNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ type: 'error' }),
|
||||
)
|
||||
})
|
||||
|
||||
it('should reject files exceeding size limit', () => {
|
||||
const { result } = renderHook(
|
||||
() => useFileUpload(defaultOptions),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
// Create a file larger than the limit (15MB)
|
||||
const largeFile = new File([new ArrayBuffer(20 * 1024 * 1024)], 'large.pdf', { type: 'application/pdf' })
|
||||
|
||||
const event = {
|
||||
target: { files: [largeFile] },
|
||||
} as unknown as React.ChangeEvent<HTMLInputElement>
|
||||
|
||||
act(() => {
|
||||
result.current.fileChangeHandle(event)
|
||||
})
|
||||
|
||||
expect(mockNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ type: 'error' }),
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle null files', () => {
|
||||
const prepareFileList = vi.fn()
|
||||
const { result } = renderHook(
|
||||
() => useFileUpload({ ...defaultOptions, prepareFileList }),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
const event = {
|
||||
target: { files: null },
|
||||
} as unknown as React.ChangeEvent<HTMLInputElement>
|
||||
|
||||
act(() => {
|
||||
result.current.fileChangeHandle(event)
|
||||
})
|
||||
|
||||
expect(prepareFileList).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('drag and drop handlers', () => {
|
||||
const TestDropzone = ({ options }: { options: typeof defaultOptions }) => {
|
||||
const {
|
||||
dropRef,
|
||||
dragRef,
|
||||
dragging,
|
||||
} = useFileUpload(options)
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div ref={dropRef} data-testid="dropzone">
|
||||
{dragging && <div ref={dragRef} data-testid="drag-overlay" />}
|
||||
</div>
|
||||
<span data-testid="dragging">{String(dragging)}</span>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
it('should set dragging true on dragenter', async () => {
|
||||
const { getByTestId } = await act(async () =>
|
||||
render(
|
||||
<ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}>
|
||||
<TestDropzone options={defaultOptions} />
|
||||
</ToastContext.Provider>,
|
||||
),
|
||||
)
|
||||
|
||||
const dropzone = getByTestId('dropzone')
|
||||
|
||||
await act(async () => {
|
||||
const dragEnterEvent = new Event('dragenter', { bubbles: true, cancelable: true })
|
||||
dropzone.dispatchEvent(dragEnterEvent)
|
||||
})
|
||||
|
||||
expect(getByTestId('dragging').textContent).toBe('true')
|
||||
})
|
||||
|
||||
it('should handle dragover event', async () => {
|
||||
const { getByTestId } = await act(async () =>
|
||||
render(
|
||||
<ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}>
|
||||
<TestDropzone options={defaultOptions} />
|
||||
</ToastContext.Provider>,
|
||||
),
|
||||
)
|
||||
|
||||
const dropzone = getByTestId('dropzone')
|
||||
|
||||
await act(async () => {
|
||||
const dragOverEvent = new Event('dragover', { bubbles: true, cancelable: true })
|
||||
dropzone.dispatchEvent(dragOverEvent)
|
||||
})
|
||||
|
||||
expect(dropzone).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should set dragging false on dragleave from drag overlay', async () => {
|
||||
const { getByTestId, queryByTestId } = await act(async () =>
|
||||
render(
|
||||
<ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}>
|
||||
<TestDropzone options={defaultOptions} />
|
||||
</ToastContext.Provider>,
|
||||
),
|
||||
)
|
||||
|
||||
const dropzone = getByTestId('dropzone')
|
||||
|
||||
await act(async () => {
|
||||
const dragEnterEvent = new Event('dragenter', { bubbles: true, cancelable: true })
|
||||
dropzone.dispatchEvent(dragEnterEvent)
|
||||
})
|
||||
|
||||
expect(getByTestId('dragging').textContent).toBe('true')
|
||||
|
||||
const dragOverlay = queryByTestId('drag-overlay')
|
||||
if (dragOverlay) {
|
||||
await act(async () => {
|
||||
const dragLeaveEvent = new Event('dragleave', { bubbles: true, cancelable: true })
|
||||
Object.defineProperty(dragLeaveEvent, 'target', { value: dragOverlay })
|
||||
dropzone.dispatchEvent(dragLeaveEvent)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
it('should handle drop with files', async () => {
|
||||
mockUpload.mockResolvedValue({ id: 'uploaded-id' })
|
||||
const prepareFileList = vi.fn()
|
||||
|
||||
const { getByTestId } = await act(async () =>
|
||||
render(
|
||||
<ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}>
|
||||
<TestDropzone options={{ ...defaultOptions, prepareFileList }} />
|
||||
</ToastContext.Provider>,
|
||||
),
|
||||
)
|
||||
|
||||
const dropzone = getByTestId('dropzone')
|
||||
const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' })
|
||||
|
||||
await act(async () => {
|
||||
const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) as Event & { dataTransfer: DataTransfer | null }
|
||||
Object.defineProperty(dropEvent, 'dataTransfer', {
|
||||
value: {
|
||||
items: [{
|
||||
getAsFile: () => mockFile,
|
||||
webkitGetAsEntry: () => null,
|
||||
}],
|
||||
},
|
||||
})
|
||||
dropzone.dispatchEvent(dropEvent)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(prepareFileList).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle drop without dataTransfer', async () => {
|
||||
const prepareFileList = vi.fn()
|
||||
|
||||
const { getByTestId } = await act(async () =>
|
||||
render(
|
||||
<ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}>
|
||||
<TestDropzone options={{ ...defaultOptions, prepareFileList }} />
|
||||
</ToastContext.Provider>,
|
||||
),
|
||||
)
|
||||
|
||||
const dropzone = getByTestId('dropzone')
|
||||
|
||||
await act(async () => {
|
||||
const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) as Event & { dataTransfer: DataTransfer | null }
|
||||
Object.defineProperty(dropEvent, 'dataTransfer', { value: null })
|
||||
dropzone.dispatchEvent(dropEvent)
|
||||
})
|
||||
|
||||
expect(prepareFileList).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should limit to single file on drop when supportBatchUpload is false', async () => {
|
||||
mockUpload.mockResolvedValue({ id: 'uploaded-id' })
|
||||
const prepareFileList = vi.fn()
|
||||
|
||||
const { getByTestId } = await act(async () =>
|
||||
render(
|
||||
<ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}>
|
||||
<TestDropzone options={{ ...defaultOptions, supportBatchUpload: false, prepareFileList }} />
|
||||
</ToastContext.Provider>,
|
||||
),
|
||||
)
|
||||
|
||||
const dropzone = getByTestId('dropzone')
|
||||
const files = [
|
||||
new File(['content1'], 'test1.pdf', { type: 'application/pdf' }),
|
||||
new File(['content2'], 'test2.pdf', { type: 'application/pdf' }),
|
||||
]
|
||||
|
||||
await act(async () => {
|
||||
const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) as Event & { dataTransfer: DataTransfer | null }
|
||||
Object.defineProperty(dropEvent, 'dataTransfer', {
|
||||
value: {
|
||||
items: files.map(f => ({
|
||||
getAsFile: () => f,
|
||||
webkitGetAsEntry: () => null,
|
||||
})),
|
||||
},
|
||||
})
|
||||
dropzone.dispatchEvent(dropEvent)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
if (prepareFileList.mock.calls.length > 0) {
|
||||
const calledFiles = prepareFileList.mock.calls[0][0]
|
||||
expect(calledFiles.length).toBe(1)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle drop with FileSystemFileEntry', async () => {
|
||||
mockUpload.mockResolvedValue({ id: 'uploaded-id' })
|
||||
const prepareFileList = vi.fn()
|
||||
const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' })
|
||||
|
||||
const { getByTestId } = await act(async () =>
|
||||
render(
|
||||
<ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}>
|
||||
<TestDropzone options={{ ...defaultOptions, prepareFileList }} />
|
||||
</ToastContext.Provider>,
|
||||
),
|
||||
)
|
||||
|
||||
const dropzone = getByTestId('dropzone')
|
||||
|
||||
await act(async () => {
|
||||
const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) as Event & { dataTransfer: DataTransfer | null }
|
||||
Object.defineProperty(dropEvent, 'dataTransfer', {
|
||||
value: {
|
||||
items: [{
|
||||
getAsFile: () => mockFile,
|
||||
webkitGetAsEntry: () => ({
|
||||
isFile: true,
|
||||
isDirectory: false,
|
||||
file: (callback: (file: File) => void) => callback(mockFile),
|
||||
}),
|
||||
}],
|
||||
},
|
||||
})
|
||||
dropzone.dispatchEvent(dropEvent)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(prepareFileList).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle drop with FileSystemDirectoryEntry', async () => {
|
||||
mockUpload.mockResolvedValue({ id: 'uploaded-id' })
|
||||
const prepareFileList = vi.fn()
|
||||
const mockFile = new File(['content'], 'nested.pdf', { type: 'application/pdf' })
|
||||
|
||||
const { getByTestId } = await act(async () =>
|
||||
render(
|
||||
<ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}>
|
||||
<TestDropzone options={{ ...defaultOptions, prepareFileList }} />
|
||||
</ToastContext.Provider>,
|
||||
),
|
||||
)
|
||||
|
||||
const dropzone = getByTestId('dropzone')
|
||||
|
||||
await act(async () => {
|
||||
let callCount = 0
|
||||
const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) as Event & { dataTransfer: DataTransfer | null }
|
||||
Object.defineProperty(dropEvent, 'dataTransfer', {
|
||||
value: {
|
||||
items: [{
|
||||
getAsFile: () => null,
|
||||
webkitGetAsEntry: () => ({
|
||||
isFile: false,
|
||||
isDirectory: true,
|
||||
name: 'folder',
|
||||
createReader: () => ({
|
||||
readEntries: (callback: (entries: Array<{ isFile: boolean, isDirectory: boolean, name?: string, file?: (cb: (f: File) => void) => void }>) => void) => {
|
||||
// First call returns file entry, second call returns empty (signals end)
|
||||
if (callCount === 0) {
|
||||
callCount++
|
||||
callback([{
|
||||
isFile: true,
|
||||
isDirectory: false,
|
||||
name: 'nested.pdf',
|
||||
file: (cb: (f: File) => void) => cb(mockFile),
|
||||
}])
|
||||
}
|
||||
else {
|
||||
callback([])
|
||||
}
|
||||
},
|
||||
}),
|
||||
}),
|
||||
}],
|
||||
},
|
||||
})
|
||||
dropzone.dispatchEvent(dropEvent)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(prepareFileList).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle drop with empty directory', async () => {
|
||||
const prepareFileList = vi.fn()
|
||||
|
||||
const { getByTestId } = await act(async () =>
|
||||
render(
|
||||
<ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}>
|
||||
<TestDropzone options={{ ...defaultOptions, prepareFileList }} />
|
||||
</ToastContext.Provider>,
|
||||
),
|
||||
)
|
||||
|
||||
const dropzone = getByTestId('dropzone')
|
||||
|
||||
await act(async () => {
|
||||
const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) as Event & { dataTransfer: DataTransfer | null }
|
||||
Object.defineProperty(dropEvent, 'dataTransfer', {
|
||||
value: {
|
||||
items: [{
|
||||
getAsFile: () => null,
|
||||
webkitGetAsEntry: () => ({
|
||||
isFile: false,
|
||||
isDirectory: true,
|
||||
name: 'empty-folder',
|
||||
createReader: () => ({
|
||||
readEntries: (callback: (entries: never[]) => void) => {
|
||||
callback([])
|
||||
},
|
||||
}),
|
||||
}),
|
||||
}],
|
||||
},
|
||||
})
|
||||
dropzone.dispatchEvent(dropEvent)
|
||||
})
|
||||
|
||||
// Should not prepare file list if no valid files
|
||||
await new Promise(resolve => setTimeout(resolve, 100))
|
||||
})
|
||||
|
||||
it('should handle entry that is neither file nor directory', async () => {
|
||||
const prepareFileList = vi.fn()
|
||||
|
||||
const { getByTestId } = await act(async () =>
|
||||
render(
|
||||
<ToastContext.Provider value={{ notify: mockNotify, close: mockClose }}>
|
||||
<TestDropzone options={{ ...defaultOptions, prepareFileList }} />
|
||||
</ToastContext.Provider>,
|
||||
),
|
||||
)
|
||||
|
||||
const dropzone = getByTestId('dropzone')
|
||||
|
||||
await act(async () => {
|
||||
const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) as Event & { dataTransfer: DataTransfer | null }
|
||||
Object.defineProperty(dropEvent, 'dataTransfer', {
|
||||
value: {
|
||||
items: [{
|
||||
getAsFile: () => null,
|
||||
webkitGetAsEntry: () => ({
|
||||
isFile: false,
|
||||
isDirectory: false,
|
||||
}),
|
||||
}],
|
||||
},
|
||||
})
|
||||
dropzone.dispatchEvent(dropEvent)
|
||||
})
|
||||
|
||||
// Should not throw and should handle gracefully
|
||||
await new Promise(resolve => setTimeout(resolve, 100))
|
||||
})
|
||||
})
|
||||
|
||||
describe('file upload', () => {
|
||||
it('should call upload with correct parameters', async () => {
|
||||
mockUpload.mockResolvedValue({ id: 'uploaded-id', name: 'test.pdf' })
|
||||
const onFileUpdate = vi.fn()
|
||||
|
||||
const { result } = renderHook(
|
||||
() => useFileUpload({ ...defaultOptions, onFileUpdate }),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' })
|
||||
const event = {
|
||||
target: { files: [mockFile] },
|
||||
} as unknown as React.ChangeEvent<HTMLInputElement>
|
||||
|
||||
act(() => {
|
||||
result.current.fileChangeHandle(event)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockUpload).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('should update progress during upload', async () => {
|
||||
let progressCallback: ((e: ProgressEvent) => void) | undefined
|
||||
|
||||
mockUpload.mockImplementation(async (options: { onprogress: (e: ProgressEvent) => void }) => {
|
||||
progressCallback = options.onprogress
|
||||
return { id: 'uploaded-id' }
|
||||
})
|
||||
|
||||
const onFileUpdate = vi.fn()
|
||||
|
||||
const { result } = renderHook(
|
||||
() => useFileUpload({ ...defaultOptions, onFileUpdate }),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' })
|
||||
const event = {
|
||||
target: { files: [mockFile] },
|
||||
} as unknown as React.ChangeEvent<HTMLInputElement>
|
||||
|
||||
act(() => {
|
||||
result.current.fileChangeHandle(event)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockUpload).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
if (progressCallback) {
|
||||
act(() => {
|
||||
progressCallback!({
|
||||
lengthComputable: true,
|
||||
loaded: 50,
|
||||
total: 100,
|
||||
} as ProgressEvent)
|
||||
})
|
||||
|
||||
expect(onFileUpdate).toHaveBeenCalled()
|
||||
}
|
||||
})
|
||||
|
||||
it('should handle upload error', async () => {
|
||||
mockUpload.mockRejectedValue(new Error('Upload failed'))
|
||||
const onFileUpdate = vi.fn()
|
||||
|
||||
const { result } = renderHook(
|
||||
() => useFileUpload({ ...defaultOptions, onFileUpdate }),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' })
|
||||
const event = {
|
||||
target: { files: [mockFile] },
|
||||
} as unknown as React.ChangeEvent<HTMLInputElement>
|
||||
|
||||
act(() => {
|
||||
result.current.fileChangeHandle(event)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ type: 'error' }),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
it('should update file with PROGRESS_COMPLETE on success', async () => {
|
||||
mockUpload.mockResolvedValue({ id: 'uploaded-id', name: 'test.pdf' })
|
||||
const onFileUpdate = vi.fn()
|
||||
|
||||
const { result } = renderHook(
|
||||
() => useFileUpload({ ...defaultOptions, onFileUpdate }),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' })
|
||||
const event = {
|
||||
target: { files: [mockFile] },
|
||||
} as unknown as React.ChangeEvent<HTMLInputElement>
|
||||
|
||||
act(() => {
|
||||
result.current.fileChangeHandle(event)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
const completeCalls = onFileUpdate.mock.calls.filter(
|
||||
([, progress]) => progress === PROGRESS_COMPLETE,
|
||||
)
|
||||
expect(completeCalls.length).toBeGreaterThan(0)
|
||||
})
|
||||
})
|
||||
|
||||
it('should update file with PROGRESS_ERROR on failure', async () => {
|
||||
mockUpload.mockRejectedValue(new Error('Upload failed'))
|
||||
const onFileUpdate = vi.fn()
|
||||
|
||||
const { result } = renderHook(
|
||||
() => useFileUpload({ ...defaultOptions, onFileUpdate }),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' })
|
||||
const event = {
|
||||
target: { files: [mockFile] },
|
||||
} as unknown as React.ChangeEvent<HTMLInputElement>
|
||||
|
||||
act(() => {
|
||||
result.current.fileChangeHandle(event)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
const errorCalls = onFileUpdate.mock.calls.filter(
|
||||
([, progress]) => progress === PROGRESS_ERROR,
|
||||
)
|
||||
expect(errorCalls.length).toBeGreaterThan(0)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('file count validation', () => {
|
||||
it('should reject when total files exceed limit', () => {
|
||||
const existingFiles: FileItem[] = Array.from({ length: 8 }, (_, i) => ({
|
||||
fileID: `existing-${i}`,
|
||||
file: { name: `existing-${i}.pdf`, size: 1024 } as CustomFile,
|
||||
progress: 100,
|
||||
}))
|
||||
|
||||
const { result } = renderHook(
|
||||
() => useFileUpload({
|
||||
...defaultOptions,
|
||||
fileList: existingFiles,
|
||||
}),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
const files = Array.from({ length: 5 }, (_, i) =>
|
||||
new File(['content'], `new-${i}.pdf`, { type: 'application/pdf' }))
|
||||
|
||||
const event = {
|
||||
target: { files },
|
||||
} as unknown as React.ChangeEvent<HTMLInputElement>
|
||||
|
||||
act(() => {
|
||||
result.current.fileChangeHandle(event)
|
||||
})
|
||||
|
||||
expect(mockNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ type: 'error' }),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('progress constants', () => {
|
||||
it('should use PROGRESS_NOT_STARTED for new files', async () => {
|
||||
mockUpload.mockResolvedValue({ id: 'file-id' })
|
||||
|
||||
const prepareFileList = vi.fn()
|
||||
const { result } = renderHook(
|
||||
() => useFileUpload({ ...defaultOptions, prepareFileList }),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' })
|
||||
const event = {
|
||||
target: { files: [mockFile] },
|
||||
} as unknown as React.ChangeEvent<HTMLInputElement>
|
||||
|
||||
act(() => {
|
||||
result.current.fileChangeHandle(event)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
if (prepareFileList.mock.calls.length > 0) {
|
||||
const files = prepareFileList.mock.calls[0][0]
|
||||
expect(files[0].progress).toBe(PROGRESS_NOT_STARTED)
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,351 @@
|
||||
'use client'
|
||||
import type { RefObject } from 'react'
|
||||
import type { CustomFile as File, FileItem } from '@/models/datasets'
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import { getFileUploadErrorMessage } from '@/app/components/base/file-uploader/utils'
|
||||
import { ToastContext } from '@/app/components/base/toast'
|
||||
import { IS_CE_EDITION } from '@/config'
|
||||
import { useLocale } from '@/context/i18n'
|
||||
import { LanguagesSupported } from '@/i18n-config/language'
|
||||
import { upload } from '@/service/base'
|
||||
import { useFileSupportTypes, useFileUploadConfig } from '@/service/use-common'
|
||||
import { getFileExtension } from '@/utils/format'
|
||||
import { PROGRESS_COMPLETE, PROGRESS_ERROR, PROGRESS_NOT_STARTED } from '../constants'
|
||||
|
||||
export type FileUploadConfig = {
|
||||
file_size_limit: number
|
||||
batch_count_limit: number
|
||||
file_upload_limit: number
|
||||
}
|
||||
|
||||
export type UseFileUploadOptions = {
|
||||
fileList: FileItem[]
|
||||
prepareFileList: (files: FileItem[]) => void
|
||||
onFileUpdate: (fileItem: FileItem, progress: number, list: FileItem[]) => void
|
||||
onFileListUpdate?: (files: FileItem[]) => void
|
||||
onPreview: (file: File) => void
|
||||
supportBatchUpload?: boolean
|
||||
/**
|
||||
* Optional list of allowed file extensions. If not provided, fetches from API.
|
||||
* Pass this when you need custom extension filtering instead of using the global config.
|
||||
*/
|
||||
allowedExtensions?: string[]
|
||||
}
|
||||
|
||||
export type UseFileUploadReturn = {
|
||||
// Refs
|
||||
dropRef: RefObject<HTMLDivElement | null>
|
||||
dragRef: RefObject<HTMLDivElement | null>
|
||||
fileUploaderRef: RefObject<HTMLInputElement | null>
|
||||
|
||||
// State
|
||||
dragging: boolean
|
||||
|
||||
// Config
|
||||
fileUploadConfig: FileUploadConfig
|
||||
acceptTypes: string[]
|
||||
supportTypesShowNames: string
|
||||
hideUpload: boolean
|
||||
|
||||
// Handlers
|
||||
selectHandle: () => void
|
||||
fileChangeHandle: (e: React.ChangeEvent<HTMLInputElement>) => void
|
||||
removeFile: (fileID: string) => void
|
||||
handlePreview: (file: File) => void
|
||||
}
|
||||
|
||||
type FileWithPath = {
|
||||
relativePath?: string
|
||||
} & File
|
||||
|
||||
export const useFileUpload = ({
|
||||
fileList,
|
||||
prepareFileList,
|
||||
onFileUpdate,
|
||||
onFileListUpdate,
|
||||
onPreview,
|
||||
supportBatchUpload = false,
|
||||
allowedExtensions,
|
||||
}: UseFileUploadOptions): UseFileUploadReturn => {
|
||||
const { t } = useTranslation()
|
||||
const { notify } = useContext(ToastContext)
|
||||
const locale = useLocale()
|
||||
|
||||
const [dragging, setDragging] = useState(false)
|
||||
const dropRef = useRef<HTMLDivElement>(null)
|
||||
const dragRef = useRef<HTMLDivElement>(null)
|
||||
const fileUploaderRef = useRef<HTMLInputElement>(null)
|
||||
const fileListRef = useRef<FileItem[]>([])
|
||||
|
||||
const hideUpload = !supportBatchUpload && fileList.length > 0
|
||||
|
||||
const { data: fileUploadConfigResponse } = useFileUploadConfig()
|
||||
const { data: supportFileTypesResponse } = useFileSupportTypes()
|
||||
// Use provided allowedExtensions or fetch from API
|
||||
const supportTypes = useMemo(
|
||||
() => allowedExtensions ?? supportFileTypesResponse?.allowed_extensions ?? [],
|
||||
[allowedExtensions, supportFileTypesResponse?.allowed_extensions],
|
||||
)
|
||||
|
||||
const supportTypesShowNames = useMemo(() => {
|
||||
const extensionMap: { [key: string]: string } = {
|
||||
md: 'markdown',
|
||||
pptx: 'pptx',
|
||||
htm: 'html',
|
||||
xlsx: 'xlsx',
|
||||
docx: 'docx',
|
||||
}
|
||||
|
||||
return [...supportTypes]
|
||||
.map(item => extensionMap[item] || item)
|
||||
.map(item => item.toLowerCase())
|
||||
.filter((item, index, self) => self.indexOf(item) === index)
|
||||
.map(item => item.toUpperCase())
|
||||
.join(locale !== LanguagesSupported[1] ? ', ' : '、 ')
|
||||
}, [supportTypes, locale])
|
||||
|
||||
const acceptTypes = useMemo(() => supportTypes.map((ext: string) => `.${ext}`), [supportTypes])
|
||||
|
||||
const fileUploadConfig = useMemo(() => ({
|
||||
file_size_limit: fileUploadConfigResponse?.file_size_limit ?? 15,
|
||||
batch_count_limit: supportBatchUpload ? (fileUploadConfigResponse?.batch_count_limit ?? 5) : 1,
|
||||
file_upload_limit: supportBatchUpload ? (fileUploadConfigResponse?.file_upload_limit ?? 5) : 1,
|
||||
}), [fileUploadConfigResponse, supportBatchUpload])
|
||||
|
||||
const isValid = useCallback((file: File) => {
|
||||
const { size } = file
|
||||
const ext = `.${getFileExtension(file.name)}`
|
||||
const isValidType = acceptTypes.includes(ext.toLowerCase())
|
||||
if (!isValidType)
|
||||
notify({ type: 'error', message: t('stepOne.uploader.validation.typeError', { ns: 'datasetCreation' }) })
|
||||
|
||||
const isValidSize = size <= fileUploadConfig.file_size_limit * 1024 * 1024
|
||||
if (!isValidSize)
|
||||
notify({ type: 'error', message: t('stepOne.uploader.validation.size', { ns: 'datasetCreation', size: fileUploadConfig.file_size_limit }) })
|
||||
|
||||
return isValidType && isValidSize
|
||||
}, [fileUploadConfig, notify, t, acceptTypes])
|
||||
|
||||
const fileUpload = useCallback(async (fileItem: FileItem): Promise<FileItem> => {
|
||||
const formData = new FormData()
|
||||
formData.append('file', fileItem.file)
|
||||
const onProgress = (e: ProgressEvent) => {
|
||||
if (e.lengthComputable) {
|
||||
const percent = Math.floor(e.loaded / e.total * 100)
|
||||
onFileUpdate(fileItem, percent, fileListRef.current)
|
||||
}
|
||||
}
|
||||
|
||||
return upload({
|
||||
xhr: new XMLHttpRequest(),
|
||||
data: formData,
|
||||
onprogress: onProgress,
|
||||
}, false, undefined, '?source=datasets')
|
||||
.then((res) => {
|
||||
const completeFile = {
|
||||
fileID: fileItem.fileID,
|
||||
file: res as unknown as File,
|
||||
progress: PROGRESS_NOT_STARTED,
|
||||
}
|
||||
const index = fileListRef.current.findIndex(item => item.fileID === fileItem.fileID)
|
||||
fileListRef.current[index] = completeFile
|
||||
onFileUpdate(completeFile, PROGRESS_COMPLETE, fileListRef.current)
|
||||
return Promise.resolve({ ...completeFile })
|
||||
})
|
||||
.catch((e) => {
|
||||
const errorMessage = getFileUploadErrorMessage(e, t('stepOne.uploader.failed', { ns: 'datasetCreation' }), t)
|
||||
notify({ type: 'error', message: errorMessage })
|
||||
onFileUpdate(fileItem, PROGRESS_ERROR, fileListRef.current)
|
||||
return Promise.resolve({ ...fileItem })
|
||||
})
|
||||
.finally()
|
||||
}, [notify, onFileUpdate, t])
|
||||
|
||||
const uploadBatchFiles = useCallback((bFiles: FileItem[]) => {
|
||||
bFiles.forEach(bf => (bf.progress = 0))
|
||||
return Promise.all(bFiles.map(fileUpload))
|
||||
}, [fileUpload])
|
||||
|
||||
const uploadMultipleFiles = useCallback(async (files: FileItem[]) => {
|
||||
const batchCountLimit = fileUploadConfig.batch_count_limit
|
||||
const length = files.length
|
||||
let start = 0
|
||||
let end = 0
|
||||
|
||||
while (start < length) {
|
||||
if (start + batchCountLimit > length)
|
||||
end = length
|
||||
else
|
||||
end = start + batchCountLimit
|
||||
const bFiles = files.slice(start, end)
|
||||
await uploadBatchFiles(bFiles)
|
||||
start = end
|
||||
}
|
||||
}, [fileUploadConfig, uploadBatchFiles])
|
||||
|
||||
const initialUpload = useCallback((files: File[]) => {
|
||||
const filesCountLimit = fileUploadConfig.file_upload_limit
|
||||
if (!files.length)
|
||||
return false
|
||||
|
||||
if (files.length + fileList.length > filesCountLimit && !IS_CE_EDITION) {
|
||||
notify({ type: 'error', message: t('stepOne.uploader.validation.filesNumber', { ns: 'datasetCreation', filesNumber: filesCountLimit }) })
|
||||
return false
|
||||
}
|
||||
|
||||
const preparedFiles = files.map((file, index) => ({
|
||||
fileID: `file${index}-${Date.now()}`,
|
||||
file,
|
||||
progress: PROGRESS_NOT_STARTED,
|
||||
}))
|
||||
const newFiles = [...fileListRef.current, ...preparedFiles]
|
||||
prepareFileList(newFiles)
|
||||
fileListRef.current = newFiles
|
||||
uploadMultipleFiles(preparedFiles)
|
||||
}, [prepareFileList, uploadMultipleFiles, notify, t, fileList, fileUploadConfig])
|
||||
|
||||
const traverseFileEntry = useCallback(
|
||||
(entry: FileSystemEntry, prefix = ''): Promise<FileWithPath[]> => {
|
||||
return new Promise((resolve) => {
|
||||
if (entry.isFile) {
|
||||
(entry as FileSystemFileEntry).file((file: FileWithPath) => {
|
||||
file.relativePath = `${prefix}${file.name}`
|
||||
resolve([file])
|
||||
})
|
||||
}
|
||||
else if (entry.isDirectory) {
|
||||
const reader = (entry as FileSystemDirectoryEntry).createReader()
|
||||
const entries: FileSystemEntry[] = []
|
||||
const read = () => {
|
||||
reader.readEntries(async (results: FileSystemEntry[]) => {
|
||||
if (!results.length) {
|
||||
const files = await Promise.all(
|
||||
entries.map(ent =>
|
||||
traverseFileEntry(ent, `${prefix}${entry.name}/`),
|
||||
),
|
||||
)
|
||||
resolve(files.flat())
|
||||
}
|
||||
else {
|
||||
entries.push(...results)
|
||||
read()
|
||||
}
|
||||
})
|
||||
}
|
||||
read()
|
||||
}
|
||||
else {
|
||||
resolve([])
|
||||
}
|
||||
})
|
||||
},
|
||||
[],
|
||||
)
|
||||
|
||||
const handleDragEnter = useCallback((e: DragEvent) => {
|
||||
e.preventDefault()
|
||||
e.stopPropagation()
|
||||
if (e.target !== dragRef.current)
|
||||
setDragging(true)
|
||||
}, [])
|
||||
|
||||
const handleDragOver = useCallback((e: DragEvent) => {
|
||||
e.preventDefault()
|
||||
e.stopPropagation()
|
||||
}, [])
|
||||
|
||||
const handleDragLeave = useCallback((e: DragEvent) => {
|
||||
e.preventDefault()
|
||||
e.stopPropagation()
|
||||
if (e.target === dragRef.current)
|
||||
setDragging(false)
|
||||
}, [])
|
||||
|
||||
const handleDrop = useCallback(
|
||||
async (e: DragEvent) => {
|
||||
e.preventDefault()
|
||||
e.stopPropagation()
|
||||
setDragging(false)
|
||||
if (!e.dataTransfer)
|
||||
return
|
||||
const nested = await Promise.all(
|
||||
Array.from(e.dataTransfer.items).map((it) => {
|
||||
const entry = (it as DataTransferItem & { webkitGetAsEntry?: () => FileSystemEntry | null }).webkitGetAsEntry?.()
|
||||
if (entry)
|
||||
return traverseFileEntry(entry)
|
||||
const f = it.getAsFile?.()
|
||||
return f ? Promise.resolve([f as FileWithPath]) : Promise.resolve([])
|
||||
}),
|
||||
)
|
||||
let files = nested.flat()
|
||||
if (!supportBatchUpload)
|
||||
files = files.slice(0, 1)
|
||||
files = files.slice(0, fileUploadConfig.batch_count_limit)
|
||||
const valid = files.filter(isValid)
|
||||
initialUpload(valid)
|
||||
},
|
||||
[initialUpload, isValid, supportBatchUpload, traverseFileEntry, fileUploadConfig],
|
||||
)
|
||||
|
||||
const selectHandle = useCallback(() => {
|
||||
if (fileUploaderRef.current)
|
||||
fileUploaderRef.current.click()
|
||||
}, [])
|
||||
|
||||
const removeFile = useCallback((fileID: string) => {
|
||||
if (fileUploaderRef.current)
|
||||
fileUploaderRef.current.value = ''
|
||||
|
||||
fileListRef.current = fileListRef.current.filter(item => item.fileID !== fileID)
|
||||
onFileListUpdate?.([...fileListRef.current])
|
||||
}, [onFileListUpdate])
|
||||
|
||||
const fileChangeHandle = useCallback((e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
let files = Array.from(e.target.files ?? []) as File[]
|
||||
files = files.slice(0, fileUploadConfig.batch_count_limit)
|
||||
initialUpload(files.filter(isValid))
|
||||
}, [isValid, initialUpload, fileUploadConfig])
|
||||
|
||||
const handlePreview = useCallback((file: File) => {
|
||||
if (file?.id)
|
||||
onPreview(file)
|
||||
}, [onPreview])
|
||||
|
||||
useEffect(() => {
|
||||
const dropArea = dropRef.current
|
||||
dropArea?.addEventListener('dragenter', handleDragEnter)
|
||||
dropArea?.addEventListener('dragover', handleDragOver)
|
||||
dropArea?.addEventListener('dragleave', handleDragLeave)
|
||||
dropArea?.addEventListener('drop', handleDrop)
|
||||
return () => {
|
||||
dropArea?.removeEventListener('dragenter', handleDragEnter)
|
||||
dropArea?.removeEventListener('dragover', handleDragOver)
|
||||
dropArea?.removeEventListener('dragleave', handleDragLeave)
|
||||
dropArea?.removeEventListener('drop', handleDrop)
|
||||
}
|
||||
}, [handleDragEnter, handleDragOver, handleDragLeave, handleDrop])
|
||||
|
||||
return {
|
||||
// Refs
|
||||
dropRef,
|
||||
dragRef,
|
||||
fileUploaderRef,
|
||||
|
||||
// State
|
||||
dragging,
|
||||
|
||||
// Config
|
||||
fileUploadConfig,
|
||||
acceptTypes,
|
||||
supportTypesShowNames,
|
||||
hideUpload,
|
||||
|
||||
// Handlers
|
||||
selectHandle,
|
||||
fileChangeHandle,
|
||||
removeFile,
|
||||
handlePreview,
|
||||
}
|
||||
}
|
||||
278
web/app/components/datasets/create/file-uploader/index.spec.tsx
Normal file
278
web/app/components/datasets/create/file-uploader/index.spec.tsx
Normal file
@ -0,0 +1,278 @@
|
||||
import type { CustomFile as File, FileItem } from '@/models/datasets'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { PROGRESS_NOT_STARTED } from './constants'
|
||||
import FileUploader from './index'
|
||||
|
||||
// Mock react-i18next
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => {
|
||||
const translations: Record<string, string> = {
|
||||
'stepOne.uploader.title': 'Upload Files',
|
||||
'stepOne.uploader.button': 'Drag and drop files, or',
|
||||
'stepOne.uploader.buttonSingleFile': 'Drag and drop file, or',
|
||||
'stepOne.uploader.browse': 'Browse',
|
||||
'stepOne.uploader.tip': 'Supports various file types',
|
||||
}
|
||||
return translations[key] || key
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock ToastContext
|
||||
const mockNotify = vi.fn()
|
||||
vi.mock('use-context-selector', async () => {
|
||||
const actual = await vi.importActual<typeof import('use-context-selector')>('use-context-selector')
|
||||
return {
|
||||
...actual,
|
||||
useContext: vi.fn(() => ({ notify: mockNotify })),
|
||||
}
|
||||
})
|
||||
|
||||
// Mock services
|
||||
vi.mock('@/service/base', () => ({
|
||||
upload: vi.fn().mockResolvedValue({ id: 'uploaded-id' }),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-common', () => ({
|
||||
useFileUploadConfig: () => ({
|
||||
data: { file_size_limit: 15, batch_count_limit: 5, file_upload_limit: 10 },
|
||||
}),
|
||||
useFileSupportTypes: () => ({
|
||||
data: { allowed_extensions: ['pdf', 'docx', 'txt'] },
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/i18n', () => ({
|
||||
useLocale: () => 'en-US',
|
||||
}))
|
||||
|
||||
vi.mock('@/i18n-config/language', () => ({
|
||||
LanguagesSupported: ['en-US', 'zh-Hans'],
|
||||
}))
|
||||
|
||||
vi.mock('@/config', () => ({
|
||||
IS_CE_EDITION: false,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/file-uploader/utils', () => ({
|
||||
getFileUploadErrorMessage: () => 'Upload error',
|
||||
}))
|
||||
|
||||
// Mock theme
|
||||
vi.mock('@/hooks/use-theme', () => ({
|
||||
default: () => ({ theme: 'light' }),
|
||||
}))
|
||||
|
||||
vi.mock('@/types/app', () => ({
|
||||
Theme: { dark: 'dark', light: 'light' },
|
||||
}))
|
||||
|
||||
// Mock DocumentFileIcon - uses relative path from file-list-item.tsx
|
||||
vi.mock('@/app/components/datasets/common/document-file-icon', () => ({
|
||||
default: ({ extension }: { extension: string }) => <div data-testid="document-icon">{extension}</div>,
|
||||
}))
|
||||
|
||||
// Mock SimplePieChart
|
||||
vi.mock('next/dynamic', () => ({
|
||||
default: () => {
|
||||
const Component = ({ percentage }: { percentage: number }) => (
|
||||
<div data-testid="pie-chart">
|
||||
{percentage}
|
||||
%
|
||||
</div>
|
||||
)
|
||||
return Component
|
||||
},
|
||||
}))
|
||||
|
||||
describe('FileUploader', () => {
|
||||
const createMockFile = (overrides: Partial<File> = {}): File => ({
|
||||
name: 'test.pdf',
|
||||
size: 1024,
|
||||
type: 'application/pdf',
|
||||
...overrides,
|
||||
} as File)
|
||||
|
||||
const createMockFileItem = (overrides: Partial<FileItem> = {}): FileItem => ({
|
||||
fileID: `file-${Date.now()}`,
|
||||
file: createMockFile(overrides.file as Partial<File>),
|
||||
progress: PROGRESS_NOT_STARTED,
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const defaultProps = {
|
||||
fileList: [] as FileItem[],
|
||||
prepareFileList: vi.fn(),
|
||||
onFileUpdate: vi.fn(),
|
||||
onFileListUpdate: vi.fn(),
|
||||
onPreview: vi.fn(),
|
||||
supportBatchUpload: true,
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('rendering', () => {
|
||||
it('should render the component', () => {
|
||||
render(<FileUploader {...defaultProps} />)
|
||||
expect(screen.getByText('Upload Files')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render dropzone when no files', () => {
|
||||
render(<FileUploader {...defaultProps} />)
|
||||
expect(screen.getByText(/Drag and drop files/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render browse button', () => {
|
||||
render(<FileUploader {...defaultProps} />)
|
||||
expect(screen.getByText('Browse')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should apply custom title className', () => {
|
||||
render(<FileUploader {...defaultProps} titleClassName="custom-class" />)
|
||||
const title = screen.getByText('Upload Files')
|
||||
expect(title).toHaveClass('custom-class')
|
||||
})
|
||||
})
|
||||
|
||||
describe('file list rendering', () => {
|
||||
it('should render file items when fileList has items', () => {
|
||||
const fileList = [
|
||||
createMockFileItem({ file: createMockFile({ name: 'file1.pdf' }) }),
|
||||
createMockFileItem({ file: createMockFile({ name: 'file2.pdf' }) }),
|
||||
]
|
||||
|
||||
render(<FileUploader {...defaultProps} fileList={fileList} />)
|
||||
|
||||
expect(screen.getByText('file1.pdf')).toBeInTheDocument()
|
||||
expect(screen.getByText('file2.pdf')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render document icons for files', () => {
|
||||
const fileList = [createMockFileItem()]
|
||||
render(<FileUploader {...defaultProps} fileList={fileList} />)
|
||||
|
||||
expect(screen.getByTestId('document-icon')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('batch upload mode', () => {
|
||||
it('should show dropzone with batch upload enabled', () => {
|
||||
render(<FileUploader {...defaultProps} supportBatchUpload={true} />)
|
||||
expect(screen.getByText(/Drag and drop files/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show single file text when batch upload disabled', () => {
|
||||
render(<FileUploader {...defaultProps} supportBatchUpload={false} />)
|
||||
expect(screen.getByText(/Drag and drop file/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should hide dropzone when not batch upload and has files', () => {
|
||||
const fileList = [createMockFileItem()]
|
||||
render(<FileUploader {...defaultProps} supportBatchUpload={false} fileList={fileList} />)
|
||||
|
||||
expect(screen.queryByText(/Drag and drop/i)).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('event handlers', () => {
|
||||
it('should handle file preview click', () => {
|
||||
const onPreview = vi.fn()
|
||||
const fileItem = createMockFileItem({
|
||||
file: createMockFile({ id: 'file-id' } as Partial<File>),
|
||||
})
|
||||
|
||||
const { container } = render(<FileUploader {...defaultProps} fileList={[fileItem]} onPreview={onPreview} />)
|
||||
|
||||
// Find the file list item container by its class pattern
|
||||
const fileElement = container.querySelector('[class*="flex h-12"]')
|
||||
if (fileElement)
|
||||
fireEvent.click(fileElement)
|
||||
|
||||
expect(onPreview).toHaveBeenCalledWith(fileItem.file)
|
||||
})
|
||||
|
||||
it('should handle file remove click', () => {
|
||||
const onFileListUpdate = vi.fn()
|
||||
const fileItem = createMockFileItem()
|
||||
|
||||
const { container } = render(
|
||||
<FileUploader {...defaultProps} fileList={[fileItem]} onFileListUpdate={onFileListUpdate} />,
|
||||
)
|
||||
|
||||
// Find the delete button (the span with cursor-pointer containing the icon)
|
||||
const deleteButtons = container.querySelectorAll('[class*="cursor-pointer"]')
|
||||
// Get the last one which should be the delete button (not the browse label)
|
||||
const deleteButton = deleteButtons[deleteButtons.length - 1]
|
||||
if (deleteButton)
|
||||
fireEvent.click(deleteButton)
|
||||
|
||||
expect(onFileListUpdate).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle browse button click', () => {
|
||||
render(<FileUploader {...defaultProps} />)
|
||||
|
||||
// The browse label should trigger file input click
|
||||
const browseLabel = screen.getByText('Browse')
|
||||
expect(browseLabel).toHaveClass('cursor-pointer')
|
||||
})
|
||||
})
|
||||
|
||||
describe('upload progress', () => {
|
||||
it('should show progress chart for uploading files', () => {
|
||||
const fileItem = createMockFileItem({ progress: 50 })
|
||||
render(<FileUploader {...defaultProps} fileList={[fileItem]} />)
|
||||
|
||||
expect(screen.getByTestId('pie-chart')).toBeInTheDocument()
|
||||
expect(screen.getByText('50%')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show progress chart for completed files', () => {
|
||||
const fileItem = createMockFileItem({ progress: 100 })
|
||||
render(<FileUploader {...defaultProps} fileList={[fileItem]} />)
|
||||
|
||||
expect(screen.queryByTestId('pie-chart')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show progress chart for not started files', () => {
|
||||
const fileItem = createMockFileItem({ progress: PROGRESS_NOT_STARTED })
|
||||
render(<FileUploader {...defaultProps} fileList={[fileItem]} />)
|
||||
|
||||
expect(screen.queryByTestId('pie-chart')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('multiple files', () => {
|
||||
it('should render all files in the list', () => {
|
||||
const fileList = [
|
||||
createMockFileItem({ fileID: 'f1', file: createMockFile({ name: 'doc1.pdf' }) }),
|
||||
createMockFileItem({ fileID: 'f2', file: createMockFile({ name: 'doc2.docx' }) }),
|
||||
createMockFileItem({ fileID: 'f3', file: createMockFile({ name: 'doc3.txt' }) }),
|
||||
]
|
||||
|
||||
render(<FileUploader {...defaultProps} fileList={fileList} />)
|
||||
|
||||
expect(screen.getByText('doc1.pdf')).toBeInTheDocument()
|
||||
expect(screen.getByText('doc2.docx')).toBeInTheDocument()
|
||||
expect(screen.getByText('doc3.txt')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('styling', () => {
|
||||
it('should have correct container width', () => {
|
||||
const { container } = render(<FileUploader {...defaultProps} />)
|
||||
const wrapper = container.firstChild as HTMLElement
|
||||
expect(wrapper).toHaveClass('w-[640px]')
|
||||
})
|
||||
|
||||
it('should have proper spacing', () => {
|
||||
const { container } = render(<FileUploader {...defaultProps} />)
|
||||
const wrapper = container.firstChild as HTMLElement
|
||||
expect(wrapper).toHaveClass('mb-5')
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -1,23 +1,10 @@
|
||||
'use client'
|
||||
import type { CustomFile as File, FileItem } from '@/models/datasets'
|
||||
import { RiDeleteBinLine, RiUploadCloud2Line } from '@remixicon/react'
|
||||
import * as React from 'react'
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import { getFileUploadErrorMessage } from '@/app/components/base/file-uploader/utils'
|
||||
import SimplePieChart from '@/app/components/base/simple-pie-chart'
|
||||
import { ToastContext } from '@/app/components/base/toast'
|
||||
import { IS_CE_EDITION } from '@/config'
|
||||
|
||||
import { useLocale } from '@/context/i18n'
|
||||
import useTheme from '@/hooks/use-theme'
|
||||
import { LanguagesSupported } from '@/i18n-config/language'
|
||||
import { upload } from '@/service/base'
|
||||
import { useFileSupportTypes, useFileUploadConfig } from '@/service/use-common'
|
||||
import { Theme } from '@/types/app'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import DocumentFileIcon from '../../common/document-file-icon'
|
||||
import FileListItem from './components/file-list-item'
|
||||
import UploadDropzone from './components/upload-dropzone'
|
||||
import { useFileUpload } from './hooks/use-file-upload'
|
||||
|
||||
type IFileUploaderProps = {
|
||||
fileList: FileItem[]
|
||||
@ -39,358 +26,62 @@ const FileUploader = ({
|
||||
supportBatchUpload = false,
|
||||
}: IFileUploaderProps) => {
|
||||
const { t } = useTranslation()
|
||||
const { notify } = useContext(ToastContext)
|
||||
const locale = useLocale()
|
||||
const [dragging, setDragging] = useState(false)
|
||||
const dropRef = useRef<HTMLDivElement>(null)
|
||||
const dragRef = useRef<HTMLDivElement>(null)
|
||||
const fileUploader = useRef<HTMLInputElement>(null)
|
||||
const hideUpload = !supportBatchUpload && fileList.length > 0
|
||||
|
||||
const { data: fileUploadConfigResponse } = useFileUploadConfig()
|
||||
const { data: supportFileTypesResponse } = useFileSupportTypes()
|
||||
const supportTypes = supportFileTypesResponse?.allowed_extensions || []
|
||||
const supportTypesShowNames = (() => {
|
||||
const extensionMap: { [key: string]: string } = {
|
||||
md: 'markdown',
|
||||
pptx: 'pptx',
|
||||
htm: 'html',
|
||||
xlsx: 'xlsx',
|
||||
docx: 'docx',
|
||||
}
|
||||
|
||||
return [...supportTypes]
|
||||
.map(item => extensionMap[item] || item) // map to standardized extension
|
||||
.map(item => item.toLowerCase()) // convert to lower case
|
||||
.filter((item, index, self) => self.indexOf(item) === index) // remove duplicates
|
||||
.map(item => item.toUpperCase()) // convert to upper case
|
||||
.join(locale !== LanguagesSupported[1] ? ', ' : '、 ')
|
||||
})()
|
||||
const ACCEPTS = supportTypes.map((ext: string) => `.${ext}`)
|
||||
const fileUploadConfig = useMemo(() => ({
|
||||
file_size_limit: fileUploadConfigResponse?.file_size_limit ?? 15,
|
||||
batch_count_limit: supportBatchUpload ? (fileUploadConfigResponse?.batch_count_limit ?? 5) : 1,
|
||||
file_upload_limit: supportBatchUpload ? (fileUploadConfigResponse?.file_upload_limit ?? 5) : 1,
|
||||
}), [fileUploadConfigResponse, supportBatchUpload])
|
||||
|
||||
const fileListRef = useRef<FileItem[]>([])
|
||||
|
||||
// utils
|
||||
const getFileType = (currentFile: File) => {
|
||||
if (!currentFile)
|
||||
return ''
|
||||
|
||||
const arr = currentFile.name.split('.')
|
||||
return arr[arr.length - 1]
|
||||
}
|
||||
|
||||
const getFileSize = (size: number) => {
|
||||
if (size / 1024 < 10)
|
||||
return `${(size / 1024).toFixed(2)}KB`
|
||||
|
||||
return `${(size / 1024 / 1024).toFixed(2)}MB`
|
||||
}
|
||||
|
||||
const isValid = useCallback((file: File) => {
|
||||
const { size } = file
|
||||
const ext = `.${getFileType(file)}`
|
||||
const isValidType = ACCEPTS.includes(ext.toLowerCase())
|
||||
if (!isValidType)
|
||||
notify({ type: 'error', message: t('stepOne.uploader.validation.typeError', { ns: 'datasetCreation' }) })
|
||||
|
||||
const isValidSize = size <= fileUploadConfig.file_size_limit * 1024 * 1024
|
||||
if (!isValidSize)
|
||||
notify({ type: 'error', message: t('stepOne.uploader.validation.size', { ns: 'datasetCreation', size: fileUploadConfig.file_size_limit }) })
|
||||
|
||||
return isValidType && isValidSize
|
||||
}, [fileUploadConfig, notify, t, ACCEPTS])
|
||||
|
||||
const fileUpload = useCallback(async (fileItem: FileItem): Promise<FileItem> => {
|
||||
const formData = new FormData()
|
||||
formData.append('file', fileItem.file)
|
||||
const onProgress = (e: ProgressEvent) => {
|
||||
if (e.lengthComputable) {
|
||||
const percent = Math.floor(e.loaded / e.total * 100)
|
||||
onFileUpdate(fileItem, percent, fileListRef.current)
|
||||
}
|
||||
}
|
||||
|
||||
return upload({
|
||||
xhr: new XMLHttpRequest(),
|
||||
data: formData,
|
||||
onprogress: onProgress,
|
||||
}, false, undefined, '?source=datasets')
|
||||
.then((res) => {
|
||||
const completeFile = {
|
||||
fileID: fileItem.fileID,
|
||||
file: res as unknown as File,
|
||||
progress: -1,
|
||||
}
|
||||
const index = fileListRef.current.findIndex(item => item.fileID === fileItem.fileID)
|
||||
fileListRef.current[index] = completeFile
|
||||
onFileUpdate(completeFile, 100, fileListRef.current)
|
||||
return Promise.resolve({ ...completeFile })
|
||||
})
|
||||
.catch((e) => {
|
||||
const errorMessage = getFileUploadErrorMessage(e, t('stepOne.uploader.failed', { ns: 'datasetCreation' }), t)
|
||||
notify({ type: 'error', message: errorMessage })
|
||||
onFileUpdate(fileItem, -2, fileListRef.current)
|
||||
return Promise.resolve({ ...fileItem })
|
||||
})
|
||||
.finally()
|
||||
}, [fileListRef, notify, onFileUpdate, t])
|
||||
|
||||
const uploadBatchFiles = useCallback((bFiles: FileItem[]) => {
|
||||
bFiles.forEach(bf => (bf.progress = 0))
|
||||
return Promise.all(bFiles.map(fileUpload))
|
||||
}, [fileUpload])
|
||||
|
||||
const uploadMultipleFiles = useCallback(async (files: FileItem[]) => {
|
||||
const batchCountLimit = fileUploadConfig.batch_count_limit
|
||||
const length = files.length
|
||||
let start = 0
|
||||
let end = 0
|
||||
|
||||
while (start < length) {
|
||||
if (start + batchCountLimit > length)
|
||||
end = length
|
||||
else
|
||||
end = start + batchCountLimit
|
||||
const bFiles = files.slice(start, end)
|
||||
await uploadBatchFiles(bFiles)
|
||||
start = end
|
||||
}
|
||||
}, [fileUploadConfig, uploadBatchFiles])
|
||||
|
||||
const initialUpload = useCallback((files: File[]) => {
|
||||
const filesCountLimit = fileUploadConfig.file_upload_limit
|
||||
if (!files.length)
|
||||
return false
|
||||
|
||||
if (files.length + fileList.length > filesCountLimit && !IS_CE_EDITION) {
|
||||
notify({ type: 'error', message: t('stepOne.uploader.validation.filesNumber', { ns: 'datasetCreation', filesNumber: filesCountLimit }) })
|
||||
return false
|
||||
}
|
||||
|
||||
const preparedFiles = files.map((file, index) => ({
|
||||
fileID: `file${index}-${Date.now()}`,
|
||||
file,
|
||||
progress: -1,
|
||||
}))
|
||||
const newFiles = [...fileListRef.current, ...preparedFiles]
|
||||
prepareFileList(newFiles)
|
||||
fileListRef.current = newFiles
|
||||
uploadMultipleFiles(preparedFiles)
|
||||
}, [prepareFileList, uploadMultipleFiles, notify, t, fileList, fileUploadConfig])
|
||||
|
||||
const handleDragEnter = (e: DragEvent) => {
|
||||
e.preventDefault()
|
||||
e.stopPropagation()
|
||||
if (e.target !== dragRef.current)
|
||||
setDragging(true)
|
||||
}
|
||||
const handleDragOver = (e: DragEvent) => {
|
||||
e.preventDefault()
|
||||
e.stopPropagation()
|
||||
}
|
||||
const handleDragLeave = (e: DragEvent) => {
|
||||
e.preventDefault()
|
||||
e.stopPropagation()
|
||||
if (e.target === dragRef.current)
|
||||
setDragging(false)
|
||||
}
|
||||
type FileWithPath = {
|
||||
relativePath?: string
|
||||
} & File
|
||||
const traverseFileEntry = useCallback(
|
||||
(entry: any, prefix = ''): Promise<FileWithPath[]> => {
|
||||
return new Promise((resolve) => {
|
||||
if (entry.isFile) {
|
||||
entry.file((file: FileWithPath) => {
|
||||
file.relativePath = `${prefix}${file.name}`
|
||||
resolve([file])
|
||||
})
|
||||
}
|
||||
else if (entry.isDirectory) {
|
||||
const reader = entry.createReader()
|
||||
const entries: any[] = []
|
||||
const read = () => {
|
||||
reader.readEntries(async (results: FileSystemEntry[]) => {
|
||||
if (!results.length) {
|
||||
const files = await Promise.all(
|
||||
entries.map(ent =>
|
||||
traverseFileEntry(ent, `${prefix}${entry.name}/`),
|
||||
),
|
||||
)
|
||||
resolve(files.flat())
|
||||
}
|
||||
else {
|
||||
entries.push(...results)
|
||||
read()
|
||||
}
|
||||
})
|
||||
}
|
||||
read()
|
||||
}
|
||||
else {
|
||||
resolve([])
|
||||
}
|
||||
})
|
||||
},
|
||||
[],
|
||||
)
|
||||
|
||||
const handleDrop = useCallback(
|
||||
async (e: DragEvent) => {
|
||||
e.preventDefault()
|
||||
e.stopPropagation()
|
||||
setDragging(false)
|
||||
if (!e.dataTransfer)
|
||||
return
|
||||
const nested = await Promise.all(
|
||||
Array.from(e.dataTransfer.items).map((it) => {
|
||||
const entry = (it as any).webkitGetAsEntry?.()
|
||||
if (entry)
|
||||
return traverseFileEntry(entry)
|
||||
const f = it.getAsFile?.()
|
||||
return f ? Promise.resolve([f]) : Promise.resolve([])
|
||||
}),
|
||||
)
|
||||
let files = nested.flat()
|
||||
if (!supportBatchUpload)
|
||||
files = files.slice(0, 1)
|
||||
files = files.slice(0, fileUploadConfig.batch_count_limit)
|
||||
const valid = files.filter(isValid)
|
||||
initialUpload(valid)
|
||||
},
|
||||
[initialUpload, isValid, supportBatchUpload, traverseFileEntry, fileUploadConfig],
|
||||
)
|
||||
const selectHandle = () => {
|
||||
if (fileUploader.current)
|
||||
fileUploader.current.click()
|
||||
}
|
||||
|
||||
const removeFile = (fileID: string) => {
|
||||
if (fileUploader.current)
|
||||
fileUploader.current.value = ''
|
||||
|
||||
fileListRef.current = fileListRef.current.filter(item => item.fileID !== fileID)
|
||||
onFileListUpdate?.([...fileListRef.current])
|
||||
}
|
||||
const fileChangeHandle = useCallback((e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
let files = Array.from(e.target.files ?? []) as File[]
|
||||
files = files.slice(0, fileUploadConfig.batch_count_limit)
|
||||
initialUpload(files.filter(isValid))
|
||||
}, [isValid, initialUpload, fileUploadConfig])
|
||||
|
||||
const { theme } = useTheme()
|
||||
const chartColor = useMemo(() => theme === Theme.dark ? '#5289ff' : '#296dff', [theme])
|
||||
|
||||
useEffect(() => {
|
||||
dropRef.current?.addEventListener('dragenter', handleDragEnter)
|
||||
dropRef.current?.addEventListener('dragover', handleDragOver)
|
||||
dropRef.current?.addEventListener('dragleave', handleDragLeave)
|
||||
dropRef.current?.addEventListener('drop', handleDrop)
|
||||
return () => {
|
||||
dropRef.current?.removeEventListener('dragenter', handleDragEnter)
|
||||
dropRef.current?.removeEventListener('dragover', handleDragOver)
|
||||
dropRef.current?.removeEventListener('dragleave', handleDragLeave)
|
||||
dropRef.current?.removeEventListener('drop', handleDrop)
|
||||
}
|
||||
}, [handleDrop])
|
||||
const {
|
||||
dropRef,
|
||||
dragRef,
|
||||
fileUploaderRef,
|
||||
dragging,
|
||||
fileUploadConfig,
|
||||
acceptTypes,
|
||||
supportTypesShowNames,
|
||||
hideUpload,
|
||||
selectHandle,
|
||||
fileChangeHandle,
|
||||
removeFile,
|
||||
handlePreview,
|
||||
} = useFileUpload({
|
||||
fileList,
|
||||
prepareFileList,
|
||||
onFileUpdate,
|
||||
onFileListUpdate,
|
||||
onPreview,
|
||||
supportBatchUpload,
|
||||
})
|
||||
|
||||
return (
|
||||
<div className="mb-5 w-[640px]">
|
||||
<div className={cn('mb-1 text-sm font-semibold leading-6 text-text-secondary', titleClassName)}>
|
||||
{t('stepOne.uploader.title', { ns: 'datasetCreation' })}
|
||||
</div>
|
||||
|
||||
{!hideUpload && (
|
||||
<input
|
||||
ref={fileUploader}
|
||||
id="fileUploader"
|
||||
className="hidden"
|
||||
type="file"
|
||||
multiple={supportBatchUpload}
|
||||
accept={ACCEPTS.join(',')}
|
||||
onChange={fileChangeHandle}
|
||||
<UploadDropzone
|
||||
dropRef={dropRef}
|
||||
dragRef={dragRef}
|
||||
fileUploaderRef={fileUploaderRef}
|
||||
dragging={dragging}
|
||||
supportBatchUpload={supportBatchUpload}
|
||||
supportTypesShowNames={supportTypesShowNames}
|
||||
fileUploadConfig={fileUploadConfig}
|
||||
acceptTypes={acceptTypes}
|
||||
onSelectFile={selectHandle}
|
||||
onFileChange={fileChangeHandle}
|
||||
/>
|
||||
)}
|
||||
|
||||
<div className={cn('mb-1 text-sm font-semibold leading-6 text-text-secondary', titleClassName)}>{t('stepOne.uploader.title', { ns: 'datasetCreation' })}</div>
|
||||
|
||||
{!hideUpload && (
|
||||
<div ref={dropRef} className={cn('relative mb-2 box-border flex min-h-20 max-w-[640px] flex-col items-center justify-center gap-1 rounded-xl border border-dashed border-components-dropzone-border bg-components-dropzone-bg px-4 py-3 text-xs leading-4 text-text-tertiary', dragging && 'border-components-dropzone-border-accent bg-components-dropzone-bg-accent')}>
|
||||
<div className="flex min-h-5 items-center justify-center text-sm leading-4 text-text-secondary">
|
||||
<RiUploadCloud2Line className="mr-2 size-5" />
|
||||
|
||||
<span>
|
||||
{supportBatchUpload ? t('stepOne.uploader.button', { ns: 'datasetCreation' }) : t('stepOne.uploader.buttonSingleFile', { ns: 'datasetCreation' })}
|
||||
{supportTypes.length > 0 && (
|
||||
<label className="ml-1 cursor-pointer text-text-accent" onClick={selectHandle}>{t('stepOne.uploader.browse', { ns: 'datasetCreation' })}</label>
|
||||
)}
|
||||
</span>
|
||||
</div>
|
||||
<div>
|
||||
{t('stepOne.uploader.tip', {
|
||||
ns: 'datasetCreation',
|
||||
size: fileUploadConfig.file_size_limit,
|
||||
supportTypes: supportTypesShowNames,
|
||||
batchCount: fileUploadConfig.batch_count_limit,
|
||||
totalCount: fileUploadConfig.file_upload_limit,
|
||||
})}
|
||||
</div>
|
||||
{dragging && <div ref={dragRef} className="absolute left-0 top-0 h-full w-full" />}
|
||||
{fileList.length > 0 && (
|
||||
<div className="max-w-[640px] cursor-default space-y-1">
|
||||
{fileList.map(fileItem => (
|
||||
<FileListItem
|
||||
key={fileItem.fileID}
|
||||
fileItem={fileItem}
|
||||
onPreview={handlePreview}
|
||||
onRemove={removeFile}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
<div className="max-w-[640px] cursor-default space-y-1">
|
||||
|
||||
{fileList.map((fileItem, index) => (
|
||||
<div
|
||||
key={`${fileItem.fileID}-${index}`}
|
||||
onClick={() => fileItem.file?.id && onPreview(fileItem.file)}
|
||||
className={cn(
|
||||
'flex h-12 max-w-[640px] items-center rounded-lg border border-components-panel-border bg-components-panel-on-panel-item-bg text-xs leading-3 text-text-tertiary shadow-xs',
|
||||
// 'border-state-destructive-border bg-state-destructive-hover',
|
||||
)}
|
||||
>
|
||||
<div className="flex w-12 shrink-0 items-center justify-center">
|
||||
<DocumentFileIcon
|
||||
size="xl"
|
||||
className="shrink-0"
|
||||
name={fileItem.file.name}
|
||||
extension={getFileType(fileItem.file)}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex shrink grow flex-col gap-0.5">
|
||||
<div className="flex w-full">
|
||||
<div className="w-0 grow truncate text-sm leading-4 text-text-secondary">{fileItem.file.name}</div>
|
||||
</div>
|
||||
<div className="w-full truncate leading-3 text-text-tertiary">
|
||||
<span className="uppercase">{getFileType(fileItem.file)}</span>
|
||||
<span className="px-1 text-text-quaternary">·</span>
|
||||
<span>{getFileSize(fileItem.file.size)}</span>
|
||||
{/* <span className='px-1 text-text-quaternary'>·</span>
|
||||
<span>10k characters</span> */}
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex w-16 shrink-0 items-center justify-end gap-1 pr-3">
|
||||
{/* <span className="flex justify-center items-center w-6 h-6 cursor-pointer">
|
||||
<RiErrorWarningFill className='size-4 text-text-warning' />
|
||||
</span> */}
|
||||
{(fileItem.progress < 100 && fileItem.progress >= 0) && (
|
||||
// <div className={s.percent}>{`${fileItem.progress}%`}</div>
|
||||
<SimplePieChart percentage={fileItem.progress} stroke={chartColor} fill={chartColor} animationDuration={0} />
|
||||
)}
|
||||
<span
|
||||
className="flex h-6 w-6 cursor-pointer items-center justify-center"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
removeFile(fileItem.fileID)
|
||||
}}
|
||||
>
|
||||
<RiDeleteBinLine className="size-4 text-text-tertiary" />
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@ -154,7 +154,7 @@ export const GeneralChunkingOptions: FC<GeneralChunkingOptionsProps> = ({
|
||||
</div>
|
||||
))}
|
||||
{
|
||||
showSummaryIndexSetting && (
|
||||
showSummaryIndexSetting && IS_CE_EDITION && (
|
||||
<div className="mt-3">
|
||||
<SummaryIndexSetting
|
||||
entry="create-document"
|
||||
|
||||
@ -12,6 +12,7 @@ import Divider from '@/app/components/base/divider'
|
||||
import { ParentChildChunk } from '@/app/components/base/icons/src/vender/knowledge'
|
||||
import RadioCard from '@/app/components/base/radio-card'
|
||||
import SummaryIndexSetting from '@/app/components/datasets/settings/summary-index-setting'
|
||||
import { IS_CE_EDITION } from '@/config'
|
||||
import { ChunkingMode } from '@/models/datasets'
|
||||
import FileList from '../../assets/file-list-3-fill.svg'
|
||||
import Note from '../../assets/note-mod.svg'
|
||||
@ -191,7 +192,7 @@ export const ParentChildOptions: FC<ParentChildOptionsProps> = ({
|
||||
</div>
|
||||
))}
|
||||
{
|
||||
showSummaryIndexSetting && (
|
||||
showSummaryIndexSetting && IS_CE_EDITION && (
|
||||
<div className="mt-3">
|
||||
<SummaryIndexSetting
|
||||
entry="create-document"
|
||||
|
||||
@ -0,0 +1,262 @@
|
||||
import type { SimpleDocumentDetail } from '@/models/datasets'
|
||||
import { render } from '@testing-library/react'
|
||||
import { describe, expect, it } from 'vitest'
|
||||
import { DataSourceType } from '@/models/datasets'
|
||||
import { DatasourceType } from '@/models/pipeline'
|
||||
import DocumentSourceIcon from './document-source-icon'
|
||||
|
||||
const createMockDoc = (overrides: Record<string, unknown> = {}): SimpleDocumentDetail => ({
|
||||
id: 'doc-1',
|
||||
position: 1,
|
||||
data_source_type: DataSourceType.FILE,
|
||||
data_source_info: {},
|
||||
data_source_detail_dict: {},
|
||||
dataset_process_rule_id: 'rule-1',
|
||||
dataset_id: 'dataset-1',
|
||||
batch: 'batch-1',
|
||||
name: 'test-document.txt',
|
||||
created_from: 'web',
|
||||
created_by: 'user-1',
|
||||
created_at: Date.now(),
|
||||
tokens: 100,
|
||||
indexing_status: 'completed',
|
||||
error: null,
|
||||
enabled: true,
|
||||
disabled_at: null,
|
||||
disabled_by: null,
|
||||
archived: false,
|
||||
archived_reason: null,
|
||||
archived_by: null,
|
||||
archived_at: null,
|
||||
updated_at: Date.now(),
|
||||
doc_type: null,
|
||||
doc_metadata: undefined,
|
||||
doc_language: 'en',
|
||||
display_status: 'available',
|
||||
word_count: 100,
|
||||
hit_count: 10,
|
||||
doc_form: 'text_model',
|
||||
...overrides,
|
||||
}) as unknown as SimpleDocumentDetail
|
||||
|
||||
describe('DocumentSourceIcon', () => {
|
||||
describe('Rendering', () => {
|
||||
it('should render without crashing', () => {
|
||||
const doc = createMockDoc()
|
||||
const { container } = render(<DocumentSourceIcon doc={doc} />)
|
||||
expect(container.firstChild).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Local File Icon', () => {
|
||||
it('should render FileTypeIcon for FILE data source type', () => {
|
||||
const doc = createMockDoc({
|
||||
data_source_type: DataSourceType.FILE,
|
||||
data_source_info: {
|
||||
upload_file: { extension: 'pdf' },
|
||||
},
|
||||
})
|
||||
|
||||
const { container } = render(<DocumentSourceIcon doc={doc} fileType="pdf" />)
|
||||
const icon = container.querySelector('svg, img')
|
||||
expect(icon).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render FileTypeIcon for localFile data source type', () => {
|
||||
const doc = createMockDoc({
|
||||
data_source_type: DatasourceType.localFile,
|
||||
created_from: 'rag-pipeline',
|
||||
data_source_info: {
|
||||
extension: 'docx',
|
||||
},
|
||||
})
|
||||
|
||||
const { container } = render(<DocumentSourceIcon doc={doc} />)
|
||||
const icon = container.querySelector('svg, img')
|
||||
expect(icon).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should use extension from upload_file for legacy data source', () => {
|
||||
const doc = createMockDoc({
|
||||
data_source_type: DataSourceType.FILE,
|
||||
created_from: 'web',
|
||||
data_source_info: {
|
||||
upload_file: { extension: 'txt' },
|
||||
},
|
||||
})
|
||||
|
||||
const { container } = render(<DocumentSourceIcon doc={doc} />)
|
||||
expect(container.firstChild).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should use fileType prop as fallback for extension', () => {
|
||||
const doc = createMockDoc({
|
||||
data_source_type: DataSourceType.FILE,
|
||||
created_from: 'web',
|
||||
data_source_info: {},
|
||||
})
|
||||
|
||||
const { container } = render(<DocumentSourceIcon doc={doc} fileType="csv" />)
|
||||
expect(container.firstChild).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Notion Icon', () => {
|
||||
it('should render NotionIcon for NOTION data source type', () => {
|
||||
const doc = createMockDoc({
|
||||
data_source_type: DataSourceType.NOTION,
|
||||
created_from: 'web',
|
||||
data_source_info: {
|
||||
notion_page_icon: 'https://notion.so/icon.png',
|
||||
},
|
||||
})
|
||||
|
||||
const { container } = render(<DocumentSourceIcon doc={doc} />)
|
||||
expect(container.firstChild).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render NotionIcon for onlineDocument data source type', () => {
|
||||
const doc = createMockDoc({
|
||||
data_source_type: DatasourceType.onlineDocument,
|
||||
created_from: 'rag-pipeline',
|
||||
data_source_info: {
|
||||
page: { page_icon: 'https://notion.so/icon.png' },
|
||||
},
|
||||
})
|
||||
|
||||
const { container } = render(<DocumentSourceIcon doc={doc} />)
|
||||
expect(container.firstChild).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should use page_icon for rag-pipeline created documents', () => {
|
||||
const doc = createMockDoc({
|
||||
data_source_type: DataSourceType.NOTION,
|
||||
created_from: 'rag-pipeline',
|
||||
data_source_info: {
|
||||
page: { page_icon: 'https://notion.so/custom-icon.png' },
|
||||
},
|
||||
})
|
||||
|
||||
const { container } = render(<DocumentSourceIcon doc={doc} />)
|
||||
expect(container.firstChild).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Web Crawl Icon', () => {
|
||||
it('should render globe icon for WEB data source type', () => {
|
||||
const doc = createMockDoc({
|
||||
data_source_type: DataSourceType.WEB,
|
||||
})
|
||||
|
||||
const { container } = render(<DocumentSourceIcon doc={doc} />)
|
||||
const icon = container.querySelector('svg')
|
||||
expect(icon).toBeInTheDocument()
|
||||
expect(icon).toHaveClass('mr-1.5')
|
||||
expect(icon).toHaveClass('size-4')
|
||||
})
|
||||
|
||||
it('should render globe icon for websiteCrawl data source type', () => {
|
||||
const doc = createMockDoc({
|
||||
data_source_type: DatasourceType.websiteCrawl,
|
||||
})
|
||||
|
||||
const { container } = render(<DocumentSourceIcon doc={doc} />)
|
||||
const icon = container.querySelector('svg')
|
||||
expect(icon).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Online Drive Icon', () => {
|
||||
it('should render FileTypeIcon for onlineDrive data source type', () => {
|
||||
const doc = createMockDoc({
|
||||
data_source_type: DatasourceType.onlineDrive,
|
||||
data_source_info: {
|
||||
name: 'document.xlsx',
|
||||
},
|
||||
})
|
||||
|
||||
const { container } = render(<DocumentSourceIcon doc={doc} />)
|
||||
expect(container.firstChild).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should extract extension from file name', () => {
|
||||
const doc = createMockDoc({
|
||||
data_source_type: DatasourceType.onlineDrive,
|
||||
data_source_info: {
|
||||
name: 'spreadsheet.xlsx',
|
||||
},
|
||||
})
|
||||
|
||||
const { container } = render(<DocumentSourceIcon doc={doc} />)
|
||||
expect(container.firstChild).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle file name without extension', () => {
|
||||
const doc = createMockDoc({
|
||||
data_source_type: DatasourceType.onlineDrive,
|
||||
data_source_info: {
|
||||
name: 'noextension',
|
||||
},
|
||||
})
|
||||
|
||||
const { container } = render(<DocumentSourceIcon doc={doc} />)
|
||||
expect(container.firstChild).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle empty file name', () => {
|
||||
const doc = createMockDoc({
|
||||
data_source_type: DatasourceType.onlineDrive,
|
||||
data_source_info: {
|
||||
name: '',
|
||||
},
|
||||
})
|
||||
|
||||
const { container } = render(<DocumentSourceIcon doc={doc} />)
|
||||
expect(container.firstChild).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle hidden files (starting with dot)', () => {
|
||||
const doc = createMockDoc({
|
||||
data_source_type: DatasourceType.onlineDrive,
|
||||
data_source_info: {
|
||||
name: '.gitignore',
|
||||
},
|
||||
})
|
||||
|
||||
const { container } = render(<DocumentSourceIcon doc={doc} />)
|
||||
expect(container.firstChild).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Unknown Data Source Type', () => {
|
||||
it('should return null for unknown data source type', () => {
|
||||
const doc = createMockDoc({
|
||||
data_source_type: 'unknown',
|
||||
})
|
||||
|
||||
const { container } = render(<DocumentSourceIcon doc={doc} />)
|
||||
expect(container.firstChild).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle undefined data_source_info', () => {
|
||||
const doc = createMockDoc({
|
||||
data_source_type: DataSourceType.FILE,
|
||||
data_source_info: undefined,
|
||||
})
|
||||
|
||||
const { container } = render(<DocumentSourceIcon doc={doc} />)
|
||||
expect(container.firstChild).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should memoize the component', () => {
|
||||
const doc = createMockDoc()
|
||||
const { rerender, container } = render(<DocumentSourceIcon doc={doc} />)
|
||||
|
||||
const firstRender = container.innerHTML
|
||||
rerender(<DocumentSourceIcon doc={doc} />)
|
||||
expect(container.innerHTML).toBe(firstRender)
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,100 @@
|
||||
import type { FC } from 'react'
|
||||
import type { LegacyDataSourceInfo, LocalFileInfo, OnlineDocumentInfo, OnlineDriveInfo, SimpleDocumentDetail } from '@/models/datasets'
|
||||
import { RiGlobalLine } from '@remixicon/react'
|
||||
import * as React from 'react'
|
||||
import FileTypeIcon from '@/app/components/base/file-uploader/file-type-icon'
|
||||
import NotionIcon from '@/app/components/base/notion-icon'
|
||||
import { extensionToFileType } from '@/app/components/datasets/hit-testing/utils/extension-to-file-type'
|
||||
import { DataSourceType } from '@/models/datasets'
|
||||
import { DatasourceType } from '@/models/pipeline'
|
||||
|
||||
type DocumentSourceIconProps = {
|
||||
doc: SimpleDocumentDetail
|
||||
fileType?: string
|
||||
}
|
||||
|
||||
const isLocalFile = (dataSourceType: DataSourceType | DatasourceType) => {
|
||||
return dataSourceType === DatasourceType.localFile || dataSourceType === DataSourceType.FILE
|
||||
}
|
||||
|
||||
const isOnlineDocument = (dataSourceType: DataSourceType | DatasourceType) => {
|
||||
return dataSourceType === DatasourceType.onlineDocument || dataSourceType === DataSourceType.NOTION
|
||||
}
|
||||
|
||||
const isWebsiteCrawl = (dataSourceType: DataSourceType | DatasourceType) => {
|
||||
return dataSourceType === DatasourceType.websiteCrawl || dataSourceType === DataSourceType.WEB
|
||||
}
|
||||
|
||||
const isOnlineDrive = (dataSourceType: DataSourceType | DatasourceType) => {
|
||||
return dataSourceType === DatasourceType.onlineDrive
|
||||
}
|
||||
|
||||
const isCreateFromRAGPipeline = (createdFrom: string) => {
|
||||
return createdFrom === 'rag-pipeline'
|
||||
}
|
||||
|
||||
const getFileExtension = (fileName: string): string => {
|
||||
if (!fileName)
|
||||
return ''
|
||||
const parts = fileName.split('.')
|
||||
if (parts.length <= 1 || (parts[0] === '' && parts.length === 2))
|
||||
return ''
|
||||
return parts[parts.length - 1].toLowerCase()
|
||||
}
|
||||
|
||||
const DocumentSourceIcon: FC<DocumentSourceIconProps> = React.memo(({
|
||||
doc,
|
||||
fileType,
|
||||
}) => {
|
||||
if (isOnlineDocument(doc.data_source_type)) {
|
||||
return (
|
||||
<NotionIcon
|
||||
className="mr-1.5"
|
||||
type="page"
|
||||
src={
|
||||
isCreateFromRAGPipeline(doc.created_from)
|
||||
? (doc.data_source_info as OnlineDocumentInfo).page.page_icon
|
||||
: (doc.data_source_info as LegacyDataSourceInfo).notion_page_icon
|
||||
}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
if (isLocalFile(doc.data_source_type)) {
|
||||
return (
|
||||
<FileTypeIcon
|
||||
type={
|
||||
extensionToFileType(
|
||||
isCreateFromRAGPipeline(doc.created_from)
|
||||
? (doc?.data_source_info as LocalFileInfo)?.extension
|
||||
: ((doc?.data_source_info as LegacyDataSourceInfo)?.upload_file?.extension ?? fileType),
|
||||
)
|
||||
}
|
||||
className="mr-1.5"
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
if (isOnlineDrive(doc.data_source_type)) {
|
||||
return (
|
||||
<FileTypeIcon
|
||||
type={
|
||||
extensionToFileType(
|
||||
getFileExtension((doc?.data_source_info as unknown as OnlineDriveInfo)?.name),
|
||||
)
|
||||
}
|
||||
className="mr-1.5"
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
if (isWebsiteCrawl(doc.data_source_type)) {
|
||||
return <RiGlobalLine className="mr-1.5 size-4" />
|
||||
}
|
||||
|
||||
return null
|
||||
})
|
||||
|
||||
DocumentSourceIcon.displayName = 'DocumentSourceIcon'
|
||||
|
||||
export default DocumentSourceIcon
|
||||
@ -0,0 +1,342 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import type { SimpleDocumentDetail } from '@/models/datasets'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { DataSourceType } from '@/models/datasets'
|
||||
import DocumentTableRow from './document-table-row'
|
||||
|
||||
const mockPush = vi.fn()
|
||||
|
||||
vi.mock('next/navigation', () => ({
|
||||
useRouter: () => ({
|
||||
push: mockPush,
|
||||
}),
|
||||
}))
|
||||
|
||||
const createTestQueryClient = () => new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: { retry: false, gcTime: 0 },
|
||||
mutations: { retry: false },
|
||||
},
|
||||
})
|
||||
|
||||
const createWrapper = () => {
|
||||
const queryClient = createTestQueryClient()
|
||||
return ({ children }: { children: ReactNode }) => (
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<table>
|
||||
<tbody>
|
||||
{children}
|
||||
</tbody>
|
||||
</table>
|
||||
</QueryClientProvider>
|
||||
)
|
||||
}
|
||||
|
||||
type LocalDoc = SimpleDocumentDetail & { percent?: number }
|
||||
|
||||
const createMockDoc = (overrides: Record<string, unknown> = {}): LocalDoc => ({
|
||||
id: 'doc-1',
|
||||
position: 1,
|
||||
data_source_type: DataSourceType.FILE,
|
||||
data_source_info: {},
|
||||
data_source_detail_dict: {
|
||||
upload_file: { name: 'test.txt', extension: 'txt' },
|
||||
},
|
||||
dataset_process_rule_id: 'rule-1',
|
||||
dataset_id: 'dataset-1',
|
||||
batch: 'batch-1',
|
||||
name: 'test-document.txt',
|
||||
created_from: 'web',
|
||||
created_by: 'user-1',
|
||||
created_at: Date.now(),
|
||||
tokens: 100,
|
||||
indexing_status: 'completed',
|
||||
error: null,
|
||||
enabled: true,
|
||||
disabled_at: null,
|
||||
disabled_by: null,
|
||||
archived: false,
|
||||
archived_reason: null,
|
||||
archived_by: null,
|
||||
archived_at: null,
|
||||
updated_at: Date.now(),
|
||||
doc_type: null,
|
||||
doc_metadata: undefined,
|
||||
doc_language: 'en',
|
||||
display_status: 'available',
|
||||
word_count: 500,
|
||||
hit_count: 10,
|
||||
doc_form: 'text_model',
|
||||
...overrides,
|
||||
}) as unknown as LocalDoc
|
||||
|
||||
// Helper to find the custom checkbox div (Checkbox component renders as a div, not a native checkbox)
|
||||
const findCheckbox = (container: HTMLElement): HTMLElement | null => {
|
||||
return container.querySelector('[class*="shadow-xs"]')
|
||||
}
|
||||
|
||||
describe('DocumentTableRow', () => {
|
||||
const defaultProps = {
|
||||
doc: createMockDoc(),
|
||||
index: 0,
|
||||
datasetId: 'dataset-1',
|
||||
isSelected: false,
|
||||
isGeneralMode: true,
|
||||
isQAMode: false,
|
||||
embeddingAvailable: true,
|
||||
selectedIds: [],
|
||||
onSelectOne: vi.fn(),
|
||||
onSelectedIdChange: vi.fn(),
|
||||
onShowRenameModal: vi.fn(),
|
||||
onUpdate: vi.fn(),
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('Rendering', () => {
|
||||
it('should render without crashing', () => {
|
||||
render(<DocumentTableRow {...defaultProps} />, { wrapper: createWrapper() })
|
||||
expect(screen.getByText('test-document.txt')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render index number correctly', () => {
|
||||
render(<DocumentTableRow {...defaultProps} index={5} />, { wrapper: createWrapper() })
|
||||
expect(screen.getByText('6')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render document name with tooltip', () => {
|
||||
render(<DocumentTableRow {...defaultProps} />, { wrapper: createWrapper() })
|
||||
expect(screen.getByText('test-document.txt')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render checkbox element', () => {
|
||||
const { container } = render(<DocumentTableRow {...defaultProps} />, { wrapper: createWrapper() })
|
||||
const checkbox = findCheckbox(container)
|
||||
expect(checkbox).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Selection', () => {
|
||||
it('should show check icon when isSelected is true', () => {
|
||||
const { container } = render(<DocumentTableRow {...defaultProps} isSelected />, { wrapper: createWrapper() })
|
||||
// When selected, the checkbox should have a check icon (RiCheckLine svg)
|
||||
const checkbox = findCheckbox(container)
|
||||
expect(checkbox).toBeInTheDocument()
|
||||
const checkIcon = checkbox?.querySelector('svg')
|
||||
expect(checkIcon).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show check icon when isSelected is false', () => {
|
||||
const { container } = render(<DocumentTableRow {...defaultProps} isSelected={false} />, { wrapper: createWrapper() })
|
||||
const checkbox = findCheckbox(container)
|
||||
expect(checkbox).toBeInTheDocument()
|
||||
// When not selected, there should be no check icon inside the checkbox
|
||||
const checkIcon = checkbox?.querySelector('svg')
|
||||
expect(checkIcon).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should call onSelectOne when checkbox is clicked', () => {
|
||||
const onSelectOne = vi.fn()
|
||||
const { container } = render(<DocumentTableRow {...defaultProps} onSelectOne={onSelectOne} />, { wrapper: createWrapper() })
|
||||
|
||||
const checkbox = findCheckbox(container)
|
||||
if (checkbox) {
|
||||
fireEvent.click(checkbox)
|
||||
expect(onSelectOne).toHaveBeenCalledWith('doc-1')
|
||||
}
|
||||
})
|
||||
|
||||
it('should stop propagation when checkbox container is clicked', () => {
|
||||
const { container } = render(<DocumentTableRow {...defaultProps} />, { wrapper: createWrapper() })
|
||||
|
||||
// Click the div containing the checkbox (which has stopPropagation)
|
||||
const checkboxContainer = container.querySelector('td')?.querySelector('div')
|
||||
if (checkboxContainer) {
|
||||
fireEvent.click(checkboxContainer)
|
||||
expect(mockPush).not.toHaveBeenCalled()
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('Row Navigation', () => {
|
||||
it('should navigate to document detail on row click', () => {
|
||||
render(<DocumentTableRow {...defaultProps} />, { wrapper: createWrapper() })
|
||||
|
||||
const row = screen.getByRole('row')
|
||||
fireEvent.click(row)
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-1/documents/doc-1')
|
||||
})
|
||||
|
||||
it('should navigate with correct datasetId and documentId', () => {
|
||||
render(
|
||||
<DocumentTableRow
|
||||
{...defaultProps}
|
||||
datasetId="custom-dataset"
|
||||
doc={createMockDoc({ id: 'custom-doc' })}
|
||||
/>,
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
const row = screen.getByRole('row')
|
||||
fireEvent.click(row)
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/custom-dataset/documents/custom-doc')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Word Count Display', () => {
|
||||
it('should display word count less than 1000 as is', () => {
|
||||
const doc = createMockDoc({ word_count: 500 })
|
||||
render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() })
|
||||
expect(screen.getByText('500')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should display word count 1000 or more in k format', () => {
|
||||
const doc = createMockDoc({ word_count: 1500 })
|
||||
render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() })
|
||||
expect(screen.getByText('1.5k')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should display 0 with empty style when word_count is 0', () => {
|
||||
const doc = createMockDoc({ word_count: 0 })
|
||||
const { container } = render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() })
|
||||
const zeroCells = container.querySelectorAll('.text-text-tertiary')
|
||||
expect(zeroCells.length).toBeGreaterThan(0)
|
||||
})
|
||||
|
||||
it('should handle undefined word_count', () => {
|
||||
const doc = createMockDoc({ word_count: undefined as unknown as number })
|
||||
const { container } = render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() })
|
||||
expect(container).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Hit Count Display', () => {
|
||||
it('should display hit count less than 1000 as is', () => {
|
||||
const doc = createMockDoc({ hit_count: 100 })
|
||||
render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() })
|
||||
expect(screen.getByText('100')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should display hit count 1000 or more in k format', () => {
|
||||
const doc = createMockDoc({ hit_count: 2500 })
|
||||
render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() })
|
||||
expect(screen.getByText('2.5k')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should display 0 with empty style when hit_count is 0', () => {
|
||||
const doc = createMockDoc({ hit_count: 0 })
|
||||
const { container } = render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() })
|
||||
const zeroCells = container.querySelectorAll('.text-text-tertiary')
|
||||
expect(zeroCells.length).toBeGreaterThan(0)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Chunking Mode', () => {
|
||||
it('should render ChunkingModeLabel with general mode', () => {
|
||||
render(<DocumentTableRow {...defaultProps} isGeneralMode isQAMode={false} />, { wrapper: createWrapper() })
|
||||
// ChunkingModeLabel should be rendered
|
||||
expect(screen.getByRole('row')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render ChunkingModeLabel with QA mode', () => {
|
||||
render(<DocumentTableRow {...defaultProps} isGeneralMode={false} isQAMode />, { wrapper: createWrapper() })
|
||||
expect(screen.getByRole('row')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Summary Status', () => {
|
||||
it('should render SummaryStatus when summary_index_status is present', () => {
|
||||
const doc = createMockDoc({ summary_index_status: 'completed' })
|
||||
render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() })
|
||||
expect(screen.getByRole('row')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render SummaryStatus when summary_index_status is absent', () => {
|
||||
const doc = createMockDoc({ summary_index_status: undefined })
|
||||
render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() })
|
||||
expect(screen.getByRole('row')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Rename Action', () => {
|
||||
it('should call onShowRenameModal when rename button is clicked', () => {
|
||||
const onShowRenameModal = vi.fn()
|
||||
const { container } = render(
|
||||
<DocumentTableRow {...defaultProps} onShowRenameModal={onShowRenameModal} />,
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
// Find the rename button by finding the RiEditLine icon's parent
|
||||
const renameButtons = container.querySelectorAll('.cursor-pointer.rounded-md')
|
||||
if (renameButtons.length > 0) {
|
||||
fireEvent.click(renameButtons[0])
|
||||
expect(onShowRenameModal).toHaveBeenCalledWith(defaultProps.doc)
|
||||
expect(mockPush).not.toHaveBeenCalled()
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('Operations', () => {
|
||||
it('should pass selectedIds to Operations component', () => {
|
||||
render(<DocumentTableRow {...defaultProps} selectedIds={['doc-1', 'doc-2']} />, { wrapper: createWrapper() })
|
||||
expect(screen.getByRole('row')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass onSelectedIdChange to Operations component', () => {
|
||||
const onSelectedIdChange = vi.fn()
|
||||
render(<DocumentTableRow {...defaultProps} onSelectedIdChange={onSelectedIdChange} />, { wrapper: createWrapper() })
|
||||
expect(screen.getByRole('row')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Document Source Icon', () => {
|
||||
it('should render with FILE data source type', () => {
|
||||
const doc = createMockDoc({ data_source_type: DataSourceType.FILE })
|
||||
render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() })
|
||||
expect(screen.getByRole('row')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render with NOTION data source type', () => {
|
||||
const doc = createMockDoc({
|
||||
data_source_type: DataSourceType.NOTION,
|
||||
data_source_info: { notion_page_icon: 'icon.png' },
|
||||
})
|
||||
render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() })
|
||||
expect(screen.getByRole('row')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render with WEB data source type', () => {
|
||||
const doc = createMockDoc({ data_source_type: DataSourceType.WEB })
|
||||
render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() })
|
||||
expect(screen.getByRole('row')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle document with very long name', () => {
|
||||
const doc = createMockDoc({ name: `${'a'.repeat(500)}.txt` })
|
||||
render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() })
|
||||
expect(screen.getByRole('row')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle document with special characters in name', () => {
|
||||
const doc = createMockDoc({ name: '<script>test</script>.txt' })
|
||||
render(<DocumentTableRow {...defaultProps} doc={doc} />, { wrapper: createWrapper() })
|
||||
expect(screen.getByText('<script>test</script>.txt')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should memoize the component', () => {
|
||||
const wrapper = createWrapper()
|
||||
const { rerender } = render(<DocumentTableRow {...defaultProps} />, { wrapper })
|
||||
|
||||
rerender(<DocumentTableRow {...defaultProps} />)
|
||||
expect(screen.getByRole('row')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,152 @@
|
||||
import type { FC } from 'react'
|
||||
import type { SimpleDocumentDetail } from '@/models/datasets'
|
||||
import { RiEditLine } from '@remixicon/react'
|
||||
import { pick } from 'es-toolkit/object'
|
||||
import { useRouter } from 'next/navigation'
|
||||
import * as React from 'react'
|
||||
import { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Checkbox from '@/app/components/base/checkbox'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import ChunkingModeLabel from '@/app/components/datasets/common/chunking-mode-label'
|
||||
import Operations from '@/app/components/datasets/documents/components/operations'
|
||||
import SummaryStatus from '@/app/components/datasets/documents/detail/completed/common/summary-status'
|
||||
import StatusItem from '@/app/components/datasets/documents/status-item'
|
||||
import useTimestamp from '@/hooks/use-timestamp'
|
||||
import { DataSourceType } from '@/models/datasets'
|
||||
import { formatNumber } from '@/utils/format'
|
||||
import DocumentSourceIcon from './document-source-icon'
|
||||
import { renderTdValue } from './utils'
|
||||
|
||||
type LocalDoc = SimpleDocumentDetail & { percent?: number }
|
||||
|
||||
type DocumentTableRowProps = {
|
||||
doc: LocalDoc
|
||||
index: number
|
||||
datasetId: string
|
||||
isSelected: boolean
|
||||
isGeneralMode: boolean
|
||||
isQAMode: boolean
|
||||
embeddingAvailable: boolean
|
||||
selectedIds: string[]
|
||||
onSelectOne: (docId: string) => void
|
||||
onSelectedIdChange: (ids: string[]) => void
|
||||
onShowRenameModal: (doc: LocalDoc) => void
|
||||
onUpdate: () => void
|
||||
}
|
||||
|
||||
const renderCount = (count: number | undefined) => {
|
||||
if (!count)
|
||||
return renderTdValue(0, true)
|
||||
|
||||
if (count < 1000)
|
||||
return count
|
||||
|
||||
return `${formatNumber((count / 1000).toFixed(1))}k`
|
||||
}
|
||||
|
||||
const DocumentTableRow: FC<DocumentTableRowProps> = React.memo(({
|
||||
doc,
|
||||
index,
|
||||
datasetId,
|
||||
isSelected,
|
||||
isGeneralMode,
|
||||
isQAMode,
|
||||
embeddingAvailable,
|
||||
selectedIds,
|
||||
onSelectOne,
|
||||
onSelectedIdChange,
|
||||
onShowRenameModal,
|
||||
onUpdate,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const { formatTime } = useTimestamp()
|
||||
const router = useRouter()
|
||||
|
||||
const isFile = doc.data_source_type === DataSourceType.FILE
|
||||
const fileType = isFile ? doc.data_source_detail_dict?.upload_file?.extension : ''
|
||||
|
||||
const handleRowClick = useCallback(() => {
|
||||
router.push(`/datasets/${datasetId}/documents/${doc.id}`)
|
||||
}, [router, datasetId, doc.id])
|
||||
|
||||
const handleCheckboxClick = useCallback((e: React.MouseEvent) => {
|
||||
e.stopPropagation()
|
||||
}, [])
|
||||
|
||||
const handleRenameClick = useCallback((e: React.MouseEvent) => {
|
||||
e.stopPropagation()
|
||||
onShowRenameModal(doc)
|
||||
}, [doc, onShowRenameModal])
|
||||
|
||||
return (
|
||||
<tr
|
||||
className="h-8 cursor-pointer border-b border-divider-subtle hover:bg-background-default-hover"
|
||||
onClick={handleRowClick}
|
||||
>
|
||||
<td className="text-left align-middle text-xs text-text-tertiary">
|
||||
<div className="flex items-center" onClick={handleCheckboxClick}>
|
||||
<Checkbox
|
||||
className="mr-2 shrink-0"
|
||||
checked={isSelected}
|
||||
onCheck={() => onSelectOne(doc.id)}
|
||||
/>
|
||||
{index + 1}
|
||||
</div>
|
||||
</td>
|
||||
<td>
|
||||
<div className="group mr-6 flex max-w-[460px] items-center hover:mr-0">
|
||||
<div className="flex shrink-0 items-center">
|
||||
<DocumentSourceIcon doc={doc} fileType={fileType} />
|
||||
</div>
|
||||
<Tooltip popupContent={doc.name}>
|
||||
<span className="grow-1 truncate text-sm">{doc.name}</span>
|
||||
</Tooltip>
|
||||
{doc.summary_index_status && (
|
||||
<div className="ml-1 hidden shrink-0 group-hover:flex">
|
||||
<SummaryStatus status={doc.summary_index_status} />
|
||||
</div>
|
||||
)}
|
||||
<div className="hidden shrink-0 group-hover:ml-auto group-hover:flex">
|
||||
<Tooltip popupContent={t('list.table.rename', { ns: 'datasetDocuments' })}>
|
||||
<div
|
||||
className="cursor-pointer rounded-md p-1 hover:bg-state-base-hover"
|
||||
onClick={handleRenameClick}
|
||||
>
|
||||
<RiEditLine className="h-4 w-4 text-text-tertiary" />
|
||||
</div>
|
||||
</Tooltip>
|
||||
</div>
|
||||
</div>
|
||||
</td>
|
||||
<td>
|
||||
<ChunkingModeLabel
|
||||
isGeneralMode={isGeneralMode}
|
||||
isQAMode={isQAMode}
|
||||
/>
|
||||
</td>
|
||||
<td>{renderCount(doc.word_count)}</td>
|
||||
<td>{renderCount(doc.hit_count)}</td>
|
||||
<td className="text-[13px] text-text-secondary">
|
||||
{formatTime(doc.created_at, t('dateTimeFormat', { ns: 'datasetHitTesting' }) as string)}
|
||||
</td>
|
||||
<td>
|
||||
<StatusItem status={doc.display_status} />
|
||||
</td>
|
||||
<td>
|
||||
<Operations
|
||||
selectedIds={selectedIds}
|
||||
onSelectedIdChange={onSelectedIdChange}
|
||||
embeddingAvailable={embeddingAvailable}
|
||||
datasetId={datasetId}
|
||||
detail={pick(doc, ['name', 'enabled', 'archived', 'id', 'data_source_type', 'doc_form', 'display_status'])}
|
||||
onUpdate={onUpdate}
|
||||
/>
|
||||
</td>
|
||||
</tr>
|
||||
)
|
||||
})
|
||||
|
||||
DocumentTableRow.displayName = 'DocumentTableRow'
|
||||
|
||||
export default DocumentTableRow
|
||||
@ -0,0 +1,4 @@
|
||||
export { default as DocumentSourceIcon } from './document-source-icon'
|
||||
export { default as DocumentTableRow } from './document-table-row'
|
||||
export { default as SortHeader } from './sort-header'
|
||||
export { renderTdValue } from './utils'
|
||||
@ -0,0 +1,124 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import SortHeader from './sort-header'
|
||||
|
||||
describe('SortHeader', () => {
|
||||
const defaultProps = {
|
||||
field: 'name' as const,
|
||||
label: 'File Name',
|
||||
currentSortField: null,
|
||||
sortOrder: 'desc' as const,
|
||||
onSort: vi.fn(),
|
||||
}
|
||||
|
||||
describe('rendering', () => {
|
||||
it('should render the label', () => {
|
||||
render(<SortHeader {...defaultProps} />)
|
||||
expect(screen.getByText('File Name')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render the sort icon', () => {
|
||||
const { container } = render(<SortHeader {...defaultProps} />)
|
||||
const icon = container.querySelector('svg')
|
||||
expect(icon).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('inactive state', () => {
|
||||
it('should have disabled text color when not active', () => {
|
||||
const { container } = render(<SortHeader {...defaultProps} />)
|
||||
const icon = container.querySelector('svg')
|
||||
expect(icon).toHaveClass('text-text-disabled')
|
||||
})
|
||||
|
||||
it('should not be rotated when not active', () => {
|
||||
const { container } = render(<SortHeader {...defaultProps} />)
|
||||
const icon = container.querySelector('svg')
|
||||
expect(icon).not.toHaveClass('rotate-180')
|
||||
})
|
||||
})
|
||||
|
||||
describe('active state', () => {
|
||||
it('should have tertiary text color when active', () => {
|
||||
const { container } = render(
|
||||
<SortHeader {...defaultProps} currentSortField="name" />,
|
||||
)
|
||||
const icon = container.querySelector('svg')
|
||||
expect(icon).toHaveClass('text-text-tertiary')
|
||||
})
|
||||
|
||||
it('should not be rotated when active and desc', () => {
|
||||
const { container } = render(
|
||||
<SortHeader {...defaultProps} currentSortField="name" sortOrder="desc" />,
|
||||
)
|
||||
const icon = container.querySelector('svg')
|
||||
expect(icon).not.toHaveClass('rotate-180')
|
||||
})
|
||||
|
||||
it('should be rotated when active and asc', () => {
|
||||
const { container } = render(
|
||||
<SortHeader {...defaultProps} currentSortField="name" sortOrder="asc" />,
|
||||
)
|
||||
const icon = container.querySelector('svg')
|
||||
expect(icon).toHaveClass('rotate-180')
|
||||
})
|
||||
})
|
||||
|
||||
describe('interaction', () => {
|
||||
it('should call onSort when clicked', () => {
|
||||
const onSort = vi.fn()
|
||||
render(<SortHeader {...defaultProps} onSort={onSort} />)
|
||||
|
||||
fireEvent.click(screen.getByText('File Name'))
|
||||
|
||||
expect(onSort).toHaveBeenCalledWith('name')
|
||||
})
|
||||
|
||||
it('should call onSort with correct field', () => {
|
||||
const onSort = vi.fn()
|
||||
render(<SortHeader {...defaultProps} field="word_count" onSort={onSort} />)
|
||||
|
||||
fireEvent.click(screen.getByText('File Name'))
|
||||
|
||||
expect(onSort).toHaveBeenCalledWith('word_count')
|
||||
})
|
||||
})
|
||||
|
||||
describe('different fields', () => {
|
||||
it('should work with word_count field', () => {
|
||||
render(
|
||||
<SortHeader
|
||||
{...defaultProps}
|
||||
field="word_count"
|
||||
label="Words"
|
||||
currentSortField="word_count"
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByText('Words')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should work with hit_count field', () => {
|
||||
render(
|
||||
<SortHeader
|
||||
{...defaultProps}
|
||||
field="hit_count"
|
||||
label="Hit Count"
|
||||
currentSortField="hit_count"
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByText('Hit Count')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should work with created_at field', () => {
|
||||
render(
|
||||
<SortHeader
|
||||
{...defaultProps}
|
||||
field="created_at"
|
||||
label="Upload Time"
|
||||
currentSortField="created_at"
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByText('Upload Time')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,44 @@
|
||||
import type { FC } from 'react'
|
||||
import type { SortField, SortOrder } from '../hooks'
|
||||
import { RiArrowDownLine } from '@remixicon/react'
|
||||
import * as React from 'react'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
type SortHeaderProps = {
|
||||
field: Exclude<SortField, null>
|
||||
label: string
|
||||
currentSortField: SortField
|
||||
sortOrder: SortOrder
|
||||
onSort: (field: SortField) => void
|
||||
}
|
||||
|
||||
const SortHeader: FC<SortHeaderProps> = React.memo(({
|
||||
field,
|
||||
label,
|
||||
currentSortField,
|
||||
sortOrder,
|
||||
onSort,
|
||||
}) => {
|
||||
const isActive = currentSortField === field
|
||||
const isDesc = isActive && sortOrder === 'desc'
|
||||
|
||||
return (
|
||||
<div
|
||||
className="flex cursor-pointer items-center hover:text-text-secondary"
|
||||
onClick={() => onSort(field)}
|
||||
>
|
||||
{label}
|
||||
<RiArrowDownLine
|
||||
className={cn(
|
||||
'ml-0.5 h-3 w-3 transition-all',
|
||||
isActive ? 'text-text-tertiary' : 'text-text-disabled',
|
||||
isActive && !isDesc ? 'rotate-180' : '',
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
})
|
||||
|
||||
SortHeader.displayName = 'SortHeader'
|
||||
|
||||
export default SortHeader
|
||||
@ -0,0 +1,90 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import { describe, expect, it } from 'vitest'
|
||||
import { renderTdValue } from './utils'
|
||||
|
||||
describe('renderTdValue', () => {
|
||||
describe('Rendering', () => {
|
||||
it('should render string value correctly', () => {
|
||||
const { container } = render(<>{renderTdValue('test value')}</>)
|
||||
expect(screen.getByText('test value')).toBeInTheDocument()
|
||||
expect(container.querySelector('div')).toHaveClass('text-text-secondary')
|
||||
})
|
||||
|
||||
it('should render number value correctly', () => {
|
||||
const { container } = render(<>{renderTdValue(42)}</>)
|
||||
expect(screen.getByText('42')).toBeInTheDocument()
|
||||
expect(container.querySelector('div')).toHaveClass('text-text-secondary')
|
||||
})
|
||||
|
||||
it('should render zero correctly', () => {
|
||||
const { container } = render(<>{renderTdValue(0)}</>)
|
||||
expect(screen.getByText('0')).toBeInTheDocument()
|
||||
expect(container.querySelector('div')).toHaveClass('text-text-secondary')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Null and undefined handling', () => {
|
||||
it('should render dash for null value', () => {
|
||||
render(<>{renderTdValue(null)}</>)
|
||||
expect(screen.getByText('-')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render dash for null value with empty style', () => {
|
||||
const { container } = render(<>{renderTdValue(null, true)}</>)
|
||||
expect(screen.getByText('-')).toBeInTheDocument()
|
||||
expect(container.querySelector('div')).toHaveClass('text-text-tertiary')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Empty style', () => {
|
||||
it('should apply text-text-tertiary class when isEmptyStyle is true', () => {
|
||||
const { container } = render(<>{renderTdValue('value', true)}</>)
|
||||
expect(container.querySelector('div')).toHaveClass('text-text-tertiary')
|
||||
})
|
||||
|
||||
it('should apply text-text-secondary class when isEmptyStyle is false', () => {
|
||||
const { container } = render(<>{renderTdValue('value', false)}</>)
|
||||
expect(container.querySelector('div')).toHaveClass('text-text-secondary')
|
||||
})
|
||||
|
||||
it('should apply text-text-secondary class when isEmptyStyle is not provided', () => {
|
||||
const { container } = render(<>{renderTdValue('value')}</>)
|
||||
expect(container.querySelector('div')).toHaveClass('text-text-secondary')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle empty string', () => {
|
||||
render(<>{renderTdValue('')}</>)
|
||||
// Empty string should still render but with no visible text
|
||||
const div = document.querySelector('div')
|
||||
expect(div).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle large numbers', () => {
|
||||
render(<>{renderTdValue(1234567890)}</>)
|
||||
expect(screen.getByText('1234567890')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle negative numbers', () => {
|
||||
render(<>{renderTdValue(-42)}</>)
|
||||
expect(screen.getByText('-42')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle special characters in string', () => {
|
||||
render(<>{renderTdValue('<script>alert("xss")</script>')}</>)
|
||||
expect(screen.getByText('<script>alert("xss")</script>')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle unicode characters', () => {
|
||||
render(<>{renderTdValue('Test Unicode: \u4E2D\u6587')}</>)
|
||||
expect(screen.getByText('Test Unicode: \u4E2D\u6587')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle very long strings', () => {
|
||||
const longString = 'a'.repeat(1000)
|
||||
render(<>{renderTdValue(longString)}</>)
|
||||
expect(screen.getByText(longString)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,16 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import s from '../../../style.module.css'
|
||||
|
||||
export const renderTdValue = (value: string | number | null, isEmptyStyle = false): ReactNode => {
|
||||
const className = cn(
|
||||
isEmptyStyle ? 'text-text-tertiary' : 'text-text-secondary',
|
||||
s.tdValue,
|
||||
)
|
||||
|
||||
return (
|
||||
<div className={className}>
|
||||
{value ?? '-'}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@ -0,0 +1,4 @@
|
||||
export { useDocumentActions } from './use-document-actions'
|
||||
export { useDocumentSelection } from './use-document-selection'
|
||||
export { useDocumentSort } from './use-document-sort'
|
||||
export type { SortField, SortOrder } from './use-document-sort'
|
||||
@ -0,0 +1,438 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { act, renderHook, waitFor } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { DocumentActionType } from '@/models/datasets'
|
||||
import * as useDocument from '@/service/knowledge/use-document'
|
||||
import { useDocumentActions } from './use-document-actions'
|
||||
|
||||
vi.mock('@/service/knowledge/use-document')
|
||||
|
||||
const mockUseDocumentArchive = vi.mocked(useDocument.useDocumentArchive)
|
||||
const mockUseDocumentSummary = vi.mocked(useDocument.useDocumentSummary)
|
||||
const mockUseDocumentEnable = vi.mocked(useDocument.useDocumentEnable)
|
||||
const mockUseDocumentDisable = vi.mocked(useDocument.useDocumentDisable)
|
||||
const mockUseDocumentDelete = vi.mocked(useDocument.useDocumentDelete)
|
||||
const mockUseDocumentBatchRetryIndex = vi.mocked(useDocument.useDocumentBatchRetryIndex)
|
||||
const mockUseDocumentDownloadZip = vi.mocked(useDocument.useDocumentDownloadZip)
|
||||
|
||||
const createTestQueryClient = () => new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: { retry: false },
|
||||
mutations: { retry: false },
|
||||
},
|
||||
})
|
||||
|
||||
const createWrapper = () => {
|
||||
const queryClient = createTestQueryClient()
|
||||
return ({ children }: { children: ReactNode }) => (
|
||||
<QueryClientProvider client={queryClient}>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
)
|
||||
}
|
||||
|
||||
describe('useDocumentActions', () => {
|
||||
const mockMutateAsync = vi.fn()
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
|
||||
// Setup all mocks with default values
|
||||
const createMockMutation = () => ({
|
||||
mutateAsync: mockMutateAsync,
|
||||
isPending: false,
|
||||
isError: false,
|
||||
isSuccess: false,
|
||||
isIdle: true,
|
||||
data: undefined,
|
||||
error: null,
|
||||
mutate: vi.fn(),
|
||||
reset: vi.fn(),
|
||||
status: 'idle' as const,
|
||||
variables: undefined,
|
||||
context: undefined,
|
||||
failureCount: 0,
|
||||
failureReason: null,
|
||||
submittedAt: 0,
|
||||
})
|
||||
|
||||
mockUseDocumentArchive.mockReturnValue(createMockMutation() as unknown as ReturnType<typeof useDocument.useDocumentArchive>)
|
||||
mockUseDocumentSummary.mockReturnValue(createMockMutation() as unknown as ReturnType<typeof useDocument.useDocumentSummary>)
|
||||
mockUseDocumentEnable.mockReturnValue(createMockMutation() as unknown as ReturnType<typeof useDocument.useDocumentEnable>)
|
||||
mockUseDocumentDisable.mockReturnValue(createMockMutation() as unknown as ReturnType<typeof useDocument.useDocumentDisable>)
|
||||
mockUseDocumentDelete.mockReturnValue(createMockMutation() as unknown as ReturnType<typeof useDocument.useDocumentDelete>)
|
||||
mockUseDocumentBatchRetryIndex.mockReturnValue(createMockMutation() as unknown as ReturnType<typeof useDocument.useDocumentBatchRetryIndex>)
|
||||
mockUseDocumentDownloadZip.mockReturnValue({
|
||||
...createMockMutation(),
|
||||
isPending: false,
|
||||
} as unknown as ReturnType<typeof useDocument.useDocumentDownloadZip>)
|
||||
})
|
||||
|
||||
describe('handleAction', () => {
|
||||
it('should call archive mutation when archive action is triggered', async () => {
|
||||
mockMutateAsync.mockResolvedValue({ result: 'success' })
|
||||
const onUpdate = vi.fn()
|
||||
const onClearSelection = vi.fn()
|
||||
|
||||
const { result } = renderHook(
|
||||
() => useDocumentActions({
|
||||
datasetId: 'ds1',
|
||||
selectedIds: ['doc1'],
|
||||
downloadableSelectedIds: [],
|
||||
onUpdate,
|
||||
onClearSelection,
|
||||
}),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleAction(DocumentActionType.archive)()
|
||||
})
|
||||
|
||||
expect(mockMutateAsync).toHaveBeenCalledWith({
|
||||
datasetId: 'ds1',
|
||||
documentIds: ['doc1'],
|
||||
})
|
||||
})
|
||||
|
||||
it('should call onUpdate on successful action', async () => {
|
||||
mockMutateAsync.mockResolvedValue({ result: 'success' })
|
||||
const onUpdate = vi.fn()
|
||||
const onClearSelection = vi.fn()
|
||||
|
||||
const { result } = renderHook(
|
||||
() => useDocumentActions({
|
||||
datasetId: 'ds1',
|
||||
selectedIds: ['doc1'],
|
||||
downloadableSelectedIds: [],
|
||||
onUpdate,
|
||||
onClearSelection,
|
||||
}),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleAction(DocumentActionType.enable)()
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(onUpdate).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('should call onClearSelection on delete action', async () => {
|
||||
mockMutateAsync.mockResolvedValue({ result: 'success' })
|
||||
const onUpdate = vi.fn()
|
||||
const onClearSelection = vi.fn()
|
||||
|
||||
const { result } = renderHook(
|
||||
() => useDocumentActions({
|
||||
datasetId: 'ds1',
|
||||
selectedIds: ['doc1'],
|
||||
downloadableSelectedIds: [],
|
||||
onUpdate,
|
||||
onClearSelection,
|
||||
}),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleAction(DocumentActionType.delete)()
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(onClearSelection).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('handleBatchReIndex', () => {
|
||||
it('should call retry index mutation', async () => {
|
||||
mockMutateAsync.mockResolvedValue({ result: 'success' })
|
||||
const onUpdate = vi.fn()
|
||||
const onClearSelection = vi.fn()
|
||||
|
||||
const { result } = renderHook(
|
||||
() => useDocumentActions({
|
||||
datasetId: 'ds1',
|
||||
selectedIds: ['doc1', 'doc2'],
|
||||
downloadableSelectedIds: [],
|
||||
onUpdate,
|
||||
onClearSelection,
|
||||
}),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleBatchReIndex()
|
||||
})
|
||||
|
||||
expect(mockMutateAsync).toHaveBeenCalledWith({
|
||||
datasetId: 'ds1',
|
||||
documentIds: ['doc1', 'doc2'],
|
||||
})
|
||||
})
|
||||
|
||||
it('should call onClearSelection on success', async () => {
|
||||
mockMutateAsync.mockResolvedValue({ result: 'success' })
|
||||
const onUpdate = vi.fn()
|
||||
const onClearSelection = vi.fn()
|
||||
|
||||
const { result } = renderHook(
|
||||
() => useDocumentActions({
|
||||
datasetId: 'ds1',
|
||||
selectedIds: ['doc1'],
|
||||
downloadableSelectedIds: [],
|
||||
onUpdate,
|
||||
onClearSelection,
|
||||
}),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleBatchReIndex()
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(onClearSelection).toHaveBeenCalled()
|
||||
expect(onUpdate).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('handleBatchDownload', () => {
|
||||
it('should not proceed when already downloading', async () => {
|
||||
mockUseDocumentDownloadZip.mockReturnValue({
|
||||
mutateAsync: mockMutateAsync,
|
||||
isPending: true,
|
||||
} as unknown as ReturnType<typeof useDocument.useDocumentDownloadZip>)
|
||||
|
||||
const { result } = renderHook(
|
||||
() => useDocumentActions({
|
||||
datasetId: 'ds1',
|
||||
selectedIds: ['doc1'],
|
||||
downloadableSelectedIds: ['doc1'],
|
||||
onUpdate: vi.fn(),
|
||||
onClearSelection: vi.fn(),
|
||||
}),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleBatchDownload()
|
||||
})
|
||||
|
||||
expect(mockMutateAsync).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call download mutation with downloadable ids', async () => {
|
||||
const mockBlob = new Blob(['test'])
|
||||
mockMutateAsync.mockResolvedValue(mockBlob)
|
||||
|
||||
mockUseDocumentDownloadZip.mockReturnValue({
|
||||
mutateAsync: mockMutateAsync,
|
||||
isPending: false,
|
||||
} as unknown as ReturnType<typeof useDocument.useDocumentDownloadZip>)
|
||||
|
||||
const { result } = renderHook(
|
||||
() => useDocumentActions({
|
||||
datasetId: 'ds1',
|
||||
selectedIds: ['doc1', 'doc2'],
|
||||
downloadableSelectedIds: ['doc1'],
|
||||
onUpdate: vi.fn(),
|
||||
onClearSelection: vi.fn(),
|
||||
}),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleBatchDownload()
|
||||
})
|
||||
|
||||
expect(mockMutateAsync).toHaveBeenCalledWith({
|
||||
datasetId: 'ds1',
|
||||
documentIds: ['doc1'],
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('isDownloadingZip', () => {
|
||||
it('should reflect isPending state from mutation', () => {
|
||||
mockUseDocumentDownloadZip.mockReturnValue({
|
||||
mutateAsync: mockMutateAsync,
|
||||
isPending: true,
|
||||
} as unknown as ReturnType<typeof useDocument.useDocumentDownloadZip>)
|
||||
|
||||
const { result } = renderHook(
|
||||
() => useDocumentActions({
|
||||
datasetId: 'ds1',
|
||||
selectedIds: [],
|
||||
downloadableSelectedIds: [],
|
||||
onUpdate: vi.fn(),
|
||||
onClearSelection: vi.fn(),
|
||||
}),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
expect(result.current.isDownloadingZip).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('error handling', () => {
|
||||
it('should show error toast when handleAction fails', async () => {
|
||||
mockMutateAsync.mockRejectedValue(new Error('Action failed'))
|
||||
const onUpdate = vi.fn()
|
||||
const onClearSelection = vi.fn()
|
||||
|
||||
const { result } = renderHook(
|
||||
() => useDocumentActions({
|
||||
datasetId: 'ds1',
|
||||
selectedIds: ['doc1'],
|
||||
downloadableSelectedIds: [],
|
||||
onUpdate,
|
||||
onClearSelection,
|
||||
}),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleAction(DocumentActionType.archive)()
|
||||
})
|
||||
|
||||
// onUpdate should not be called on error
|
||||
expect(onUpdate).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should show error toast when handleBatchReIndex fails', async () => {
|
||||
mockMutateAsync.mockRejectedValue(new Error('Re-index failed'))
|
||||
const onUpdate = vi.fn()
|
||||
const onClearSelection = vi.fn()
|
||||
|
||||
const { result } = renderHook(
|
||||
() => useDocumentActions({
|
||||
datasetId: 'ds1',
|
||||
selectedIds: ['doc1'],
|
||||
downloadableSelectedIds: [],
|
||||
onUpdate,
|
||||
onClearSelection,
|
||||
}),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleBatchReIndex()
|
||||
})
|
||||
|
||||
// onUpdate and onClearSelection should not be called on error
|
||||
expect(onUpdate).not.toHaveBeenCalled()
|
||||
expect(onClearSelection).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should show error toast when handleBatchDownload fails', async () => {
|
||||
mockMutateAsync.mockRejectedValue(new Error('Download failed'))
|
||||
|
||||
mockUseDocumentDownloadZip.mockReturnValue({
|
||||
mutateAsync: mockMutateAsync,
|
||||
isPending: false,
|
||||
} as unknown as ReturnType<typeof useDocument.useDocumentDownloadZip>)
|
||||
|
||||
const { result } = renderHook(
|
||||
() => useDocumentActions({
|
||||
datasetId: 'ds1',
|
||||
selectedIds: ['doc1'],
|
||||
downloadableSelectedIds: ['doc1'],
|
||||
onUpdate: vi.fn(),
|
||||
onClearSelection: vi.fn(),
|
||||
}),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleBatchDownload()
|
||||
})
|
||||
|
||||
// Mutation was called but failed
|
||||
expect(mockMutateAsync).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should show error toast when handleBatchDownload returns null blob', async () => {
|
||||
mockMutateAsync.mockResolvedValue(null)
|
||||
|
||||
mockUseDocumentDownloadZip.mockReturnValue({
|
||||
mutateAsync: mockMutateAsync,
|
||||
isPending: false,
|
||||
} as unknown as ReturnType<typeof useDocument.useDocumentDownloadZip>)
|
||||
|
||||
const { result } = renderHook(
|
||||
() => useDocumentActions({
|
||||
datasetId: 'ds1',
|
||||
selectedIds: ['doc1'],
|
||||
downloadableSelectedIds: ['doc1'],
|
||||
onUpdate: vi.fn(),
|
||||
onClearSelection: vi.fn(),
|
||||
}),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleBatchDownload()
|
||||
})
|
||||
|
||||
// Mutation was called but returned null
|
||||
expect(mockMutateAsync).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('all action types', () => {
|
||||
it('should handle summary action', async () => {
|
||||
mockMutateAsync.mockResolvedValue({ result: 'success' })
|
||||
const onUpdate = vi.fn()
|
||||
|
||||
const { result } = renderHook(
|
||||
() => useDocumentActions({
|
||||
datasetId: 'ds1',
|
||||
selectedIds: ['doc1'],
|
||||
downloadableSelectedIds: [],
|
||||
onUpdate,
|
||||
onClearSelection: vi.fn(),
|
||||
}),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleAction(DocumentActionType.summary)()
|
||||
})
|
||||
|
||||
expect(mockMutateAsync).toHaveBeenCalled()
|
||||
await waitFor(() => {
|
||||
expect(onUpdate).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle disable action', async () => {
|
||||
mockMutateAsync.mockResolvedValue({ result: 'success' })
|
||||
const onUpdate = vi.fn()
|
||||
|
||||
const { result } = renderHook(
|
||||
() => useDocumentActions({
|
||||
datasetId: 'ds1',
|
||||
selectedIds: ['doc1'],
|
||||
downloadableSelectedIds: [],
|
||||
onUpdate,
|
||||
onClearSelection: vi.fn(),
|
||||
}),
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleAction(DocumentActionType.disable)()
|
||||
})
|
||||
|
||||
expect(mockMutateAsync).toHaveBeenCalled()
|
||||
await waitFor(() => {
|
||||
expect(onUpdate).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,126 @@
|
||||
import type { CommonResponse } from '@/models/common'
|
||||
import { useCallback, useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { DocumentActionType } from '@/models/datasets'
|
||||
import {
|
||||
useDocumentArchive,
|
||||
useDocumentBatchRetryIndex,
|
||||
useDocumentDelete,
|
||||
useDocumentDisable,
|
||||
useDocumentDownloadZip,
|
||||
useDocumentEnable,
|
||||
useDocumentSummary,
|
||||
} from '@/service/knowledge/use-document'
|
||||
import { asyncRunSafe } from '@/utils'
|
||||
import { downloadBlob } from '@/utils/download'
|
||||
|
||||
type UseDocumentActionsOptions = {
|
||||
datasetId: string
|
||||
selectedIds: string[]
|
||||
downloadableSelectedIds: string[]
|
||||
onUpdate: () => void
|
||||
onClearSelection: () => void
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a random ZIP filename for bulk document downloads.
|
||||
* We intentionally avoid leaking dataset info in the exported archive name.
|
||||
*/
|
||||
const generateDocsZipFileName = (): string => {
|
||||
const randomPart = (typeof crypto !== 'undefined' && typeof crypto.randomUUID === 'function')
|
||||
? crypto.randomUUID()
|
||||
: `${Date.now().toString(36)}${Math.random().toString(36).slice(2, 10)}`
|
||||
return `${randomPart}-docs.zip`
|
||||
}
|
||||
|
||||
export const useDocumentActions = ({
|
||||
datasetId,
|
||||
selectedIds,
|
||||
downloadableSelectedIds,
|
||||
onUpdate,
|
||||
onClearSelection,
|
||||
}: UseDocumentActionsOptions) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const { mutateAsync: archiveDocument } = useDocumentArchive()
|
||||
const { mutateAsync: generateSummary } = useDocumentSummary()
|
||||
const { mutateAsync: enableDocument } = useDocumentEnable()
|
||||
const { mutateAsync: disableDocument } = useDocumentDisable()
|
||||
const { mutateAsync: deleteDocument } = useDocumentDelete()
|
||||
const { mutateAsync: retryIndexDocument } = useDocumentBatchRetryIndex()
|
||||
const { mutateAsync: requestDocumentsZip, isPending: isDownloadingZip } = useDocumentDownloadZip()
|
||||
|
||||
type SupportedActionType
|
||||
= | typeof DocumentActionType.archive
|
||||
| typeof DocumentActionType.summary
|
||||
| typeof DocumentActionType.enable
|
||||
| typeof DocumentActionType.disable
|
||||
| typeof DocumentActionType.delete
|
||||
|
||||
const actionMutationMap = useMemo(() => ({
|
||||
[DocumentActionType.archive]: archiveDocument,
|
||||
[DocumentActionType.summary]: generateSummary,
|
||||
[DocumentActionType.enable]: enableDocument,
|
||||
[DocumentActionType.disable]: disableDocument,
|
||||
[DocumentActionType.delete]: deleteDocument,
|
||||
} as const), [archiveDocument, generateSummary, enableDocument, disableDocument, deleteDocument])
|
||||
|
||||
const handleAction = useCallback((actionName: SupportedActionType) => {
|
||||
return async () => {
|
||||
const opApi = actionMutationMap[actionName]
|
||||
if (!opApi)
|
||||
return
|
||||
|
||||
const [e] = await asyncRunSafe<CommonResponse>(
|
||||
opApi({ datasetId, documentIds: selectedIds }),
|
||||
)
|
||||
|
||||
if (!e) {
|
||||
if (actionName === DocumentActionType.delete)
|
||||
onClearSelection()
|
||||
Toast.notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) })
|
||||
onUpdate()
|
||||
}
|
||||
else {
|
||||
Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) })
|
||||
}
|
||||
}
|
||||
}, [actionMutationMap, datasetId, selectedIds, onClearSelection, onUpdate, t])
|
||||
|
||||
const handleBatchReIndex = useCallback(async () => {
|
||||
const [e] = await asyncRunSafe<CommonResponse>(
|
||||
retryIndexDocument({ datasetId, documentIds: selectedIds }),
|
||||
)
|
||||
if (!e) {
|
||||
onClearSelection()
|
||||
Toast.notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) })
|
||||
onUpdate()
|
||||
}
|
||||
else {
|
||||
Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) })
|
||||
}
|
||||
}, [retryIndexDocument, datasetId, selectedIds, onClearSelection, onUpdate, t])
|
||||
|
||||
const handleBatchDownload = useCallback(async () => {
|
||||
if (isDownloadingZip)
|
||||
return
|
||||
|
||||
const [e, blob] = await asyncRunSafe(
|
||||
requestDocumentsZip({ datasetId, documentIds: downloadableSelectedIds }),
|
||||
)
|
||||
if (e || !blob) {
|
||||
Toast.notify({ type: 'error', message: t('actionMsg.downloadUnsuccessfully', { ns: 'common' }) })
|
||||
return
|
||||
}
|
||||
|
||||
downloadBlob({ data: blob, fileName: generateDocsZipFileName() })
|
||||
}, [datasetId, downloadableSelectedIds, isDownloadingZip, requestDocumentsZip, t])
|
||||
|
||||
return {
|
||||
handleAction,
|
||||
handleBatchReIndex,
|
||||
handleBatchDownload,
|
||||
isDownloadingZip,
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,317 @@
|
||||
import type { SimpleDocumentDetail } from '@/models/datasets'
|
||||
import { act, renderHook } from '@testing-library/react'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { DataSourceType } from '@/models/datasets'
|
||||
import { useDocumentSelection } from './use-document-selection'
|
||||
|
||||
type LocalDoc = SimpleDocumentDetail & { percent?: number }
|
||||
|
||||
const createMockDocument = (overrides: Partial<LocalDoc> = {}): LocalDoc => ({
|
||||
id: 'doc1',
|
||||
name: 'Test Document',
|
||||
data_source_type: DataSourceType.FILE,
|
||||
data_source_info: {},
|
||||
data_source_detail_dict: {},
|
||||
word_count: 100,
|
||||
hit_count: 10,
|
||||
created_at: 1000000,
|
||||
position: 1,
|
||||
doc_form: 'text_model',
|
||||
enabled: true,
|
||||
archived: false,
|
||||
display_status: 'available',
|
||||
created_from: 'api',
|
||||
...overrides,
|
||||
} as LocalDoc)
|
||||
|
||||
describe('useDocumentSelection', () => {
|
||||
describe('isAllSelected', () => {
|
||||
it('should return false when documents is empty', () => {
|
||||
const onSelectedIdChange = vi.fn()
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSelection({
|
||||
documents: [],
|
||||
selectedIds: [],
|
||||
onSelectedIdChange,
|
||||
}),
|
||||
)
|
||||
|
||||
expect(result.current.isAllSelected).toBe(false)
|
||||
})
|
||||
|
||||
it('should return true when all documents are selected', () => {
|
||||
const docs = [
|
||||
createMockDocument({ id: 'doc1' }),
|
||||
createMockDocument({ id: 'doc2' }),
|
||||
]
|
||||
const onSelectedIdChange = vi.fn()
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSelection({
|
||||
documents: docs,
|
||||
selectedIds: ['doc1', 'doc2'],
|
||||
onSelectedIdChange,
|
||||
}),
|
||||
)
|
||||
|
||||
expect(result.current.isAllSelected).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false when not all documents are selected', () => {
|
||||
const docs = [
|
||||
createMockDocument({ id: 'doc1' }),
|
||||
createMockDocument({ id: 'doc2' }),
|
||||
]
|
||||
const onSelectedIdChange = vi.fn()
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSelection({
|
||||
documents: docs,
|
||||
selectedIds: ['doc1'],
|
||||
onSelectedIdChange,
|
||||
}),
|
||||
)
|
||||
|
||||
expect(result.current.isAllSelected).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isSomeSelected', () => {
|
||||
it('should return false when no documents are selected', () => {
|
||||
const docs = [createMockDocument({ id: 'doc1' })]
|
||||
const onSelectedIdChange = vi.fn()
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSelection({
|
||||
documents: docs,
|
||||
selectedIds: [],
|
||||
onSelectedIdChange,
|
||||
}),
|
||||
)
|
||||
|
||||
expect(result.current.isSomeSelected).toBe(false)
|
||||
})
|
||||
|
||||
it('should return true when some documents are selected', () => {
|
||||
const docs = [
|
||||
createMockDocument({ id: 'doc1' }),
|
||||
createMockDocument({ id: 'doc2' }),
|
||||
]
|
||||
const onSelectedIdChange = vi.fn()
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSelection({
|
||||
documents: docs,
|
||||
selectedIds: ['doc1'],
|
||||
onSelectedIdChange,
|
||||
}),
|
||||
)
|
||||
|
||||
expect(result.current.isSomeSelected).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('onSelectAll', () => {
|
||||
it('should select all documents when none are selected', () => {
|
||||
const docs = [
|
||||
createMockDocument({ id: 'doc1' }),
|
||||
createMockDocument({ id: 'doc2' }),
|
||||
]
|
||||
const onSelectedIdChange = vi.fn()
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSelection({
|
||||
documents: docs,
|
||||
selectedIds: [],
|
||||
onSelectedIdChange,
|
||||
}),
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.onSelectAll()
|
||||
})
|
||||
|
||||
expect(onSelectedIdChange).toHaveBeenCalledWith(['doc1', 'doc2'])
|
||||
})
|
||||
|
||||
it('should deselect all when all are selected', () => {
|
||||
const docs = [
|
||||
createMockDocument({ id: 'doc1' }),
|
||||
createMockDocument({ id: 'doc2' }),
|
||||
]
|
||||
const onSelectedIdChange = vi.fn()
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSelection({
|
||||
documents: docs,
|
||||
selectedIds: ['doc1', 'doc2'],
|
||||
onSelectedIdChange,
|
||||
}),
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.onSelectAll()
|
||||
})
|
||||
|
||||
expect(onSelectedIdChange).toHaveBeenCalledWith([])
|
||||
})
|
||||
|
||||
it('should add to existing selection when some are selected', () => {
|
||||
const docs = [
|
||||
createMockDocument({ id: 'doc1' }),
|
||||
createMockDocument({ id: 'doc2' }),
|
||||
createMockDocument({ id: 'doc3' }),
|
||||
]
|
||||
const onSelectedIdChange = vi.fn()
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSelection({
|
||||
documents: docs,
|
||||
selectedIds: ['doc1'],
|
||||
onSelectedIdChange,
|
||||
}),
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.onSelectAll()
|
||||
})
|
||||
|
||||
expect(onSelectedIdChange).toHaveBeenCalledWith(['doc1', 'doc2', 'doc3'])
|
||||
})
|
||||
})
|
||||
|
||||
describe('onSelectOne', () => {
|
||||
it('should add document to selection when not selected', () => {
|
||||
const onSelectedIdChange = vi.fn()
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSelection({
|
||||
documents: [],
|
||||
selectedIds: [],
|
||||
onSelectedIdChange,
|
||||
}),
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.onSelectOne('doc1')
|
||||
})
|
||||
|
||||
expect(onSelectedIdChange).toHaveBeenCalledWith(['doc1'])
|
||||
})
|
||||
|
||||
it('should remove document from selection when already selected', () => {
|
||||
const onSelectedIdChange = vi.fn()
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSelection({
|
||||
documents: [],
|
||||
selectedIds: ['doc1', 'doc2'],
|
||||
onSelectedIdChange,
|
||||
}),
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.onSelectOne('doc1')
|
||||
})
|
||||
|
||||
expect(onSelectedIdChange).toHaveBeenCalledWith(['doc2'])
|
||||
})
|
||||
})
|
||||
|
||||
describe('hasErrorDocumentsSelected', () => {
|
||||
it('should return false when no error documents are selected', () => {
|
||||
const docs = [
|
||||
createMockDocument({ id: 'doc1', display_status: 'available' }),
|
||||
createMockDocument({ id: 'doc2', display_status: 'error' }),
|
||||
]
|
||||
const onSelectedIdChange = vi.fn()
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSelection({
|
||||
documents: docs,
|
||||
selectedIds: ['doc1'],
|
||||
onSelectedIdChange,
|
||||
}),
|
||||
)
|
||||
|
||||
expect(result.current.hasErrorDocumentsSelected).toBe(false)
|
||||
})
|
||||
|
||||
it('should return true when an error document is selected', () => {
|
||||
const docs = [
|
||||
createMockDocument({ id: 'doc1', display_status: 'available' }),
|
||||
createMockDocument({ id: 'doc2', display_status: 'error' }),
|
||||
]
|
||||
const onSelectedIdChange = vi.fn()
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSelection({
|
||||
documents: docs,
|
||||
selectedIds: ['doc2'],
|
||||
onSelectedIdChange,
|
||||
}),
|
||||
)
|
||||
|
||||
expect(result.current.hasErrorDocumentsSelected).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('downloadableSelectedIds', () => {
|
||||
it('should return only FILE type documents from selection', () => {
|
||||
const docs = [
|
||||
createMockDocument({ id: 'doc1', data_source_type: DataSourceType.FILE }),
|
||||
createMockDocument({ id: 'doc2', data_source_type: DataSourceType.NOTION }),
|
||||
createMockDocument({ id: 'doc3', data_source_type: DataSourceType.FILE }),
|
||||
]
|
||||
const onSelectedIdChange = vi.fn()
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSelection({
|
||||
documents: docs,
|
||||
selectedIds: ['doc1', 'doc2', 'doc3'],
|
||||
onSelectedIdChange,
|
||||
}),
|
||||
)
|
||||
|
||||
expect(result.current.downloadableSelectedIds).toEqual(['doc1', 'doc3'])
|
||||
})
|
||||
|
||||
it('should return empty array when no FILE documents selected', () => {
|
||||
const docs = [
|
||||
createMockDocument({ id: 'doc1', data_source_type: DataSourceType.NOTION }),
|
||||
createMockDocument({ id: 'doc2', data_source_type: DataSourceType.WEB }),
|
||||
]
|
||||
const onSelectedIdChange = vi.fn()
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSelection({
|
||||
documents: docs,
|
||||
selectedIds: ['doc1', 'doc2'],
|
||||
onSelectedIdChange,
|
||||
}),
|
||||
)
|
||||
|
||||
expect(result.current.downloadableSelectedIds).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('clearSelection', () => {
|
||||
it('should call onSelectedIdChange with empty array', () => {
|
||||
const onSelectedIdChange = vi.fn()
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSelection({
|
||||
documents: [],
|
||||
selectedIds: ['doc1', 'doc2'],
|
||||
onSelectedIdChange,
|
||||
}),
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.clearSelection()
|
||||
})
|
||||
|
||||
expect(onSelectedIdChange).toHaveBeenCalledWith([])
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,66 @@
|
||||
import type { SimpleDocumentDetail } from '@/models/datasets'
|
||||
import { uniq } from 'es-toolkit/array'
|
||||
import { useCallback, useMemo } from 'react'
|
||||
import { DataSourceType } from '@/models/datasets'
|
||||
|
||||
type LocalDoc = SimpleDocumentDetail & { percent?: number }
|
||||
|
||||
type UseDocumentSelectionOptions = {
|
||||
documents: LocalDoc[]
|
||||
selectedIds: string[]
|
||||
onSelectedIdChange: (selectedIds: string[]) => void
|
||||
}
|
||||
|
||||
export const useDocumentSelection = ({
|
||||
documents,
|
||||
selectedIds,
|
||||
onSelectedIdChange,
|
||||
}: UseDocumentSelectionOptions) => {
|
||||
const isAllSelected = useMemo(() => {
|
||||
return documents.length > 0 && documents.every(doc => selectedIds.includes(doc.id))
|
||||
}, [documents, selectedIds])
|
||||
|
||||
const isSomeSelected = useMemo(() => {
|
||||
return documents.some(doc => selectedIds.includes(doc.id))
|
||||
}, [documents, selectedIds])
|
||||
|
||||
const onSelectAll = useCallback(() => {
|
||||
if (isAllSelected)
|
||||
onSelectedIdChange([])
|
||||
else
|
||||
onSelectedIdChange(uniq([...selectedIds, ...documents.map(doc => doc.id)]))
|
||||
}, [isAllSelected, documents, onSelectedIdChange, selectedIds])
|
||||
|
||||
const onSelectOne = useCallback((docId: string) => {
|
||||
onSelectedIdChange(
|
||||
selectedIds.includes(docId)
|
||||
? selectedIds.filter(id => id !== docId)
|
||||
: [...selectedIds, docId],
|
||||
)
|
||||
}, [selectedIds, onSelectedIdChange])
|
||||
|
||||
const hasErrorDocumentsSelected = useMemo(() => {
|
||||
return documents.some(doc => selectedIds.includes(doc.id) && doc.display_status === 'error')
|
||||
}, [documents, selectedIds])
|
||||
|
||||
const downloadableSelectedIds = useMemo(() => {
|
||||
const selectedSet = new Set(selectedIds)
|
||||
return documents
|
||||
.filter(doc => selectedSet.has(doc.id) && doc.data_source_type === DataSourceType.FILE)
|
||||
.map(doc => doc.id)
|
||||
}, [documents, selectedIds])
|
||||
|
||||
const clearSelection = useCallback(() => {
|
||||
onSelectedIdChange([])
|
||||
}, [onSelectedIdChange])
|
||||
|
||||
return {
|
||||
isAllSelected,
|
||||
isSomeSelected,
|
||||
onSelectAll,
|
||||
onSelectOne,
|
||||
hasErrorDocumentsSelected,
|
||||
downloadableSelectedIds,
|
||||
clearSelection,
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,340 @@
|
||||
import type { SimpleDocumentDetail } from '@/models/datasets'
|
||||
import { act, renderHook } from '@testing-library/react'
|
||||
import { describe, expect, it } from 'vitest'
|
||||
import { useDocumentSort } from './use-document-sort'
|
||||
|
||||
type LocalDoc = SimpleDocumentDetail & { percent?: number }
|
||||
|
||||
const createMockDocument = (overrides: Partial<LocalDoc> = {}): LocalDoc => ({
|
||||
id: 'doc1',
|
||||
name: 'Test Document',
|
||||
data_source_type: 'upload_file',
|
||||
data_source_info: {},
|
||||
data_source_detail_dict: {},
|
||||
word_count: 100,
|
||||
hit_count: 10,
|
||||
created_at: 1000000,
|
||||
position: 1,
|
||||
doc_form: 'text_model',
|
||||
enabled: true,
|
||||
archived: false,
|
||||
display_status: 'available',
|
||||
created_from: 'api',
|
||||
...overrides,
|
||||
} as LocalDoc)
|
||||
|
||||
describe('useDocumentSort', () => {
|
||||
describe('initial state', () => {
|
||||
it('should return null sortField initially', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSort({
|
||||
documents: [],
|
||||
statusFilterValue: '',
|
||||
remoteSortValue: '',
|
||||
}),
|
||||
)
|
||||
|
||||
expect(result.current.sortField).toBeNull()
|
||||
expect(result.current.sortOrder).toBe('desc')
|
||||
})
|
||||
|
||||
it('should return documents unchanged when no sort is applied', () => {
|
||||
const docs = [
|
||||
createMockDocument({ id: 'doc1', name: 'B' }),
|
||||
createMockDocument({ id: 'doc2', name: 'A' }),
|
||||
]
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSort({
|
||||
documents: docs,
|
||||
statusFilterValue: '',
|
||||
remoteSortValue: '',
|
||||
}),
|
||||
)
|
||||
|
||||
expect(result.current.sortedDocuments).toEqual(docs)
|
||||
})
|
||||
})
|
||||
|
||||
describe('handleSort', () => {
|
||||
it('should set sort field when called', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSort({
|
||||
documents: [],
|
||||
statusFilterValue: '',
|
||||
remoteSortValue: '',
|
||||
}),
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.handleSort('name')
|
||||
})
|
||||
|
||||
expect(result.current.sortField).toBe('name')
|
||||
expect(result.current.sortOrder).toBe('desc')
|
||||
})
|
||||
|
||||
it('should toggle sort order when same field is clicked twice', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSort({
|
||||
documents: [],
|
||||
statusFilterValue: '',
|
||||
remoteSortValue: '',
|
||||
}),
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.handleSort('name')
|
||||
})
|
||||
expect(result.current.sortOrder).toBe('desc')
|
||||
|
||||
act(() => {
|
||||
result.current.handleSort('name')
|
||||
})
|
||||
expect(result.current.sortOrder).toBe('asc')
|
||||
|
||||
act(() => {
|
||||
result.current.handleSort('name')
|
||||
})
|
||||
expect(result.current.sortOrder).toBe('desc')
|
||||
})
|
||||
|
||||
it('should reset to desc when different field is selected', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSort({
|
||||
documents: [],
|
||||
statusFilterValue: '',
|
||||
remoteSortValue: '',
|
||||
}),
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.handleSort('name')
|
||||
})
|
||||
act(() => {
|
||||
result.current.handleSort('name')
|
||||
})
|
||||
expect(result.current.sortOrder).toBe('asc')
|
||||
|
||||
act(() => {
|
||||
result.current.handleSort('word_count')
|
||||
})
|
||||
expect(result.current.sortField).toBe('word_count')
|
||||
expect(result.current.sortOrder).toBe('desc')
|
||||
})
|
||||
|
||||
it('should not change state when null is passed', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSort({
|
||||
documents: [],
|
||||
statusFilterValue: '',
|
||||
remoteSortValue: '',
|
||||
}),
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.handleSort(null)
|
||||
})
|
||||
|
||||
expect(result.current.sortField).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('sorting documents', () => {
|
||||
const docs = [
|
||||
createMockDocument({ id: 'doc1', name: 'Banana', word_count: 200, hit_count: 5, created_at: 3000 }),
|
||||
createMockDocument({ id: 'doc2', name: 'Apple', word_count: 100, hit_count: 10, created_at: 1000 }),
|
||||
createMockDocument({ id: 'doc3', name: 'Cherry', word_count: 300, hit_count: 1, created_at: 2000 }),
|
||||
]
|
||||
|
||||
it('should sort by name descending', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSort({
|
||||
documents: docs,
|
||||
statusFilterValue: '',
|
||||
remoteSortValue: '',
|
||||
}),
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.handleSort('name')
|
||||
})
|
||||
|
||||
const names = result.current.sortedDocuments.map(d => d.name)
|
||||
expect(names).toEqual(['Cherry', 'Banana', 'Apple'])
|
||||
})
|
||||
|
||||
it('should sort by name ascending', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSort({
|
||||
documents: docs,
|
||||
statusFilterValue: '',
|
||||
remoteSortValue: '',
|
||||
}),
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.handleSort('name')
|
||||
})
|
||||
act(() => {
|
||||
result.current.handleSort('name')
|
||||
})
|
||||
|
||||
const names = result.current.sortedDocuments.map(d => d.name)
|
||||
expect(names).toEqual(['Apple', 'Banana', 'Cherry'])
|
||||
})
|
||||
|
||||
it('should sort by word_count descending', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSort({
|
||||
documents: docs,
|
||||
statusFilterValue: '',
|
||||
remoteSortValue: '',
|
||||
}),
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.handleSort('word_count')
|
||||
})
|
||||
|
||||
const counts = result.current.sortedDocuments.map(d => d.word_count)
|
||||
expect(counts).toEqual([300, 200, 100])
|
||||
})
|
||||
|
||||
it('should sort by hit_count ascending', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSort({
|
||||
documents: docs,
|
||||
statusFilterValue: '',
|
||||
remoteSortValue: '',
|
||||
}),
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.handleSort('hit_count')
|
||||
})
|
||||
act(() => {
|
||||
result.current.handleSort('hit_count')
|
||||
})
|
||||
|
||||
const counts = result.current.sortedDocuments.map(d => d.hit_count)
|
||||
expect(counts).toEqual([1, 5, 10])
|
||||
})
|
||||
|
||||
it('should sort by created_at descending', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSort({
|
||||
documents: docs,
|
||||
statusFilterValue: '',
|
||||
remoteSortValue: '',
|
||||
}),
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.handleSort('created_at')
|
||||
})
|
||||
|
||||
const times = result.current.sortedDocuments.map(d => d.created_at)
|
||||
expect(times).toEqual([3000, 2000, 1000])
|
||||
})
|
||||
})
|
||||
|
||||
describe('status filtering', () => {
|
||||
const docs = [
|
||||
createMockDocument({ id: 'doc1', display_status: 'available' }),
|
||||
createMockDocument({ id: 'doc2', display_status: 'error' }),
|
||||
createMockDocument({ id: 'doc3', display_status: 'available' }),
|
||||
]
|
||||
|
||||
it('should not filter when statusFilterValue is empty', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSort({
|
||||
documents: docs,
|
||||
statusFilterValue: '',
|
||||
remoteSortValue: '',
|
||||
}),
|
||||
)
|
||||
|
||||
expect(result.current.sortedDocuments.length).toBe(3)
|
||||
})
|
||||
|
||||
it('should not filter when statusFilterValue is all', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSort({
|
||||
documents: docs,
|
||||
statusFilterValue: 'all',
|
||||
remoteSortValue: '',
|
||||
}),
|
||||
)
|
||||
|
||||
expect(result.current.sortedDocuments.length).toBe(3)
|
||||
})
|
||||
})
|
||||
|
||||
describe('remoteSortValue reset', () => {
|
||||
it('should reset sort state when remoteSortValue changes', () => {
|
||||
const { result, rerender } = renderHook(
|
||||
({ remoteSortValue }) =>
|
||||
useDocumentSort({
|
||||
documents: [],
|
||||
statusFilterValue: '',
|
||||
remoteSortValue,
|
||||
}),
|
||||
{ initialProps: { remoteSortValue: 'initial' } },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.handleSort('name')
|
||||
})
|
||||
act(() => {
|
||||
result.current.handleSort('name')
|
||||
})
|
||||
expect(result.current.sortField).toBe('name')
|
||||
expect(result.current.sortOrder).toBe('asc')
|
||||
|
||||
rerender({ remoteSortValue: 'changed' })
|
||||
|
||||
expect(result.current.sortField).toBeNull()
|
||||
expect(result.current.sortOrder).toBe('desc')
|
||||
})
|
||||
})
|
||||
|
||||
describe('edge cases', () => {
|
||||
it('should handle documents with missing values', () => {
|
||||
const docs = [
|
||||
createMockDocument({ id: 'doc1', name: undefined as unknown as string, word_count: undefined }),
|
||||
createMockDocument({ id: 'doc2', name: 'Test', word_count: 100 }),
|
||||
]
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSort({
|
||||
documents: docs,
|
||||
statusFilterValue: '',
|
||||
remoteSortValue: '',
|
||||
}),
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.handleSort('name')
|
||||
})
|
||||
|
||||
expect(result.current.sortedDocuments.length).toBe(2)
|
||||
})
|
||||
|
||||
it('should handle empty documents array', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useDocumentSort({
|
||||
documents: [],
|
||||
statusFilterValue: '',
|
||||
remoteSortValue: '',
|
||||
}),
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.handleSort('name')
|
||||
})
|
||||
|
||||
expect(result.current.sortedDocuments).toEqual([])
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,102 @@
|
||||
import type { SimpleDocumentDetail } from '@/models/datasets'
|
||||
import { useCallback, useMemo, useRef, useState } from 'react'
|
||||
import { normalizeStatusForQuery } from '@/app/components/datasets/documents/status-filter'
|
||||
|
||||
export type SortField = 'name' | 'word_count' | 'hit_count' | 'created_at' | null
|
||||
export type SortOrder = 'asc' | 'desc'
|
||||
|
||||
type LocalDoc = SimpleDocumentDetail & { percent?: number }
|
||||
|
||||
type UseDocumentSortOptions = {
|
||||
documents: LocalDoc[]
|
||||
statusFilterValue: string
|
||||
remoteSortValue: string
|
||||
}
|
||||
|
||||
export const useDocumentSort = ({
|
||||
documents,
|
||||
statusFilterValue,
|
||||
remoteSortValue,
|
||||
}: UseDocumentSortOptions) => {
|
||||
const [sortField, setSortField] = useState<SortField>(null)
|
||||
const [sortOrder, setSortOrder] = useState<SortOrder>('desc')
|
||||
const prevRemoteSortValueRef = useRef(remoteSortValue)
|
||||
|
||||
// Reset sort when remote sort changes
|
||||
if (prevRemoteSortValueRef.current !== remoteSortValue) {
|
||||
prevRemoteSortValueRef.current = remoteSortValue
|
||||
setSortField(null)
|
||||
setSortOrder('desc')
|
||||
}
|
||||
|
||||
const handleSort = useCallback((field: SortField) => {
|
||||
if (field === null)
|
||||
return
|
||||
|
||||
if (sortField === field) {
|
||||
setSortOrder(prev => prev === 'asc' ? 'desc' : 'asc')
|
||||
}
|
||||
else {
|
||||
setSortField(field)
|
||||
setSortOrder('desc')
|
||||
}
|
||||
}, [sortField])
|
||||
|
||||
const sortedDocuments = useMemo(() => {
|
||||
let filteredDocs = documents
|
||||
|
||||
if (statusFilterValue && statusFilterValue !== 'all') {
|
||||
filteredDocs = filteredDocs.filter(doc =>
|
||||
typeof doc.display_status === 'string'
|
||||
&& normalizeStatusForQuery(doc.display_status) === statusFilterValue,
|
||||
)
|
||||
}
|
||||
|
||||
if (!sortField)
|
||||
return filteredDocs
|
||||
|
||||
const sortedDocs = [...filteredDocs].sort((a, b) => {
|
||||
let aValue: string | number
|
||||
let bValue: string | number
|
||||
|
||||
switch (sortField) {
|
||||
case 'name':
|
||||
aValue = a.name?.toLowerCase() || ''
|
||||
bValue = b.name?.toLowerCase() || ''
|
||||
break
|
||||
case 'word_count':
|
||||
aValue = a.word_count || 0
|
||||
bValue = b.word_count || 0
|
||||
break
|
||||
case 'hit_count':
|
||||
aValue = a.hit_count || 0
|
||||
bValue = b.hit_count || 0
|
||||
break
|
||||
case 'created_at':
|
||||
aValue = a.created_at
|
||||
bValue = b.created_at
|
||||
break
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
|
||||
if (sortField === 'name') {
|
||||
const result = (aValue as string).localeCompare(bValue as string)
|
||||
return sortOrder === 'asc' ? result : -result
|
||||
}
|
||||
else {
|
||||
const result = (aValue as number) - (bValue as number)
|
||||
return sortOrder === 'asc' ? result : -result
|
||||
}
|
||||
})
|
||||
|
||||
return sortedDocs
|
||||
}, [documents, sortField, sortOrder, statusFilterValue])
|
||||
|
||||
return {
|
||||
sortField,
|
||||
sortOrder,
|
||||
handleSort,
|
||||
sortedDocuments,
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,487 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import type { Props as PaginationProps } from '@/app/components/base/pagination'
|
||||
import type { SimpleDocumentDetail } from '@/models/datasets'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { ChunkingMode, DataSourceType } from '@/models/datasets'
|
||||
import DocumentList from '../list'
|
||||
|
||||
const mockPush = vi.fn()
|
||||
|
||||
vi.mock('next/navigation', () => ({
|
||||
useRouter: () => ({
|
||||
push: mockPush,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/dataset-detail', () => ({
|
||||
useDatasetDetailContextWithSelector: (selector: (state: { dataset: { doc_form: string } }) => unknown) =>
|
||||
selector({ dataset: { doc_form: ChunkingMode.text } }),
|
||||
}))
|
||||
|
||||
const createTestQueryClient = () => new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: { retry: false, gcTime: 0 },
|
||||
mutations: { retry: false },
|
||||
},
|
||||
})
|
||||
|
||||
const createWrapper = () => {
|
||||
const queryClient = createTestQueryClient()
|
||||
return ({ children }: { children: ReactNode }) => (
|
||||
<QueryClientProvider client={queryClient}>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
)
|
||||
}
|
||||
|
||||
const createMockDoc = (overrides: Partial<SimpleDocumentDetail> = {}): SimpleDocumentDetail => ({
|
||||
id: `doc-${Math.random().toString(36).substr(2, 9)}`,
|
||||
position: 1,
|
||||
data_source_type: DataSourceType.FILE,
|
||||
data_source_info: {},
|
||||
data_source_detail_dict: {
|
||||
upload_file: { name: 'test.txt', extension: 'txt' },
|
||||
},
|
||||
dataset_process_rule_id: 'rule-1',
|
||||
batch: 'batch-1',
|
||||
name: 'test-document.txt',
|
||||
created_from: 'web',
|
||||
created_by: 'user-1',
|
||||
created_at: Date.now(),
|
||||
tokens: 100,
|
||||
indexing_status: 'completed',
|
||||
error: null,
|
||||
enabled: true,
|
||||
disabled_at: null,
|
||||
disabled_by: null,
|
||||
archived: false,
|
||||
archived_reason: null,
|
||||
archived_by: null,
|
||||
archived_at: null,
|
||||
updated_at: Date.now(),
|
||||
doc_type: null,
|
||||
doc_metadata: undefined,
|
||||
display_status: 'available',
|
||||
word_count: 500,
|
||||
hit_count: 10,
|
||||
doc_form: 'text_model',
|
||||
...overrides,
|
||||
} as SimpleDocumentDetail)
|
||||
|
||||
const defaultPagination: PaginationProps = {
|
||||
current: 1,
|
||||
onChange: vi.fn(),
|
||||
total: 100,
|
||||
}
|
||||
|
||||
describe('DocumentList', () => {
|
||||
const defaultProps = {
|
||||
embeddingAvailable: true,
|
||||
documents: [
|
||||
createMockDoc({ id: 'doc-1', name: 'Document 1.txt', word_count: 100, hit_count: 5 }),
|
||||
createMockDoc({ id: 'doc-2', name: 'Document 2.txt', word_count: 200, hit_count: 10 }),
|
||||
createMockDoc({ id: 'doc-3', name: 'Document 3.txt', word_count: 300, hit_count: 15 }),
|
||||
],
|
||||
selectedIds: [] as string[],
|
||||
onSelectedIdChange: vi.fn(),
|
||||
datasetId: 'dataset-1',
|
||||
pagination: defaultPagination,
|
||||
onUpdate: vi.fn(),
|
||||
onManageMetadata: vi.fn(),
|
||||
statusFilterValue: '',
|
||||
remoteSortValue: '',
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('Rendering', () => {
|
||||
it('should render without crashing', () => {
|
||||
render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() })
|
||||
expect(screen.getByRole('table')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render all documents', () => {
|
||||
render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() })
|
||||
expect(screen.getByText('Document 1.txt')).toBeInTheDocument()
|
||||
expect(screen.getByText('Document 2.txt')).toBeInTheDocument()
|
||||
expect(screen.getByText('Document 3.txt')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render table headers', () => {
|
||||
render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() })
|
||||
expect(screen.getByText('#')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render pagination when total is provided', () => {
|
||||
render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() })
|
||||
// Pagination component should be present
|
||||
expect(screen.getByRole('table')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render pagination when total is 0', () => {
|
||||
const props = {
|
||||
...defaultProps,
|
||||
pagination: { ...defaultPagination, total: 0 },
|
||||
}
|
||||
render(<DocumentList {...props} />, { wrapper: createWrapper() })
|
||||
expect(screen.getByRole('table')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render empty table when no documents', () => {
|
||||
const props = { ...defaultProps, documents: [] }
|
||||
render(<DocumentList {...props} />, { wrapper: createWrapper() })
|
||||
expect(screen.getByRole('table')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Selection', () => {
|
||||
// Helper to find checkboxes (custom div components, not native checkboxes)
|
||||
const findCheckboxes = (container: HTMLElement): NodeListOf<Element> => {
|
||||
return container.querySelectorAll('[class*="shadow-xs"]')
|
||||
}
|
||||
|
||||
it('should render header checkbox when embeddingAvailable', () => {
|
||||
const { container } = render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() })
|
||||
const checkboxes = findCheckboxes(container)
|
||||
expect(checkboxes.length).toBeGreaterThan(0)
|
||||
})
|
||||
|
||||
it('should not render header checkbox when embedding not available', () => {
|
||||
const props = { ...defaultProps, embeddingAvailable: false }
|
||||
render(<DocumentList {...props} />, { wrapper: createWrapper() })
|
||||
// Row checkboxes should still be there, but header checkbox should be hidden
|
||||
expect(screen.getByRole('table')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should call onSelectedIdChange when select all is clicked', () => {
|
||||
const onSelectedIdChange = vi.fn()
|
||||
const props = { ...defaultProps, onSelectedIdChange }
|
||||
const { container } = render(<DocumentList {...props} />, { wrapper: createWrapper() })
|
||||
|
||||
const checkboxes = findCheckboxes(container)
|
||||
if (checkboxes.length > 0) {
|
||||
fireEvent.click(checkboxes[0])
|
||||
expect(onSelectedIdChange).toHaveBeenCalled()
|
||||
}
|
||||
})
|
||||
|
||||
it('should show all checkboxes as checked when all are selected', () => {
|
||||
const props = {
|
||||
...defaultProps,
|
||||
selectedIds: ['doc-1', 'doc-2', 'doc-3'],
|
||||
}
|
||||
const { container } = render(<DocumentList {...props} />, { wrapper: createWrapper() })
|
||||
|
||||
const checkboxes = findCheckboxes(container)
|
||||
// When checked, checkbox should have a check icon (svg) inside
|
||||
checkboxes.forEach((checkbox) => {
|
||||
const checkIcon = checkbox.querySelector('svg')
|
||||
expect(checkIcon).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should show indeterminate state when some are selected', () => {
|
||||
const props = {
|
||||
...defaultProps,
|
||||
selectedIds: ['doc-1'],
|
||||
}
|
||||
const { container } = render(<DocumentList {...props} />, { wrapper: createWrapper() })
|
||||
|
||||
// First checkbox is the header checkbox which should be indeterminate
|
||||
const checkboxes = findCheckboxes(container)
|
||||
expect(checkboxes.length).toBeGreaterThan(0)
|
||||
// Header checkbox should show indeterminate icon, not check icon
|
||||
// Just verify it's rendered
|
||||
expect(checkboxes[0]).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should call onSelectedIdChange with single document when row checkbox is clicked', () => {
|
||||
const onSelectedIdChange = vi.fn()
|
||||
const props = { ...defaultProps, onSelectedIdChange }
|
||||
const { container } = render(<DocumentList {...props} />, { wrapper: createWrapper() })
|
||||
|
||||
// Click the second checkbox (first row checkbox)
|
||||
const checkboxes = findCheckboxes(container)
|
||||
if (checkboxes.length > 1) {
|
||||
fireEvent.click(checkboxes[1])
|
||||
expect(onSelectedIdChange).toHaveBeenCalled()
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('Sorting', () => {
|
||||
it('should render sort headers for sortable columns', () => {
|
||||
render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() })
|
||||
// Find svg icons which indicate sortable columns
|
||||
const sortIcons = document.querySelectorAll('svg')
|
||||
expect(sortIcons.length).toBeGreaterThan(0)
|
||||
})
|
||||
|
||||
it('should update sort order when sort header is clicked', () => {
|
||||
render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() })
|
||||
|
||||
// Find and click a sort header by its parent div containing the label text
|
||||
const sortableHeaders = document.querySelectorAll('[class*="cursor-pointer"]')
|
||||
if (sortableHeaders.length > 0) {
|
||||
fireEvent.click(sortableHeaders[0])
|
||||
}
|
||||
|
||||
expect(screen.getByRole('table')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Batch Actions', () => {
|
||||
it('should show batch action bar when documents are selected', () => {
|
||||
const props = {
|
||||
...defaultProps,
|
||||
selectedIds: ['doc-1', 'doc-2'],
|
||||
}
|
||||
render(<DocumentList {...props} />, { wrapper: createWrapper() })
|
||||
|
||||
// BatchAction component should be visible
|
||||
expect(screen.getByRole('table')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show batch action bar when no documents selected', () => {
|
||||
render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() })
|
||||
|
||||
// BatchAction should not be present
|
||||
expect(screen.getByRole('table')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render batch action bar with archive option', () => {
|
||||
const props = {
|
||||
...defaultProps,
|
||||
selectedIds: ['doc-1'],
|
||||
}
|
||||
render(<DocumentList {...props} />, { wrapper: createWrapper() })
|
||||
|
||||
// BatchAction component should be visible when documents are selected
|
||||
expect(screen.getByRole('table')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render batch action bar with enable option', () => {
|
||||
const props = {
|
||||
...defaultProps,
|
||||
selectedIds: ['doc-1'],
|
||||
}
|
||||
render(<DocumentList {...props} />, { wrapper: createWrapper() })
|
||||
|
||||
expect(screen.getByRole('table')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render batch action bar with disable option', () => {
|
||||
const props = {
|
||||
...defaultProps,
|
||||
selectedIds: ['doc-1'],
|
||||
}
|
||||
render(<DocumentList {...props} />, { wrapper: createWrapper() })
|
||||
|
||||
expect(screen.getByRole('table')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render batch action bar with delete option', () => {
|
||||
const props = {
|
||||
...defaultProps,
|
||||
selectedIds: ['doc-1'],
|
||||
}
|
||||
render(<DocumentList {...props} />, { wrapper: createWrapper() })
|
||||
|
||||
expect(screen.getByRole('table')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should clear selection when cancel is clicked', () => {
|
||||
const onSelectedIdChange = vi.fn()
|
||||
const props = {
|
||||
...defaultProps,
|
||||
selectedIds: ['doc-1'],
|
||||
onSelectedIdChange,
|
||||
}
|
||||
render(<DocumentList {...props} />, { wrapper: createWrapper() })
|
||||
|
||||
const cancelButton = screen.queryByRole('button', { name: /cancel/i })
|
||||
if (cancelButton) {
|
||||
fireEvent.click(cancelButton)
|
||||
expect(onSelectedIdChange).toHaveBeenCalledWith([])
|
||||
}
|
||||
})
|
||||
|
||||
it('should show download option for downloadable documents', () => {
|
||||
const props = {
|
||||
...defaultProps,
|
||||
selectedIds: ['doc-1'],
|
||||
documents: [
|
||||
createMockDoc({ id: 'doc-1', data_source_type: DataSourceType.FILE }),
|
||||
],
|
||||
}
|
||||
render(<DocumentList {...props} />, { wrapper: createWrapper() })
|
||||
|
||||
// BatchAction should be visible
|
||||
expect(screen.getByRole('table')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show re-index option for error documents', () => {
|
||||
const props = {
|
||||
...defaultProps,
|
||||
selectedIds: ['doc-1'],
|
||||
documents: [
|
||||
createMockDoc({ id: 'doc-1', display_status: 'error' }),
|
||||
],
|
||||
}
|
||||
render(<DocumentList {...props} />, { wrapper: createWrapper() })
|
||||
|
||||
// BatchAction with re-index should be present for error documents
|
||||
expect(screen.getByRole('table')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Row Click Navigation', () => {
|
||||
it('should navigate to document detail when row is clicked', () => {
|
||||
render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() })
|
||||
|
||||
const rows = screen.getAllByRole('row')
|
||||
// First row is header, second row is first document
|
||||
if (rows.length > 1) {
|
||||
fireEvent.click(rows[1])
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-1/documents/doc-1')
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('Rename Modal', () => {
|
||||
it('should not show rename modal initially', () => {
|
||||
render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() })
|
||||
|
||||
// RenameModal should not be visible initially
|
||||
const modal = screen.queryByRole('dialog')
|
||||
expect(modal).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show rename modal when rename button is clicked', () => {
|
||||
const { container } = render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() })
|
||||
|
||||
// Find and click the rename button in the first row
|
||||
const renameButtons = container.querySelectorAll('.cursor-pointer.rounded-md')
|
||||
if (renameButtons.length > 0) {
|
||||
fireEvent.click(renameButtons[0])
|
||||
}
|
||||
|
||||
// After clicking rename, the modal should potentially be visible
|
||||
expect(screen.getByRole('table')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should call onUpdate when document is renamed', () => {
|
||||
const onUpdate = vi.fn()
|
||||
const props = { ...defaultProps, onUpdate }
|
||||
render(<DocumentList {...props} />, { wrapper: createWrapper() })
|
||||
|
||||
// The handleRenamed callback wraps onUpdate
|
||||
expect(screen.getByRole('table')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edit Metadata Modal', () => {
|
||||
it('should handle edit metadata action', () => {
|
||||
const props = {
|
||||
...defaultProps,
|
||||
selectedIds: ['doc-1'],
|
||||
}
|
||||
render(<DocumentList {...props} />, { wrapper: createWrapper() })
|
||||
|
||||
const editButton = screen.queryByRole('button', { name: /metadata/i })
|
||||
if (editButton) {
|
||||
fireEvent.click(editButton)
|
||||
}
|
||||
|
||||
expect(screen.getByRole('table')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should call onManageMetadata when manage metadata is triggered', () => {
|
||||
const onManageMetadata = vi.fn()
|
||||
const props = {
|
||||
...defaultProps,
|
||||
selectedIds: ['doc-1'],
|
||||
onManageMetadata,
|
||||
}
|
||||
render(<DocumentList {...props} />, { wrapper: createWrapper() })
|
||||
|
||||
// The onShowManage callback in EditMetadataBatchModal should call hideEditModal then onManageMetadata
|
||||
expect(screen.getByRole('table')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Chunking Mode', () => {
|
||||
it('should render with general mode', () => {
|
||||
render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() })
|
||||
expect(screen.getByRole('table')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render with QA mode', () => {
|
||||
// This test uses the default mock which returns ChunkingMode.text
|
||||
// The component will compute isQAMode based on doc_form
|
||||
render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() })
|
||||
expect(screen.getByRole('table')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render with parent-child mode', () => {
|
||||
render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() })
|
||||
expect(screen.getByRole('table')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle empty documents array', () => {
|
||||
const props = { ...defaultProps, documents: [] }
|
||||
render(<DocumentList {...props} />, { wrapper: createWrapper() })
|
||||
|
||||
expect(screen.getByRole('table')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle documents with missing optional fields', () => {
|
||||
const docWithMissingFields = createMockDoc({
|
||||
word_count: undefined as unknown as number,
|
||||
hit_count: undefined as unknown as number,
|
||||
})
|
||||
const props = {
|
||||
...defaultProps,
|
||||
documents: [docWithMissingFields],
|
||||
}
|
||||
render(<DocumentList {...props} />, { wrapper: createWrapper() })
|
||||
|
||||
expect(screen.getByRole('table')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle status filter value', () => {
|
||||
const props = {
|
||||
...defaultProps,
|
||||
statusFilterValue: 'completed',
|
||||
}
|
||||
render(<DocumentList {...props} />, { wrapper: createWrapper() })
|
||||
|
||||
expect(screen.getByRole('table')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle remote sort value', () => {
|
||||
const props = {
|
||||
...defaultProps,
|
||||
remoteSortValue: 'created_at',
|
||||
}
|
||||
render(<DocumentList {...props} />, { wrapper: createWrapper() })
|
||||
|
||||
expect(screen.getByRole('table')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle large number of documents', () => {
|
||||
const manyDocs = Array.from({ length: 20 }, (_, i) =>
|
||||
createMockDoc({ id: `doc-${i}`, name: `Document ${i}.txt` }))
|
||||
const props = { ...defaultProps, documents: manyDocs }
|
||||
render(<DocumentList {...props} />, { wrapper: createWrapper() })
|
||||
|
||||
expect(screen.getByRole('table')).toBeInTheDocument()
|
||||
}, 10000)
|
||||
})
|
||||
})
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user