diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 106c26bbed..bfb1c85436 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -9,6 +9,9 @@ # CODEOWNERS file /.github/CODEOWNERS @laipz8200 @crazywoola +# Agents +/.agents/skills/ @hyoban + # Docs /docs/ @crazywoola @@ -21,6 +24,10 @@ /api/services/tools/mcp_tools_manage_service.py @Nov1c444 /api/controllers/mcp/ @Nov1c444 /api/controllers/console/app/mcp_server.py @Nov1c444 + +# Backend - Tests +/api/tests/ @laipz8200 @QuantumGhost + /api/tests/**/*mcp* @Nov1c444 # Backend - Workflow - Engine (Core graph execution engine) @@ -231,6 +238,9 @@ # Frontend - Base Components /web/app/components/base/ @iamjoel @zxhlyh +# Frontend - Base Components Tests +/web/app/components/base/**/*.spec.tsx @hyoban @CodingOnStar + # Frontend - Utils and Hooks /web/utils/classnames.ts @iamjoel @zxhlyh /web/utils/time.ts @iamjoel @zxhlyh diff --git a/api/commands.py b/api/commands.py index 4b811fb1e6..c4f2c9edbb 100644 --- a/api/commands.py +++ b/api/commands.py @@ -1450,54 +1450,58 @@ def clear_orphaned_file_records(force: bool): all_ids_in_tables = [] for ids_table in ids_tables: query = "" - if ids_table["type"] == "uuid": - click.echo( - click.style( - f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}", fg="white" + match ids_table["type"]: + case "uuid": + click.echo( + click.style( + f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}", + fg="white", + ) ) - ) - query = ( - f"SELECT {ids_table['column']} FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for i in rs: - all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])}) - elif ids_table["type"] == "text": - click.echo( - click.style( - f"- Listing file-id-like strings in column {ids_table['column']} in table {ids_table['table']}", - fg="white", + c = ids_table["column"] + query = f"SELECT {c} 
FROM {ids_table['table']} WHERE {c} IS NOT NULL" + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for i in rs: + all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])}) + case "text": + t = ids_table["table"] + click.echo( + click.style( + f"- Listing file-id-like strings in column {ids_table['column']} in table {t}", + fg="white", + ) ) - ) - query = ( - f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id " - f"FROM {ids_table['table']}" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for i in rs: - for j in i[0]: - all_ids_in_tables.append({"table": ids_table["table"], "id": j}) - elif ids_table["type"] == "json": - click.echo( - click.style( - ( - f"- Listing file-id-like JSON string in column {ids_table['column']} " - f"in table {ids_table['table']}" - ), - fg="white", + query = ( + f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id " + f"FROM {ids_table['table']}" ) - ) - query = ( - f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id " - f"FROM {ids_table['table']}" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for i in rs: - for j in i[0]: - all_ids_in_tables.append({"table": ids_table["table"], "id": j}) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for i in rs: + for j in i[0]: + all_ids_in_tables.append({"table": ids_table["table"], "id": j}) + case "json": + click.echo( + click.style( + ( + f"- Listing file-id-like JSON string in column {ids_table['column']} " + f"in table {ids_table['table']}" + ), + fg="white", + ) + ) + query = ( + f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id " + f"FROM {ids_table['table']}" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for i in rs: + for j in i[0]: + all_ids_in_tables.append({"table": ids_table["table"], "id": j}) + 
case _: + pass click.echo(click.style(f"Found {len(all_ids_in_tables)} file ids in tables.", fg="white")) except Exception as e: @@ -1737,59 +1741,18 @@ def file_usage( if src_filter != src: continue - if ids_table["type"] == "uuid": - # Direct UUID match - query = ( - f"SELECT {ids_table['pk_column']}, {ids_table['column']} " - f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for row in rs: - record_id = str(row[0]) - ref_file_id = str(row[1]) - if ref_file_id not in file_key_map: - continue - storage_key = file_key_map[ref_file_id] - - # Apply filters - if file_id and ref_file_id != file_id: - continue - if key and not storage_key.endswith(key): - continue - - # Only collect items within the requested page range - if offset <= total_count < offset + limit: - paginated_usages.append( - { - "src": f"{ids_table['table']}.{ids_table['column']}", - "record_id": record_id, - "file_id": ref_file_id, - "key": storage_key, - } - ) - total_count += 1 - - elif ids_table["type"] in ("text", "json"): - # Extract UUIDs from text/json content - column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"] - query = ( - f"SELECT {ids_table['pk_column']}, {column_cast} " - f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for row in rs: - record_id = str(row[0]) - content = str(row[1]) - - # Find all UUIDs in the content - import re - - uuid_pattern = re.compile(guid_regexp, re.IGNORECASE) - matches = uuid_pattern.findall(content) - - for ref_file_id in matches: + match ids_table["type"]: + case "uuid": + # Direct UUID match + query = ( + f"SELECT {ids_table['pk_column']}, {ids_table['column']} " + f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for row in 
rs: + record_id = str(row[0]) + ref_file_id = str(row[1]) if ref_file_id not in file_key_map: continue storage_key = file_key_map[ref_file_id] @@ -1812,6 +1775,50 @@ def file_usage( ) total_count += 1 + case "text" | "json": + # Extract UUIDs from text/json content + column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"] + query = ( + f"SELECT {ids_table['pk_column']}, {column_cast} " + f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for row in rs: + record_id = str(row[0]) + content = str(row[1]) + + # Find all UUIDs in the content + import re + + uuid_pattern = re.compile(guid_regexp, re.IGNORECASE) + matches = uuid_pattern.findall(content) + + for ref_file_id in matches: + if ref_file_id not in file_key_map: + continue + storage_key = file_key_map[ref_file_id] + + # Apply filters + if file_id and ref_file_id != file_id: + continue + if key and not storage_key.endswith(key): + continue + + # Only collect items within the requested page range + if offset <= total_count < offset + limit: + paginated_usages.append( + { + "src": f"{ids_table['table']}.{ids_table['column']}", + "record_id": record_id, + "file_id": ref_file_id, + "key": storage_key, + } + ) + total_count += 1 + case _: + pass + # Output results if output_json: result = { diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 6a4c1528b0..9931bb5dd7 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,10 +1,11 @@ from typing import Any, Literal from flask import abort, make_response, request -from flask_restx import Resource, fields, marshal, marshal_with -from pydantic import BaseModel, Field, field_validator +from flask_restx import Resource +from pydantic import BaseModel, Field, TypeAdapter, field_validator from controllers.common.errors import 
NoFileUploadedError, TooManyFilesError +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import ( account_initialization_required, @@ -16,9 +17,11 @@ from controllers.console.wraps import ( ) from extensions.ext_redis import redis_client from fields.annotation_fields import ( - annotation_fields, - annotation_hit_history_fields, - build_annotation_model, + Annotation, + AnnotationExportList, + AnnotationHitHistory, + AnnotationHitHistoryList, + AnnotationList, ) from libs.helper import uuid_value from libs.login import login_required @@ -89,6 +92,14 @@ reg(CreateAnnotationPayload) reg(UpdateAnnotationPayload) reg(AnnotationReplyStatusQuery) reg(AnnotationFilePayload) +register_schema_models( + console_ns, + Annotation, + AnnotationList, + AnnotationExportList, + AnnotationHitHistory, + AnnotationHitHistoryList, +) @console_ns.route("/apps//annotation-reply/") @@ -107,10 +118,11 @@ class AnnotationReplyActionApi(Resource): def post(self, app_id, action: Literal["enable", "disable"]): app_id = str(app_id) args = AnnotationReplyPayload.model_validate(console_ns.payload) - if action == "enable": - result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id) - elif action == "disable": - result = AppAnnotationService.disable_app_annotation(app_id) + match action: + case "enable": + result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id) + case "disable": + result = AppAnnotationService.disable_app_annotation(app_id) return result, 200 @@ -201,33 +213,33 @@ class AnnotationApi(Resource): app_id = str(app_id) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) - response = { - "data": marshal(annotation_list, annotation_fields), - "has_more": len(annotation_list) == limit, - "limit": limit, - "total": total, - "page": page, - } - return response, 200 + annotation_models = 
TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True) + response = AnnotationList( + data=annotation_models, + has_more=len(annotation_list) == limit, + limit=limit, + total=total, + page=page, + ) + return response.model_dump(mode="json"), 200 @console_ns.doc("create_annotation") @console_ns.doc(description="Create a new annotation for an app") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__]) - @console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns)) + @console_ns.response(201, "Annotation created successfully", console_ns.models[Annotation.__name__]) @console_ns.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @cloud_edition_billing_resource_check("annotation") - @marshal_with(annotation_fields) @edit_permission_required def post(self, app_id): app_id = str(app_id) args = CreateAnnotationPayload.model_validate(console_ns.payload) data = args.model_dump(exclude_none=True) annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id) - return annotation + return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json") @setup_required @login_required @@ -264,7 +276,7 @@ class AnnotationExportApi(Resource): @console_ns.response( 200, "Annotations exported successfully", - console_ns.model("AnnotationList", {"data": fields.List(fields.Nested(build_annotation_model(console_ns)))}), + console_ns.models[AnnotationExportList.__name__], ) @console_ns.response(403, "Insufficient permissions") @setup_required @@ -274,7 +286,8 @@ class AnnotationExportApi(Resource): def get(self, app_id): app_id = str(app_id) annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) - response_data = {"data": marshal(annotation_list, annotation_fields)} + annotation_models = 
TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True) + response_data = AnnotationExportList(data=annotation_models).model_dump(mode="json") # Create response with secure headers for CSV export response = make_response(response_data, 200) @@ -289,7 +302,7 @@ class AnnotationUpdateDeleteApi(Resource): @console_ns.doc("update_delete_annotation") @console_ns.doc(description="Update or delete an annotation") @console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"}) - @console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns)) + @console_ns.response(200, "Annotation updated successfully", console_ns.models[Annotation.__name__]) @console_ns.response(204, "Annotation deleted successfully") @console_ns.response(403, "Insufficient permissions") @console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__]) @@ -298,7 +311,6 @@ class AnnotationUpdateDeleteApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") @edit_permission_required - @marshal_with(annotation_fields) def post(self, app_id, annotation_id): app_id = str(app_id) annotation_id = str(annotation_id) @@ -306,7 +318,7 @@ class AnnotationUpdateDeleteApi(Resource): annotation = AppAnnotationService.update_app_annotation_directly( args.model_dump(exclude_none=True), app_id, annotation_id ) - return annotation + return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json") @setup_required @login_required @@ -414,14 +426,7 @@ class AnnotationHitHistoryListApi(Resource): @console_ns.response( 200, "Hit histories retrieved successfully", - console_ns.model( - "AnnotationHitHistoryList", - { - "data": fields.List( - fields.Nested(console_ns.model("AnnotationHitHistoryItem", annotation_hit_history_fields)) - ) - }, - ), + console_ns.models[AnnotationHitHistoryList.__name__], ) @console_ns.response(403, "Insufficient permissions") @setup_required 
@@ -436,11 +441,14 @@ class AnnotationHitHistoryListApi(Resource): annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories( app_id, annotation_id, page, limit ) - response = { - "data": marshal(annotation_hit_history_list, annotation_hit_history_fields), - "has_more": len(annotation_hit_history_list) == limit, - "limit": limit, - "total": total, - "page": page, - } - return response + history_models = TypeAdapter(list[AnnotationHitHistory]).validate_python( + annotation_hit_history_list, from_attributes=True + ) + response = AnnotationHitHistoryList( + data=history_models, + has_more=len(annotation_hit_history_list) == limit, + limit=limit, + total=total, + page=page, + ) + return response.model_dump(mode="json") diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index d344ede466..941db325bf 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -6,6 +6,7 @@ from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError import services +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( AppUnavailableError, @@ -33,7 +34,6 @@ from services.errors.audio import ( ) logger = logging.getLogger(__name__) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class TextToSpeechPayload(BaseModel): @@ -47,13 +47,11 @@ class TextToSpeechVoiceQuery(BaseModel): language: str = Field(..., description="Language code") -console_ns.schema_model( - TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) -console_ns.schema_model( - TextToSpeechVoiceQuery.__name__, - TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +class AudioTranscriptResponse(BaseModel): + text: str = Field(description="Transcribed text from audio") + + 
+register_schema_models(console_ns, AudioTranscriptResponse, TextToSpeechPayload, TextToSpeechVoiceQuery) @console_ns.route("/apps//audio-to-text") @@ -64,7 +62,7 @@ class ChatMessageAudioApi(Resource): @console_ns.response( 200, "Audio transcription successful", - console_ns.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}), + console_ns.models[AudioTranscriptResponse.__name__], ) @console_ns.response(400, "Bad request - No audio uploaded or unsupported type") @console_ns.response(413, "Audio file too large") diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 14910c5895..c8b4e83ae6 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -509,16 +509,19 @@ class ChatConversationApi(Resource): case "created_at" | "-created_at" | _: query = query.where(Conversation.created_at <= end_datetime_utc) - if args.annotation_status == "annotated": - query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore - MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id - ) - elif args.annotation_status == "not_annotated": - query = ( - query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) - .group_by(Conversation.id) - .having(func.count(MessageAnnotation.id) == 0) - ) + match args.annotation_status: + case "annotated": + query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore + MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id + ) + case "not_annotated": + query = ( + query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) + .group_by(Conversation.id) + .having(func.count(MessageAnnotation.id) == 0) + ) + case "all": + pass if app_model.mode == AppMode.ADVANCED_CHAT: query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER) diff --git 
a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index ab1628d5d4..0bea777870 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, field_validator from sqlalchemy import exists, select from werkzeug.exceptions import InternalServerError, NotFound +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( CompletionRequestError, @@ -35,7 +36,6 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAft from services.message_service import MessageService, attach_message_extra_contents logger = logging.getLogger(__name__) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class ChatMessagesQuery(BaseModel): @@ -90,13 +90,22 @@ class FeedbackExportQuery(BaseModel): raise ValueError("has_comment must be a boolean value") -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +class AnnotationCountResponse(BaseModel): + count: int = Field(description="Number of annotations") -reg(ChatMessagesQuery) -reg(MessageFeedbackPayload) -reg(FeedbackExportQuery) +class SuggestedQuestionsResponse(BaseModel): + data: list[str] = Field(description="Suggested question") + + +register_schema_models( + console_ns, + ChatMessagesQuery, + MessageFeedbackPayload, + FeedbackExportQuery, + AnnotationCountResponse, + SuggestedQuestionsResponse, +) # Register models for flask_restx to avoid dict type issues in Swagger # Register in dependency order: base models first, then dependent models @@ -232,7 +241,7 @@ class ChatMessageListApi(Resource): @marshal_with(message_infinite_scroll_pagination_model) @edit_permission_required def get(self, app_model): - args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = 
ChatMessagesQuery.model_validate(request.args.to_dict()) conversation = ( db.session.query(Conversation) @@ -358,7 +367,7 @@ class MessageAnnotationCountApi(Resource): @console_ns.response( 200, "Annotation count retrieved successfully", - console_ns.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}), + console_ns.models[AnnotationCountResponse.__name__], ) @get_app_model @setup_required @@ -378,9 +387,7 @@ class MessageSuggestedQuestionApi(Resource): @console_ns.response( 200, "Suggested questions retrieved successfully", - console_ns.model( - "SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))} - ), + console_ns.models[SuggestedQuestionsResponse.__name__], ) @console_ns.response(404, "Message or conversation not found") @setup_required @@ -430,7 +437,7 @@ class MessageFeedbackExportApi(Resource): @login_required @account_initialization_required def get(self, app_model): - args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = FeedbackExportQuery.model_validate(request.args.to_dict()) # Import the service function from services.feedback_service import FeedbackService diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 0dd7d33ae9..3a3278ec9d 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -2,9 +2,11 @@ import logging import httpx from flask import current_app, redirect, request -from flask_restx import Resource, fields +from flask_restx import Resource +from pydantic import BaseModel, Field from configs import dify_config +from controllers.common.schema import register_schema_models from libs.login import login_required from libs.oauth_data_source import NotionOAuth @@ -14,6 +16,26 @@ from ..wraps import account_initialization_required, is_admin_or_owner_required, logger = 
logging.getLogger(__name__) +class OAuthDataSourceResponse(BaseModel): + data: str = Field(description="Authorization URL or 'internal' for internal setup") + + +class OAuthDataSourceBindingResponse(BaseModel): + result: str = Field(description="Operation result") + + +class OAuthDataSourceSyncResponse(BaseModel): + result: str = Field(description="Operation result") + + +register_schema_models( + console_ns, + OAuthDataSourceResponse, + OAuthDataSourceBindingResponse, + OAuthDataSourceSyncResponse, +) + + def get_oauth_providers(): with current_app.app_context(): notion_oauth = NotionOAuth( @@ -34,10 +56,7 @@ class OAuthDataSource(Resource): @console_ns.response( 200, "Authorization URL or internal setup success", - console_ns.model( - "OAuthDataSourceResponse", - {"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")}, - ), + console_ns.models[OAuthDataSourceResponse.__name__], ) @console_ns.response(400, "Invalid provider") @console_ns.response(403, "Admin privileges required") @@ -101,7 +120,7 @@ class OAuthDataSourceBinding(Resource): @console_ns.response( 200, "Data source binding success", - console_ns.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}), + console_ns.models[OAuthDataSourceBindingResponse.__name__], ) @console_ns.response(400, "Invalid provider or code") def get(self, provider: str): @@ -133,7 +152,7 @@ class OAuthDataSourceSync(Resource): @console_ns.response( 200, "Data source sync success", - console_ns.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}), + console_ns.models[OAuthDataSourceSyncResponse.__name__], ) @console_ns.response(400, "Invalid provider or sync failed") @setup_required diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 394f205d93..1ed931b0d7 100644 --- a/api/controllers/console/auth/forgot_password.py +++ 
b/api/controllers/console/auth/forgot_password.py @@ -2,10 +2,11 @@ import base64 import secrets from flask import request -from flask_restx import Resource, fields +from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import Session +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.auth.error import ( EmailCodeError, @@ -48,8 +49,31 @@ class ForgotPasswordResetPayload(BaseModel): return valid_password(value) -for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload): - console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +class ForgotPasswordEmailResponse(BaseModel): + result: str = Field(description="Operation result") + data: str | None = Field(default=None, description="Reset token") + code: str | None = Field(default=None, description="Error code if account not found") + + +class ForgotPasswordCheckResponse(BaseModel): + is_valid: bool = Field(description="Whether code is valid") + email: EmailStr = Field(description="Email address") + token: str = Field(description="New reset token") + + +class ForgotPasswordResetResponse(BaseModel): + result: str = Field(description="Operation result") + + +register_schema_models( + console_ns, + ForgotPasswordSendPayload, + ForgotPasswordCheckPayload, + ForgotPasswordResetPayload, + ForgotPasswordEmailResponse, + ForgotPasswordCheckResponse, + ForgotPasswordResetResponse, +) @console_ns.route("/forgot-password") @@ -60,14 +84,7 @@ class ForgotPasswordSendEmailApi(Resource): @console_ns.response( 200, "Email sent successfully", - console_ns.model( - "ForgotPasswordEmailResponse", - { - "result": fields.String(description="Operation result"), - "data": fields.String(description="Reset token"), - "code": fields.String(description="Error code if account not found"), - }, - ), + 
console_ns.models[ForgotPasswordEmailResponse.__name__], ) @console_ns.response(400, "Invalid email or rate limit exceeded") @setup_required @@ -106,14 +123,7 @@ class ForgotPasswordCheckApi(Resource): @console_ns.response( 200, "Code verified successfully", - console_ns.model( - "ForgotPasswordCheckResponse", - { - "is_valid": fields.Boolean(description="Whether code is valid"), - "email": fields.String(description="Email address"), - "token": fields.String(description="New reset token"), - }, - ), + console_ns.models[ForgotPasswordCheckResponse.__name__], ) @console_ns.response(400, "Invalid code or token") @setup_required @@ -163,7 +173,7 @@ class ForgotPasswordResetApi(Resource): @console_ns.response( 200, "Password reset successfully", - console_ns.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}), + console_ns.models[ForgotPasswordResetResponse.__name__], ) @console_ns.response(400, "Invalid token or password mismatch") @setup_required diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py index 6162d88a0b..38ea5d2dae 100644 --- a/api/controllers/console/auth/oauth_server.py +++ b/api/controllers/console/auth/oauth_server.py @@ -155,43 +155,43 @@ class OAuthServerUserTokenApi(Resource): grant_type = OAuthGrantType(payload.grant_type) except ValueError: raise BadRequest("invalid grant_type") + match grant_type: + case OAuthGrantType.AUTHORIZATION_CODE: + if not payload.code: + raise BadRequest("code is required") - if grant_type == OAuthGrantType.AUTHORIZATION_CODE: - if not payload.code: - raise BadRequest("code is required") + if payload.client_secret != oauth_provider_app.client_secret: + raise BadRequest("client_secret is invalid") - if payload.client_secret != oauth_provider_app.client_secret: - raise BadRequest("client_secret is invalid") + if payload.redirect_uri not in oauth_provider_app.redirect_uris: + raise BadRequest("redirect_uri is invalid") - if 
payload.redirect_uri not in oauth_provider_app.redirect_uris: - raise BadRequest("redirect_uri is invalid") + access_token, refresh_token = OAuthServerService.sign_oauth_access_token( + grant_type, code=payload.code, client_id=oauth_provider_app.client_id + ) + return jsonable_encoder( + { + "access_token": access_token, + "token_type": "Bearer", + "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, + "refresh_token": refresh_token, + } + ) + case OAuthGrantType.REFRESH_TOKEN: + if not payload.refresh_token: + raise BadRequest("refresh_token is required") - access_token, refresh_token = OAuthServerService.sign_oauth_access_token( - grant_type, code=payload.code, client_id=oauth_provider_app.client_id - ) - return jsonable_encoder( - { - "access_token": access_token, - "token_type": "Bearer", - "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, - "refresh_token": refresh_token, - } - ) - elif grant_type == OAuthGrantType.REFRESH_TOKEN: - if not payload.refresh_token: - raise BadRequest("refresh_token is required") - - access_token, refresh_token = OAuthServerService.sign_oauth_access_token( - grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id - ) - return jsonable_encoder( - { - "access_token": access_token, - "token_type": "Bearer", - "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, - "refresh_token": refresh_token, - } - ) + access_token, refresh_token = OAuthServerService.sign_oauth_access_token( + grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id + ) + return jsonable_encoder( + { + "access_token": access_token, + "token_type": "Bearer", + "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, + "refresh_token": refresh_token, + } + ) @console_ns.route("/oauth/provider/account") diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 01e9bf77c0..daef4e005a 100644 --- a/api/controllers/console/datasets/data_source.py +++ 
b/api/controllers/console/datasets/data_source.py @@ -1,6 +1,6 @@ import json from collections.abc import Generator -from typing import Any, cast +from typing import Any, Literal, cast from flask import request from flask_restx import Resource, fields, marshal_with @@ -157,9 +157,8 @@ class DataSourceApi(Resource): @setup_required @login_required @account_initialization_required - def patch(self, binding_id, action): + def patch(self, binding_id, action: Literal["enable", "disable"]): binding_id = str(binding_id) - action = str(action) with Session(db.engine) as session: data_source_binding = session.execute( select(DataSourceOauthBinding).filter_by(id=binding_id) @@ -167,23 +166,24 @@ class DataSourceApi(Resource): if data_source_binding is None: raise NotFound("Data source binding not found.") # enable binding - if action == "enable": - if data_source_binding.disabled: - data_source_binding.disabled = False - data_source_binding.updated_at = naive_utc_now() - db.session.add(data_source_binding) - db.session.commit() - else: - raise ValueError("Data source is not disabled.") - # disable binding - if action == "disable": - if not data_source_binding.disabled: - data_source_binding.disabled = True - data_source_binding.updated_at = naive_utc_now() - db.session.add(data_source_binding) - db.session.commit() - else: - raise ValueError("Data source is disabled.") + match action: + case "enable": + if data_source_binding.disabled: + data_source_binding.disabled = False + data_source_binding.updated_at = naive_utc_now() + db.session.add(data_source_binding) + db.session.commit() + else: + raise ValueError("Data source is not disabled.") + # disable binding + case "disable": + if not data_source_binding.disabled: + data_source_binding.disabled = True + data_source_binding.updated_at = naive_utc_now() + db.session.add(data_source_binding) + db.session.commit() + else: + raise ValueError("Data source is disabled.") return {"result": "success"}, 200 diff --git 
a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 6e3c0db8a3..bf097d374a 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -576,63 +576,62 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): if document.indexing_status in {"completed", "error"}: raise DocumentAlreadyFinishedError() data_source_info = document.data_source_info_dict + match document.data_source_type: + case "upload_file": + if not data_source_info: + continue + file_id = data_source_info["upload_file_id"] + file_detail = ( + db.session.query(UploadFile) + .where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id) + .first() + ) - if document.data_source_type == "upload_file": - if not data_source_info: - continue - file_id = data_source_info["upload_file_id"] - file_detail = ( - db.session.query(UploadFile) - .where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id) - .first() - ) + if file_detail is None: + raise NotFound("File not found.") - if file_detail is None: - raise NotFound("File not found.") + extract_setting = ExtractSetting( + datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form + ) + extract_settings.append(extract_setting) + case "notion_import": + if not data_source_info: + continue + extract_setting = ExtractSetting( + datasource_type=DatasourceType.NOTION, + notion_info=NotionInfo.model_validate( + { + "credential_id": data_source_info.get("credential_id"), + "notion_workspace_id": data_source_info["notion_workspace_id"], + "notion_obj_id": data_source_info["notion_page_id"], + "notion_page_type": data_source_info["type"], + "tenant_id": current_tenant_id, + } + ), + document_model=document.doc_form, + ) + extract_settings.append(extract_setting) + case "website_crawl": + if not data_source_info: + continue + extract_setting = ExtractSetting( + 
datasource_type=DatasourceType.WEBSITE, + website_info=WebsiteInfo.model_validate( + { + "provider": data_source_info["provider"], + "job_id": data_source_info["job_id"], + "url": data_source_info["url"], + "tenant_id": current_tenant_id, + "mode": data_source_info["mode"], + "only_main_content": data_source_info["only_main_content"], + } + ), + document_model=document.doc_form, + ) + extract_settings.append(extract_setting) - extract_setting = ExtractSetting( - datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form - ) - extract_settings.append(extract_setting) - - elif document.data_source_type == "notion_import": - if not data_source_info: - continue - extract_setting = ExtractSetting( - datasource_type=DatasourceType.NOTION, - notion_info=NotionInfo.model_validate( - { - "credential_id": data_source_info.get("credential_id"), - "notion_workspace_id": data_source_info["notion_workspace_id"], - "notion_obj_id": data_source_info["notion_page_id"], - "notion_page_type": data_source_info["type"], - "tenant_id": current_tenant_id, - } - ), - document_model=document.doc_form, - ) - extract_settings.append(extract_setting) - elif document.data_source_type == "website_crawl": - if not data_source_info: - continue - extract_setting = ExtractSetting( - datasource_type=DatasourceType.WEBSITE, - website_info=WebsiteInfo.model_validate( - { - "provider": data_source_info["provider"], - "job_id": data_source_info["job_id"], - "url": data_source_info["url"], - "tenant_id": current_tenant_id, - "mode": data_source_info["mode"], - "only_main_content": data_source_info["only_main_content"], - } - ), - document_model=document.doc_form, - ) - extract_settings.append(extract_setting) - - else: - raise ValueError("Data source type not support") + case _: + raise ValueError("Data source type not support") indexing_runner = IndexingRunner() try: response = indexing_runner.indexing_estimate( @@ -954,23 +953,24 @@ class 
DocumentProcessingApi(DocumentResource): if not current_user.is_dataset_editor: raise Forbidden() - if action == "pause": - if document.indexing_status != "indexing": - raise InvalidActionError("Document not in indexing state.") + match action: + case "pause": + if document.indexing_status != "indexing": + raise InvalidActionError("Document not in indexing state.") - document.paused_by = current_user.id - document.paused_at = naive_utc_now() - document.is_paused = True - db.session.commit() + document.paused_by = current_user.id + document.paused_at = naive_utc_now() + document.is_paused = True + db.session.commit() - elif action == "resume": - if document.indexing_status not in {"paused", "error"}: - raise InvalidActionError("Document not in paused or error state.") + case "resume": + if document.indexing_status not in {"paused", "error"}: + raise InvalidActionError("Document not in paused or error state.") - document.paused_by = None - document.paused_at = None - document.is_paused = False - db.session.commit() + document.paused_by = None + document.paused_at = None + document.is_paused = False + db.session.commit() return {"result": "success"}, 200 @@ -1339,6 +1339,18 @@ class DocumentGenerateSummaryApi(Resource): missing_ids = set(document_list) - found_ids raise NotFound(f"Some documents not found: {list(missing_ids)}") + # Update need_summary to True for documents that don't have it set + # This handles the case where documents were created when summary_index_setting was disabled + documents_to_update = [doc for doc in documents if not doc.need_summary and doc.doc_form != "qa_model"] + + if documents_to_update: + document_ids_to_update = [str(doc.id) for doc in documents_to_update] + DocumentService.update_documents_need_summary( + dataset_id=dataset_id, + document_ids=document_ids_to_update, + need_summary=True, + ) + # Dispatch async tasks for each document for document in documents: # Skip qa_model documents as they don't generate summaries diff --git 
a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index 05fc4cd714..2e69ddc5ab 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -126,10 +126,11 @@ class DatasetMetadataBuiltInFieldActionApi(Resource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - if action == "enable": - MetadataService.enable_built_in_field(dataset) - elif action == "disable": - MetadataService.disable_built_in_field(dataset) + match action: + case "enable": + MetadataService.enable_built_in_field(dataset) + case "disable": + MetadataService.disable_built_in_field(dataset) return {"result": "success"}, 200 diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index 1eb0cdb019..c417967c88 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -1,8 +1,9 @@ import logging -from typing import Any, cast +from typing import Any, Literal, cast from flask import request -from flask_restx import Resource, fields, marshal, marshal_with, reqparse +from flask_restx import Resource, fields, marshal, marshal_with +from pydantic import BaseModel from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services @@ -51,7 +52,7 @@ from fields.app_fields import ( tag_fields, ) from fields.dataset_fields import dataset_fields -from fields.member_fields import build_simple_account_model +from fields.member_fields import simple_account_fields from fields.workflow_fields import ( conversation_variable_fields, pipeline_variable_fields, @@ -103,7 +104,7 @@ app_detail_fields_with_site_copy["tags"] = fields.List(fields.Nested(tag_model)) app_detail_fields_with_site_copy["site"] = fields.Nested(site_model) app_detail_with_site_model = get_or_create_model("TrialAppDetailWithSite", app_detail_fields_with_site_copy) -simple_account_model = 
build_simple_account_model(console_ns) +simple_account_model = get_or_create_model("SimpleAccount", simple_account_fields) conversation_variable_model = get_or_create_model("TrialConversationVariable", conversation_variable_fields) pipeline_variable_model = get_or_create_model("TrialPipelineVariable", pipeline_variable_fields) @@ -117,7 +118,56 @@ workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipel workflow_model = get_or_create_model("TrialWorkflow", workflow_fields_copy) +# Pydantic models for request validation +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class WorkflowRunRequest(BaseModel): + inputs: dict + files: list | None = None + + +class ChatRequest(BaseModel): + inputs: dict + query: str + files: list | None = None + conversation_id: str | None = None + parent_message_id: str | None = None + retriever_from: str = "explore_app" + + +class TextToSpeechRequest(BaseModel): + message_id: str | None = None + voice: str | None = None + text: str | None = None + streaming: bool | None = None + + +class CompletionRequest(BaseModel): + inputs: dict + query: str = "" + files: list | None = None + response_mode: Literal["blocking", "streaming"] | None = None + retriever_from: str = "explore_app" + + +# Register schemas for Swagger documentation +console_ns.schema_model( + WorkflowRunRequest.__name__, WorkflowRunRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) +console_ns.schema_model( + ChatRequest.__name__, ChatRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) +console_ns.schema_model( + TextToSpeechRequest.__name__, TextToSpeechRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) +console_ns.schema_model( + CompletionRequest.__name__, CompletionRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) + + class TrialAppWorkflowRunApi(TrialAppResource): + @console_ns.expect(console_ns.models[WorkflowRunRequest.__name__]) def 
post(self, trial_app): """ Run workflow @@ -129,10 +179,8 @@ class TrialAppWorkflowRunApi(TrialAppResource): if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") - parser.add_argument("files", type=list, required=False, location="json") - args = parser.parse_args() + request_data = WorkflowRunRequest.model_validate(console_ns.payload) + args = request_data.model_dump() assert current_user is not None try: app_id = app_model.id @@ -183,6 +231,7 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource): class TrialChatApi(TrialAppResource): + @console_ns.expect(console_ns.models[ChatRequest.__name__]) @trial_feature_enable def post(self, trial_app): app_model = trial_app @@ -190,14 +239,14 @@ class TrialChatApi(TrialAppResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("query", type=str, required=True, location="json") - parser.add_argument("files", type=list, required=False, location="json") - parser.add_argument("conversation_id", type=uuid_value, location="json") - parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json") - parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") - args = parser.parse_args() + request_data = ChatRequest.model_validate(console_ns.payload) + args = request_data.model_dump() + + # Validate UUID values if provided + if args.get("conversation_id"): + args["conversation_id"] = uuid_value(args["conversation_id"]) + if args.get("parent_message_id"): + args["parent_message_id"] = uuid_value(args["parent_message_id"]) args["auto_generate_name"] = False @@ -320,20 +369,16 @@ class TrialChatAudioApi(TrialAppResource): class 
TrialChatTextApi(TrialAppResource): + @console_ns.expect(console_ns.models[TextToSpeechRequest.__name__]) @trial_feature_enable def post(self, trial_app): app_model = trial_app try: - parser = reqparse.RequestParser() - parser.add_argument("message_id", type=str, required=False, location="json") - parser.add_argument("voice", type=str, location="json") - parser.add_argument("text", type=str, location="json") - parser.add_argument("streaming", type=bool, location="json") - args = parser.parse_args() + request_data = TextToSpeechRequest.model_validate(console_ns.payload) - message_id = args.get("message_id", None) - text = args.get("text", None) - voice = args.get("voice", None) + message_id = request_data.message_id + text = request_data.text + voice = request_data.voice if not isinstance(current_user, Account): raise ValueError("current_user must be an Account instance") @@ -371,19 +416,15 @@ class TrialChatTextApi(TrialAppResource): class TrialCompletionApi(TrialAppResource): + @console_ns.expect(console_ns.models[CompletionRequest.__name__]) @trial_feature_enable def post(self, trial_app): app_model = trial_app if app_model.mode != "completion": raise NotCompletionAppError() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("query", type=str, location="json", default="") - parser.add_argument("files", type=list, required=False, location="json") - parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") - args = parser.parse_args() + request_data = CompletionRequest.model_validate(console_ns.payload) + args = request_data.model_dump() streaming = args["response_mode"] == "streaming" args["auto_generate_name"] = False diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 1e98d622fe..d3811e2d1b 100644 
--- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,58 +1,60 @@ -from pydantic import BaseModel, Field +from flask_restx import Resource, fields from werkzeug.exceptions import Unauthorized -from controllers.fastopenapi import console_router from libs.login import current_account_with_tenant, current_user, login_required -from services.feature_service import FeatureModel, FeatureService, SystemFeatureModel +from services.feature_service import FeatureService +from . import console_ns from .wraps import account_initialization_required, cloud_utm_record, setup_required -class FeatureResponse(BaseModel): - features: FeatureModel = Field(description="Feature configuration object") +@console_ns.route("/features") +class FeatureApi(Resource): + @console_ns.doc("get_tenant_features") + @console_ns.doc(description="Get feature configuration for current tenant") + @console_ns.response( + 200, + "Success", + console_ns.model("FeatureResponse", {"features": fields.Raw(description="Feature configuration object")}), + ) + @setup_required + @login_required + @account_initialization_required + @cloud_utm_record + def get(self): + """Get feature configuration for current tenant""" + _, current_tenant_id = current_account_with_tenant() + + return FeatureService.get_features(current_tenant_id).model_dump() -class SystemFeatureResponse(BaseModel): - features: SystemFeatureModel = Field(description="System feature configuration object") +@console_ns.route("/system-features") +class SystemFeatureApi(Resource): + @console_ns.doc("get_system_features") + @console_ns.doc(description="Get system-wide feature configuration") + @console_ns.response( + 200, + "Success", + console_ns.model( + "SystemFeatureResponse", {"features": fields.Raw(description="System feature configuration object")} + ), + ) + def get(self): + """Get system-wide feature configuration + NOTE: This endpoint is unauthenticated by design, as it provides system features + data required 
for dashboard initialization. -@console_router.get( - "/features", - response_model=FeatureResponse, - tags=["console"], -) -@setup_required -@login_required -@account_initialization_required -@cloud_utm_record -def get_tenant_features() -> FeatureResponse: - """Get feature configuration for current tenant.""" - _, current_tenant_id = current_account_with_tenant() + Authentication would create circular dependency (can't login without dashboard loading). - return FeatureResponse(features=FeatureService.get_features(current_tenant_id)) - - -@console_router.get( - "/system-features", - response_model=SystemFeatureResponse, - tags=["console"], -) -def get_system_features() -> SystemFeatureResponse: - """Get system-wide feature configuration - - NOTE: This endpoint is unauthenticated by design, as it provides system features - data required for dashboard initialization. - - Authentication would create circular dependency (can't login without dashboard loading). - - Only non-sensitive configuration data should be returned by this endpoint. - """ - # NOTE(QuantumGhost): ideally we should access `current_user.is_authenticated` - # without a try-catch. However, due to the implementation of user loader (the `load_user_from_request` - # in api/extensions/ext_login.py), accessing `current_user.is_authenticated` will - # raise `Unauthorized` exception if authentication token is not provided. - try: - is_authenticated = current_user.is_authenticated - except Unauthorized: - is_authenticated = False - return SystemFeatureResponse(features=FeatureService.get_system_features(is_authenticated=is_authenticated)) + Only non-sensitive configuration data should be returned by this endpoint. + """ + # NOTE(QuantumGhost): ideally we should access `current_user.is_authenticated` + # without a try-catch. 
However, due to the implementation of user loader (the `load_user_from_request` + # in api/extensions/ext_login.py), accessing `current_user.is_authenticated` will + # raise `Unauthorized` exception if authentication token is not provided. + try: + is_authenticated = current_user.is_authenticated + except Unauthorized: + is_authenticated = False + return FeatureService.get_system_features(is_authenticated=is_authenticated).model_dump() diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index e828d54ff4..bc0776f658 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -1,14 +1,27 @@ from typing import Literal -from uuid import UUID +from flask import request +from flask_restx import Namespace, Resource, fields, marshal_with from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden +from controllers.common.schema import register_schema_models +from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required -from controllers.fastopenapi import console_router from libs.login import current_account_with_tenant, login_required from services.tag_service import TagService +dataset_tag_fields = { + "id": fields.String, + "name": fields.String, + "type": fields.String, + "binding_count": fields.String, +} + + +def build_dataset_tag_fields(api_or_ns: Namespace): + return api_or_ns.model("DataSetTag", dataset_tag_fields) + class TagBasePayload(BaseModel): name: str = Field(description="Tag name", min_length=1, max_length=50) @@ -32,129 +45,115 @@ class TagListQueryParam(BaseModel): keyword: str | None = Field(None, description="Search keyword") -class TagResponse(BaseModel): - id: str = Field(description="Tag ID") - name: str = Field(description="Tag name") - type: str = Field(description="Tag type") - binding_count: int = Field(description="Number of bindings") - - -class 
TagBindingResult(BaseModel): - result: Literal["success"] = Field(description="Operation result", examples=["success"]) - - -@console_router.get( - "/tags", - response_model=list[TagResponse], - tags=["console"], +register_schema_models( + console_ns, + TagBasePayload, + TagBindingPayload, + TagBindingRemovePayload, + TagListQueryParam, ) -@setup_required -@login_required -@account_initialization_required -def list_tags(query: TagListQueryParam) -> list[TagResponse]: - _, current_tenant_id = current_account_with_tenant() - tags = TagService.get_tags(query.type, current_tenant_id, query.keyword) - - return [ - TagResponse( - id=tag.id, - name=tag.name, - type=tag.type, - binding_count=int(tag.binding_count), - ) - for tag in tags - ] -@console_router.post( - "/tags", - response_model=TagResponse, - tags=["console"], -) -@setup_required -@login_required -@account_initialization_required -def create_tag(payload: TagBasePayload) -> TagResponse: - current_user, _ = current_account_with_tenant() - # The role of the current user in the tag table must be admin, owner, or editor - if not (current_user.has_edit_permission or current_user.is_dataset_editor): - raise Forbidden() +@console_ns.route("/tags") +class TagListApi(Resource): + @setup_required + @login_required + @account_initialization_required + @console_ns.doc( + params={"type": 'Tag type filter. 
Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."} ) @marshal_with(dataset_tag_fields) def get(self): _, current_tenant_id = current_account_with_tenant() raw_args = request.args.to_dict() param = TagListQueryParam.model_validate(raw_args) tags = TagService.get_tags(param.type, current_tenant_id, param.keyword) - tag = TagService.save_tags(payload.model_dump()) + return tags, 200 - return TagResponse(id=tag.id, name=tag.name, type=tag.type, binding_count=0) + @console_ns.expect(console_ns.models[TagBasePayload.__name__]) + @setup_required + @login_required + @account_initialization_required + def post(self): + current_user, _ = current_account_with_tenant() + # The role of the current user in the tag table must be admin, owner, or editor + if not (current_user.has_edit_permission or current_user.is_dataset_editor): + raise Forbidden() + + payload = TagBasePayload.model_validate(console_ns.payload or {}) + tag = TagService.save_tags(payload.model_dump()) + + response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} + + return response, 200 -@console_router.patch( - "/tags/", - response_model=TagResponse, - tags=["console"], -) -@setup_required -@login_required -@account_initialization_required -def update_tag(tag_id: UUID, payload: TagBasePayload) -> TagResponse: - current_user, _ = current_account_with_tenant() - tag_id_str = str(tag_id) - # The role of the current user in the ta table must be admin, owner, or editor - if not (current_user.has_edit_permission or current_user.is_dataset_editor): - raise Forbidden() +@console_ns.route("/tags/") +class TagUpdateDeleteApi(Resource): + @console_ns.expect(console_ns.models[TagBasePayload.__name__]) + @setup_required + @login_required + @account_initialization_required + def patch(self, tag_id): + current_user, _ = current_account_with_tenant() + tag_id = str(tag_id) + # The role of the current user in the tag table must be admin, owner, or editor + if not 
(current_user.has_edit_permission or current_user.is_dataset_editor): + raise Forbidden() - tag = TagService.update_tags(payload.model_dump(), tag_id_str) + payload = TagBasePayload.model_validate(console_ns.payload or {}) + tag = TagService.update_tags(payload.model_dump(), tag_id) - binding_count = TagService.get_tag_binding_count(tag_id_str) + binding_count = TagService.get_tag_binding_count(tag_id) - return TagResponse(id=tag.id, name=tag.name, type=tag.type, binding_count=binding_count) + response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} + + return response, 200 + + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + def delete(self, tag_id): + tag_id = str(tag_id) + + TagService.delete_tag(tag_id) + + return "", 204 -@console_router.delete( - "/tags/", - tags=["console"], - status_code=204, -) -@setup_required -@login_required -@account_initialization_required -@edit_permission_required -def delete_tag(tag_id: UUID) -> None: - tag_id_str = str(tag_id) +@console_ns.route("/tag-bindings/create") +class TagBindingCreateApi(Resource): + @console_ns.expect(console_ns.models[TagBindingPayload.__name__]) + @setup_required + @login_required + @account_initialization_required + def post(self): + current_user, _ = current_account_with_tenant() + # The role of the current user in the tag table must be admin, owner, editor, or dataset_operator + if not (current_user.has_edit_permission or current_user.is_dataset_editor): + raise Forbidden() - TagService.delete_tag(tag_id_str) + payload = TagBindingPayload.model_validate(console_ns.payload or {}) + TagService.save_tag_binding(payload.model_dump()) + + return {"result": "success"}, 200 -@console_router.post( - "/tag-bindings/remove", - response_model=TagBindingResult, - tags=["console"], -) -@setup_required -@login_required -@account_initialization_required -def create_tag_binding(payload: TagBindingPayload) -> TagBindingResult: - 
current_user, _ = current_account_with_tenant() - # The role of the current user in the tag table must be admin, owner, editor, or dataset_operator - if not (current_user.has_edit_permission or current_user.is_dataset_editor): - raise Forbidden() +@console_ns.route("/tag-bindings/remove") +class TagBindingDeleteApi(Resource): + @console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__]) + @setup_required + @login_required + @account_initialization_required + def post(self): + current_user, _ = current_account_with_tenant() + # The role of the current user in the tag table must be admin, owner, editor, or dataset_operator + if not (current_user.has_edit_permission or current_user.is_dataset_editor): + raise Forbidden() - TagService.save_tag_binding(payload.model_dump()) + payload = TagBindingRemovePayload.model_validate(console_ns.payload or {}) + TagService.delete_tag_binding(payload.model_dump()) - return TagBindingResult(result="success") - - -@console_router.post( - "/tag-bindings/remove", - response_model=TagBindingResult, - tags=["console"], -) -@setup_required -@login_required -@account_initialization_required -def delete_tag_binding(payload: TagBindingRemovePayload) -> TagBindingResult: - current_user, _ = current_account_with_tenant() - # The role of the current user in the tag table must be admin, owner, editor, or dataset_operator - if not (current_user.has_edit_permission or current_user.is_dataset_editor): - raise Forbidden() - - TagService.delete_tag_binding(payload.model_dump()) - - return TagBindingResult(result="success") + return {"result": "success"}, 200 diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 38c66525b3..708df62642 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -12,6 +12,7 @@ from sqlalchemy.orm import Session from configs import dify_config from constants.languages import supported_language +from 
controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.auth.error import ( EmailAlreadyInUseError, @@ -37,7 +38,7 @@ from controllers.console.wraps import ( setup_required, ) from extensions.ext_database import db -from fields.member_fields import account_fields +from fields.member_fields import Account as AccountResponse from libs.datetime_utils import naive_utc_now from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone from libs.login import current_account_with_tenant, login_required @@ -170,6 +171,12 @@ reg(ChangeEmailSendPayload) reg(ChangeEmailValidityPayload) reg(ChangeEmailResetPayload) reg(CheckEmailUniquePayload) +register_schema_models(console_ns, AccountResponse) + + +def _serialize_account(account) -> dict: + return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json") + integrate_fields = { "provider": fields.String, @@ -236,11 +243,11 @@ class AccountProfileApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) @enterprise_license_required def get(self): current_user, _ = current_account_with_tenant() - return current_user + return _serialize_account(current_user) @console_ns.route("/account/name") @@ -249,14 +256,14 @@ class AccountNameApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} args = AccountNamePayload.model_validate(payload) updated_account = AccountService.update_account(current_user, name=args.name) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/avatar") @@ -265,7 +272,7 @@ class 
AccountAvatarApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} @@ -273,7 +280,7 @@ class AccountAvatarApi(Resource): updated_account = AccountService.update_account(current_user, avatar=args.avatar) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/interface-language") @@ -282,7 +289,7 @@ class AccountInterfaceLanguageApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} @@ -290,7 +297,7 @@ class AccountInterfaceLanguageApi(Resource): updated_account = AccountService.update_account(current_user, interface_language=args.interface_language) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/interface-theme") @@ -299,7 +306,7 @@ class AccountInterfaceThemeApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} @@ -307,7 +314,7 @@ class AccountInterfaceThemeApi(Resource): updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/timezone") @@ -316,7 +323,7 @@ class AccountTimezoneApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", 
console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} @@ -324,7 +331,7 @@ class AccountTimezoneApi(Resource): updated_account = AccountService.update_account(current_user, timezone=args.timezone) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/password") @@ -333,7 +340,7 @@ class AccountPasswordApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} @@ -344,7 +351,7 @@ class AccountPasswordApi(Resource): except ServiceCurrentPasswordIncorrectError: raise CurrentPasswordIncorrectError() - return {"result": "success"} + return _serialize_account(current_user) @console_ns.route("/account/integrates") @@ -620,7 +627,7 @@ class ChangeEmailResetApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): payload = console_ns.payload or {} args = ChangeEmailResetPayload.model_validate(payload) @@ -649,7 +656,7 @@ class ChangeEmailResetApi(Resource): email=normalized_new_email, ) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/change-email/check-email-unique") diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index bfd9fc6c29..1897cbdca7 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -1,9 +1,10 @@ from typing import Any from flask import request -from flask_restx import Resource, fields +from flask_restx import Resource from pydantic import BaseModel, Field +from 
controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder @@ -38,15 +39,53 @@ class EndpointListForPluginQuery(EndpointListQuery): plugin_id: str +class EndpointCreateResponse(BaseModel): + success: bool = Field(description="Operation success") + + +class EndpointListResponse(BaseModel): + endpoints: list[dict[str, Any]] = Field(description="Endpoint information") + + +class PluginEndpointListResponse(BaseModel): + endpoints: list[dict[str, Any]] = Field(description="Endpoint information") + + +class EndpointDeleteResponse(BaseModel): + success: bool = Field(description="Operation success") + + +class EndpointUpdateResponse(BaseModel): + success: bool = Field(description="Operation success") + + +class EndpointEnableResponse(BaseModel): + success: bool = Field(description="Operation success") + + +class EndpointDisableResponse(BaseModel): + success: bool = Field(description="Operation success") + + def reg(cls: type[BaseModel]): console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) -reg(EndpointCreatePayload) -reg(EndpointIdPayload) -reg(EndpointUpdatePayload) -reg(EndpointListQuery) -reg(EndpointListForPluginQuery) +register_schema_models( + console_ns, + EndpointCreatePayload, + EndpointIdPayload, + EndpointUpdatePayload, + EndpointListQuery, + EndpointListForPluginQuery, + EndpointCreateResponse, + EndpointListResponse, + PluginEndpointListResponse, + EndpointDeleteResponse, + EndpointUpdateResponse, + EndpointEnableResponse, + EndpointDisableResponse, +) @console_ns.route("/workspaces/current/endpoints/create") @@ -57,7 +96,7 @@ class EndpointCreateApi(Resource): @console_ns.response( 200, "Endpoint created successfully", - console_ns.model("EndpointCreateResponse", {"success": 
fields.Boolean(description="Operation success")}), + console_ns.models[EndpointCreateResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required @@ -91,9 +130,7 @@ class EndpointListApi(Resource): @console_ns.response( 200, "Success", - console_ns.model( - "EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))} - ), + console_ns.models[EndpointListResponse.__name__], ) @setup_required @login_required @@ -126,9 +163,7 @@ class EndpointListForSinglePluginApi(Resource): @console_ns.response( 200, "Success", - console_ns.model( - "PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))} - ), + console_ns.models[PluginEndpointListResponse.__name__], ) @setup_required @login_required @@ -163,7 +198,7 @@ class EndpointDeleteApi(Resource): @console_ns.response( 200, "Endpoint deleted successfully", - console_ns.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointDeleteResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required @@ -190,7 +225,7 @@ class EndpointUpdateApi(Resource): @console_ns.response( 200, "Endpoint updated successfully", - console_ns.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointUpdateResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required @@ -221,7 +256,7 @@ class EndpointEnableApi(Resource): @console_ns.response( 200, "Endpoint enabled successfully", - console_ns.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointEnableResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required @@ -248,7 +283,7 @@ class EndpointDisableApi(Resource): @console_ns.response( 200, "Endpoint disabled successfully", - 
console_ns.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointDisableResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 271cdce3c3..dd302b90d6 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,12 +1,12 @@ from urllib import parse from flask import abort, request -from flask_restx import Resource, fields, marshal_with -from pydantic import BaseModel, Field +from flask_restx import Resource +from pydantic import BaseModel, Field, TypeAdapter import services from configs import dify_config -from controllers.common.schema import get_or_create_model, register_enum_models +from controllers.common.schema import register_enum_models, register_schema_models from controllers.console import console_ns from controllers.console.auth.error import ( CannotTransferOwnerToSelfError, @@ -25,7 +25,7 @@ from controllers.console.wraps import ( setup_required, ) from extensions.ext_database import db -from fields.member_fields import account_with_role_fields, account_with_role_list_fields +from fields.member_fields import AccountWithRole, AccountWithRoleList from libs.helper import extract_remote_ip from libs.login import current_account_with_tenant, login_required from models.account import Account, TenantAccountRole @@ -69,12 +69,7 @@ reg(OwnerTransferEmailPayload) reg(OwnerTransferCheckPayload) reg(OwnerTransferPayload) register_enum_models(console_ns, TenantAccountRole) - -account_with_role_model = get_or_create_model("AccountWithRole", account_with_role_fields) - -account_with_role_list_fields_copy = account_with_role_list_fields.copy() -account_with_role_list_fields_copy["accounts"] = fields.List(fields.Nested(account_with_role_model)) -account_with_role_list_model = 
get_or_create_model("AccountWithRoleList", account_with_role_list_fields_copy) +register_schema_models(console_ns, AccountWithRole, AccountWithRoleList) @console_ns.route("/workspaces/current/members") @@ -84,13 +79,15 @@ class MemberListApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_with_role_list_model) + @console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__]) def get(self): current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") members = TenantService.get_tenant_members(current_user.current_tenant) - return {"result": "success", "accounts": members}, 200 + member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True) + response = AccountWithRoleList(accounts=member_models) + return response.model_dump(mode="json"), 200 @console_ns.route("/workspaces/current/members/invite-email") @@ -235,13 +232,15 @@ class DatasetOperatorMemberListApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_with_role_list_model) + @console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__]) def get(self): current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") members = TenantService.get_dataset_operator_members(current_user.current_tenant) - return {"result": "success", "accounts": members}, 200 + member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True) + response = AccountWithRoleList(accounts=member_models) + return response.model_dump(mode="json"), 200 @console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email") diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index 85ac9336d6..ef254ca357 100644 --- 
a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -1,16 +1,16 @@ from typing import Literal from flask import request -from flask_restx import Namespace, Resource, fields +from flask_restx import Resource from flask_restx.api import HTTPStatus -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, TypeAdapter from controllers.common.schema import register_schema_models from controllers.console.wraps import edit_permission_required from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_app_token from extensions.ext_redis import redis_client -from fields.annotation_fields import annotation_fields, build_annotation_model +from fields.annotation_fields import Annotation, AnnotationList from models.model import App from services.annotation_service import AppAnnotationService @@ -26,7 +26,9 @@ class AnnotationReplyActionPayload(BaseModel): embedding_model_name: str = Field(description="Embedding model name") -register_schema_models(service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload) +register_schema_models( + service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload, Annotation, AnnotationList +) @service_api_ns.route("/apps/annotation-reply/") @@ -45,10 +47,11 @@ class AnnotationReplyActionApi(Resource): def post(self, app_model: App, action: Literal["enable", "disable"]): """Enable or disable annotation reply feature.""" args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump() - if action == "enable": - result = AppAnnotationService.enable_app_annotation(args, app_model.id) - elif action == "disable": - result = AppAnnotationService.disable_app_annotation(app_model.id) + match action: + case "enable": + result = AppAnnotationService.enable_app_annotation(args, app_model.id) + case "disable": + result = AppAnnotationService.disable_app_annotation(app_model.id) return result, 200 @@ 
-82,23 +85,6 @@ class AnnotationReplyActionStatusApi(Resource): return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 -# Define annotation list response model -annotation_list_fields = { - "data": fields.List(fields.Nested(annotation_fields)), - "has_more": fields.Boolean, - "limit": fields.Integer, - "total": fields.Integer, - "page": fields.Integer, -} - - -def build_annotation_list_model(api_or_ns: Namespace): - """Build the annotation list model for the API or Namespace.""" - copied_annotation_list_fields = annotation_list_fields.copy() - copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns))) - return api_or_ns.model("AnnotationList", copied_annotation_list_fields) - - @service_api_ns.route("/apps/annotations") class AnnotationListApi(Resource): @service_api_ns.doc("list_annotations") @@ -109,8 +95,12 @@ class AnnotationListApi(Resource): 401: "Unauthorized - invalid API token", } ) + @service_api_ns.response( + 200, + "Annotations retrieved successfully", + service_api_ns.models[AnnotationList.__name__], + ) @validate_app_token - @service_api_ns.marshal_with(build_annotation_list_model(service_api_ns)) def get(self, app_model: App): """List annotations for the application.""" page = request.args.get("page", default=1, type=int) @@ -118,13 +108,15 @@ class AnnotationListApi(Resource): keyword = request.args.get("keyword", default="", type=str) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword) - return { - "data": annotation_list, - "has_more": len(annotation_list) == limit, - "limit": limit, - "total": total, - "page": page, - } + annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True) + response = AnnotationList( + data=annotation_models, + has_more=len(annotation_list) == limit, + limit=limit, + total=total, + page=page, + ) + return response.model_dump(mode="json") 
@service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__]) @service_api_ns.doc("create_annotation") @@ -135,13 +127,18 @@ class AnnotationListApi(Resource): 401: "Unauthorized - invalid API token", } ) + @service_api_ns.response( + HTTPStatus.CREATED, + "Annotation created successfully", + service_api_ns.models[Annotation.__name__], + ) @validate_app_token - @service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED) def post(self, app_model: App): """Create a new annotation.""" args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump() annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id) - return annotation, 201 + response = Annotation.model_validate(annotation, from_attributes=True) + return response.model_dump(mode="json"), HTTPStatus.CREATED @service_api_ns.route("/apps/annotations/") @@ -158,14 +155,19 @@ class AnnotationUpdateDeleteApi(Resource): 404: "Annotation not found", } ) + @service_api_ns.response( + 200, + "Annotation updated successfully", + service_api_ns.models[Annotation.__name__], + ) @validate_app_token @edit_permission_required - @service_api_ns.marshal_with(build_annotation_model(service_api_ns)) def put(self, app_model: App, annotation_id: str): """Update an existing annotation.""" args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump() annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id) - return annotation + response = Annotation.model_validate(annotation, from_attributes=True) + return response.model_dump(mode="json") @service_api_ns.doc("delete_annotation") @service_api_ns.doc(description="Delete an annotation") diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index c11f64585a..db5cabe8aa 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ 
b/api/controllers/service_api/dataset/dataset.py @@ -17,7 +17,7 @@ from controllers.service_api.wraps import ( from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager from fields.dataset_fields import dataset_detail_fields -from fields.tag_fields import build_dataset_tag_fields +from fields.tag_fields import DataSetTag from libs.login import current_user from models.account import Account from models.dataset import DatasetPermissionEnum @@ -114,6 +114,7 @@ register_schema_models( TagBindingPayload, TagUnbindingPayload, DatasetListQuery, + DataSetTag, ) @@ -480,15 +481,14 @@ class DatasetTagsApi(DatasetApiResource): 401: "Unauthorized - invalid API token", } ) - @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) def get(self, _): """Get all knowledge type tags.""" assert isinstance(current_user, Account) cid = current_user.current_tenant_id assert cid is not None tags = TagService.get_tags("knowledge", cid) - - return tags, 200 + tag_models = TypeAdapter(list[DataSetTag]).validate_python(tags, from_attributes=True) + return [tag.model_dump(mode="json") for tag in tag_models], 200 @service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__]) @service_api_ns.doc("create_dataset_tag") @@ -500,7 +500,6 @@ class DatasetTagsApi(DatasetApiResource): 403: "Forbidden - insufficient permissions", } ) - @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) def post(self, _): """Add a knowledge type tag.""" assert isinstance(current_user, Account) @@ -510,7 +509,9 @@ class DatasetTagsApi(DatasetApiResource): payload = TagCreatePayload.model_validate(service_api_ns.payload or {}) tag = TagService.save_tags({"name": payload.name, "type": "knowledge"}) - response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} + response = DataSetTag.model_validate( + {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} + ).model_dump(mode="json") 
return response, 200 @service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__]) @@ -523,7 +524,6 @@ class DatasetTagsApi(DatasetApiResource): 403: "Forbidden - insufficient permissions", } ) - @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) def patch(self, _): assert isinstance(current_user, Account) if not (current_user.has_edit_permission or current_user.is_dataset_editor): @@ -536,8 +536,9 @@ class DatasetTagsApi(DatasetApiResource): binding_count = TagService.get_tag_binding_count(tag_id) - response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} - + response = DataSetTag.model_validate( + {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} + ).model_dump(mode="json") return response, 200 @service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__]) diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index b8d9508004..692342a38a 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -168,10 +168,11 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - if action == "enable": - MetadataService.enable_built_in_field(dataset) - elif action == "disable": - MetadataService.disable_built_in_field(dataset) + match action: + case "enable": + MetadataService.enable_built_in_field(dataset) + case "disable": + MetadataService.disable_built_in_field(dataset) return {"result": "success"}, 200 diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 24acced0d1..e597a72fc0 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -73,14 +73,14 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe # If caller needs 
end-user context, attach EndUser to current_user if fetch_user_arg: - if fetch_user_arg.fetch_from == WhereisUserArg.QUERY: - user_id = request.args.get("user") - elif fetch_user_arg.fetch_from == WhereisUserArg.JSON: - user_id = request.get_json().get("user") - elif fetch_user_arg.fetch_from == WhereisUserArg.FORM: - user_id = request.form.get("user") - else: - user_id = None + user_id = None + match fetch_user_arg.fetch_from: + case WhereisUserArg.QUERY: + user_id = request.args.get("user") + case WhereisUserArg.JSON: + user_id = request.get_json().get("user") + case WhereisUserArg.FORM: + user_id = request.form.get("user") if not user_id and fetch_user_arg.required: raise ValueError("Arg user must be provided.") diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index c1f336fdde..9b981dfc09 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -14,16 +14,17 @@ class AgentConfigManager: agent_dict = config.get("agent_mode", {}) agent_strategy = agent_dict.get("strategy", "cot") - if agent_strategy == "function_call": - strategy = AgentEntity.Strategy.FUNCTION_CALLING - elif agent_strategy in {"cot", "react"}: - strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT - else: - # old configs, try to detect default strategy - if config["model"]["provider"] == "openai": + match agent_strategy: + case "function_call": strategy = AgentEntity.Strategy.FUNCTION_CALLING - else: + case "cot" | "react": strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT + case _: + # old configs, try to detect default strategy + if config["model"]["provider"] == "openai": + strategy = AgentEntity.Strategy.FUNCTION_CALLING + else: + strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT agent_tools = [] for tool in agent_dict.get("tools", []): diff --git a/api/core/app/apps/common/workflow_response_converter.py 
b/api/core/app/apps/common/workflow_response_converter.py index 6d329063f8..ba2f9b418a 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -268,7 +268,7 @@ class WorkflowResponseConverter: data=WorkflowFinishStreamResponse.Data( id=run_id, workflow_id=workflow_id, - status=status.value, + status=status, outputs=encoded_outputs, error=error, elapsed_time=elapsed_time, @@ -512,13 +512,13 @@ class WorkflowResponseConverter: metadata = self._merge_metadata(event.execution_metadata, snapshot) if isinstance(event, QueueNodeSucceededEvent): - status = WorkflowNodeExecutionStatus.SUCCEEDED.value + status = WorkflowNodeExecutionStatus.SUCCEEDED error_message = event.error elif isinstance(event, QueueNodeFailedEvent): - status = WorkflowNodeExecutionStatus.FAILED.value + status = WorkflowNodeExecutionStatus.FAILED error_message = event.error else: - status = WorkflowNodeExecutionStatus.EXCEPTION.value + status = WorkflowNodeExecutionStatus.EXCEPTION error_message = event.error return NodeFinishStreamResponse( @@ -585,7 +585,7 @@ class WorkflowResponseConverter: process_data_truncated=process_data_truncated, outputs=outputs, outputs_truncated=outputs_truncated, - status=WorkflowNodeExecutionStatus.RETRY.value, + status=WorkflowNodeExecutionStatus.RETRY, error=event.error, elapsed_time=elapsed_time, execution_metadata=metadata, diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index ea4441b5d8..eca96cb074 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -120,7 +120,7 @@ class PipelineGenerator(BaseAppGenerator): raise ValueError("Pipeline dataset is required") inputs: Mapping[str, Any] = args["inputs"] start_node_id: str = args["start_node_id"] - datasource_type: str = args["datasource_type"] + datasource_type = DatasourceProviderType(args["datasource_type"]) 
datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list( datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user ) @@ -660,7 +660,7 @@ class PipelineGenerator(BaseAppGenerator): tenant_id: str, dataset_id: str, built_in_field_enabled: bool, - datasource_type: str, + datasource_type: DatasourceProviderType, datasource_info: Mapping[str, Any], created_from: str, position: int, @@ -668,17 +668,17 @@ class PipelineGenerator(BaseAppGenerator): batch: str, document_form: str, ): - if datasource_type == "local_file": - name = datasource_info.get("name", "untitled") - elif datasource_type == "online_document": - name = datasource_info.get("page", {}).get("page_name", "untitled") - elif datasource_type == "website_crawl": - name = datasource_info.get("title", "untitled") - elif datasource_type == "online_drive": - name = datasource_info.get("name", "untitled") - else: - raise ValueError(f"Unsupported datasource type: {datasource_type}") - + match datasource_type: + case DatasourceProviderType.LOCAL_FILE: + name = datasource_info.get("name", "untitled") + case DatasourceProviderType.ONLINE_DOCUMENT: + name = datasource_info.get("page", {}).get("page_name", "untitled") + case DatasourceProviderType.WEBSITE_CRAWL: + name = datasource_info.get("title", "untitled") + case DatasourceProviderType.ONLINE_DRIVE: + name = datasource_info.get("name", "untitled") + case _: + raise ValueError(f"Unsupported datasource type: {datasource_type}") document = Document( tenant_id=tenant_id, dataset_id=dataset_id, @@ -706,7 +706,7 @@ class PipelineGenerator(BaseAppGenerator): def _format_datasource_info_list( self, - datasource_type: str, + datasource_type: DatasourceProviderType, datasource_info_list: list[Mapping[str, Any]], pipeline: Pipeline, workflow: Workflow, @@ -716,7 +716,7 @@ class PipelineGenerator(BaseAppGenerator): """ Format datasource info list. 
""" - if datasource_type == "online_drive": + if datasource_type == DatasourceProviderType.ONLINE_DRIVE: all_files: list[Mapping[str, Any]] = [] datasource_node_data = None datasource_nodes = workflow.graph_dict.get("nodes", []) diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 3f38904d2f..b988ba677b 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -8,7 +8,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities import AgentNodeStrategyInit from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.nodes.human_input.entities import FormInput, UserAction @@ -231,7 +231,7 @@ class WorkflowFinishStreamResponse(StreamResponse): id: str workflow_id: str - status: str + status: WorkflowExecutionStatus outputs: Mapping[str, Any] | None = None error: str | None = None elapsed_time: float @@ -398,7 +398,7 @@ class NodeFinishStreamResponse(StreamResponse): process_data_truncated: bool = False outputs: Mapping[str, Any] | None = None outputs_truncated: bool = True - status: str + status: WorkflowNodeExecutionStatus error: str | None = None elapsed_time: float execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None @@ -462,7 +462,7 @@ class NodeRetryStreamResponse(StreamResponse): process_data_truncated: bool = False outputs: Mapping[str, Any] | None = None outputs_truncated: bool = False - status: str + status: WorkflowNodeExecutionStatus error: str | None = None elapsed_time: float execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None @@ -806,7 +806,7 @@ class 
WorkflowAppBlockingResponse(AppBlockingResponse): id: str workflow_id: str - status: str + status: WorkflowExecutionStatus outputs: Mapping[str, Any] | None = None error: str | None = None elapsed_time: float diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index e172e88298..4e3ad7bb75 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -369,77 +369,78 @@ class IndexingRunner: # Generate summary preview summary_index_setting = tmp_processing_rule.get("summary_index_setting") if summary_index_setting and summary_index_setting.get("enable") and preview_texts: - preview_texts = index_processor.generate_summary_preview(tenant_id, preview_texts, summary_index_setting) + preview_texts = index_processor.generate_summary_preview( + tenant_id, preview_texts, summary_index_setting, doc_language + ) return IndexingEstimate(total_segments=total_segments, preview=preview_texts) def _extract( self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict ) -> list[Document]: - # load file - if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}: - return [] - data_source_info = dataset_document.data_source_info_dict text_docs = [] - if dataset_document.data_source_type == "upload_file": - if not data_source_info or "upload_file_id" not in data_source_info: - raise ValueError("no upload file found") - stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]) - file_detail = db.session.scalars(stmt).one_or_none() + match dataset_document.data_source_type: + case "upload_file": + if not data_source_info or "upload_file_id" not in data_source_info: + raise ValueError("no upload file found") + stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]) + file_detail = db.session.scalars(stmt).one_or_none() - if file_detail: + if file_detail: + extract_setting = ExtractSetting( + 
datasource_type=DatasourceType.FILE, + upload_file=file_detail, + document_model=dataset_document.doc_form, + ) + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) + case "notion_import": + if ( + not data_source_info + or "notion_workspace_id" not in data_source_info + or "notion_page_id" not in data_source_info + ): + raise ValueError("no notion import info found") extract_setting = ExtractSetting( - datasource_type=DatasourceType.FILE, - upload_file=file_detail, + datasource_type=DatasourceType.NOTION, + notion_info=NotionInfo.model_validate( + { + "credential_id": data_source_info.get("credential_id"), + "notion_workspace_id": data_source_info["notion_workspace_id"], + "notion_obj_id": data_source_info["notion_page_id"], + "notion_page_type": data_source_info["type"], + "document": dataset_document, + "tenant_id": dataset_document.tenant_id, + } + ), document_model=dataset_document.doc_form, ) text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) - elif dataset_document.data_source_type == "notion_import": - if ( - not data_source_info - or "notion_workspace_id" not in data_source_info - or "notion_page_id" not in data_source_info - ): - raise ValueError("no notion import info found") - extract_setting = ExtractSetting( - datasource_type=DatasourceType.NOTION, - notion_info=NotionInfo.model_validate( - { - "credential_id": data_source_info.get("credential_id"), - "notion_workspace_id": data_source_info["notion_workspace_id"], - "notion_obj_id": data_source_info["notion_page_id"], - "notion_page_type": data_source_info["type"], - "document": dataset_document, - "tenant_id": dataset_document.tenant_id, - } - ), - document_model=dataset_document.doc_form, - ) - text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) - elif dataset_document.data_source_type == "website_crawl": - if ( - not data_source_info - or "provider" not in data_source_info - 
or "url" not in data_source_info - or "job_id" not in data_source_info - ): - raise ValueError("no website import info found") - extract_setting = ExtractSetting( - datasource_type=DatasourceType.WEBSITE, - website_info=WebsiteInfo.model_validate( - { - "provider": data_source_info["provider"], - "job_id": data_source_info["job_id"], - "tenant_id": dataset_document.tenant_id, - "url": data_source_info["url"], - "mode": data_source_info["mode"], - "only_main_content": data_source_info["only_main_content"], - } - ), - document_model=dataset_document.doc_form, - ) - text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) + case "website_crawl": + if ( + not data_source_info + or "provider" not in data_source_info + or "url" not in data_source_info + or "job_id" not in data_source_info + ): + raise ValueError("no website import info found") + extract_setting = ExtractSetting( + datasource_type=DatasourceType.WEBSITE, + website_info=WebsiteInfo.model_validate( + { + "provider": data_source_info["provider"], + "job_id": data_source_info["job_id"], + "tenant_id": dataset_document.tenant_id, + "url": data_source_info["url"], + "mode": data_source_info["mode"], + "only_main_content": data_source_info["only_main_content"], + } + ), + document_model=dataset_document.doc_form, + ) + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) + case _: + return [] # update document status to splitting self._update_document_index_status( document_id=dataset_document.id, diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index d46cf049dd..ee9a016c95 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -441,11 +441,13 @@ DEFAULT_GENERATOR_SUMMARY_PROMPT = ( Requirements: 1. Write a concise summary in plain text -2. Use the same language as the input content +2. You must write in {language}. No language other than {language} should be used. 3. 
Focus on important facts, concepts, and details 4. If images are included, describe their key information 5. Do not use words like "好的", "ok", "I understand", "This text discusses", "The content mentions" 6. Write directly without extra words +7. If there is not enough content to generate a meaningful summary, + return an empty string without any explanation or prompt Output only the summary text. Start summarizing now: diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index 151a3de7d9..6e76321ea0 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -48,12 +48,22 @@ class BaseIndexProcessor(ABC): @abstractmethod def generate_summary_preview( - self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + self, + tenant_id: str, + preview_texts: list[PreviewDetail], + summary_index_setting: dict, + doc_language: str | None = None, ) -> list[PreviewDetail]: """ For each segment in preview_texts, generate a summary using LLM and attach it to the segment. The summary can be stored in a new attribute, e.g., summary. This method should be implemented by subclasses. 
+ + Args: + tenant_id: Tenant ID + preview_texts: List of preview details to generate summaries for + summary_index_setting: Summary index configuration + doc_language: Optional document language to ensure summary is generated in the correct language """ raise NotImplementedError diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index ab91e29145..41d7656f8a 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -275,7 +275,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor): raise ValueError("Chunks is not a list") def generate_summary_preview( - self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + self, + tenant_id: str, + preview_texts: list[PreviewDetail], + summary_index_setting: dict, + doc_language: str | None = None, ) -> list[PreviewDetail]: """ For each segment, concurrently call generate_summary to generate a summary @@ -298,11 +302,15 @@ class ParagraphIndexProcessor(BaseIndexProcessor): if flask_app: # Ensure Flask app context in worker thread with flask_app.app_context(): - summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting) + summary, _ = self.generate_summary( + tenant_id, preview.content, summary_index_setting, document_language=doc_language + ) preview.summary = summary else: # Fallback: try without app context (may fail) - summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting) + summary, _ = self.generate_summary( + tenant_id, preview.content, summary_index_setting, document_language=doc_language + ) preview.summary = summary # Generate summaries concurrently using ThreadPoolExecutor @@ -356,6 +364,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): text: str, summary_index_setting: dict | None = None, segment_id: str | None = None, + 
document_language: str | None = None, ) -> tuple[str, LLMUsage]: """ Generate summary for the given text using ModelInstance.invoke_llm and the default or custom summary prompt, @@ -366,6 +375,8 @@ class ParagraphIndexProcessor(BaseIndexProcessor): text: Text content to summarize summary_index_setting: Summary index configuration segment_id: Optional segment ID to fetch attachments from SegmentAttachmentBinding table + document_language: Optional document language (e.g., "Chinese", "English") + to ensure summary is generated in the correct language Returns: Tuple of (summary_content, llm_usage) where llm_usage is LLMUsage object @@ -381,8 +392,22 @@ class ParagraphIndexProcessor(BaseIndexProcessor): raise ValueError("model_name and model_provider_name are required in summary_index_setting") # Import default summary prompt + is_default_prompt = False if not summary_prompt: summary_prompt = DEFAULT_GENERATOR_SUMMARY_PROMPT + is_default_prompt = True + + # Format prompt with document language only for default prompt + # Custom prompts are used as-is to avoid interfering with user-defined templates + # If document_language is provided, use it; otherwise, use "the same language as the input content" + # This is especially important for image-only chunks where text is empty or minimal + if is_default_prompt: + language_for_prompt = document_language or "the same language as the input content" + try: + summary_prompt = summary_prompt.format(language=language_for_prompt) + except KeyError: + # If default prompt doesn't have {language} placeholder, use it as-is + pass provider_manager = ProviderManager() provider_model_bundle = provider_manager.get_provider_model_bundle( diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 961df2e50c..0ea77405ed 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ 
b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -358,7 +358,11 @@ class ParentChildIndexProcessor(BaseIndexProcessor): } def generate_summary_preview( - self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + self, + tenant_id: str, + preview_texts: list[PreviewDetail], + summary_index_setting: dict, + doc_language: str | None = None, ) -> list[PreviewDetail]: """ For each parent chunk in preview_texts, concurrently call generate_summary to generate a summary @@ -389,6 +393,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): tenant_id=tenant_id, text=preview.content, summary_index_setting=summary_index_setting, + document_language=doc_language, ) preview.summary = summary else: @@ -397,6 +402,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): tenant_id=tenant_id, text=preview.content, summary_index_setting=summary_index_setting, + document_language=doc_language, ) preview.summary = summary diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 272d2ed351..40d9caaa69 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -241,7 +241,11 @@ class QAIndexProcessor(BaseIndexProcessor): } def generate_summary_preview( - self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + self, + tenant_id: str, + preview_texts: list[PreviewDetail], + summary_index_setting: dict, + doc_language: str | None = None, ) -> list[PreviewDetail]: """ QA model doesn't generate summaries, so this method returns preview_texts unchanged. 
diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 5a365f769d..e195aebe6d 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -192,32 +192,33 @@ class AgentNode(Node[AgentNodeData]): result[parameter_name] = None continue agent_input = node_data.agent_parameters[parameter_name] - if agent_input.type == "variable": - variable = variable_pool.get(agent_input.value) # type: ignore - if variable is None: - raise AgentVariableNotFoundError(str(agent_input.value)) - parameter_value = variable.value - elif agent_input.type in {"mixed", "constant"}: - # variable_pool.convert_template expects a string template, - # but if passing a dict, convert to JSON string first before rendering - try: - if not isinstance(agent_input.value, str): - parameter_value = json.dumps(agent_input.value, ensure_ascii=False) - else: + match agent_input.type: + case "variable": + variable = variable_pool.get(agent_input.value) # type: ignore + if variable is None: + raise AgentVariableNotFoundError(str(agent_input.value)) + parameter_value = variable.value + case "mixed" | "constant": + # variable_pool.convert_template expects a string template, + # but if passing a dict, convert to JSON string first before rendering + try: + if not isinstance(agent_input.value, str): + parameter_value = json.dumps(agent_input.value, ensure_ascii=False) + else: + parameter_value = str(agent_input.value) + except TypeError: parameter_value = str(agent_input.value) - except TypeError: - parameter_value = str(agent_input.value) - segment_group = variable_pool.convert_template(parameter_value) - parameter_value = segment_group.log if for_log else segment_group.text - # variable_pool.convert_template returns a string, - # so we need to convert it back to a dictionary - try: - if not isinstance(agent_input.value, str): - parameter_value = json.loads(parameter_value) - except json.JSONDecodeError: - 
parameter_value = parameter_value - else: - raise AgentInputTypeError(agent_input.type) + segment_group = variable_pool.convert_template(parameter_value) + parameter_value = segment_group.log if for_log else segment_group.text + # variable_pool.convert_template returns a string, + # so we need to convert it back to a dictionary + try: + if not isinstance(agent_input.value, str): + parameter_value = json.loads(parameter_value) + except json.JSONDecodeError: + parameter_value = parameter_value + case _: + raise AgentInputTypeError(agent_input.type) value = parameter_value if parameter.type == "array[tools]": value = cast(list[dict[str, Any]], value) @@ -374,12 +375,13 @@ class AgentNode(Node[AgentNodeData]): result: dict[str, Any] = {} for parameter_name in typed_node_data.agent_parameters: input = typed_node_data.agent_parameters[parameter_name] - if input.type in ["mixed", "constant"]: - selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors() - for selector in selectors: - result[selector.variable] = selector.value_selector - elif input.type == "variable": - result[parameter_name] = input.value + match input.type: + case "mixed" | "constant": + selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors() + for selector in selectors: + result[selector.variable] = selector.value_selector + case "variable": + result[parameter_name] = input.value result = {node_id + "." 
+ key: value for key, value in result.items()} diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index fd71d610b4..a732a70417 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -270,15 +270,18 @@ class DatasourceNode(Node[DatasourceNodeData]): if typed_node_data.datasource_parameters: for parameter_name in typed_node_data.datasource_parameters: input = typed_node_data.datasource_parameters[parameter_name] - if input.type == "mixed": - assert isinstance(input.value, str) - selectors = VariableTemplateParser(input.value).extract_variable_selectors() - for selector in selectors: - result[selector.variable] = selector.value_selector - elif input.type == "variable": - result[parameter_name] = input.value - elif input.type == "constant": - pass + match input.type: + case "mixed": + assert isinstance(input.value, str) + selectors = VariableTemplateParser(input.value).extract_variable_selectors() + for selector in selectors: + result[selector.variable] = selector.value_selector + case "variable": + result[parameter_name] = input.value + case "constant": + pass + case None: + pass result = {node_id + "." 
+ key: value for key, value in result.items()} @@ -308,99 +311,107 @@ class DatasourceNode(Node[DatasourceNodeData]): variables: dict[str, Any] = {} for message in message_stream: - if message.type in { - DatasourceMessage.MessageType.IMAGE_LINK, - DatasourceMessage.MessageType.BINARY_LINK, - DatasourceMessage.MessageType.IMAGE, - }: - assert isinstance(message.message, DatasourceMessage.TextMessage) + match message.type: + case ( + DatasourceMessage.MessageType.IMAGE_LINK + | DatasourceMessage.MessageType.BINARY_LINK + | DatasourceMessage.MessageType.IMAGE + ): + assert isinstance(message.message, DatasourceMessage.TextMessage) - url = message.message.text - transfer_method = FileTransferMethod.TOOL_FILE + url = message.message.text + transfer_method = FileTransferMethod.TOOL_FILE - datasource_file_id = str(url).split("/")[-1].split(".")[0] + datasource_file_id = str(url).split("/")[-1].split(".")[0] - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) - datasource_file = session.scalar(stmt) - if datasource_file is None: - raise ToolFileError(f"Tool file {datasource_file_id} does not exist") + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) + datasource_file = session.scalar(stmt) + if datasource_file is None: + raise ToolFileError(f"Tool file {datasource_file_id} does not exist") - mapping = { - "tool_file_id": datasource_file_id, - "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype), - "transfer_method": transfer_method, - "url": url, - } - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=self.tenant_id, - ) - files.append(file) - elif message.type == DatasourceMessage.MessageType.BLOB: - # get tool file id - assert isinstance(message.message, DatasourceMessage.TextMessage) - assert message.meta - - datasource_file_id = message.message.text.split("/")[-1].split(".")[0] - with Session(db.engine) as session: - stmt 
= select(ToolFile).where(ToolFile.id == datasource_file_id) - datasource_file = session.scalar(stmt) - if datasource_file is None: - raise ToolFileError(f"datasource file {datasource_file_id} not exists") - - mapping = { - "tool_file_id": datasource_file_id, - "transfer_method": FileTransferMethod.TOOL_FILE, - } - - files.append( - file_factory.build_from_mapping( + mapping = { + "tool_file_id": datasource_file_id, + "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype), + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( mapping=mapping, tenant_id=self.tenant_id, ) - ) - elif message.type == DatasourceMessage.MessageType.TEXT: - assert isinstance(message.message, DatasourceMessage.TextMessage) - text += message.message.text - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk=message.message.text, - is_final=False, - ) - elif message.type == DatasourceMessage.MessageType.JSON: - assert isinstance(message.message, DatasourceMessage.JsonMessage) - json.append(message.message.json_object) - elif message.type == DatasourceMessage.MessageType.LINK: - assert isinstance(message.message, DatasourceMessage.TextMessage) - stream_text = f"Link: {message.message.text}\n" - text += stream_text - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk=stream_text, - is_final=False, - ) - elif message.type == DatasourceMessage.MessageType.VARIABLE: - assert isinstance(message.message, DatasourceMessage.VariableMessage) - variable_name = message.message.variable_name - variable_value = message.message.variable_value - if message.message.stream: - if not isinstance(variable_value, str): - raise ValueError("When 'stream' is True, 'variable_value' must be a string.") - if variable_name not in variables: - variables[variable_name] = "" - variables[variable_name] += variable_value + files.append(file) + case DatasourceMessage.MessageType.BLOB: + # get tool file id + assert 
isinstance(message.message, DatasourceMessage.TextMessage) + assert message.meta + datasource_file_id = message.message.text.split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) + datasource_file = session.scalar(stmt) + if datasource_file is None: + raise ToolFileError(f"datasource file {datasource_file_id} not exists") + + mapping = { + "tool_file_id": datasource_file_id, + "transfer_method": FileTransferMethod.TOOL_FILE, + } + + files.append( + file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, + ) + ) + case DatasourceMessage.MessageType.TEXT: + assert isinstance(message.message, DatasourceMessage.TextMessage) + text += message.message.text yield StreamChunkEvent( - selector=[self._node_id, variable_name], - chunk=variable_value, + selector=[self._node_id, "text"], + chunk=message.message.text, is_final=False, ) - else: - variables[variable_name] = variable_value - elif message.type == DatasourceMessage.MessageType.FILE: - assert message.meta is not None - files.append(message.meta["file"]) + case DatasourceMessage.MessageType.JSON: + assert isinstance(message.message, DatasourceMessage.JsonMessage) + json.append(message.message.json_object) + case DatasourceMessage.MessageType.LINK: + assert isinstance(message.message, DatasourceMessage.TextMessage) + stream_text = f"Link: {message.message.text}\n" + text += stream_text + yield StreamChunkEvent( + selector=[self._node_id, "text"], + chunk=stream_text, + is_final=False, + ) + case DatasourceMessage.MessageType.VARIABLE: + assert isinstance(message.message, DatasourceMessage.VariableMessage) + variable_name = message.message.variable_name + variable_value = message.message.variable_value + if message.message.stream: + if not isinstance(variable_value, str): + raise ValueError("When 'stream' is True, 'variable_value' must be a string.") + if variable_name not in variables: + variables[variable_name] = "" 
+ variables[variable_name] += variable_value + + yield StreamChunkEvent( + selector=[self._node_id, variable_name], + chunk=variable_value, + is_final=False, + ) + else: + variables[variable_name] = variable_value + case DatasourceMessage.MessageType.FILE: + assert message.meta is not None + files.append(message.meta["file"]) + case ( + DatasourceMessage.MessageType.BLOB_CHUNK + | DatasourceMessage.MessageType.LOG + | DatasourceMessage.MessageType.RETRIEVER_RESOURCES + ): + pass + # mark the end of the stream yield StreamChunkEvent( selector=[self._node_id, "text"], diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index b88c2d510f..2aff953bc6 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -78,12 +78,21 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): indexing_technique = node_data.indexing_technique or dataset.indexing_technique summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting + # Try to get document language if document_id is available + doc_language = None + document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) + if document_id: + document = db.session.query(Document).filter_by(id=document_id.value).first() + if document and document.doc_language: + doc_language = document.doc_language + outputs = self._get_preview_output_with_summaries( node_data.chunk_structure, chunks, dataset=dataset, indexing_technique=indexing_technique, summary_index_setting=summary_index_setting, + doc_language=doc_language, ) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -315,6 +324,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): dataset: Dataset, indexing_technique: str | None = None, summary_index_setting: dict | None = None, + doc_language: str | None = None, ) -> Mapping[str, Any]: """ 
Generate preview output with summaries for chunks in preview mode. @@ -326,6 +336,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): dataset: Dataset object (for tenant_id) indexing_technique: Indexing technique from node config or dataset summary_index_setting: Summary index setting from node config or dataset + doc_language: Optional document language to ensure summary is generated in the correct language """ index_processor = IndexProcessorFactory(chunk_structure).init_index_processor() preview_output = index_processor.format_preview(chunks) @@ -365,6 +376,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): tenant_id=dataset.tenant_id, text=preview_item["content"], summary_index_setting=summary_index_setting, + document_language=doc_language, ) if summary: preview_item["summary"] = summary @@ -374,6 +386,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): tenant_id=dataset.tenant_id, text=preview_item["content"], summary_index_setting=summary_index_setting, + document_language=doc_language, ) if summary: preview_item["summary"] = summary diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 3c4850ebac..0827494a48 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -303,33 +303,34 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: if node_data.multiple_retrieval_config is None: raise ValueError("multiple_retrieval_config is required") - if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": - if node_data.multiple_retrieval_config.reranking_model: - reranking_model = { - "reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider, - 
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model, - } - else: + match node_data.multiple_retrieval_config.reranking_mode: + case "reranking_model": + if node_data.multiple_retrieval_config.reranking_model: + reranking_model = { + "reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider, + "reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model, + } + else: + reranking_model = None + weights = None + case "weighted_score": + if node_data.multiple_retrieval_config.weights is None: + raise ValueError("weights is required") reranking_model = None - weights = None - elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score": - if node_data.multiple_retrieval_config.weights is None: - raise ValueError("weights is required") - reranking_model = None - vector_setting = node_data.multiple_retrieval_config.weights.vector_setting - weights = { - "vector_setting": { - "vector_weight": vector_setting.vector_weight, - "embedding_provider_name": vector_setting.embedding_provider_name, - "embedding_model_name": vector_setting.embedding_model_name, - }, - "keyword_setting": { - "keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight - }, - } - else: - reranking_model = None - weights = None + vector_setting = node_data.multiple_retrieval_config.weights.vector_setting + weights = { + "vector_setting": { + "vector_weight": vector_setting.vector_weight, + "embedding_provider_name": vector_setting.embedding_provider_name, + "embedding_model_name": vector_setting.embedding_model_name, + }, + "keyword_setting": { + "keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight + }, + } + case _: + reranking_model = None + weights = None all_documents = dataset_retrieval.multiple_retrieve( app_id=self.app_id, tenant_id=self.tenant_id, @@ -453,73 +454,74 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, 
Node[KnowledgeRetrievalNodeD ) filters: list[Any] = [] metadata_condition = None - if node_data.metadata_filtering_mode == "disabled": - return None, None, usage - elif node_data.metadata_filtering_mode == "automatic": - automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func( - dataset_ids, query, node_data - ) - usage = self._merge_usage(usage, automatic_usage) - if automatic_metadata_filters: - conditions = [] - for sequence, filter in enumerate(automatic_metadata_filters): - DatasetRetrieval.process_metadata_filter_func( - sequence, - filter.get("condition", ""), - filter.get("metadata_name", ""), - filter.get("value"), - filters, - ) - conditions.append( - Condition( - name=filter.get("metadata_name"), # type: ignore - comparison_operator=filter.get("condition"), # type: ignore - value=filter.get("value"), - ) - ) - metadata_condition = MetadataCondition( - logical_operator=node_data.metadata_filtering_conditions.logical_operator - if node_data.metadata_filtering_conditions - else "or", - conditions=conditions, + match node_data.metadata_filtering_mode: + case "disabled": + return None, None, usage + case "automatic": + automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func( + dataset_ids, query, node_data ) - elif node_data.metadata_filtering_mode == "manual": - if node_data.metadata_filtering_conditions: - conditions = [] - for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore - metadata_name = condition.name - expected_value = condition.value - if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"): - if isinstance(expected_value, str): - expected_value = self.graph_runtime_state.variable_pool.convert_template( - expected_value - ).value[0] - if expected_value.value_type in {"number", "integer", "float"}: - expected_value = expected_value.value - elif expected_value.value_type == "string": - expected_value = 
re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() - else: - raise ValueError("Invalid expected metadata value type") - conditions.append( - Condition( - name=metadata_name, - comparison_operator=condition.comparison_operator, - value=expected_value, + usage = self._merge_usage(usage, automatic_usage) + if automatic_metadata_filters: + conditions = [] + for sequence, filter in enumerate(automatic_metadata_filters): + DatasetRetrieval.process_metadata_filter_func( + sequence, + filter.get("condition", ""), + filter.get("metadata_name", ""), + filter.get("value"), + filters, ) + conditions.append( + Condition( + name=filter.get("metadata_name"), # type: ignore + comparison_operator=filter.get("condition"), # type: ignore + value=filter.get("value"), + ) + ) + metadata_condition = MetadataCondition( + logical_operator=node_data.metadata_filtering_conditions.logical_operator + if node_data.metadata_filtering_conditions + else "or", + conditions=conditions, ) - filters = DatasetRetrieval.process_metadata_filter_func( - sequence, - condition.comparison_operator, - metadata_name, - expected_value, - filters, + case "manual": + if node_data.metadata_filtering_conditions: + conditions = [] + for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore + metadata_name = condition.name + expected_value = condition.value + if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"): + if isinstance(expected_value, str): + expected_value = self.graph_runtime_state.variable_pool.convert_template( + expected_value + ).value[0] + if expected_value.value_type in {"number", "integer", "float"}: + expected_value = expected_value.value + elif expected_value.value_type == "string": + expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() + else: + raise ValueError("Invalid expected metadata value type") + conditions.append( + Condition( + name=metadata_name, + 
comparison_operator=condition.comparison_operator, + value=expected_value, + ) + ) + filters = DatasetRetrieval.process_metadata_filter_func( + sequence, + condition.comparison_operator, + metadata_name, + expected_value, + filters, + ) + metadata_condition = MetadataCondition( + logical_operator=node_data.metadata_filtering_conditions.logical_operator, + conditions=conditions, ) - metadata_condition = MetadataCondition( - logical_operator=node_data.metadata_filtering_conditions.logical_operator, - conditions=conditions, - ) - else: - raise ValueError("Invalid metadata filtering mode") + case _: + raise ValueError("Invalid metadata filtering mode") if filters: if ( node_data.metadata_filtering_conditions diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 68ac60e4f6..60d76db9b6 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -482,16 +482,17 @@ class ToolNode(Node[ToolNodeData]): result = {} for parameter_name in typed_node_data.tool_parameters: input = typed_node_data.tool_parameters[parameter_name] - if input.type == "mixed": - assert isinstance(input.value, str) - selectors = VariableTemplateParser(input.value).extract_variable_selectors() - for selector in selectors: - result[selector.variable] = selector.value_selector - elif input.type == "variable": - selector_key = ".".join(input.value) - result[f"#{selector_key}#"] = input.value - elif input.type == "constant": - pass + match input.type: + case "mixed": + assert isinstance(input.value, str) + selectors = VariableTemplateParser(input.value).extract_variable_selectors() + for selector in selectors: + result[selector.variable] = selector.value_selector + case "variable": + selector_key = ".".join(input.value) + result[f"#{selector_key}#"] = input.value + case "constant": + pass result = {node_id + "." 
+ key: value for key, value in result.items()} diff --git a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py index c1608f58a5..18eed4e481 100644 --- a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py +++ b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py @@ -390,8 +390,7 @@ class ClickZettaVolumeStorage(BaseStorage): """ content = self.load_once(filename) - with Path(target_filepath).open("wb") as f: - f.write(content) + Path(target_filepath).write_bytes(content) logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath) diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index e69306dcb2..a646950722 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -1,36 +1,69 @@ -from flask_restx import Namespace, fields +from __future__ import annotations -from libs.helper import TimestampField +from datetime import datetime -annotation_fields = { - "id": fields.String, - "question": fields.String, - "answer": fields.Raw(attribute="content"), - "hit_count": fields.Integer, - "created_at": TimestampField, - # 'account': fields.Nested(simple_account_fields, allow_null=True) -} +from pydantic import BaseModel, ConfigDict, Field, field_validator -def build_annotation_model(api_or_ns: Namespace): - """Build the annotation model for the API or Namespace.""" - return api_or_ns.model("Annotation", annotation_fields) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -annotation_list_fields = { - "data": fields.List(fields.Nested(annotation_fields)), -} +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) -annotation_hit_history_fields = 
{ - "id": fields.String, - "source": fields.String, - "score": fields.Float, - "question": fields.String, - "created_at": TimestampField, - "match": fields.String(attribute="annotation_question"), - "response": fields.String(attribute="annotation_content"), -} -annotation_hit_history_list_fields = { - "data": fields.List(fields.Nested(annotation_hit_history_fields)), -} +class Annotation(ResponseModel): + id: str + question: str | None = None + answer: str | None = Field(default=None, validation_alias="content") + hit_count: int | None = None + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AnnotationList(ResponseModel): + data: list[Annotation] + has_more: bool + limit: int + total: int + page: int + + +class AnnotationExportList(ResponseModel): + data: list[Annotation] + + +class AnnotationHitHistory(ResponseModel): + id: str + source: str | None = None + score: float | None = None + question: str | None = None + created_at: int | None = None + match: str | None = Field(default=None, validation_alias="annotation_question") + response: str | None = Field(default=None, validation_alias="annotation_content") + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AnnotationHitHistoryList(ResponseModel): + data: list[AnnotationHitHistory] + has_more: bool + limit: int + total: int + page: int diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py index 5389b0213a..effe7bfb20 100644 --- a/api/fields/end_user_fields.py +++ b/api/fields/end_user_fields.py @@ -1,4 +1,7 @@ -from flask_restx import Namespace, fields +from __future__ import annotations + +from flask_restx import fields +from pydantic import BaseModel, ConfigDict simple_end_user_fields = { "id": 
fields.String, @@ -8,5 +11,18 @@ simple_end_user_fields = { } -def build_simple_end_user_model(api_or_ns: Namespace): - return api_or_ns.model("SimpleEndUser", simple_end_user_fields) +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) + + +class SimpleEndUser(ResponseModel): + id: str + type: str + is_anonymous: bool + session_id: str | None = None diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 25160927e6..11d9a1a2fc 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -1,6 +1,11 @@ -from flask_restx import Namespace, fields +from __future__ import annotations -from libs.helper import AvatarUrlField, TimestampField +from datetime import datetime + +from flask_restx import fields +from pydantic import BaseModel, ConfigDict, computed_field, field_validator + +from core.file import helpers as file_helpers simple_account_fields = { "id": fields.String, @@ -9,36 +14,78 @@ simple_account_fields = { } -def build_simple_account_model(api_or_ns: Namespace): - return api_or_ns.model("SimpleAccount", simple_account_fields) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -account_fields = { - "id": fields.String, - "name": fields.String, - "avatar": fields.String, - "avatar_url": AvatarUrlField, - "email": fields.String, - "is_password_set": fields.Boolean, - "interface_language": fields.String, - "interface_theme": fields.String, - "timezone": fields.String, - "last_login_at": TimestampField, - "last_login_ip": fields.String, - "created_at": TimestampField, -} +def _build_avatar_url(avatar: str | None) -> str | None: + if avatar is None: + return None + if avatar.startswith(("http://", "https://")): + return avatar + return file_helpers.get_signed_file_url(avatar) -account_with_role_fields = { 
- "id": fields.String, - "name": fields.String, - "avatar": fields.String, - "avatar_url": AvatarUrlField, - "email": fields.String, - "last_login_at": TimestampField, - "last_active_at": TimestampField, - "created_at": TimestampField, - "role": fields.String, - "status": fields.String, -} -account_with_role_list_fields = {"accounts": fields.List(fields.Nested(account_with_role_fields))} +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) + + +class SimpleAccount(ResponseModel): + id: str + name: str + email: str + + +class _AccountAvatar(ResponseModel): + avatar: str | None = None + + @computed_field(return_type=str | None) # type: ignore[prop-decorator] + @property + def avatar_url(self) -> str | None: + return _build_avatar_url(self.avatar) + + +class Account(_AccountAvatar): + id: str + name: str + email: str + is_password_set: bool + interface_language: str | None = None + interface_theme: str | None = None + timezone: str | None = None + last_login_at: int | None = None + last_login_ip: str | None = None + created_at: int | None = None + + @field_validator("last_login_at", "created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AccountWithRole(_AccountAvatar): + id: str + name: str + email: str + last_login_at: int | None = None + last_active_at: int | None = None + created_at: int | None = None + role: str + status: str + + @field_validator("last_login_at", "last_active_at", "created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AccountWithRoleList(ResponseModel): + accounts: list[AccountWithRole] diff --git a/api/fields/tag_fields.py b/api/fields/tag_fields.py index e359a4408c..7cb64e5ca8 100644 --- 
a/api/fields/tag_fields.py +++ b/api/fields/tag_fields.py @@ -1,12 +1,20 @@ -from flask_restx import Namespace, fields +from __future__ import annotations -dataset_tag_fields = { - "id": fields.String, - "name": fields.String, - "type": fields.String, - "binding_count": fields.String, -} +from pydantic import BaseModel, ConfigDict -def build_dataset_tag_fields(api_or_ns: Namespace): - return api_or_ns.model("DataSetTag", dataset_tag_fields) +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) + + +class DataSetTag(ResponseModel): + id: str + name: str + type: str + binding_count: str | None = None diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index ae70356322..d0e762f62b 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -1,7 +1,7 @@ from flask_restx import Namespace, fields -from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields -from fields.member_fields import build_simple_account_model, simple_account_fields +from fields.end_user_fields import simple_end_user_fields +from fields.member_fields import simple_account_fields from fields.workflow_run_fields import ( build_workflow_run_for_archived_log_model, build_workflow_run_for_log_model, @@ -25,17 +25,9 @@ workflow_app_log_partial_fields = { def build_workflow_app_log_partial_model(api_or_ns: Namespace): """Build the workflow app log partial model for the API or Namespace.""" workflow_run_model = build_workflow_run_for_log_model(api_or_ns) - simple_account_model = build_simple_account_model(api_or_ns) - simple_end_user_model = build_simple_end_user_model(api_or_ns) copied_fields = workflow_app_log_partial_fields.copy() copied_fields["workflow_run"] = fields.Nested(workflow_run_model, attribute="workflow_run", allow_null=True) - 
copied_fields["created_by_account"] = fields.Nested( - simple_account_model, attribute="created_by_account", allow_null=True - ) - copied_fields["created_by_end_user"] = fields.Nested( - simple_end_user_model, attribute="created_by_end_user", allow_null=True - ) return api_or_ns.model("WorkflowAppLogPartial", copied_fields) @@ -52,17 +44,9 @@ workflow_archived_log_partial_fields = { def build_workflow_archived_log_partial_model(api_or_ns: Namespace): """Build the workflow archived log partial model for the API or Namespace.""" workflow_run_model = build_workflow_run_for_archived_log_model(api_or_ns) - simple_account_model = build_simple_account_model(api_or_ns) - simple_end_user_model = build_simple_end_user_model(api_or_ns) copied_fields = workflow_archived_log_partial_fields.copy() copied_fields["workflow_run"] = fields.Nested(workflow_run_model, allow_null=True) - copied_fields["created_by_account"] = fields.Nested( - simple_account_model, attribute="created_by_account", allow_null=True - ) - copied_fields["created_by_end_user"] = fields.Nested( - simple_end_user_model, attribute="created_by_end_user", allow_null=True - ) return api_or_ns.model("WorkflowArchivedLogPartial", copied_fields) diff --git a/api/pyproject.toml b/api/pyproject.toml index 47847862a4..9d028cd58e 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dify-api" -version = "1.12.0" +version = "1.12.1" requires-python = ">=3.11,<3.13" dependencies = [ @@ -145,7 +145,7 @@ dev = [ "types-openpyxl~=3.1.5", "types-pexpect~=4.9.0", "types-protobuf~=5.29.1", - "types-psutil~=7.0.0", + "types-psutil~=7.2.2", "types-psycopg2~=2.9.21", "types-pygments~=2.19.0", "types-pymysql~=1.1.0", diff --git a/api/services/account_service.py b/api/services/account_service.py index 35e4a505af..d3893c1207 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -327,6 +327,17 @@ class AccountService: @staticmethod def delete_account(account: Account): 
"""Delete account. This method only adds a task to the queue for deletion.""" + # Queue account deletion sync tasks for all workspaces BEFORE account deletion (enterprise only) + from services.enterprise.account_deletion_sync import sync_account_deletion + + sync_success = sync_account_deletion(account_id=account.id, source="account_deleted") + if not sync_success: + logger.warning( + "Enterprise account deletion sync failed for account %s; proceeding with local deletion.", + account.id, + ) + + # Now proceed with async account deletion delete_account_task.delay(account.id) @staticmethod @@ -1230,6 +1241,19 @@ class TenantService: if dify_config.BILLING_ENABLED: BillingService.clean_billing_info_cache(tenant.id) + # Queue account deletion sync task for enterprise backend to reassign resources (enterprise only) + from services.enterprise.account_deletion_sync import sync_workspace_member_removal + + sync_success = sync_workspace_member_removal( + workspace_id=tenant.id, member_id=account.id, source="workspace_member_removed" + ) + if not sync_success: + logger.warning( + "Enterprise workspace member removal sync failed: workspace_id=%s, member_id=%s", + tenant.id, + account.id, + ) + @staticmethod def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account): """Update member role""" diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 56e9cc6a00..8ebc87a670 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -158,7 +158,7 @@ class AppAnnotationService: .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) ) annotations = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False) - return annotations.items, annotations.total + return annotations.items, annotations.total or 0 @classmethod def export_annotation_list_by_app_id(cls, app_id: str): @@ -524,7 +524,7 @@ class AppAnnotationService: 
annotation_hit_histories = db.paginate( select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False ) - return annotation_hit_histories.items, annotation_hit_histories.total + return annotation_hit_histories.items, annotation_hit_histories.total or 0 @classmethod def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None: diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 0b3fcbe4ae..1ea6c4e1c3 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -16,6 +16,7 @@ from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, NotFound from configs import dify_config +from core.db.session_factory import session_factory from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.file import helpers as file_helpers from core.helper.name_generator import generate_incremental_name @@ -1388,6 +1389,46 @@ class DocumentService: ).all() return documents + @staticmethod + def update_documents_need_summary(dataset_id: str, document_ids: Sequence[str], need_summary: bool = True) -> int: + """ + Update need_summary field for multiple documents. + + This method handles the case where documents were created when summary_index_setting was disabled, + and need to be updated when summary_index_setting is later enabled. 
+ + Args: + dataset_id: Dataset ID + document_ids: List of document IDs to update + need_summary: Value to set for need_summary field (default: True) + + Returns: + Number of documents updated + """ + if not document_ids: + return 0 + + document_id_list: list[str] = [str(document_id) for document_id in document_ids] + + with session_factory.create_session() as session: + updated_count = ( + session.query(Document) + .filter( + Document.id.in_(document_id_list), + Document.dataset_id == dataset_id, + Document.doc_form != "qa_model", # Skip qa_model documents + ) + .update({Document.need_summary: need_summary}, synchronize_session=False) + ) + session.commit() + logger.info( + "Updated need_summary to %s for %d documents in dataset %s", + need_summary, + updated_count, + dataset_id, + ) + return updated_count + @staticmethod def get_document_download_url(document: Document) -> str: """ @@ -2937,14 +2978,15 @@ class DocumentService: """ now = naive_utc_now() - if action == "enable": - return DocumentService._prepare_enable_update(document, now) - elif action == "disable": - return DocumentService._prepare_disable_update(document, user, now) - elif action == "archive": - return DocumentService._prepare_archive_update(document, user, now) - elif action == "un_archive": - return DocumentService._prepare_unarchive_update(document, now) + match action: + case "enable": + return DocumentService._prepare_enable_update(document, now) + case "disable": + return DocumentService._prepare_disable_update(document, user, now) + case "archive": + return DocumentService._prepare_archive_update(document, user, now) + case "un_archive": + return DocumentService._prepare_unarchive_update(document, now) return None @@ -3581,56 +3623,57 @@ class SegmentService: # Check if segment_ids is not empty to avoid WHERE false condition if not segment_ids or len(segment_ids) == 0: return - if action == "enable": - segments = db.session.scalars( - select(DocumentSegment).where( - 
DocumentSegment.id.in_(segment_ids), - DocumentSegment.dataset_id == dataset.id, - DocumentSegment.document_id == document.id, - DocumentSegment.enabled == False, - ) - ).all() - if not segments: - return - real_deal_segment_ids = [] - for segment in segments: - indexing_cache_key = f"segment_{segment.id}_indexing" - cache_result = redis_client.get(indexing_cache_key) - if cache_result is not None: - continue - segment.enabled = True - segment.disabled_at = None - segment.disabled_by = None - db.session.add(segment) - real_deal_segment_ids.append(segment.id) - db.session.commit() + match action: + case "enable": + segments = db.session.scalars( + select(DocumentSegment).where( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.enabled == False, + ) + ).all() + if not segments: + return + real_deal_segment_ids = [] + for segment in segments: + indexing_cache_key = f"segment_{segment.id}_indexing" + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + continue + segment.enabled = True + segment.disabled_at = None + segment.disabled_by = None + db.session.add(segment) + real_deal_segment_ids.append(segment.id) + db.session.commit() - enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id) - elif action == "disable": - segments = db.session.scalars( - select(DocumentSegment).where( - DocumentSegment.id.in_(segment_ids), - DocumentSegment.dataset_id == dataset.id, - DocumentSegment.document_id == document.id, - DocumentSegment.enabled == True, - ) - ).all() - if not segments: - return - real_deal_segment_ids = [] - for segment in segments: - indexing_cache_key = f"segment_{segment.id}_indexing" - cache_result = redis_client.get(indexing_cache_key) - if cache_result is not None: - continue - segment.enabled = False - segment.disabled_at = naive_utc_now() - segment.disabled_by = current_user.id - 
db.session.add(segment) - real_deal_segment_ids.append(segment.id) - db.session.commit() + enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id) + case "disable": + segments = db.session.scalars( + select(DocumentSegment).where( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.enabled == True, + ) + ).all() + if not segments: + return + real_deal_segment_ids = [] + for segment in segments: + indexing_cache_key = f"segment_{segment.id}_indexing" + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + continue + segment.enabled = False + segment.disabled_at = naive_utc_now() + segment.disabled_by = current_user.id + db.session.add(segment) + real_deal_segment_ids.append(segment.id) + db.session.commit() - disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id) + disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id) @classmethod def create_child_chunk( diff --git a/api/services/enterprise/account_deletion_sync.py b/api/services/enterprise/account_deletion_sync.py new file mode 100644 index 0000000000..c7ff42894d --- /dev/null +++ b/api/services/enterprise/account_deletion_sync.py @@ -0,0 +1,115 @@ +import json +import logging +import uuid +from datetime import UTC, datetime + +from redis import RedisError + +from configs import dify_config +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.account import TenantAccountJoin + +logger = logging.getLogger(__name__) + +ACCOUNT_DELETION_SYNC_QUEUE = "enterprise:member:sync:queue" +ACCOUNT_DELETION_SYNC_TASK_TYPE = "sync_member_deletion_from_workspace" + + +def _queue_task(workspace_id: str, member_id: str, *, source: str) -> bool: + """ + Queue an account deletion sync task to Redis. + + Internal helper function. 
Do not call directly - use the public functions instead. + + Args: + workspace_id: The workspace/tenant ID to sync + member_id: The member/account ID that was removed + source: Source of the sync request (for debugging/tracking) + + Returns: + bool: True if task was queued successfully, False otherwise + """ + try: + task = { + "task_id": str(uuid.uuid4()), + "workspace_id": workspace_id, + "member_id": member_id, + "retry_count": 0, + "created_at": datetime.now(UTC).isoformat(), + "source": source, + "type": ACCOUNT_DELETION_SYNC_TASK_TYPE, + } + + # Push to Redis list (queue) - LPUSH adds to the head, worker consumes from tail with RPOP + redis_client.lpush(ACCOUNT_DELETION_SYNC_QUEUE, json.dumps(task)) + + logger.info( + "Queued account deletion sync task for workspace %s, member %s, task_id: %s, source: %s", + workspace_id, + member_id, + task["task_id"], + source, + ) + return True + + except (RedisError, TypeError) as e: + logger.error( + "Failed to queue account deletion sync for workspace %s, member %s: %s", + workspace_id, + member_id, + str(e), + exc_info=True, + ) + # Don't raise - we don't want to fail member deletion if queueing fails + return False + + +def sync_workspace_member_removal(workspace_id: str, member_id: str, *, source: str) -> bool: + """ + Sync a single workspace member removal (enterprise only). + + Queues a task for the enterprise backend to reassign resources from the removed member. + Handles enterprise edition check internally. Safe to call in community edition (no-op). 
+ + Args: + workspace_id: The workspace/tenant ID + member_id: The member/account ID that was removed + source: Source of the sync request (e.g., "workspace_member_removed") + + Returns: + bool: True if task was queued (or skipped in community), False if queueing failed + """ + if not dify_config.ENTERPRISE_ENABLED: + return True + + return _queue_task(workspace_id=workspace_id, member_id=member_id, source=source) + + +def sync_account_deletion(account_id: str, *, source: str) -> bool: + """ + Sync full account deletion across all workspaces (enterprise only). + + Fetches all workspace memberships for the account and queues a sync task for each. + Handles enterprise edition check internally. Safe to call in community edition (no-op). + + Args: + account_id: The account ID being deleted + source: Source of the sync request (e.g., "account_deleted") + + Returns: + bool: True if all tasks were queued (or skipped in community), False if any queueing failed + """ + if not dify_config.ENTERPRISE_ENABLED: + return True + + # Fetch all workspaces the account belongs to + workspace_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account_id).all() + + # Queue sync task for each workspace + success = True + for join in workspace_joins: + if not _queue_task(workspace_id=join.tenant_id, member_id=account_id, source=source): + success = False + + return success diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index 8ea365e907..d0dfbc1070 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -174,6 +174,10 @@ class RagPipelineTransformService: else: dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() + # Copy summary_index_setting from dataset to knowledge_index node configuration + if dataset.summary_index_setting: + knowledge_configuration.summary_index_setting = 
dataset.summary_index_setting + knowledge_configuration_dict.update(knowledge_configuration.model_dump()) node["data"] = knowledge_configuration_dict return node diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index b8e1f8bc3f..7c03ceed5b 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -49,11 +49,18 @@ class SummaryIndexService: # Use lazy import to avoid circular import from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor + # Get document language to ensure summary is generated in the correct language + # This is especially important for image-only chunks where text is empty or minimal + document_language = None + if segment.document and segment.document.doc_language: + document_language = segment.document.doc_language + summary_content, usage = ParagraphIndexProcessor.generate_summary( tenant_id=dataset.tenant_id, text=segment.content, summary_index_setting=summary_index_setting, segment_id=segment.id, + document_language=document_language, ) if not summary_content: @@ -558,6 +565,9 @@ class SummaryIndexService: ) session.add(summary_record) + # Commit the batch created records + session.commit() + @staticmethod def update_summary_record_error( segment: DocumentSegment, @@ -762,7 +772,6 @@ class SummaryIndexService: dataset=dataset, status="not_started", ) - session.commit() # Commit initial records summary_records = [] diff --git a/api/services/tag_service.py b/api/services/tag_service.py index 56f4ae9494..bd3585acf4 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -24,7 +24,7 @@ class TagService: escaped_keyword = escape_like_pattern(keyword) query = query.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\"))) query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at) - results = query.order_by(Tag.created_at.desc()).all() + results: list = 
query.order_by(Tag.created_at.desc()).all() return results @staticmethod diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 817249845a..6240f2200f 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -259,8 +259,8 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str): - def del_workflow_archive_log(workflow_archive_log_id: str): - db.session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete( + def del_workflow_archive_log(session, workflow_archive_log_id: str): + session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete( synchronize_session=False ) @@ -420,7 +420,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int: total_files_deleted = 0 while True: - with session_factory.create_session() as session: + with session_factory.create_session() as session, session.begin(): # Get a batch of draft variable IDs along with their file_ids query_sql = """ SELECT id, file_id FROM workflow_draft_variables diff --git a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py index f46d1bf5db..d020233620 100644 --- a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -10,7 +10,10 @@ from models import Tenant from models.enums import CreatorUserRole from models.model import App, UploadFile from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile -from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch +from tasks.remove_app_and_related_data_task import ( + _delete_draft_variables, + 
delete_draft_variables_batch, +) @pytest.fixture @@ -297,12 +300,18 @@ class TestDeleteDraftVariablesWithOffloadIntegration: def test_delete_draft_variables_with_offload_data(self, mock_storage, setup_offload_test_data): data = setup_offload_test_data app_id = data["app"].id + upload_file_ids = [uf.id for uf in data["upload_files"]] + variable_file_ids = [vf.id for vf in data["variable_files"]] mock_storage.delete.return_value = None with session_factory.create_session() as session: draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() - var_files_before = session.query(WorkflowDraftVariableFile).count() - upload_files_before = session.query(UploadFile).count() + var_files_before = ( + session.query(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_(variable_file_ids)) + .count() + ) + upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() assert draft_vars_before == 3 assert var_files_before == 2 assert upload_files_before == 2 @@ -315,8 +324,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration: assert draft_vars_after == 0 with session_factory.create_session() as session: - var_files_after = session.query(WorkflowDraftVariableFile).count() - upload_files_after = session.query(UploadFile).count() + var_files_after = ( + session.query(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_(variable_file_ids)) + .count() + ) + upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() assert var_files_after == 0 assert upload_files_after == 0 @@ -329,6 +342,8 @@ class TestDeleteDraftVariablesWithOffloadIntegration: def test_delete_draft_variables_storage_failure_continues_cleanup(self, mock_storage, setup_offload_test_data): data = setup_offload_test_data app_id = data["app"].id + upload_file_ids = [uf.id for uf in data["upload_files"]] + variable_file_ids = [vf.id for vf in data["variable_files"]] 
mock_storage.delete.side_effect = [Exception("Storage error"), None] deleted_count = delete_draft_variables_batch(app_id, batch_size=10) @@ -339,8 +354,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration: assert draft_vars_after == 0 with session_factory.create_session() as session: - var_files_after = session.query(WorkflowDraftVariableFile).count() - upload_files_after = session.query(UploadFile).count() + var_files_after = ( + session.query(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_(variable_file_ids)) + .count() + ) + upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() assert var_files_after == 0 assert upload_files_after == 0 @@ -395,3 +414,275 @@ class TestDeleteDraftVariablesWithOffloadIntegration: if app2_obj: session.delete(app2_obj) session.commit() + + +class TestDeleteDraftVariablesSessionCommit: + """Test suite to verify session commit behavior in delete_draft_variables_batch.""" + + @pytest.fixture + def setup_offload_test_data(self, app_and_tenant): + """Create test data with offload files for session commit tests.""" + from core.variables.types import SegmentType + from libs.datetime_utils import naive_utc_now + + tenant, app = app_and_tenant + + with session_factory.create_session() as session: + upload_file1 = UploadFile( + tenant_id=tenant.id, + storage_type="local", + key="test/file1.json", + name="file1.json", + size=1024, + extension="json", + mime_type="application/json", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid.uuid4()), + created_at=naive_utc_now(), + used=False, + ) + upload_file2 = UploadFile( + tenant_id=tenant.id, + storage_type="local", + key="test/file2.json", + name="file2.json", + size=2048, + extension="json", + mime_type="application/json", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid.uuid4()), + created_at=naive_utc_now(), + used=False, + ) + session.add(upload_file1) + session.add(upload_file2) + 
session.flush() + + var_file1 = WorkflowDraftVariableFile( + tenant_id=tenant.id, + app_id=app.id, + user_id=str(uuid.uuid4()), + upload_file_id=upload_file1.id, + size=1024, + length=10, + value_type=SegmentType.STRING, + ) + var_file2 = WorkflowDraftVariableFile( + tenant_id=tenant.id, + app_id=app.id, + user_id=str(uuid.uuid4()), + upload_file_id=upload_file2.id, + size=2048, + length=20, + value_type=SegmentType.OBJECT, + ) + session.add(var_file1) + session.add(var_file2) + session.flush() + + draft_var1 = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id="node_1", + name="large_var_1", + value=StringSegment(value="truncated..."), + node_execution_id=str(uuid.uuid4()), + file_id=var_file1.id, + ) + draft_var2 = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id="node_2", + name="large_var_2", + value=StringSegment(value="truncated..."), + node_execution_id=str(uuid.uuid4()), + file_id=var_file2.id, + ) + draft_var3 = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id="node_3", + name="regular_var", + value=StringSegment(value="regular_value"), + node_execution_id=str(uuid.uuid4()), + ) + session.add(draft_var1) + session.add(draft_var2) + session.add(draft_var3) + session.commit() + + data = { + "app": app, + "tenant": tenant, + "upload_files": [upload_file1, upload_file2], + "variable_files": [var_file1, var_file2], + "draft_variables": [draft_var1, draft_var2, draft_var3], + } + + yield data + + with session_factory.create_session() as session: + for table, ids in [ + (WorkflowDraftVariable, [v.id for v in data["draft_variables"]]), + (WorkflowDraftVariableFile, [vf.id for vf in data["variable_files"]]), + (UploadFile, [uf.id for uf in data["upload_files"]]), + ]: + cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False) + session.execute(cleanup_query) + session.commit() + + @pytest.fixture + def setup_commit_test_data(self, app_and_tenant): + """Create test data 
for session commit tests.""" + tenant, app = app_and_tenant + variable_ids: list[str] = [] + + with session_factory.create_session() as session: + variables = [] + for i in range(10): + var = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id=f"node_{i}", + name=f"var_{i}", + value=StringSegment(value="test_value"), + node_execution_id=str(uuid.uuid4()), + ) + session.add(var) + variables.append(var) + session.commit() + variable_ids = [v.id for v in variables] + + yield { + "app": app, + "tenant": tenant, + "variable_ids": variable_ids, + } + + with session_factory.create_session() as session: + cleanup_query = ( + delete(WorkflowDraftVariable) + .where(WorkflowDraftVariable.id.in_(variable_ids)) + .execution_options(synchronize_session=False) + ) + session.execute(cleanup_query) + session.commit() + + def test_session_commit_is_called_after_each_batch(self, setup_commit_test_data): + """Test that session.begin() is used for automatic transaction management.""" + data = setup_commit_test_data + app_id = data["app"].id + + # Since session.begin() is used, the transaction is automatically committed + # when the with block exits successfully. We verify this by checking that + # data is actually persisted. 
+ deleted_count = delete_draft_variables_batch(app_id, batch_size=3) + + # Verify all data was deleted (proves transaction was committed) + with session_factory.create_session() as session: + remaining_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + + assert deleted_count == 10 + assert remaining_count == 0 + + def test_data_persisted_after_batch_deletion(self, setup_commit_test_data): + """Test that data is actually persisted to database after batch deletion with commits.""" + data = setup_commit_test_data + app_id = data["app"].id + variable_ids = data["variable_ids"] + + # Verify initial state + with session_factory.create_session() as session: + initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + assert initial_count == 10 + + # Perform deletion with small batch size to force multiple commits + deleted_count = delete_draft_variables_batch(app_id, batch_size=3) + + assert deleted_count == 10 + + # Verify all data is deleted in a new session (proves commits worked) + with session_factory.create_session() as session: + final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + assert final_count == 0 + + # Verify specific IDs are deleted + with session_factory.create_session() as session: + remaining_vars = ( + session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id.in_(variable_ids)).count() + ) + assert remaining_vars == 0 + + def test_session_commit_with_empty_dataset(self, setup_commit_test_data): + """Test session behavior when deleting from an empty dataset.""" + nonexistent_app_id = str(uuid.uuid4()) + + # Should not raise any errors and should return 0 + deleted_count = delete_draft_variables_batch(nonexistent_app_id, batch_size=10) + assert deleted_count == 0 + + def test_session_commit_with_single_batch(self, setup_commit_test_data): + """Test that commit happens correctly when all data fits in a single batch.""" + data = setup_commit_test_data + 
app_id = data["app"].id + + with session_factory.create_session() as session: + initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + assert initial_count == 10 + + # Delete all in a single batch + deleted_count = delete_draft_variables_batch(app_id, batch_size=100) + assert deleted_count == 10 + + # Verify data is persisted + with session_factory.create_session() as session: + final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + assert final_count == 0 + + def test_invalid_batch_size_raises_error(self, setup_commit_test_data): + """Test that invalid batch size raises ValueError.""" + data = setup_commit_test_data + app_id = data["app"].id + + with pytest.raises(ValueError, match="batch_size must be positive"): + delete_draft_variables_batch(app_id, batch_size=0) + + with pytest.raises(ValueError, match="batch_size must be positive"): + delete_draft_variables_batch(app_id, batch_size=-1) + + @patch("extensions.ext_storage.storage") + def test_session_commit_with_offload_data_cleanup(self, mock_storage, setup_offload_test_data): + """Test that session commits correctly when cleaning up offload data.""" + data = setup_offload_test_data + app_id = data["app"].id + upload_file_ids = [uf.id for uf in data["upload_files"]] + mock_storage.delete.return_value = None + + # Verify initial state + with session_factory.create_session() as session: + draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + var_files_before = ( + session.query(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]])) + .count() + ) + upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() + assert draft_vars_before == 3 + assert var_files_before == 2 + assert upload_files_before == 2 + + # Delete variables with offload data + deleted_count = delete_draft_variables_batch(app_id, batch_size=10) + assert 
deleted_count == 3 + + # Verify all data is persisted (deleted) in new session + with session_factory.create_session() as session: + draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + var_files_after = ( + session.query(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]])) + .count() + ) + upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() + assert draft_vars_after == 0 + assert var_files_after == 0 + assert upload_files_after == 0 + + # Verify storage cleanup was called + assert mock_storage.delete.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index 4b6b5048a1..606e7e0b57 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -1016,7 +1016,7 @@ class TestAccountService: def test_delete_account(self, db_session_with_containers, mock_external_service_dependencies): """ - Test account deletion (should add task to queue). + Test account deletion (should add task to queue and sync to enterprise). 
""" fake = Faker() email = fake.email() @@ -1034,10 +1034,18 @@ class TestAccountService: password=password, ) - with patch("services.account_service.delete_account_task") as mock_delete_task: + with ( + patch("services.account_service.delete_account_task") as mock_delete_task, + patch("services.enterprise.account_deletion_sync.sync_account_deletion") as mock_sync, + ): + mock_sync.return_value = True + # Delete account AccountService.delete_account(account) + # Verify sync was called + mock_sync.assert_called_once_with(account_id=account.id, source="account_deleted") + # Verify task was added to queue mock_delete_task.delay.assert_called_once_with(account.id) @@ -1716,7 +1724,7 @@ class TestTenantService: def test_remove_member_from_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): """ - Test successful member removal from tenant. + Test successful member removal from tenant (should sync to enterprise). """ fake = Faker() tenant_name = fake.company() @@ -1751,7 +1759,15 @@ class TestTenantService: TenantService.create_tenant_member(tenant, member_account, role="normal") # Remove member - TenantService.remove_member_from_tenant(tenant, member_account, owner_account) + with patch("services.enterprise.account_deletion_sync.sync_workspace_member_removal") as mock_sync: + mock_sync.return_value = True + + TenantService.remove_member_from_tenant(tenant, member_account, owner_account) + + # Verify sync was called + mock_sync.assert_called_once_with( + workspace_id=tenant.id, member_id=member_account.id, source="workspace_member_removed" + ) # Verify member was removed from extensions.ext_database import db diff --git a/api/tests/unit_tests/controllers/console/test_fastopenapi_feature.py b/api/tests/unit_tests/controllers/console/test_fastopenapi_feature.py deleted file mode 100644 index 68495dd979..0000000000 --- a/api/tests/unit_tests/controllers/console/test_fastopenapi_feature.py +++ /dev/null @@ -1,291 +0,0 @@ -import builtins 
-import contextlib -import importlib -import sys -from unittest.mock import MagicMock, PropertyMock, patch - -import pytest -from flask import Flask -from flask.views import MethodView -from werkzeug.exceptions import Unauthorized - -from extensions import ext_fastopenapi -from extensions.ext_database import db -from services.feature_service import FeatureModel, SystemFeatureModel - - -@pytest.fixture -def app(): - """ - Creates a Flask application instance configured for testing. - """ - app = Flask(__name__) - app.config["TESTING"] = True - app.config["SECRET_KEY"] = "test-secret" - app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:" - - # Initialize the database with the app - db.init_app(app) - - return app - - -@pytest.fixture(autouse=True) -def fix_method_view_issue(monkeypatch): - """ - Automatic fixture to patch 'builtins.MethodView'. - - Why this is needed: - The official legacy codebase contains a global patch in its initialization logic: - if not hasattr(builtins, "MethodView"): - builtins.MethodView = MethodView - - Some dependencies (like ext_fastopenapi or older Flask extensions) might implicitly - rely on 'MethodView' being available in the global builtins namespace. - - Refactoring Note: - While patching builtins is generally discouraged due to global side effects, - this fixture reproduces the production environment's state to ensure tests are realistic. - We use 'monkeypatch' to ensure that this change is undone after the test finishes, - keeping other tests isolated. 
- """ - if not hasattr(builtins, "MethodView"): - # 'raising=False' allows us to set an attribute that doesn't exist yet - monkeypatch.setattr(builtins, "MethodView", MethodView, raising=False) - - -# ------------------------------------------------------------------------------ -# Helper Functions for Fixture Complexity Reduction -# ------------------------------------------------------------------------------ - - -def _create_isolated_router(): - """ - Creates a fresh, isolated router instance to prevent route pollution. - """ - import controllers.fastopenapi - - # Dynamically get the class type (e.g., FlaskRouter) to avoid hardcoding dependencies - RouterClass = type(controllers.fastopenapi.console_router) - return RouterClass() - - -@contextlib.contextmanager -def _patch_auth_and_router(temp_router): - """ - Context manager that applies all necessary patches for: - 1. The console_router (redirecting to our isolated temp_router) - 2. Authentication decorators (disabling them with no-ops) - 3. User/Account loaders (mocking authenticated state) - """ - - def noop(f): - return f - - # We patch the SOURCE of the decorators/functions, not the destination module. - # This ensures that when 'controllers.console.feature' imports them, it gets the mocks. 
- with ( - patch("controllers.fastopenapi.console_router", temp_router), - patch("extensions.ext_fastopenapi.console_router", temp_router), - patch("controllers.console.wraps.setup_required", side_effect=noop), - patch("libs.login.login_required", side_effect=noop), - patch("controllers.console.wraps.account_initialization_required", side_effect=noop), - patch("controllers.console.wraps.cloud_utm_record", side_effect=noop), - patch("libs.login.current_account_with_tenant", return_value=(MagicMock(), "tenant-id")), - patch("libs.login.current_user", MagicMock(is_authenticated=True)), - ): - # Explicitly reload ext_fastopenapi to ensure it uses the patched console_router - import extensions.ext_fastopenapi - - importlib.reload(extensions.ext_fastopenapi) - - yield - - -def _force_reload_module(target_module: str, alias_module: str): - """ - Forces a reload of the specified module and handles sys.modules aliasing. - - Why reload? - Python decorators (like @route, @login_required) run at IMPORT time. - To apply our patches (mocks/no-ops) to these decorators, we must re-import - the module while the patches are active. - - Why alias? - If 'ext_fastopenapi' imports the controller as 'api.controllers...', but we import - it as 'controllers...', Python treats them as two separate modules. This causes: - 1. Double execution of decorators (registering routes twice -> AssertionError). - 2. Type mismatch errors (Class A from module X is not Class A from module Y). - - This function ensures both names point to the SAME loaded module instance. - """ - # 1. Clean existing entries to force re-import - if target_module in sys.modules: - del sys.modules[target_module] - if alias_module in sys.modules: - del sys.modules[alias_module] - - # 2. Import the module (triggering decorators with active patches) - module = importlib.import_module(target_module) - - # 3. 
Alias the module in sys.modules to prevent double loading - sys.modules[alias_module] = sys.modules[target_module] - - return module - - -def _cleanup_modules(target_module: str, alias_module: str): - """ - Removes the module and its alias from sys.modules to prevent side effects - on other tests. - """ - if target_module in sys.modules: - del sys.modules[target_module] - if alias_module in sys.modules: - del sys.modules[alias_module] - - -@pytest.fixture -def mock_feature_module_env(): - """ - Sets up a mocked environment for the feature module. - - This fixture orchestrates: - 1. Creating an isolated router. - 2. Patching authentication and global dependencies. - 3. Reloading the controller module to apply patches to decorators. - 4. cleaning up sys.modules afterwards. - """ - target_module = "controllers.console.feature" - alias_module = "api.controllers.console.feature" - - # 1. Prepare isolated router - temp_router = _create_isolated_router() - - # 2. Apply patches - try: - with _patch_auth_and_router(temp_router): - # 3. Reload module to register routes on the temp_router - feature_module = _force_reload_module(target_module, alias_module) - - yield feature_module - - finally: - # 4. 
Teardown: Clean up sys.modules - _cleanup_modules(target_module, alias_module) - - -# ------------------------------------------------------------------------------ -# Test Cases -# ------------------------------------------------------------------------------ - - -@pytest.mark.parametrize( - ("url", "service_mock_path", "mock_model_instance", "json_key"), - [ - ( - "/console/api/features", - "controllers.console.feature.FeatureService.get_features", - FeatureModel(can_replace_logo=True), - "features", - ), - ( - "/console/api/system-features", - "controllers.console.feature.FeatureService.get_system_features", - SystemFeatureModel(enable_marketplace=True), - "features", - ), - ], -) -def test_console_features_success(app, mock_feature_module_env, url, service_mock_path, mock_model_instance, json_key): - """ - Tests that the feature APIs return a 200 OK status and correct JSON structure. - """ - # Patch the service layer to return our mock model instance - with patch(service_mock_path, return_value=mock_model_instance): - # Initialize the API extension - ext_fastopenapi.init_app(app) - - client = app.test_client() - response = client.get(url) - - # Assertions - assert response.status_code == 200, f"Request failed with status {response.status_code}: {response.text}" - - # Verify the JSON response matches the Pydantic model dump - expected_data = mock_model_instance.model_dump(mode="json") - assert response.get_json() == {json_key: expected_data} - - -@pytest.mark.parametrize( - ("url", "service_mock_path"), - [ - ("/console/api/features", "controllers.console.feature.FeatureService.get_features"), - ("/console/api/system-features", "controllers.console.feature.FeatureService.get_system_features"), - ], -) -def test_console_features_service_error(app, mock_feature_module_env, url, service_mock_path): - """ - Tests how the application handles Service layer errors. 
- - Note: When an exception occurs in the view, it is typically caught by the framework - (Flask or the OpenAPI wrapper) and converted to a 500 error response. - This test verifies that the application returns a 500 status code. - """ - # Simulate a service failure - with patch(service_mock_path, side_effect=ValueError("Service Failure")): - ext_fastopenapi.init_app(app) - client = app.test_client() - - # When an exception occurs in the view, it is typically caught by the framework - # (Flask or the OpenAPI wrapper) and converted to a 500 error response. - response = client.get(url) - - assert response.status_code == 500 - # Check if the error details are exposed in the response (depends on error handler config) - # We accept either generic 500 or the specific error message - assert "Service Failure" in response.text or "Internal Server Error" in response.text - - -def test_system_features_unauthenticated(app, mock_feature_module_env): - """ - Tests that /console/api/system-features endpoint works without authentication. - - This test verifies the try-except block in get_system_features that handles - unauthenticated requests by passing is_authenticated=False to the service layer. - """ - feature_module = mock_feature_module_env - - # Override the behavior of the current_user mock - # The fixture patched 'libs.login.current_user', so 'controllers.console.feature.current_user' - # refers to that same Mock object. - mock_user = feature_module.current_user - - # Simulate property access raising Unauthorized - # Note: We must reset side_effect if it was set, or set it here. - # The fixture initialized it as MagicMock(is_authenticated=True). - # We want type(mock_user).is_authenticated to raise Unauthorized. 
- type(mock_user).is_authenticated = PropertyMock(side_effect=Unauthorized) - - # Patch the service layer for this specific test - with patch("controllers.console.feature.FeatureService.get_system_features") as mock_service: - # Setup mock service return value - mock_model = SystemFeatureModel(enable_marketplace=True) - mock_service.return_value = mock_model - - # Initialize app - ext_fastopenapi.init_app(app) - client = app.test_client() - - # Act - response = client.get("/console/api/system-features") - - # Assert - assert response.status_code == 200, f"Request failed: {response.text}" - - # Verify service was called with is_authenticated=False - mock_service.assert_called_once_with(is_authenticated=False) - - # Verify response body - expected_data = mock_model.model_dump(mode="json") - assert response.get_json() == {"features": expected_data} diff --git a/api/tests/unit_tests/controllers/console/test_fastopenapi_tags.py b/api/tests/unit_tests/controllers/console/test_fastopenapi_tags.py deleted file mode 100644 index 62d143f32d..0000000000 --- a/api/tests/unit_tests/controllers/console/test_fastopenapi_tags.py +++ /dev/null @@ -1,222 +0,0 @@ -import builtins -import contextlib -import importlib -import sys -from types import SimpleNamespace -from unittest.mock import MagicMock, patch - -import pytest -from flask import Flask -from flask.views import MethodView - -from extensions import ext_fastopenapi -from extensions.ext_database import db - - -@pytest.fixture -def app(): - app = Flask(__name__) - app.config["TESTING"] = True - app.config["SECRET_KEY"] = "test-secret" - app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:" - - db.init_app(app) - - return app - - -@pytest.fixture(autouse=True) -def fix_method_view_issue(monkeypatch): - if not hasattr(builtins, "MethodView"): - monkeypatch.setattr(builtins, "MethodView", MethodView, raising=False) - - -def _create_isolated_router(): - import controllers.fastopenapi - - router_class = 
type(controllers.fastopenapi.console_router) - return router_class() - - -@contextlib.contextmanager -def _patch_auth_and_router(temp_router): - def noop(func): - return func - - default_user = MagicMock(has_edit_permission=True, is_dataset_editor=False) - - with ( - patch("controllers.fastopenapi.console_router", temp_router), - patch("extensions.ext_fastopenapi.console_router", temp_router), - patch("controllers.console.wraps.setup_required", side_effect=noop), - patch("libs.login.login_required", side_effect=noop), - patch("controllers.console.wraps.account_initialization_required", side_effect=noop), - patch("controllers.console.wraps.edit_permission_required", side_effect=noop), - patch("libs.login.current_account_with_tenant", return_value=(default_user, "tenant-id")), - patch("configs.dify_config.EDITION", "CLOUD"), - ): - import extensions.ext_fastopenapi - - importlib.reload(extensions.ext_fastopenapi) - - yield - - -def _force_reload_module(target_module: str, alias_module: str): - if target_module in sys.modules: - del sys.modules[target_module] - if alias_module in sys.modules: - del sys.modules[alias_module] - - module = importlib.import_module(target_module) - sys.modules[alias_module] = sys.modules[target_module] - - return module - - -def _dedupe_routes(router): - seen = set() - unique_routes = [] - for path, method, endpoint in reversed(router.get_routes()): - key = (path, method, endpoint.__name__) - if key in seen: - continue - seen.add(key) - unique_routes.append((path, method, endpoint)) - router._routes = list(reversed(unique_routes)) - - -def _cleanup_modules(target_module: str, alias_module: str): - if target_module in sys.modules: - del sys.modules[target_module] - if alias_module in sys.modules: - del sys.modules[alias_module] - - -@pytest.fixture -def mock_tags_module_env(): - target_module = "controllers.console.tag.tags" - alias_module = "api.controllers.console.tag.tags" - temp_router = _create_isolated_router() - - try: - with 
_patch_auth_and_router(temp_router): - tags_module = _force_reload_module(target_module, alias_module) - _dedupe_routes(temp_router) - yield tags_module - finally: - _cleanup_modules(target_module, alias_module) - - -def test_list_tags_success(app: Flask, mock_tags_module_env): - # Arrange - tag = SimpleNamespace(id="tag-1", name="Alpha", type="app", binding_count=2) - with patch("controllers.console.tag.tags.TagService.get_tags", return_value=[tag]): - ext_fastopenapi.init_app(app) - client = app.test_client() - - # Act - response = client.get("/console/api/tags?type=app&keyword=Alpha") - - # Assert - assert response.status_code == 200 - assert response.get_json() == [ - {"id": "tag-1", "name": "Alpha", "type": "app", "binding_count": 2}, - ] - - -def test_create_tag_success(app: Flask, mock_tags_module_env): - # Arrange - tag = SimpleNamespace(id="tag-2", name="Beta", type="app") - with patch("controllers.console.tag.tags.TagService.save_tags", return_value=tag) as mock_save: - ext_fastopenapi.init_app(app) - client = app.test_client() - - # Act - response = client.post("/console/api/tags", json={"name": "Beta", "type": "app"}) - - # Assert - assert response.status_code == 200 - assert response.get_json() == { - "id": "tag-2", - "name": "Beta", - "type": "app", - "binding_count": 0, - } - mock_save.assert_called_once_with({"name": "Beta", "type": "app"}) - - -def test_update_tag_success(app: Flask, mock_tags_module_env): - # Arrange - tag = SimpleNamespace(id="tag-3", name="Gamma", type="app") - with ( - patch("controllers.console.tag.tags.TagService.update_tags", return_value=tag) as mock_update, - patch("controllers.console.tag.tags.TagService.get_tag_binding_count", return_value=4), - ): - ext_fastopenapi.init_app(app) - client = app.test_client() - - # Act - response = client.patch( - "/console/api/tags/11111111-1111-1111-1111-111111111111", - json={"name": "Gamma", "type": "app"}, - ) - - # Assert - assert response.status_code == 200 - assert 
response.get_json() == { - "id": "tag-3", - "name": "Gamma", - "type": "app", - "binding_count": 4, - } - mock_update.assert_called_once_with( - {"name": "Gamma", "type": "app"}, - "11111111-1111-1111-1111-111111111111", - ) - - -def test_delete_tag_success(app: Flask, mock_tags_module_env): - # Arrange - with patch("controllers.console.tag.tags.TagService.delete_tag") as mock_delete: - ext_fastopenapi.init_app(app) - client = app.test_client() - - # Act - response = client.delete("/console/api/tags/11111111-1111-1111-1111-111111111111") - - # Assert - assert response.status_code == 204 - mock_delete.assert_called_once_with("11111111-1111-1111-1111-111111111111") - - -def test_create_tag_binding_success(app: Flask, mock_tags_module_env): - # Arrange - payload = {"tag_ids": ["tag-1", "tag-2"], "target_id": "target-1", "type": "app"} - with patch("controllers.console.tag.tags.TagService.save_tag_binding") as mock_bind: - ext_fastopenapi.init_app(app) - client = app.test_client() - - # Act - response = client.post("/console/api/tag-bindings/create", json=payload) - - # Assert - assert response.status_code == 200 - assert response.get_json() == {"result": "success"} - mock_bind.assert_called_once_with(payload) - - -def test_delete_tag_binding_success(app: Flask, mock_tags_module_env): - # Arrange - payload = {"tag_id": "tag-1", "target_id": "target-1", "type": "app"} - with patch("controllers.console.tag.tags.TagService.delete_tag_binding") as mock_unbind: - ext_fastopenapi.init_app(app) - client = app.test_client() - - # Act - response = client.post("/console/api/tag-bindings/remove", json=payload) - - # Assert - assert response.status_code == 200 - assert response.get_json() == {"result": "success"} - mock_unbind.assert_called_once_with(payload) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py new file mode 100644 index 0000000000..94c3019d5e --- 
/dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py @@ -0,0 +1,364 @@ +"""Endpoint tests for controllers.console.workspace.tool_providers.""" + +from __future__ import annotations + +import builtins +import importlib +from contextlib import contextmanager +from types import ModuleType, SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from flask.views import MethodView + +if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + + +_CONTROLLER_MODULE: ModuleType | None = None +_WRAPS_MODULE: ModuleType | None = None +_CONTROLLER_PATCHERS: list[patch] = [] + + +@contextmanager +def _mock_db(): + mock_session = SimpleNamespace(query=lambda *args, **kwargs: SimpleNamespace(first=lambda: True)) + with patch("extensions.ext_database.db.session", mock_session): + yield + + +@pytest.fixture +def app() -> Flask: + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + return flask_app + + +@pytest.fixture +def controller_module(monkeypatch: pytest.MonkeyPatch): + module_name = "controllers.console.workspace.tool_providers" + global _CONTROLLER_MODULE + if _CONTROLLER_MODULE is None: + + def _noop(func): + return func + + patch_targets = [ + ("libs.login.login_required", _noop), + ("controllers.console.wraps.setup_required", _noop), + ("controllers.console.wraps.account_initialization_required", _noop), + ("controllers.console.wraps.is_admin_or_owner_required", _noop), + ("controllers.console.wraps.enterprise_license_required", _noop), + ] + for target, value in patch_targets: + patcher = patch(target, value) + patcher.start() + _CONTROLLER_PATCHERS.append(patcher) + monkeypatch.setenv("DIFY_SETUP_READY", "true") + with _mock_db(): + _CONTROLLER_MODULE = importlib.import_module(module_name) + + module = _CONTROLLER_MODULE + monkeypatch.setattr(module, "jsonable_encoder", lambda payload: payload) + + # Ensure decorators that 
consult deployment edition do not reach the database. + global _WRAPS_MODULE + wraps_module = importlib.import_module("controllers.console.wraps") + _WRAPS_MODULE = wraps_module + monkeypatch.setattr(module.dify_config, "EDITION", "CLOUD") + monkeypatch.setattr(wraps_module.dify_config, "EDITION", "CLOUD") + + login_module = importlib.import_module("libs.login") + monkeypatch.setattr(login_module, "check_csrf_token", lambda *args, **kwargs: None) + return module + + +def _mock_account(user_id: str = "user-123") -> SimpleNamespace: + return SimpleNamespace(id=user_id, status="active", is_authenticated=True, current_tenant_id=None) + + +def _set_current_account( + monkeypatch: pytest.MonkeyPatch, + controller_module: ModuleType, + user: SimpleNamespace, + tenant_id: str, +) -> None: + def _getter(): + return user, tenant_id + + user.current_tenant_id = tenant_id + + monkeypatch.setattr(controller_module, "current_account_with_tenant", _getter) + if _WRAPS_MODULE is not None: + monkeypatch.setattr(_WRAPS_MODULE, "current_account_with_tenant", _getter) + + login_module = importlib.import_module("libs.login") + monkeypatch.setattr(login_module, "_get_user", lambda: user) + + +def test_tool_provider_list_calls_service_with_query( + app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch +): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-456") + + service_mock = MagicMock(return_value=[{"provider": "builtin"}]) + monkeypatch.setattr(controller_module.ToolCommonService, "list_tool_providers", service_mock) + + with app.test_request_context("/workspaces/current/tool-providers?type=builtin"): + response = controller_module.ToolProviderListApi().get() + + assert response == [{"provider": "builtin"}] + service_mock.assert_called_once_with(user.id, "tenant-456", "builtin") + + +def test_builtin_provider_add_passes_payload( + app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch +): + user = 
_mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-456") + + service_mock = MagicMock(return_value={"status": "ok"}) + monkeypatch.setattr(controller_module.BuiltinToolManageService, "add_builtin_tool_provider", service_mock) + + payload = { + "credentials": {"api_key": "sk-test"}, + "name": "MyTool", + "type": controller_module.CredentialType.API_KEY, + } + + with app.test_request_context( + "/workspaces/current/tool-provider/builtin/openai/add", + method="POST", + json=payload, + ): + response = controller_module.ToolBuiltinProviderAddApi().post(provider="openai") + + assert response == {"status": "ok"} + service_mock.assert_called_once_with( + user_id="user-123", + tenant_id="tenant-456", + provider="openai", + credentials={"api_key": "sk-test"}, + name="MyTool", + api_type=controller_module.CredentialType.API_KEY, + ) + + +def test_builtin_provider_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-tenant-789") + _set_current_account(monkeypatch, controller_module, user, "tenant-789") + + service_mock = MagicMock(return_value=[{"name": "tool-a"}]) + monkeypatch.setattr(controller_module.BuiltinToolManageService, "list_builtin_tool_provider_tools", service_mock) + monkeypatch.setattr(controller_module, "jsonable_encoder", lambda payload: payload) + + with app.test_request_context( + "/workspaces/current/tool-provider/builtin/my-provider/tools", + method="GET", + ): + response = controller_module.ToolBuiltinProviderListToolsApi().get(provider="my-provider") + + assert response == [{"name": "tool-a"}] + service_mock.assert_called_once_with("tenant-789", "my-provider") + + +def test_builtin_provider_info_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-tenant-9") + _set_current_account(monkeypatch, controller_module, user, "tenant-9") + service_mock = MagicMock(return_value={"info": True}) + 
monkeypatch.setattr(controller_module.BuiltinToolManageService, "get_builtin_tool_provider_info", service_mock) + + with app.test_request_context("/info", method="GET"): + resp = controller_module.ToolBuiltinProviderInfoApi().get(provider="demo") + + assert resp == {"info": True} + service_mock.assert_called_once_with("tenant-9", "demo") + + +def test_builtin_provider_credentials_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-tenant-cred") + _set_current_account(monkeypatch, controller_module, user, "tenant-cred") + service_mock = MagicMock(return_value=[{"cred": 1}]) + monkeypatch.setattr( + controller_module.BuiltinToolManageService, + "get_builtin_tool_provider_credentials", + service_mock, + ) + + with app.test_request_context("/creds", method="GET"): + resp = controller_module.ToolBuiltinProviderGetCredentialsApi().get(provider="demo") + + assert resp == [{"cred": 1}] + service_mock.assert_called_once_with(tenant_id="tenant-cred", provider_name="demo") + + +def test_api_provider_remote_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-10") + service_mock = MagicMock(return_value={"schema": "ok"}) + monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider_remote_schema", service_mock) + + with app.test_request_context("/remote?url=https://example.com/"): + resp = controller_module.ToolApiProviderGetRemoteSchemaApi().get() + + assert resp == {"schema": "ok"} + service_mock.assert_called_once_with(user.id, "tenant-10", "https://example.com/") + + +def test_api_provider_list_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-11") + service_mock = MagicMock(return_value=[{"tool": "t"}]) + monkeypatch.setattr(controller_module.ApiToolManageService, 
"list_api_tool_provider_tools", service_mock) + + with app.test_request_context("/tools?provider=foo"): + resp = controller_module.ToolApiProviderListToolsApi().get() + + assert resp == [{"tool": "t"}] + service_mock.assert_called_once_with(user.id, "tenant-11", "foo") + + +def test_api_provider_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-12") + service_mock = MagicMock(return_value={"provider": "foo"}) + monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider", service_mock) + + with app.test_request_context("/get?provider=foo"): + resp = controller_module.ToolApiProviderGetApi().get() + + assert resp == {"provider": "foo"} + service_mock.assert_called_once_with(user.id, "tenant-12", "foo") + + +def test_builtin_provider_credentials_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-tenant-13") + _set_current_account(monkeypatch, controller_module, user, "tenant-13") + service_mock = MagicMock(return_value={"schema": True}) + monkeypatch.setattr( + controller_module.BuiltinToolManageService, + "list_builtin_provider_credentials_schema", + service_mock, + ) + + with app.test_request_context("/schema", method="GET"): + resp = controller_module.ToolBuiltinProviderCredentialsSchemaApi().get( + provider="demo", credential_type="api-key" + ) + + assert resp == {"schema": True} + service_mock.assert_called_once() + + +def test_workflow_provider_get_by_tool(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-wf") + tool_service = MagicMock(return_value={"wf": 1}) + monkeypatch.setattr( + controller_module.WorkflowToolManageService, + "get_workflow_tool_by_tool_id", + tool_service, + ) + + tool_id = "00000000-0000-0000-0000-000000000001" + with 
app.test_request_context(f"/workflow?workflow_tool_id={tool_id}"): + resp = controller_module.ToolWorkflowProviderGetApi().get() + + assert resp == {"wf": 1} + tool_service.assert_called_once_with(user.id, "tenant-wf", tool_id) + + +def test_workflow_provider_get_by_app(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-wf2") + service_mock = MagicMock(return_value={"app": 1}) + monkeypatch.setattr( + controller_module.WorkflowToolManageService, + "get_workflow_tool_by_app_id", + service_mock, + ) + + app_id = "00000000-0000-0000-0000-000000000002" + with app.test_request_context(f"/workflow?workflow_app_id={app_id}"): + resp = controller_module.ToolWorkflowProviderGetApi().get() + + assert resp == {"app": 1} + service_mock.assert_called_once_with(user.id, "tenant-wf2", app_id) + + +def test_workflow_provider_list_tools(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-wf3") + service_mock = MagicMock(return_value=[{"id": 1}]) + monkeypatch.setattr(controller_module.WorkflowToolManageService, "list_single_workflow_tools", service_mock) + + tool_id = "00000000-0000-0000-0000-000000000003" + with app.test_request_context(f"/workflow/tools?workflow_tool_id={tool_id}"): + resp = controller_module.ToolWorkflowProviderListToolApi().get() + + assert resp == [{"id": 1}] + service_mock.assert_called_once_with(user.id, "tenant-wf3", tool_id) + + +def test_builtin_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-bt") + + provider = SimpleNamespace(to_dict=lambda: {"name": "builtin"}) + monkeypatch.setattr( + controller_module.BuiltinToolManageService, + "list_builtin_tools", + MagicMock(return_value=[provider]), + ) + + with 
app.test_request_context("/tools/builtin"): + resp = controller_module.ToolBuiltinListApi().get() + + assert resp == [{"name": "builtin"}] + + +def test_api_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-tenant-api") + _set_current_account(monkeypatch, controller_module, user, "tenant-api") + + provider = SimpleNamespace(to_dict=lambda: {"name": "api"}) + monkeypatch.setattr( + controller_module.ApiToolManageService, + "list_api_tools", + MagicMock(return_value=[provider]), + ) + + with app.test_request_context("/tools/api"): + resp = controller_module.ToolApiListApi().get() + + assert resp == [{"name": "api"}] + + +def test_workflow_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-wf4") + + provider = SimpleNamespace(to_dict=lambda: {"name": "wf"}) + monkeypatch.setattr( + controller_module.WorkflowToolManageService, + "list_tenant_workflow_tools", + MagicMock(return_value=[provider]), + ) + + with app.test_request_context("/tools/workflow"): + resp = controller_module.ToolWorkflowListApi().get() + + assert resp == [{"name": "wf"}] + + +def test_tool_labels_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-label") + _set_current_account(monkeypatch, controller_module, user, "tenant-labels") + monkeypatch.setattr(controller_module.ToolLabelsService, "list_tool_labels", lambda: ["a", "b"]) + + with app.test_request_context("/tool-labels"): + resp = controller_module.ToolLabelsApi().get() + + assert resp == ["a", "b"] diff --git a/api/tests/unit_tests/services/enterprise/test_account_deletion_sync.py b/api/tests/unit_tests/services/enterprise/test_account_deletion_sync.py new file mode 100644 index 0000000000..b66111902c --- /dev/null +++ b/api/tests/unit_tests/services/enterprise/test_account_deletion_sync.py @@ -0,0 +1,276 @@ +"""Unit tests 
for account deletion synchronization. + +This test module verifies the enterprise account deletion sync functionality, +including Redis queuing, error handling, and community vs enterprise behavior. +""" + +from unittest.mock import MagicMock, patch + +import pytest +from redis import RedisError + +from services.enterprise.account_deletion_sync import ( + _queue_task, + sync_account_deletion, + sync_workspace_member_removal, +) + + +class TestQueueTask: + """Unit tests for the _queue_task helper function.""" + + @pytest.fixture + def mock_redis_client(self): + """Mock redis_client for testing.""" + with patch("services.enterprise.account_deletion_sync.redis_client") as mock_redis: + yield mock_redis + + @pytest.fixture + def mock_uuid(self): + """Mock UUID generation for predictable task IDs.""" + with patch("services.enterprise.account_deletion_sync.uuid.uuid4") as mock_uuid_gen: + mock_uuid_gen.return_value = MagicMock(hex="test-task-id-1234") + yield mock_uuid_gen + + def test_queue_task_success(self, mock_redis_client, mock_uuid): + """Test successful task queueing to Redis.""" + # Arrange + workspace_id = "ws-123" + member_id = "member-456" + source = "test_source" + + # Act + result = _queue_task(workspace_id=workspace_id, member_id=member_id, source=source) + + # Assert + assert result is True + mock_redis_client.lpush.assert_called_once() + + # Verify the task payload structure + call_args = mock_redis_client.lpush.call_args[0] + assert call_args[0] == "enterprise:member:sync:queue" + + import json + + task_data = json.loads(call_args[1]) + assert task_data["workspace_id"] == workspace_id + assert task_data["member_id"] == member_id + assert task_data["source"] == source + assert task_data["type"] == "sync_member_deletion_from_workspace" + assert task_data["retry_count"] == 0 + assert "task_id" in task_data + assert "created_at" in task_data + + def test_queue_task_redis_error(self, mock_redis_client, caplog): + """Test handling of Redis connection 
errors.""" + # Arrange + mock_redis_client.lpush.side_effect = RedisError("Connection failed") + + # Act + result = _queue_task(workspace_id="ws-123", member_id="member-456", source="test_source") + + # Assert + assert result is False + assert "Failed to queue account deletion sync" in caplog.text + + def test_queue_task_type_error(self, mock_redis_client, caplog): + """Test handling of JSON serialization errors.""" + # Arrange + mock_redis_client.lpush.side_effect = TypeError("Cannot serialize") + + # Act + result = _queue_task(workspace_id="ws-123", member_id="member-456", source="test_source") + + # Assert + assert result is False + assert "Failed to queue account deletion sync" in caplog.text + + +class TestSyncWorkspaceMemberRemoval: + """Unit tests for sync_workspace_member_removal function.""" + + @pytest.fixture + def mock_queue_task(self): + """Mock _queue_task for testing.""" + with patch("services.enterprise.account_deletion_sync._queue_task") as mock_queue: + mock_queue.return_value = True + yield mock_queue + + def test_sync_workspace_member_removal_enterprise_enabled(self, mock_queue_task): + """Test sync when ENTERPRISE_ENABLED is True.""" + # Arrange + workspace_id = "ws-123" + member_id = "member-456" + source = "workspace_member_removed" + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + # Act + result = sync_workspace_member_removal(workspace_id=workspace_id, member_id=member_id, source=source) + + # Assert + assert result is True + mock_queue_task.assert_called_once_with(workspace_id=workspace_id, member_id=member_id, source=source) + + def test_sync_workspace_member_removal_enterprise_disabled(self, mock_queue_task): + """Test sync when ENTERPRISE_ENABLED is False (community edition).""" + # Arrange + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = False + + # Act + result = 
sync_workspace_member_removal(workspace_id="ws-123", member_id="member-456", source="test_source") + + # Assert + assert result is True + mock_queue_task.assert_not_called() + + def test_sync_workspace_member_removal_queue_failure(self, mock_queue_task): + """Test handling of queue task failures.""" + # Arrange + mock_queue_task.return_value = False + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + # Act + result = sync_workspace_member_removal(workspace_id="ws-123", member_id="member-456", source="test_source") + + # Assert + assert result is False + + +class TestSyncAccountDeletion: + """Unit tests for sync_account_deletion function.""" + + @pytest.fixture + def mock_db_session(self): + """Mock database session for testing.""" + with patch("services.enterprise.account_deletion_sync.db.session") as mock_session: + yield mock_session + + @pytest.fixture + def mock_queue_task(self): + """Mock _queue_task for testing.""" + with patch("services.enterprise.account_deletion_sync._queue_task") as mock_queue: + mock_queue.return_value = True + yield mock_queue + + def test_sync_account_deletion_enterprise_disabled(self, mock_db_session, mock_queue_task): + """Test sync when ENTERPRISE_ENABLED is False (community edition).""" + # Arrange + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = False + + # Act + result = sync_account_deletion(account_id="acc-123", source="account_deleted") + + # Assert + assert result is True + mock_db_session.query.assert_not_called() + mock_queue_task.assert_not_called() + + def test_sync_account_deletion_multiple_workspaces(self, mock_db_session, mock_queue_task): + """Test sync for account with multiple workspace memberships.""" + # Arrange + account_id = "acc-123" + + # Mock workspace joins + mock_join1 = MagicMock() + mock_join1.tenant_id = "tenant-1" + mock_join2 = MagicMock() + 
mock_join2.tenant_id = "tenant-2" + mock_join3 = MagicMock() + mock_join3.tenant_id = "tenant-3" + + mock_query = MagicMock() + mock_query.filter_by.return_value.all.return_value = [mock_join1, mock_join2, mock_join3] + mock_db_session.query.return_value = mock_query + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + # Act + result = sync_account_deletion(account_id=account_id, source="account_deleted") + + # Assert + assert result is True + assert mock_queue_task.call_count == 3 + + # Verify each workspace was queued + mock_queue_task.assert_any_call(workspace_id="tenant-1", member_id=account_id, source="account_deleted") + mock_queue_task.assert_any_call(workspace_id="tenant-2", member_id=account_id, source="account_deleted") + mock_queue_task.assert_any_call(workspace_id="tenant-3", member_id=account_id, source="account_deleted") + + def test_sync_account_deletion_no_workspaces(self, mock_db_session, mock_queue_task): + """Test sync for account with no workspace memberships.""" + # Arrange + mock_query = MagicMock() + mock_query.filter_by.return_value.all.return_value = [] + mock_db_session.query.return_value = mock_query + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + # Act + result = sync_account_deletion(account_id="acc-123", source="account_deleted") + + # Assert + assert result is True + mock_queue_task.assert_not_called() + + def test_sync_account_deletion_partial_failure(self, mock_db_session, mock_queue_task): + """Test sync when some tasks fail to queue.""" + # Arrange + account_id = "acc-123" + + # Mock workspace joins + mock_join1 = MagicMock() + mock_join1.tenant_id = "tenant-1" + mock_join2 = MagicMock() + mock_join2.tenant_id = "tenant-2" + mock_join3 = MagicMock() + mock_join3.tenant_id = "tenant-3" + + mock_query = MagicMock() + mock_query.filter_by.return_value.all.return_value = 
[mock_join1, mock_join2, mock_join3] + mock_db_session.query.return_value = mock_query + + # Mock queue_task to fail for second workspace + def queue_side_effect(workspace_id, member_id, source): + return workspace_id != "tenant-2" + + mock_queue_task.side_effect = queue_side_effect + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + # Act + result = sync_account_deletion(account_id=account_id, source="account_deleted") + + # Assert + assert result is False # Should return False if any task fails + assert mock_queue_task.call_count == 3 + + def test_sync_account_deletion_all_failures(self, mock_db_session, mock_queue_task): + """Test sync when all tasks fail to queue.""" + # Arrange + mock_join = MagicMock() + mock_join.tenant_id = "tenant-1" + + mock_query = MagicMock() + mock_query.filter_by.return_value.all.return_value = [mock_join] + mock_db_session.query.return_value = mock_query + + mock_queue_task.return_value = False + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + # Act + result = sync_account_deletion(account_id="acc-123", source="account_deleted") + + # Assert + assert result is False + mock_queue_task.assert_called_once() diff --git a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py index a14bbb01d0..2b11e42cd5 100644 --- a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py @@ -350,7 +350,7 @@ class TestDeleteWorkflowArchiveLogs: mock_query.where.return_value = mock_delete_query mock_db.session.query.return_value = mock_query - delete_func("log-1") + delete_func(mock_db.session, "log-1") mock_db.session.query.assert_called_once_with(WorkflowArchiveLog) mock_query.where.assert_called_once() diff --git a/api/uv.lock 
b/api/uv.lock index 04d9a7c021..0a17741f9a 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1368,7 +1368,7 @@ wheels = [ [[package]] name = "dify-api" -version = "1.12.0" +version = "1.12.1" source = { virtual = "." } dependencies = [ { name = "aliyun-log-python-sdk" }, @@ -1707,7 +1707,7 @@ dev = [ { name = "types-openpyxl", specifier = "~=3.1.5" }, { name = "types-pexpect", specifier = "~=4.9.0" }, { name = "types-protobuf", specifier = "~=5.29.1" }, - { name = "types-psutil", specifier = "~=7.0.0" }, + { name = "types-psutil", specifier = "~=7.2.2" }, { name = "types-psycopg2", specifier = "~=2.9.21" }, { name = "types-pygments", specifier = "~=2.19.0" }, { name = "types-pymysql", specifier = "~=1.1.0" }, @@ -6508,11 +6508,11 @@ wheels = [ [[package]] name = "types-psutil" -version = "7.0.0.20251116" +version = "7.2.2.20260130" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/47/ec/c1e9308b91582cad1d7e7d3007fd003ef45a62c2500f8219313df5fc3bba/types_psutil-7.0.0.20251116.tar.gz", hash = "sha256:92b5c78962e55ce1ed7b0189901a4409ece36ab9fd50c3029cca7e681c606c8a", size = 22192, upload-time = "2025-11-16T03:10:32.859Z" } +sdist = { url = "https://files.pythonhosted.org/packages/69/14/fc5fb0a6ddfadf68c27e254a02ececd4d5c7fdb0efcb7e7e917a183497fb/types_psutil-7.2.2.20260130.tar.gz", hash = "sha256:15b0ab69c52841cf9ce3c383e8480c620a4d13d6a8e22b16978ebddac5590950", size = 26535, upload-time = "2026-01-30T03:58:14.116Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c3/0e/11ba08a5375c21039ed5f8e6bba41e9452fb69f0e2f7ee05ed5cca2a2cdf/types_psutil-7.0.0.20251116-py3-none-any.whl", hash = "sha256:74c052de077c2024b85cd435e2cba971165fe92a5eace79cbeb821e776dbc047", size = 25376, upload-time = "2025-11-16T03:10:31.813Z" }, + { url = "https://files.pythonhosted.org/packages/17/d7/60974b7e31545d3768d1770c5fe6e093182c3bfd819429b33133ba6b3e89/types_psutil-7.2.2.20260130-py3-none-any.whl", hash = 
"sha256:15523a3caa7b3ff03ac7f9b78a6470a59f88f48df1d74a39e70e06d2a99107da", size = 32876, upload-time = "2026-01-30T03:58:13.172Z" }, ] [[package]] diff --git a/dev/pytest/pytest_unit_tests.sh b/dev/pytest/pytest_unit_tests.sh index 7c39a48bf4..a034083304 100755 --- a/dev/pytest/pytest_unit_tests.sh +++ b/dev/pytest/pytest_unit_tests.sh @@ -1,5 +1,5 @@ #!/bin/bash -set -x +set -euxo pipefail SCRIPT_DIR="$(dirname "$(realpath "$0")")" cd "$SCRIPT_DIR/../.." diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index e27b51bcc0..cb5e2c47f7 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -21,7 +21,7 @@ services: # API service api: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.12.1 restart: always environment: # Use the shared environment variables. @@ -63,7 +63,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.12.1 restart: always environment: # Use the shared environment variables. @@ -102,7 +102,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.12.1 restart: always environment: # Use the shared environment variables. @@ -132,7 +132,7 @@ services: # Frontend web application. 
web: - image: langgenius/dify-web:1.12.0 + image: langgenius/dify-web:1.12.1 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -662,13 +662,14 @@ services: - "${IRIS_SUPER_SERVER_PORT:-1972}:1972" - "${IRIS_WEB_SERVER_PORT:-52773}:52773" volumes: - - ./volumes/iris:/opt/iris + - ./volumes/iris:/durable - ./iris/iris-init.script:/iris-init.script - ./iris/docker-entrypoint.sh:/custom-entrypoint.sh entrypoint: ["/custom-entrypoint.sh"] tty: true environment: TZ: ${IRIS_TIMEZONE:-UTC} + ISC_DATA_DIRECTORY: /durable/iris # Oracle vector database oracle: diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index a5518ceee9..161fdc6c3f 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -712,7 +712,7 @@ services: # API service api: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.12.1 restart: always environment: # Use the shared environment variables. @@ -754,7 +754,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.12.1 restart: always environment: # Use the shared environment variables. @@ -793,7 +793,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.12.1 restart: always environment: # Use the shared environment variables. @@ -823,7 +823,7 @@ services: # Frontend web application. 
web: - image: langgenius/dify-web:1.12.0 + image: langgenius/dify-web:1.12.1 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -1353,13 +1353,14 @@ services: - "${IRIS_SUPER_SERVER_PORT:-1972}:1972" - "${IRIS_WEB_SERVER_PORT:-52773}:52773" volumes: - - ./volumes/iris:/opt/iris + - ./volumes/iris:/durable - ./iris/iris-init.script:/iris-init.script - ./iris/docker-entrypoint.sh:/custom-entrypoint.sh entrypoint: ["/custom-entrypoint.sh"] tty: true environment: TZ: ${IRIS_TIMEZONE:-UTC} + ISC_DATA_DIRECTORY: /durable/iris # Oracle vector database oracle: diff --git a/docker/iris/docker-entrypoint.sh b/docker/iris/docker-entrypoint.sh index 067bfa03e2..1a3b10423b 100755 --- a/docker/iris/docker-entrypoint.sh +++ b/docker/iris/docker-entrypoint.sh @@ -1,15 +1,33 @@ #!/bin/bash set -e -# IRIS configuration flag file -IRIS_CONFIG_DONE="/opt/iris/.iris-configured" +# IRIS configuration flag file (stored in durable directory to persist with data) +IRIS_CONFIG_DONE="/durable/.iris-configured" + +# Function to wait for IRIS to be ready +wait_for_iris() { + echo "Waiting for IRIS to be ready..." + local max_attempts=30 + local attempt=1 + while [ "$attempt" -le "$max_attempts" ]; do + if iris qlist IRIS 2>/dev/null | grep -q "running"; then + echo "IRIS is ready." + return 0 + fi + echo "Attempt $attempt/$max_attempts: IRIS not ready yet, waiting..." + sleep 2 + attempt=$((attempt + 1)) + done + echo "ERROR: IRIS failed to start within expected time." >&2 + return 1 +} # Function to configure IRIS configure_iris() { echo "Configuring IRIS for first-time setup..." 
# Wait for IRIS to be fully started - sleep 5 + wait_for_iris # Execute the initialization script iris session IRIS < /iris-init.script diff --git a/web/app/components/app-initializer.tsx b/web/app/components/app-initializer.tsx index 3410ecbe9a..dfbac5d743 100644 --- a/web/app/components/app-initializer.tsx +++ b/web/app/components/app-initializer.tsx @@ -3,7 +3,7 @@ import type { ReactNode } from 'react' import Cookies from 'js-cookie' import { usePathname, useRouter, useSearchParams } from 'next/navigation' -import { parseAsString, useQueryState } from 'nuqs' +import { parseAsBoolean, useQueryState } from 'nuqs' import { useCallback, useEffect, useState } from 'react' import { EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION, @@ -28,7 +28,7 @@ export const AppInitializer = ({ const [init, setInit] = useState(false) const [oauthNewUser, setOauthNewUser] = useQueryState( 'oauth_new_user', - parseAsString.withOptions({ history: 'replace' }), + parseAsBoolean.withOptions({ history: 'replace' }), ) const isSetupFinished = useCallback(async () => { @@ -46,7 +46,7 @@ export const AppInitializer = ({ (async () => { const action = searchParams.get('action') - if (oauthNewUser === 'true') { + if (oauthNewUser) { let utmInfo = null const utmInfoStr = Cookies.get('utm_info') if (utmInfoStr) { diff --git a/web/app/components/app/create-app-dialog/app-card/index.tsx b/web/app/components/app/create-app-dialog/app-card/index.tsx index 15cfbd5411..e203edfc8c 100644 --- a/web/app/components/app/create-app-dialog/app-card/index.tsx +++ b/web/app/components/app/create-app-dialog/app-card/index.tsx @@ -62,19 +62,19 @@ const AppCard = ({ {app.description} - {canCreate && ( + {(canCreate || isTrialApp) && ( )} diff --git a/web/app/components/app/create-app-modal/index.spec.tsx b/web/app/components/app/create-app-modal/index.spec.tsx index cb8f4db67f..8c368df62c 100644 --- a/web/app/components/app/create-app-modal/index.spec.tsx +++ b/web/app/components/app/create-app-modal/index.spec.tsx @@ 
-1,3 +1,4 @@ +import type { App } from '@/types/app' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { useRouter } from 'next/navigation' import { afterAll, beforeEach, describe, expect, it, vi } from 'vitest' @@ -13,8 +14,8 @@ import { getRedirection } from '@/utils/app-redirection' import CreateAppModal from './index' vi.mock('ahooks', () => ({ - useDebounceFn: (fn: (...args: any[]) => any) => { - const run = (...args: any[]) => fn(...args) + useDebounceFn: unknown>(fn: T) => { + const run = (...args: Parameters) => fn(...args) const cancel = vi.fn() const flush = vi.fn() return { run, cancel, flush } @@ -83,7 +84,7 @@ describe('CreateAppModal', () => { beforeEach(() => { vi.clearAllMocks() - mockUseRouter.mockReturnValue({ push: mockPush } as any) + mockUseRouter.mockReturnValue({ push: mockPush } as unknown as ReturnType) mockUseProviderContext.mockReturnValue({ plan: { type: AppModeEnum.ADVANCED_CHAT, @@ -92,10 +93,10 @@ describe('CreateAppModal', () => { reset: {}, }, enableBilling: true, - } as any) + } as unknown as ReturnType) mockUseAppContext.mockReturnValue({ isCurrentWorkspaceEditor: true, - } as any) + } as unknown as ReturnType) mockSetItem.mockClear() Object.defineProperty(window, 'localStorage', { value: { @@ -118,13 +119,13 @@ describe('CreateAppModal', () => { }) it('creates an app, notifies success, and fires callbacks', async () => { - const mockApp = { id: 'app-1', mode: AppModeEnum.ADVANCED_CHAT } - mockCreateApp.mockResolvedValue(mockApp as any) + const mockApp: Partial = { id: 'app-1', mode: AppModeEnum.ADVANCED_CHAT } + mockCreateApp.mockResolvedValue(mockApp as App) const { onClose, onSuccess } = renderModal() const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder') fireEvent.change(nameInput, { target: { value: 'My App' } }) - fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' })) + fireEvent.click(screen.getByRole('button', { name: /app\.newApp\.Create/ })) 
await waitFor(() => expect(mockCreateApp).toHaveBeenCalledWith({ name: 'My App', @@ -152,7 +153,7 @@ describe('CreateAppModal', () => { const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder') fireEvent.change(nameInput, { target: { value: 'My App' } }) - fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' })) + fireEvent.click(screen.getByRole('button', { name: /app\.newApp\.Create/ })) await waitFor(() => expect(mockCreateApp).toHaveBeenCalled()) expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'boom' }) diff --git a/web/app/components/datasets/common/image-uploader/utils.spec.ts b/web/app/components/datasets/common/image-uploader/utils.spec.ts index 0150b1fb23..5741f5704f 100644 --- a/web/app/components/datasets/common/image-uploader/utils.spec.ts +++ b/web/app/components/datasets/common/image-uploader/utils.spec.ts @@ -216,13 +216,22 @@ describe('image-uploader utils', () => { type FileCallback = (file: MockFile) => void type EntriesCallback = (entries: FileSystemEntry[]) => void + // Helper to create mock FileSystemEntry with required properties + const createMockEntry = (props: { + isFile: boolean + isDirectory: boolean + name?: string + file?: (callback: FileCallback) => void + createReader?: () => { readEntries: (callback: EntriesCallback) => void } + }): FileSystemEntry => props as unknown as FileSystemEntry + it('should resolve with file array for file entry', async () => { const mockFile: MockFile = { name: 'test.png' } - const mockEntry = { + const mockEntry = createMockEntry({ isFile: true, isDirectory: false, file: (callback: FileCallback) => callback(mockFile), - } + }) const result = await traverseFileEntry(mockEntry) expect(result).toHaveLength(1) @@ -232,11 +241,11 @@ describe('image-uploader utils', () => { it('should resolve with file array with prefix for nested file', async () => { const mockFile: MockFile = { name: 'test.png' } - const mockEntry = { + const mockEntry = 
createMockEntry({ isFile: true, isDirectory: false, file: (callback: FileCallback) => callback(mockFile), - } + }) const result = await traverseFileEntry(mockEntry, 'folder/') expect(result).toHaveLength(1) @@ -244,24 +253,24 @@ describe('image-uploader utils', () => { }) it('should resolve empty array for unknown entry type', async () => { - const mockEntry = { + const mockEntry = createMockEntry({ isFile: false, isDirectory: false, - } + }) const result = await traverseFileEntry(mockEntry) expect(result).toEqual([]) }) it('should handle directory with no files', async () => { - const mockEntry = { + const mockEntry = createMockEntry({ isFile: false, isDirectory: true, name: 'empty-folder', createReader: () => ({ readEntries: (callback: EntriesCallback) => callback([]), }), - } + }) const result = await traverseFileEntry(mockEntry) expect(result).toEqual([]) @@ -271,20 +280,20 @@ describe('image-uploader utils', () => { const mockFile1: MockFile = { name: 'file1.png' } const mockFile2: MockFile = { name: 'file2.png' } - const mockFileEntry1 = { + const mockFileEntry1 = createMockEntry({ isFile: true, isDirectory: false, file: (callback: FileCallback) => callback(mockFile1), - } + }) - const mockFileEntry2 = { + const mockFileEntry2 = createMockEntry({ isFile: true, isDirectory: false, file: (callback: FileCallback) => callback(mockFile2), - } + }) let readCount = 0 - const mockEntry = { + const mockEntry = createMockEntry({ isFile: false, isDirectory: true, name: 'folder', @@ -292,14 +301,14 @@ describe('image-uploader utils', () => { readEntries: (callback: EntriesCallback) => { if (readCount === 0) { readCount++ - callback([mockFileEntry1, mockFileEntry2] as unknown as FileSystemEntry[]) + callback([mockFileEntry1, mockFileEntry2]) } else { callback([]) } }, }), - } + }) const result = await traverseFileEntry(mockEntry) expect(result).toHaveLength(2) diff --git a/web/app/components/datasets/common/image-uploader/utils.ts 
b/web/app/components/datasets/common/image-uploader/utils.ts index c2fad83840..d8c8582e2a 100644 --- a/web/app/components/datasets/common/image-uploader/utils.ts +++ b/web/app/components/datasets/common/image-uploader/utils.ts @@ -18,17 +18,17 @@ type FileWithPath = { relativePath?: string } & File -export const traverseFileEntry = (entry: any, prefix = ''): Promise => { +export const traverseFileEntry = (entry: FileSystemEntry, prefix = ''): Promise => { return new Promise((resolve) => { if (entry.isFile) { - entry.file((file: FileWithPath) => { + (entry as FileSystemFileEntry).file((file: FileWithPath) => { file.relativePath = `${prefix}${file.name}` resolve([file]) }) } else if (entry.isDirectory) { - const reader = entry.createReader() - const entries: any[] = [] + const reader = (entry as FileSystemDirectoryEntry).createReader() + const entries: FileSystemEntry[] = [] const read = () => { reader.readEntries(async (results: FileSystemEntry[]) => { if (!results.length) { diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.spec.tsx b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.spec.tsx new file mode 100644 index 0000000000..e4955f58f6 --- /dev/null +++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.spec.tsx @@ -0,0 +1,1045 @@ +import type { ReactNode } from 'react' +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' +import { act, renderHook, waitFor } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { CreateFromDSLModalTab, useDSLImport } from './use-dsl-import' + +// Mock next/navigation +const mockPush = vi.fn() +vi.mock('next/navigation', () => ({ + useRouter: () => ({ + push: mockPush, + }), +})) + +// Mock service hooks +const mockImportDSL = vi.fn() +const mockImportDSLConfirm = vi.fn() + 
+vi.mock('@/service/use-pipeline', () => ({ + useImportPipelineDSL: () => ({ + mutateAsync: mockImportDSL, + }), + useImportPipelineDSLConfirm: () => ({ + mutateAsync: mockImportDSLConfirm, + }), +})) + +// Mock plugin dependencies hook +const mockHandleCheckPluginDependencies = vi.fn() + +vi.mock('@/app/components/workflow/plugin-dependency/hooks', () => ({ + usePluginDependencies: () => ({ + handleCheckPluginDependencies: mockHandleCheckPluginDependencies, + }), +})) + +// Mock toast context +const mockNotify = vi.fn() + +vi.mock('use-context-selector', async () => { + const actual = await vi.importActual('use-context-selector') + return { + ...actual, + useContext: vi.fn(() => ({ notify: mockNotify })), + } +}) + +// Test data builders +const createImportDSLResponse = (overrides = {}) => ({ + id: 'import-123', + status: 'completed' as const, + pipeline_id: 'pipeline-456', + dataset_id: 'dataset-789', + current_dsl_version: '1.0.0', + imported_dsl_version: '1.0.0', + ...overrides, +}) + +// Helper function to create QueryClient wrapper +const createWrapper = () => { + const queryClient = new QueryClient({ + defaultOptions: { + queries: { retry: false }, + mutations: { retry: false }, + }, + }) + return ({ children }: { children: ReactNode }) => ( + + {children} + + ) +} + +describe('useDSLImport', () => { + beforeEach(() => { + vi.clearAllMocks() + mockImportDSL.mockReset() + mockImportDSLConfirm.mockReset() + mockPush.mockReset() + mockNotify.mockReset() + mockHandleCheckPluginDependencies.mockReset() + }) + + describe('initialization', () => { + it('should initialize with default values', () => { + const { result } = renderHook( + () => useDSLImport({}), + { wrapper: createWrapper() }, + ) + + expect(result.current.currentFile).toBeUndefined() + expect(result.current.currentTab).toBe(CreateFromDSLModalTab.FROM_FILE) + expect(result.current.dslUrlValue).toBe('') + expect(result.current.showConfirmModal).toBe(false) + 
expect(result.current.versions).toBeUndefined() + expect(result.current.buttonDisabled).toBe(true) + expect(result.current.isConfirming).toBe(false) + }) + + it('should use provided activeTab', () => { + const { result } = renderHook( + () => useDSLImport({ activeTab: CreateFromDSLModalTab.FROM_URL }), + { wrapper: createWrapper() }, + ) + + expect(result.current.currentTab).toBe(CreateFromDSLModalTab.FROM_URL) + }) + + it('should use provided dslUrl', () => { + const { result } = renderHook( + () => useDSLImport({ dslUrl: 'https://example.com/test.pipeline' }), + { wrapper: createWrapper() }, + ) + + expect(result.current.dslUrlValue).toBe('https://example.com/test.pipeline') + }) + }) + + describe('setCurrentTab', () => { + it('should update current tab', () => { + const { result } = renderHook( + () => useDSLImport({}), + { wrapper: createWrapper() }, + ) + + act(() => { + result.current.setCurrentTab(CreateFromDSLModalTab.FROM_URL) + }) + + expect(result.current.currentTab).toBe(CreateFromDSLModalTab.FROM_URL) + }) + }) + + describe('setDslUrlValue', () => { + it('should update DSL URL value', () => { + const { result } = renderHook( + () => useDSLImport({}), + { wrapper: createWrapper() }, + ) + + act(() => { + result.current.setDslUrlValue('https://new-url.com/pipeline') + }) + + expect(result.current.dslUrlValue).toBe('https://new-url.com/pipeline') + }) + }) + + describe('handleFile', () => { + it('should set file and trigger file reading', async () => { + const { result } = renderHook( + () => useDSLImport({}), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['test content'], 'test.pipeline', { type: 'application/octet-stream' }) + + await act(async () => { + result.current.handleFile(mockFile) + }) + + expect(result.current.currentFile).toBe(mockFile) + expect(result.current.buttonDisabled).toBe(false) + }) + + it('should clear file when undefined is passed', async () => { + const { result } = renderHook( + () => useDSLImport({}), + { 
wrapper: createWrapper() }, + ) + + const mockFile = new File(['test content'], 'test.pipeline', { type: 'application/octet-stream' }) + + // First set a file + await act(async () => { + result.current.handleFile(mockFile) + }) + + expect(result.current.currentFile).toBe(mockFile) + + // Then clear it + await act(async () => { + result.current.handleFile(undefined) + }) + + expect(result.current.currentFile).toBeUndefined() + expect(result.current.buttonDisabled).toBe(true) + }) + }) + + describe('buttonDisabled', () => { + it('should be true when file tab is active and no file is selected', () => { + const { result } = renderHook( + () => useDSLImport({ activeTab: CreateFromDSLModalTab.FROM_FILE }), + { wrapper: createWrapper() }, + ) + + expect(result.current.buttonDisabled).toBe(true) + }) + + it('should be false when file tab is active and file is selected', async () => { + const { result } = renderHook( + () => useDSLImport({ activeTab: CreateFromDSLModalTab.FROM_FILE }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pipeline', { type: 'application/octet-stream' }) + + await act(async () => { + result.current.handleFile(mockFile) + }) + + expect(result.current.buttonDisabled).toBe(false) + }) + + it('should be true when URL tab is active and no URL is entered', () => { + const { result } = renderHook( + () => useDSLImport({ activeTab: CreateFromDSLModalTab.FROM_URL }), + { wrapper: createWrapper() }, + ) + + expect(result.current.buttonDisabled).toBe(true) + }) + + it('should be false when URL tab is active and URL is entered', () => { + const { result } = renderHook( + () => useDSLImport({ activeTab: CreateFromDSLModalTab.FROM_URL, dslUrl: 'https://example.com' }), + { wrapper: createWrapper() }, + ) + + expect(result.current.buttonDisabled).toBe(false) + }) + }) + + describe('handleCreateApp with URL mode', () => { + it('should call importDSL with URL mode', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) 
+ mockImportDSL.mockResolvedValue(createImportDSLResponse()) + mockHandleCheckPluginDependencies.mockResolvedValue(undefined) + + const onSuccess = vi.fn() + const onClose = vi.fn() + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + onSuccess, + onClose, + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) // Wait for debounce + }) + + await waitFor(() => { + expect(mockImportDSL).toHaveBeenCalledWith({ + mode: 'yaml-url', + yaml_url: 'https://example.com/test.pipeline', + }) + }) + + vi.useRealTimers() + }) + + it('should handle successful import with COMPLETED status', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + mockImportDSL.mockResolvedValue(createImportDSLResponse({ status: 'completed' })) + mockHandleCheckPluginDependencies.mockResolvedValue(undefined) + + const onSuccess = vi.fn() + const onClose = vi.fn() + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + onSuccess, + onClose, + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await waitFor(() => { + expect(onSuccess).toHaveBeenCalled() + expect(onClose).toHaveBeenCalled() + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'success', + })) + expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-789/pipeline') + }) + + vi.useRealTimers() + }) + + it('should handle import with COMPLETED_WITH_WARNINGS status', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + mockImportDSL.mockResolvedValue(createImportDSLResponse({ status: 'completed-with-warnings' })) + mockHandleCheckPluginDependencies.mockResolvedValue(undefined) + + const onSuccess = vi.fn() + const onClose = vi.fn() + + 
const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + onSuccess, + onClose, + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'warning', + })) + }) + + vi.useRealTimers() + }) + + it('should handle import with PENDING status and show confirm modal', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + mockImportDSL.mockResolvedValue(createImportDSLResponse({ + status: 'pending', + imported_dsl_version: '0.9.0', + current_dsl_version: '1.0.0', + })) + + const onClose = vi.fn() + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + onClose, + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await waitFor(() => { + expect(onClose).toHaveBeenCalled() + }) + + // Wait for setTimeout to show confirm modal + await act(async () => { + vi.advanceTimersByTime(400) + }) + + expect(result.current.showConfirmModal).toBe(true) + expect(result.current.versions).toEqual({ + importedVersion: '0.9.0', + systemVersion: '1.0.0', + }) + + vi.useRealTimers() + }) + + it('should handle API error (null response)', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + mockImportDSL.mockResolvedValue(null) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + 
})) + }) + + vi.useRealTimers() + }) + + it('should handle FAILED status', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + mockImportDSL.mockResolvedValue(createImportDSLResponse({ status: 'failed' })) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + })) + }) + + vi.useRealTimers() + }) + + it('should check plugin dependencies when pipeline_id is present', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + mockImportDSL.mockResolvedValue(createImportDSLResponse({ + status: 'completed', + pipeline_id: 'pipeline-123', + })) + mockHandleCheckPluginDependencies.mockResolvedValue(undefined) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await waitFor(() => { + expect(mockHandleCheckPluginDependencies).toHaveBeenCalledWith('pipeline-123', true) + }) + + vi.useRealTimers() + }) + + it('should not check plugin dependencies when pipeline_id is null', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + mockImportDSL.mockResolvedValue(createImportDSLResponse({ + status: 'completed', + pipeline_id: null, + })) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await waitFor(() => { + 
expect(mockHandleCheckPluginDependencies).not.toHaveBeenCalled() + }) + + vi.useRealTimers() + }) + + it('should return early when URL tab is active but no URL is provided', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: '', + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + expect(mockImportDSL).not.toHaveBeenCalled() + + vi.useRealTimers() + }) + }) + + describe('handleCreateApp with FILE mode', () => { + it('should call importDSL with file content mode', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + mockImportDSL.mockResolvedValue(createImportDSLResponse()) + mockHandleCheckPluginDependencies.mockResolvedValue(undefined) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_FILE, + }), + { wrapper: createWrapper() }, + ) + + const fileContent = 'test yaml content' + const mockFile = new File([fileContent], 'test.pipeline', { type: 'application/octet-stream' }) + + // Set up file and wait for FileReader to complete + await act(async () => { + result.current.handleFile(mockFile) + // Give FileReader time to process + await new Promise(resolve => setTimeout(resolve, 100)) + }) + + // Trigger create + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await waitFor(() => { + expect(mockImportDSL).toHaveBeenCalledWith({ + mode: 'yaml-content', + yaml_content: fileContent, + }) + }) + + vi.useRealTimers() + }) + + it('should return early when file tab is active but no file is selected', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_FILE, + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + 
result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + expect(mockImportDSL).not.toHaveBeenCalled() + + vi.useRealTimers() + }) + }) + + describe('onDSLConfirm', () => { + it('should call importDSLConfirm and handle success', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + + // First, trigger pending status to get importId + mockImportDSL.mockResolvedValue(createImportDSLResponse({ + id: 'import-123', + status: 'pending', + })) + + mockImportDSLConfirm.mockResolvedValue({ + status: 'completed', + pipeline_id: 'pipeline-456', + dataset_id: 'dataset-789', + }) + + mockHandleCheckPluginDependencies.mockResolvedValue(undefined) + + const onSuccess = vi.fn() + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + onSuccess, + }), + { wrapper: createWrapper() }, + ) + + // Trigger pending status + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + // Wait for confirm modal to show + await act(async () => { + vi.advanceTimersByTime(400) + }) + + expect(result.current.showConfirmModal).toBe(true) + + // Call onDSLConfirm + await act(async () => { + result.current.onDSLConfirm() + }) + + await waitFor(() => { + expect(mockImportDSLConfirm).toHaveBeenCalledWith('import-123') + expect(onSuccess).toHaveBeenCalled() + expect(result.current.showConfirmModal).toBe(false) + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'success', + })) + }) + + vi.useRealTimers() + }) + + it('should handle confirm API error', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + + mockImportDSL.mockResolvedValue(createImportDSLResponse({ + id: 'import-123', + status: 'pending', + })) + + mockImportDSLConfirm.mockResolvedValue(null) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + }), + 
{ wrapper: createWrapper() }, + ) + + // Trigger pending status + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await act(async () => { + vi.advanceTimersByTime(400) + }) + + // Call onDSLConfirm + await act(async () => { + result.current.onDSLConfirm() + }) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + })) + }) + + vi.useRealTimers() + }) + + it('should handle confirm with FAILED status', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + + mockImportDSL.mockResolvedValue(createImportDSLResponse({ + id: 'import-123', + status: 'pending', + })) + + mockImportDSLConfirm.mockResolvedValue({ + status: 'failed', + pipeline_id: 'pipeline-456', + dataset_id: 'dataset-789', + }) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + }), + { wrapper: createWrapper() }, + ) + + // Trigger pending status + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await act(async () => { + vi.advanceTimersByTime(400) + }) + + // Call onDSLConfirm + await act(async () => { + result.current.onDSLConfirm() + }) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + })) + }) + + vi.useRealTimers() + }) + + it('should return early when importId is not set', async () => { + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + }), + { wrapper: createWrapper() }, + ) + + // Call onDSLConfirm without triggering pending status + await act(async () => { + result.current.onDSLConfirm() + }) + + expect(mockImportDSLConfirm).not.toHaveBeenCalled() + }) + + it('should check plugin dependencies on confirm success', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + 
+ mockImportDSL.mockResolvedValue(createImportDSLResponse({ + id: 'import-123', + status: 'pending', + })) + + mockImportDSLConfirm.mockResolvedValue({ + status: 'completed', + pipeline_id: 'pipeline-789', + dataset_id: 'dataset-789', + }) + + mockHandleCheckPluginDependencies.mockResolvedValue(undefined) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + }), + { wrapper: createWrapper() }, + ) + + // Trigger pending status + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await act(async () => { + vi.advanceTimersByTime(400) + }) + + // Call onDSLConfirm + await act(async () => { + result.current.onDSLConfirm() + }) + + await waitFor(() => { + expect(mockHandleCheckPluginDependencies).toHaveBeenCalledWith('pipeline-789', true) + }) + + vi.useRealTimers() + }) + + it('should set isConfirming during confirm process', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + + let resolveConfirm: (value: unknown) => void + mockImportDSL.mockResolvedValue(createImportDSLResponse({ + id: 'import-123', + status: 'pending', + })) + + mockImportDSLConfirm.mockImplementation(() => new Promise((resolve) => { + resolveConfirm = resolve + })) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + }), + { wrapper: createWrapper() }, + ) + + // Trigger pending status + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await act(async () => { + vi.advanceTimersByTime(400) + }) + + expect(result.current.isConfirming).toBe(false) + + // Start confirm + let confirmPromise: Promise + act(() => { + confirmPromise = result.current.onDSLConfirm() + }) + + await waitFor(() => { + expect(result.current.isConfirming).toBe(true) + }) + + // Resolve confirm + await act(async () => { + 
resolveConfirm!({ + status: 'completed', + pipeline_id: 'pipeline-789', + dataset_id: 'dataset-789', + }) + }) + + await confirmPromise! + + expect(result.current.isConfirming).toBe(false) + + vi.useRealTimers() + }) + }) + + describe('handleCancelConfirm', () => { + it('should close confirm modal', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + + mockImportDSL.mockResolvedValue(createImportDSLResponse({ + id: 'import-123', + status: 'pending', + })) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + }), + { wrapper: createWrapper() }, + ) + + // Trigger pending status to show confirm modal + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await act(async () => { + vi.advanceTimersByTime(400) + }) + + expect(result.current.showConfirmModal).toBe(true) + + // Cancel confirm + act(() => { + result.current.handleCancelConfirm() + }) + + expect(result.current.showConfirmModal).toBe(false) + + vi.useRealTimers() + }) + }) + + describe('duplicate submission prevention', () => { + it('should prevent duplicate submissions while creating', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + + let resolveImport: (value: unknown) => void + mockImportDSL.mockImplementation(() => new Promise((resolve) => { + resolveImport = resolve + })) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + }), + { wrapper: createWrapper() }, + ) + + // First call + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + // Second call should be ignored + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + // Third call should be ignored + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + 
// Only one call should be made + expect(mockImportDSL).toHaveBeenCalledTimes(1) + + // Resolve the first call + await act(async () => { + resolveImport!(createImportDSLResponse()) + }) + + vi.useRealTimers() + }) + }) + + describe('file reading', () => { + it('should read file content using FileReader', async () => { + const { result } = renderHook( + () => useDSLImport({ activeTab: CreateFromDSLModalTab.FROM_FILE }), + { wrapper: createWrapper() }, + ) + + const fileContent = 'yaml content here' + const mockFile = new File([fileContent], 'test.pipeline', { type: 'application/octet-stream' }) + + await act(async () => { + result.current.handleFile(mockFile) + }) + + expect(result.current.currentFile).toBe(mockFile) + }) + + it('should clear file content when file is removed', async () => { + const { result } = renderHook( + () => useDSLImport({ activeTab: CreateFromDSLModalTab.FROM_FILE }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pipeline', { type: 'application/octet-stream' }) + + // Set file + await act(async () => { + result.current.handleFile(mockFile) + }) + + // Clear file + await act(async () => { + result.current.handleFile(undefined) + }) + + expect(result.current.currentFile).toBeUndefined() + }) + }) + + describe('navigation after import', () => { + it('should navigate to pipeline page after successful import', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + mockImportDSL.mockResolvedValue(createImportDSLResponse({ + status: 'completed', + dataset_id: 'test-dataset-id', + })) + mockHandleCheckPluginDependencies.mockResolvedValue(undefined) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await waitFor(() => { + 
expect(mockPush).toHaveBeenCalledWith('/datasets/test-dataset-id/pipeline') + }) + + vi.useRealTimers() + }) + + it('should navigate to pipeline page after confirm success', async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + + mockImportDSL.mockResolvedValue(createImportDSLResponse({ + id: 'import-123', + status: 'pending', + })) + + mockImportDSLConfirm.mockResolvedValue({ + status: 'completed', + pipeline_id: 'pipeline-456', + dataset_id: 'confirm-dataset-id', + }) + + mockHandleCheckPluginDependencies.mockResolvedValue(undefined) + + const { result } = renderHook( + () => useDSLImport({ + activeTab: CreateFromDSLModalTab.FROM_URL, + dslUrl: 'https://example.com/test.pipeline', + }), + { wrapper: createWrapper() }, + ) + + // Trigger pending status + await act(async () => { + result.current.handleCreateApp() + vi.advanceTimersByTime(400) + }) + + await act(async () => { + vi.advanceTimersByTime(400) + }) + + // Call onDSLConfirm + await act(async () => { + result.current.onDSLConfirm() + }) + + await waitFor(() => { + expect(mockPush).toHaveBeenCalledWith('/datasets/confirm-dataset-id/pipeline') + }) + + vi.useRealTimers() + }) + }) + + describe('enum export', () => { + it('should export CreateFromDSLModalTab enum with correct values', () => { + expect(CreateFromDSLModalTab.FROM_FILE).toBe('from-file') + expect(CreateFromDSLModalTab.FROM_URL).toBe('from-url') + }) + }) +}) diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts new file mode 100644 index 0000000000..87e55ea740 --- /dev/null +++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts @@ -0,0 +1,218 @@ +'use client' +import { useDebounceFn } from 'ahooks' +import { useRouter } from 'next/navigation' +import { useCallback, useMemo, useRef, useState } from 
'react' +import { useTranslation } from 'react-i18next' +import { useContext } from 'use-context-selector' +import { ToastContext } from '@/app/components/base/toast' +import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks' +import { + DSLImportMode, + DSLImportStatus, +} from '@/models/app' +import { useImportPipelineDSL, useImportPipelineDSLConfirm } from '@/service/use-pipeline' + +export enum CreateFromDSLModalTab { + FROM_FILE = 'from-file', + FROM_URL = 'from-url', +} + +export type UseDSLImportOptions = { + activeTab?: CreateFromDSLModalTab + dslUrl?: string + onSuccess?: () => void + onClose?: () => void +} + +export type DSLVersions = { + importedVersion: string + systemVersion: string +} + +export const useDSLImport = ({ + activeTab = CreateFromDSLModalTab.FROM_FILE, + dslUrl = '', + onSuccess, + onClose, +}: UseDSLImportOptions) => { + const { push } = useRouter() + const { t } = useTranslation() + const { notify } = useContext(ToastContext) + + const [currentFile, setDSLFile] = useState() + const [fileContent, setFileContent] = useState() + const [currentTab, setCurrentTab] = useState(activeTab) + const [dslUrlValue, setDslUrlValue] = useState(dslUrl) + const [showConfirmModal, setShowConfirmModal] = useState(false) + const [versions, setVersions] = useState() + const [importId, setImportId] = useState() + const [isConfirming, setIsConfirming] = useState(false) + + const { handleCheckPluginDependencies } = usePluginDependencies() + const isCreatingRef = useRef(false) + + const { mutateAsync: importDSL } = useImportPipelineDSL() + const { mutateAsync: importDSLConfirm } = useImportPipelineDSLConfirm() + + const readFile = useCallback((file: File) => { + const reader = new FileReader() + reader.onload = (event) => { + const content = event.target?.result + setFileContent(content as string) + } + reader.readAsText(file) + }, []) + + const handleFile = useCallback((file?: File) => { + setDSLFile(file) + if (file) + 
readFile(file) + if (!file) + setFileContent('') + }, [readFile]) + + const onCreate = useCallback(async () => { + if (currentTab === CreateFromDSLModalTab.FROM_FILE && !currentFile) + return + if (currentTab === CreateFromDSLModalTab.FROM_URL && !dslUrlValue) + return + if (isCreatingRef.current) + return + + isCreatingRef.current = true + + let response + if (currentTab === CreateFromDSLModalTab.FROM_FILE) { + response = await importDSL({ + mode: DSLImportMode.YAML_CONTENT, + yaml_content: fileContent || '', + }) + } + if (currentTab === CreateFromDSLModalTab.FROM_URL) { + response = await importDSL({ + mode: DSLImportMode.YAML_URL, + yaml_url: dslUrlValue || '', + }) + } + + if (!response) { + notify({ type: 'error', message: t('creation.errorTip', { ns: 'datasetPipeline' }) }) + isCreatingRef.current = false + return + } + + const { id, status, pipeline_id, dataset_id, imported_dsl_version, current_dsl_version } = response + + if (status === DSLImportStatus.COMPLETED || status === DSLImportStatus.COMPLETED_WITH_WARNINGS) { + onSuccess?.() + onClose?.() + + notify({ + type: status === DSLImportStatus.COMPLETED ? 'success' : 'warning', + message: t(status === DSLImportStatus.COMPLETED ? 'creation.successTip' : 'creation.caution', { ns: 'datasetPipeline' }), + children: status === DSLImportStatus.COMPLETED_WITH_WARNINGS && t('newApp.appCreateDSLWarning', { ns: 'app' }), + }) + + if (pipeline_id) + await handleCheckPluginDependencies(pipeline_id, true) + + push(`/datasets/${dataset_id}/pipeline`) + isCreatingRef.current = false + } + else if (status === DSLImportStatus.PENDING) { + setVersions({ + importedVersion: imported_dsl_version ?? '', + systemVersion: current_dsl_version ?? 
'', + }) + onClose?.() + setTimeout(() => { + setShowConfirmModal(true) + }, 300) + setImportId(id) + isCreatingRef.current = false + } + else { + notify({ type: 'error', message: t('creation.errorTip', { ns: 'datasetPipeline' }) }) + isCreatingRef.current = false + } + }, [ + currentTab, + currentFile, + dslUrlValue, + fileContent, + importDSL, + notify, + t, + onSuccess, + onClose, + handleCheckPluginDependencies, + push, + ]) + + const { run: handleCreateApp } = useDebounceFn(onCreate, { wait: 300 }) + + const onDSLConfirm = useCallback(async () => { + if (!importId) + return + + setIsConfirming(true) + const response = await importDSLConfirm(importId) + setIsConfirming(false) + + if (!response) { + notify({ type: 'error', message: t('creation.errorTip', { ns: 'datasetPipeline' }) }) + return + } + + const { status, pipeline_id, dataset_id } = response + + if (status === DSLImportStatus.COMPLETED) { + onSuccess?.() + setShowConfirmModal(false) + + notify({ + type: 'success', + message: t('creation.successTip', { ns: 'datasetPipeline' }), + }) + + if (pipeline_id) + await handleCheckPluginDependencies(pipeline_id, true) + + push(`/datasets/${dataset_id}/pipeline`) + } + else if (status === DSLImportStatus.FAILED) { + notify({ type: 'error', message: t('creation.errorTip', { ns: 'datasetPipeline' }) }) + } + }, [importId, importDSLConfirm, notify, t, onSuccess, handleCheckPluginDependencies, push]) + + const handleCancelConfirm = useCallback(() => { + setShowConfirmModal(false) + }, []) + + const buttonDisabled = useMemo(() => { + if (currentTab === CreateFromDSLModalTab.FROM_FILE) + return !currentFile + if (currentTab === CreateFromDSLModalTab.FROM_URL) + return !dslUrlValue + return false + }, [currentTab, currentFile, dslUrlValue]) + + return { + // State + currentFile, + currentTab, + dslUrlValue, + showConfirmModal, + versions, + buttonDisabled, + isConfirming, + + // Actions + setCurrentTab, + setDslUrlValue, + handleFile, + handleCreateApp, + onDSLConfirm, 
+ handleCancelConfirm, + } +} diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/index.tsx b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/index.tsx index 2d187010b8..079ea90687 100644 --- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/index.tsx +++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/index.tsx @@ -1,24 +1,18 @@ 'use client' -import { useDebounceFn, useKeyPress } from 'ahooks' +import { useKeyPress } from 'ahooks' import { noop } from 'es-toolkit/function' -import { useRouter } from 'next/navigation' -import { useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast' -import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks' -import { - DSLImportMode, - DSLImportStatus, -} from '@/models/app' -import { useImportPipelineDSL, useImportPipelineDSLConfirm } from '@/service/use-pipeline' +import DSLConfirmModal from './dsl-confirm-modal' import Header from './header' +import { CreateFromDSLModalTab, useDSLImport } from './hooks/use-dsl-import' import Tab from './tab' import Uploader from './uploader' +export { CreateFromDSLModalTab } + type CreateFromDSLModalProps = { show: boolean onSuccess?: () => void @@ -27,11 +21,6 @@ type CreateFromDSLModalProps = { dslUrl?: string } -export enum CreateFromDSLModalTab { - FROM_FILE = 'from-file', - FROM_URL = 'from-url', -} - const CreateFromDSLModal = ({ show, onSuccess, @@ -39,149 +28,33 @@ const CreateFromDSLModal = ({ activeTab = CreateFromDSLModalTab.FROM_FILE, dslUrl = '', }: CreateFromDSLModalProps) => { - const { push } = 
useRouter() const { t } = useTranslation() - const { notify } = useContext(ToastContext) - const [currentFile, setDSLFile] = useState() - const [fileContent, setFileContent] = useState() - const [currentTab, setCurrentTab] = useState(activeTab) - const [dslUrlValue, setDslUrlValue] = useState(dslUrl) - const [showErrorModal, setShowErrorModal] = useState(false) - const [versions, setVersions] = useState<{ importedVersion: string, systemVersion: string }>() - const [importId, setImportId] = useState() - const { handleCheckPluginDependencies } = usePluginDependencies() - const readFile = (file: File) => { - const reader = new FileReader() - reader.onload = function (event) { - const content = event.target?.result - setFileContent(content as string) - } - reader.readAsText(file) - } - - const handleFile = (file?: File) => { - setDSLFile(file) - if (file) - readFile(file) - if (!file) - setFileContent('') - } - - const isCreatingRef = useRef(false) - - const { mutateAsync: importDSL } = useImportPipelineDSL() - - const onCreate = async () => { - if (currentTab === CreateFromDSLModalTab.FROM_FILE && !currentFile) - return - if (currentTab === CreateFromDSLModalTab.FROM_URL && !dslUrlValue) - return - if (isCreatingRef.current) - return - isCreatingRef.current = true - let response - if (currentTab === CreateFromDSLModalTab.FROM_FILE) { - response = await importDSL({ - mode: DSLImportMode.YAML_CONTENT, - yaml_content: fileContent || '', - }) - } - if (currentTab === CreateFromDSLModalTab.FROM_URL) { - response = await importDSL({ - mode: DSLImportMode.YAML_URL, - yaml_url: dslUrlValue || '', - }) - } - - if (!response) { - notify({ type: 'error', message: t('creation.errorTip', { ns: 'datasetPipeline' }) }) - isCreatingRef.current = false - return - } - const { id, status, pipeline_id, dataset_id, imported_dsl_version, current_dsl_version } = response - if (status === DSLImportStatus.COMPLETED || status === DSLImportStatus.COMPLETED_WITH_WARNINGS) { - if (onSuccess) - 
onSuccess() - if (onClose) - onClose() - - notify({ - type: status === DSLImportStatus.COMPLETED ? 'success' : 'warning', - message: t(status === DSLImportStatus.COMPLETED ? 'creation.successTip' : 'creation.caution', { ns: 'datasetPipeline' }), - children: status === DSLImportStatus.COMPLETED_WITH_WARNINGS && t('newApp.appCreateDSLWarning', { ns: 'app' }), - }) - if (pipeline_id) - await handleCheckPluginDependencies(pipeline_id, true) - push(`/datasets/${dataset_id}/pipeline`) - isCreatingRef.current = false - } - else if (status === DSLImportStatus.PENDING) { - setVersions({ - importedVersion: imported_dsl_version ?? '', - systemVersion: current_dsl_version ?? '', - }) - if (onClose) - onClose() - setTimeout(() => { - setShowErrorModal(true) - }, 300) - setImportId(id) - isCreatingRef.current = false - } - else { - notify({ type: 'error', message: t('creation.errorTip', { ns: 'datasetPipeline' }) }) - isCreatingRef.current = false - } - } - - const { run: handleCreateApp } = useDebounceFn(onCreate, { wait: 300 }) - - useKeyPress('esc', () => { - if (show && !showErrorModal) - onClose() + const { + currentFile, + currentTab, + dslUrlValue, + showConfirmModal, + versions, + buttonDisabled, + isConfirming, + setCurrentTab, + setDslUrlValue, + handleFile, + handleCreateApp, + onDSLConfirm, + handleCancelConfirm, + } = useDSLImport({ + activeTab, + dslUrl, + onSuccess, + onClose, }) - const { mutateAsync: importDSLConfirm } = useImportPipelineDSLConfirm() - - const onDSLConfirm = async () => { - if (!importId) - return - const response = await importDSLConfirm(importId) - - if (!response) { - notify({ type: 'error', message: t('creation.errorTip', { ns: 'datasetPipeline' }) }) - return - } - - const { status, pipeline_id, dataset_id } = response - - if (status === DSLImportStatus.COMPLETED) { - if (onSuccess) - onSuccess() - if (onClose) - onClose() - - notify({ - type: 'success', - message: t('creation.successTip', { ns: 'datasetPipeline' }), - }) - if (pipeline_id) 
- await handleCheckPluginDependencies(pipeline_id, true) - push(`datasets/${dataset_id}/pipeline`) - } - else if (status === DSLImportStatus.FAILED) { - notify({ type: 'error', message: t('creation.errorTip', { ns: 'datasetPipeline' }) }) - } - } - - const buttonDisabled = useMemo(() => { - if (currentTab === CreateFromDSLModalTab.FROM_FILE) - return !currentFile - if (currentTab === CreateFromDSLModalTab.FROM_URL) - return !dslUrlValue - return false - }, [currentTab, currentFile, dslUrlValue]) + useKeyPress('esc', () => { + if (show && !showConfirmModal) + onClose() + }) return ( <> @@ -196,29 +69,25 @@ const CreateFromDSLModal = ({ setCurrentTab={setCurrentTab} />
- { - currentTab === CreateFromDSLModalTab.FROM_FILE && ( - - ) - } - { - currentTab === CreateFromDSLModalTab.FROM_URL && ( -
-
- DSL URL -
- setDslUrlValue(e.target.value)} - /> + {currentTab === CreateFromDSLModalTab.FROM_FILE && ( + + )} + {currentTab === CreateFromDSLModalTab.FROM_URL && ( +
+
+ DSL URL
- ) - } + setDslUrlValue(e.target.value)} + /> +
+ )}
- setShowErrorModal(false)} - className="w-[480px]" - > -
-
{t('newApp.appCreateDSLErrorTitle', { ns: 'app' })}
-
-
{t('newApp.appCreateDSLErrorPart1', { ns: 'app' })}
-
{t('newApp.appCreateDSLErrorPart2', { ns: 'app' })}
-
-
- {t('newApp.appCreateDSLErrorPart3', { ns: 'app' })} - {versions?.importedVersion} -
-
- {t('newApp.appCreateDSLErrorPart4', { ns: 'app' })} - {versions?.systemVersion} -
-
-
-
- - -
-
+ {showConfirmModal && ( + + )} ) } diff --git a/web/app/components/datasets/create/file-uploader/components/file-list-item.spec.tsx b/web/app/components/datasets/create/file-uploader/components/file-list-item.spec.tsx new file mode 100644 index 0000000000..4da20a7bf7 --- /dev/null +++ b/web/app/components/datasets/create/file-uploader/components/file-list-item.spec.tsx @@ -0,0 +1,334 @@ +import type { FileListItemProps } from './file-list-item' +import type { CustomFile as File, FileItem } from '@/models/datasets' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { PROGRESS_COMPLETE, PROGRESS_ERROR, PROGRESS_NOT_STARTED } from '../constants' +import FileListItem from './file-list-item' + +// Mock theme hook - can be changed per test +let mockTheme = 'light' +vi.mock('@/hooks/use-theme', () => ({ + default: () => ({ theme: mockTheme }), +})) + +// Mock theme types +vi.mock('@/types/app', () => ({ + Theme: { dark: 'dark', light: 'light' }, +})) + +// Mock SimplePieChart with dynamic import handling +vi.mock('next/dynamic', () => ({ + default: () => { + const DynamicComponent = ({ percentage, stroke, fill }: { percentage: number, stroke: string, fill: string }) => ( +
+ Pie Chart: + {' '} + {percentage} + % +
+ ) + DynamicComponent.displayName = 'SimplePieChart' + return DynamicComponent + }, +})) + +// Mock DocumentFileIcon +vi.mock('@/app/components/datasets/common/document-file-icon', () => ({ + default: ({ name, extension, size }: { name: string, extension: string, size: string }) => ( +
+ Document Icon +
+ ), +})) + +describe('FileListItem', () => { + const createMockFile = (overrides: Partial = {}): File => ({ + name: 'test-document.pdf', + size: 1024 * 100, // 100KB + type: 'application/pdf', + lastModified: Date.now(), + ...overrides, + } as File) + + const createMockFileItem = (overrides: Partial = {}): FileItem => ({ + fileID: 'file-123', + file: createMockFile(overrides.file as Partial), + progress: PROGRESS_NOT_STARTED, + ...overrides, + }) + + const defaultProps: FileListItemProps = { + fileItem: createMockFileItem(), + onPreview: vi.fn(), + onRemove: vi.fn(), + } + + beforeEach(() => { + vi.clearAllMocks() + mockTheme = 'light' + }) + + describe('rendering', () => { + it('should render the file item container', () => { + const { container } = render() + const item = container.firstChild as HTMLElement + expect(item).toHaveClass('flex', 'h-12', 'items-center', 'rounded-lg') + }) + + it('should render document icon with correct props', () => { + render() + const icon = screen.getByTestId('document-icon') + expect(icon).toBeInTheDocument() + expect(icon).toHaveAttribute('data-name', 'test-document.pdf') + expect(icon).toHaveAttribute('data-extension', 'pdf') + expect(icon).toHaveAttribute('data-size', 'xl') + }) + + it('should render file name', () => { + render() + expect(screen.getByText('test-document.pdf')).toBeInTheDocument() + }) + + it('should render file extension in uppercase via CSS class', () => { + render() + const extensionSpan = screen.getByText('pdf') + expect(extensionSpan).toBeInTheDocument() + expect(extensionSpan).toHaveClass('uppercase') + }) + + it('should render file size', () => { + render() + // Default mock file is 100KB (1024 * 100 bytes) + expect(screen.getByText('100.00 KB')).toBeInTheDocument() + }) + + it('should render delete button', () => { + const { container } = render() + const deleteButton = container.querySelector('.cursor-pointer') + expect(deleteButton).toBeInTheDocument() + }) + }) + + describe('progress states', () => 
{ + it('should show progress chart when uploading (0-99)', () => { + const fileItem = createMockFileItem({ progress: 50 }) + render() + + const pieChart = screen.getByTestId('pie-chart') + expect(pieChart).toBeInTheDocument() + expect(pieChart).toHaveAttribute('data-percentage', '50') + }) + + it('should show progress chart at 0%', () => { + const fileItem = createMockFileItem({ progress: 0 }) + render() + + const pieChart = screen.getByTestId('pie-chart') + expect(pieChart).toHaveAttribute('data-percentage', '0') + }) + + it('should not show progress chart when complete (100)', () => { + const fileItem = createMockFileItem({ progress: PROGRESS_COMPLETE }) + render() + + expect(screen.queryByTestId('pie-chart')).not.toBeInTheDocument() + }) + + it('should not show progress chart when not started (-1)', () => { + const fileItem = createMockFileItem({ progress: PROGRESS_NOT_STARTED }) + render() + + expect(screen.queryByTestId('pie-chart')).not.toBeInTheDocument() + }) + }) + + describe('error state', () => { + it('should show error indicator when progress is PROGRESS_ERROR', () => { + const fileItem = createMockFileItem({ progress: PROGRESS_ERROR }) + const { container } = render() + + const errorIndicator = container.querySelector('.text-text-destructive') + expect(errorIndicator).toBeInTheDocument() + }) + + it('should not show error indicator when not in error state', () => { + const { container } = render() + const errorIndicator = container.querySelector('.text-text-destructive') + expect(errorIndicator).not.toBeInTheDocument() + }) + }) + + describe('theme handling', () => { + it('should use correct chart color for light theme', () => { + mockTheme = 'light' + const fileItem = createMockFileItem({ progress: 50 }) + render() + + const pieChart = screen.getByTestId('pie-chart') + expect(pieChart).toHaveAttribute('data-stroke', '#296dff') + expect(pieChart).toHaveAttribute('data-fill', '#296dff') + }) + + it('should use correct chart color for dark theme', () => 
{ + mockTheme = 'dark' + const fileItem = createMockFileItem({ progress: 50 }) + render() + + const pieChart = screen.getByTestId('pie-chart') + expect(pieChart).toHaveAttribute('data-stroke', '#5289ff') + expect(pieChart).toHaveAttribute('data-fill', '#5289ff') + }) + }) + + describe('event handlers', () => { + it('should call onPreview when item is clicked with file id', () => { + const onPreview = vi.fn() + const fileItem = createMockFileItem({ + file: createMockFile({ id: 'uploaded-id' } as Partial), + }) + render() + + const item = screen.getByText('test-document.pdf').closest('[class*="flex h-12"]')! + fireEvent.click(item) + + expect(onPreview).toHaveBeenCalledTimes(1) + expect(onPreview).toHaveBeenCalledWith(fileItem.file) + }) + + it('should not call onPreview when file has no id', () => { + const onPreview = vi.fn() + const fileItem = createMockFileItem() + render() + + const item = screen.getByText('test-document.pdf').closest('[class*="flex h-12"]')! + fireEvent.click(item) + + expect(onPreview).not.toHaveBeenCalled() + }) + + it('should call onRemove when delete button is clicked', () => { + const onRemove = vi.fn() + const fileItem = createMockFileItem() + const { container } = render() + + const deleteButton = container.querySelector('.cursor-pointer')! + fireEvent.click(deleteButton) + + expect(onRemove).toHaveBeenCalledTimes(1) + expect(onRemove).toHaveBeenCalledWith('file-123') + }) + + it('should stop propagation when delete button is clicked', () => { + const onPreview = vi.fn() + const onRemove = vi.fn() + const fileItem = createMockFileItem({ + file: createMockFile({ id: 'uploaded-id' } as Partial), + }) + const { container } = render() + + const deleteButton = container.querySelector('.cursor-pointer')! 
+ fireEvent.click(deleteButton) + + expect(onRemove).toHaveBeenCalledTimes(1) + expect(onPreview).not.toHaveBeenCalled() + }) + }) + + describe('file type handling', () => { + it('should handle files with multiple dots in name', () => { + const fileItem = createMockFileItem({ + file: createMockFile({ name: 'my.document.file.docx' }), + }) + render() + + expect(screen.getByText('my.document.file.docx')).toBeInTheDocument() + expect(screen.getByText('docx')).toBeInTheDocument() + }) + + it('should handle files without extension', () => { + const fileItem = createMockFileItem({ + file: createMockFile({ name: 'README' }), + }) + render() + + // File name appears once, and extension area shows empty string + expect(screen.getByText('README')).toBeInTheDocument() + }) + + it('should handle various file extensions', () => { + const extensions = ['txt', 'md', 'json', 'csv', 'xlsx'] + + extensions.forEach((ext) => { + const fileItem = createMockFileItem({ + file: createMockFile({ name: `file.${ext}` }), + }) + const { unmount } = render() + expect(screen.getByText(ext)).toBeInTheDocument() + unmount() + }) + }) + }) + + describe('file size display', () => { + it('should display size in KB for small files', () => { + const fileItem = createMockFileItem({ + file: createMockFile({ size: 5 * 1024 }), + }) + render() + expect(screen.getByText('5.00 KB')).toBeInTheDocument() + }) + + it('should display size in MB for larger files', () => { + const fileItem = createMockFileItem({ + file: createMockFile({ size: 5 * 1024 * 1024 }), + }) + render() + expect(screen.getByText('5.00 MB')).toBeInTheDocument() + }) + }) + + describe('upload progress values', () => { + it('should show chart at progress 1', () => { + const fileItem = createMockFileItem({ progress: 1 }) + render() + expect(screen.getByTestId('pie-chart')).toBeInTheDocument() + }) + + it('should show chart at progress 99', () => { + const fileItem = createMockFileItem({ progress: 99 }) + render() + 
expect(screen.getByTestId('pie-chart')).toHaveAttribute('data-percentage', '99') + }) + + it('should not show chart at progress 100', () => { + const fileItem = createMockFileItem({ progress: 100 }) + render() + expect(screen.queryByTestId('pie-chart')).not.toBeInTheDocument() + }) + }) + + describe('styling', () => { + it('should have proper shadow styling', () => { + const { container } = render() + const item = container.firstChild as HTMLElement + expect(item).toHaveClass('shadow-xs') + }) + + it('should have proper border styling', () => { + const { container } = render() + const item = container.firstChild as HTMLElement + expect(item).toHaveClass('border', 'border-components-panel-border') + }) + + it('should truncate long file names', () => { + const longFileName = 'this-is-a-very-long-file-name-that-should-be-truncated.pdf' + const fileItem = createMockFileItem({ + file: createMockFile({ name: longFileName }), + }) + render() + + const nameElement = screen.getByText(longFileName) + expect(nameElement).toHaveClass('truncate') + }) + }) +}) diff --git a/web/app/components/datasets/create/file-uploader/components/file-list-item.tsx b/web/app/components/datasets/create/file-uploader/components/file-list-item.tsx new file mode 100644 index 0000000000..d36773fa5c --- /dev/null +++ b/web/app/components/datasets/create/file-uploader/components/file-list-item.tsx @@ -0,0 +1,89 @@ +'use client' +import type { CustomFile as File, FileItem } from '@/models/datasets' +import { RiDeleteBinLine, RiErrorWarningFill } from '@remixicon/react' +import dynamic from 'next/dynamic' +import { useMemo } from 'react' +import DocumentFileIcon from '@/app/components/datasets/common/document-file-icon' +import useTheme from '@/hooks/use-theme' +import { Theme } from '@/types/app' +import { formatFileSize, getFileExtension } from '@/utils/format' +import { PROGRESS_COMPLETE, PROGRESS_ERROR } from '../constants' + +const SimplePieChart = dynamic(() => 
import('@/app/components/base/simple-pie-chart'), { ssr: false }) + +export type FileListItemProps = { + fileItem: FileItem + onPreview: (file: File) => void + onRemove: (fileID: string) => void +} + +const FileListItem = ({ + fileItem, + onPreview, + onRemove, +}: FileListItemProps) => { + const { theme } = useTheme() + const chartColor = useMemo(() => theme === Theme.dark ? '#5289ff' : '#296dff', [theme]) + + const isUploading = fileItem.progress >= 0 && fileItem.progress < PROGRESS_COMPLETE + const isError = fileItem.progress === PROGRESS_ERROR + + const handleClick = () => { + if (fileItem.file?.id) + onPreview(fileItem.file) + } + + const handleRemove = (e: React.MouseEvent) => { + e.stopPropagation() + onRemove(fileItem.fileID) + } + + return ( +
+
+ +
+
+
+
+ {fileItem.file.name} +
+
+
+ {getFileExtension(fileItem.file.name)} + · + {formatFileSize(fileItem.file.size)} +
+
+
+ {isUploading && ( + + )} + {isError && ( + + )} + + + +
+
+ ) +} + +export default FileListItem diff --git a/web/app/components/datasets/create/file-uploader/components/upload-dropzone.spec.tsx b/web/app/components/datasets/create/file-uploader/components/upload-dropzone.spec.tsx new file mode 100644 index 0000000000..112d61250b --- /dev/null +++ b/web/app/components/datasets/create/file-uploader/components/upload-dropzone.spec.tsx @@ -0,0 +1,210 @@ +import type { RefObject } from 'react' +import type { UploadDropzoneProps } from './upload-dropzone' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import UploadDropzone from './upload-dropzone' + +// Helper to create mock ref objects for testing +const createMockRef = (value: T | null = null): RefObject => ({ current: value }) + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string, options?: Record) => { + const translations: Record = { + 'stepOne.uploader.button': 'Drag and drop files, or', + 'stepOne.uploader.buttonSingleFile': 'Drag and drop file, or', + 'stepOne.uploader.browse': 'Browse', + 'stepOne.uploader.tip': 'Supports {{supportTypes}}, Max {{size}}MB each, up to {{batchCount}} files at a time, {{totalCount}} files total', + } + let result = translations[key] || key + if (options && typeof options === 'object') { + Object.entries(options).forEach(([k, v]) => { + result = result.replace(`{{${k}}}`, String(v)) + }) + } + return result + }, + }), +})) + +describe('UploadDropzone', () => { + const defaultProps: UploadDropzoneProps = { + dropRef: createMockRef() as RefObject, + dragRef: createMockRef() as RefObject, + fileUploaderRef: createMockRef() as RefObject, + dragging: false, + supportBatchUpload: true, + supportTypesShowNames: 'PDF, DOCX, TXT', + fileUploadConfig: { + file_size_limit: 15, + batch_count_limit: 5, + file_upload_limit: 10, + }, + acceptTypes: ['.pdf', '.docx', '.txt'], + onSelectFile: vi.fn(), + onFileChange: 
vi.fn(), + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('rendering', () => { + it('should render the dropzone container', () => { + const { container } = render() + const dropzone = container.querySelector('[class*="border-dashed"]') + expect(dropzone).toBeInTheDocument() + }) + + it('should render hidden file input', () => { + render() + const input = document.getElementById('fileUploader') as HTMLInputElement + expect(input).toBeInTheDocument() + expect(input).toHaveClass('hidden') + expect(input).toHaveAttribute('type', 'file') + }) + + it('should render upload icon', () => { + render() + const icon = document.querySelector('svg') + expect(icon).toBeInTheDocument() + }) + + it('should render browse label when extensions are allowed', () => { + render() + expect(screen.getByText('Browse')).toBeInTheDocument() + }) + + it('should not render browse label when no extensions allowed', () => { + render() + expect(screen.queryByText('Browse')).not.toBeInTheDocument() + }) + + it('should render file size and count limits', () => { + render() + const tipText = screen.getByText(/Supports.*Max.*15MB/i) + expect(tipText).toBeInTheDocument() + }) + }) + + describe('file input configuration', () => { + it('should allow multiple files when supportBatchUpload is true', () => { + render() + const input = document.getElementById('fileUploader') as HTMLInputElement + expect(input).toHaveAttribute('multiple') + }) + + it('should not allow multiple files when supportBatchUpload is false', () => { + render() + const input = document.getElementById('fileUploader') as HTMLInputElement + expect(input).not.toHaveAttribute('multiple') + }) + + it('should set accept attribute with correct types', () => { + render() + const input = document.getElementById('fileUploader') as HTMLInputElement + expect(input).toHaveAttribute('accept', '.pdf,.docx') + }) + }) + + describe('text content', () => { + it('should show batch upload text when supportBatchUpload is true', () => { + 
render() + expect(screen.getByText(/Drag and drop files/i)).toBeInTheDocument() + }) + + it('should show single file text when supportBatchUpload is false', () => { + render() + expect(screen.getByText(/Drag and drop file/i)).toBeInTheDocument() + }) + }) + + describe('dragging state', () => { + it('should apply dragging styles when dragging is true', () => { + const { container } = render() + const dropzone = container.querySelector('[class*="border-components-dropzone-border-accent"]') + expect(dropzone).toBeInTheDocument() + }) + + it('should render drag overlay when dragging', () => { + const dragRef = createMockRef() + render(} />) + const overlay = document.querySelector('.absolute.left-0.top-0') + expect(overlay).toBeInTheDocument() + }) + + it('should not render drag overlay when not dragging', () => { + render() + const overlay = document.querySelector('.absolute.left-0.top-0') + expect(overlay).not.toBeInTheDocument() + }) + }) + + describe('event handlers', () => { + it('should call onSelectFile when browse label is clicked', () => { + const onSelectFile = vi.fn() + render() + + const browseLabel = screen.getByText('Browse') + fireEvent.click(browseLabel) + + expect(onSelectFile).toHaveBeenCalledTimes(1) + }) + + it('should call onFileChange when files are selected', () => { + const onFileChange = vi.fn() + render() + + const input = document.getElementById('fileUploader') as HTMLInputElement + const file = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + + fireEvent.change(input, { target: { files: [file] } }) + + expect(onFileChange).toHaveBeenCalledTimes(1) + }) + }) + + describe('refs', () => { + it('should attach dropRef to drop container', () => { + const dropRef = createMockRef() + render(} />) + expect(dropRef.current).toBeInstanceOf(HTMLDivElement) + }) + + it('should attach fileUploaderRef to input element', () => { + const fileUploaderRef = createMockRef() + render(} />) + 
expect(fileUploaderRef.current).toBeInstanceOf(HTMLInputElement) + }) + + it('should attach dragRef to overlay when dragging', () => { + const dragRef = createMockRef() + render(} />) + expect(dragRef.current).toBeInstanceOf(HTMLDivElement) + }) + }) + + describe('styling', () => { + it('should have base dropzone styling', () => { + const { container } = render() + const dropzone = container.querySelector('[class*="border-dashed"]') + expect(dropzone).toBeInTheDocument() + expect(dropzone).toHaveClass('rounded-xl') + }) + + it('should have cursor-pointer on browse label', () => { + render() + const browseLabel = screen.getByText('Browse') + expect(browseLabel).toHaveClass('cursor-pointer') + }) + }) + + describe('accessibility', () => { + it('should have an accessible file input', () => { + render() + const input = document.getElementById('fileUploader') as HTMLInputElement + expect(input).toHaveAttribute('id', 'fileUploader') + }) + }) +}) diff --git a/web/app/components/datasets/create/file-uploader/components/upload-dropzone.tsx b/web/app/components/datasets/create/file-uploader/components/upload-dropzone.tsx new file mode 100644 index 0000000000..9fa577dace --- /dev/null +++ b/web/app/components/datasets/create/file-uploader/components/upload-dropzone.tsx @@ -0,0 +1,84 @@ +'use client' +import type { RefObject } from 'react' +import type { FileUploadConfig } from '../hooks/use-file-upload' +import { RiUploadCloud2Line } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { cn } from '@/utils/classnames' + +export type UploadDropzoneProps = { + dropRef: RefObject + dragRef: RefObject + fileUploaderRef: RefObject + dragging: boolean + supportBatchUpload: boolean + supportTypesShowNames: string + fileUploadConfig: FileUploadConfig + acceptTypes: string[] + onSelectFile: () => void + onFileChange: (e: React.ChangeEvent) => void +} + +const UploadDropzone = ({ + dropRef, + dragRef, + fileUploaderRef, + dragging, + supportBatchUpload, + 
supportTypesShowNames, + fileUploadConfig, + acceptTypes, + onSelectFile, + onFileChange, +}: UploadDropzoneProps) => { + const { t } = useTranslation() + + return ( + <> + +
+
+ + + {supportBatchUpload + ? t('stepOne.uploader.button', { ns: 'datasetCreation' }) + : t('stepOne.uploader.buttonSingleFile', { ns: 'datasetCreation' })} + {acceptTypes.length > 0 && ( + + )} + +
+
+ {t('stepOne.uploader.tip', { + ns: 'datasetCreation', + size: fileUploadConfig.file_size_limit, + supportTypes: supportTypesShowNames, + batchCount: fileUploadConfig.batch_count_limit, + totalCount: fileUploadConfig.file_upload_limit, + })} +
+ {dragging &&
} +
+ + ) +} + +export default UploadDropzone diff --git a/web/app/components/datasets/create/file-uploader/constants.ts b/web/app/components/datasets/create/file-uploader/constants.ts new file mode 100644 index 0000000000..cda2dae868 --- /dev/null +++ b/web/app/components/datasets/create/file-uploader/constants.ts @@ -0,0 +1,3 @@ +export const PROGRESS_NOT_STARTED = -1 +export const PROGRESS_ERROR = -2 +export const PROGRESS_COMPLETE = 100 diff --git a/web/app/components/datasets/create/file-uploader/hooks/use-file-upload.spec.tsx b/web/app/components/datasets/create/file-uploader/hooks/use-file-upload.spec.tsx new file mode 100644 index 0000000000..222f038c84 --- /dev/null +++ b/web/app/components/datasets/create/file-uploader/hooks/use-file-upload.spec.tsx @@ -0,0 +1,921 @@ +import type { ReactNode } from 'react' +import type { CustomFile, FileItem } from '@/models/datasets' +import { act, render, renderHook, waitFor } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { ToastContext } from '@/app/components/base/toast' + +import { PROGRESS_COMPLETE, PROGRESS_ERROR, PROGRESS_NOT_STARTED } from '../constants' +// Import after mocks +import { useFileUpload } from './use-file-upload' + +// Mock notify function +const mockNotify = vi.fn() +const mockClose = vi.fn() + +// Mock ToastContext +vi.mock('use-context-selector', async () => { + const actual = await vi.importActual('use-context-selector') + return { + ...actual, + useContext: vi.fn(() => ({ notify: mockNotify, close: mockClose })), + } +}) + +// Mock upload service +const mockUpload = vi.fn() +vi.mock('@/service/base', () => ({ + upload: (...args: unknown[]) => mockUpload(...args), +})) + +// Mock file upload config +const mockFileUploadConfig = { + file_size_limit: 15, + batch_count_limit: 5, + file_upload_limit: 10, +} + +const mockSupportTypes = { + allowed_extensions: ['pdf', 'docx', 'txt', 'md'], +} + +vi.mock('@/service/use-common', () => ({ + 
useFileUploadConfig: () => ({ data: mockFileUploadConfig }), + useFileSupportTypes: () => ({ data: mockSupportTypes }), +})) + +// Mock i18n +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock locale +vi.mock('@/context/i18n', () => ({ + useLocale: () => 'en-US', +})) + +vi.mock('@/i18n-config/language', () => ({ + LanguagesSupported: ['en-US', 'zh-Hans'], +})) + +// Mock config +vi.mock('@/config', () => ({ + IS_CE_EDITION: false, +})) + +// Mock file upload error message +vi.mock('@/app/components/base/file-uploader/utils', () => ({ + getFileUploadErrorMessage: (_e: unknown, defaultMsg: string) => defaultMsg, +})) + +const createWrapper = () => { + return ({ children }: { children: ReactNode }) => ( + + {children} + + ) +} + +describe('useFileUpload', () => { + const defaultOptions = { + fileList: [] as FileItem[], + prepareFileList: vi.fn(), + onFileUpdate: vi.fn(), + onFileListUpdate: vi.fn(), + onPreview: vi.fn(), + supportBatchUpload: true, + } + + beforeEach(() => { + vi.clearAllMocks() + mockUpload.mockReset() + // Default mock to return a resolved promise to avoid unhandled rejections + mockUpload.mockResolvedValue({ id: 'default-id' }) + mockNotify.mockReset() + }) + + describe('initialization', () => { + it('should initialize with default values', () => { + const { result } = renderHook( + () => useFileUpload(defaultOptions), + { wrapper: createWrapper() }, + ) + + expect(result.current.dragging).toBe(false) + expect(result.current.hideUpload).toBe(false) + expect(result.current.dropRef.current).toBeNull() + expect(result.current.dragRef.current).toBeNull() + expect(result.current.fileUploaderRef.current).toBeNull() + }) + + it('should set hideUpload true when not batch upload and has files', () => { + const { result } = renderHook( + () => useFileUpload({ + ...defaultOptions, + supportBatchUpload: false, + fileList: [{ fileID: 'file-1', file: {} as CustomFile, progress: 100 }], + }), + { 
wrapper: createWrapper() }, + ) + + expect(result.current.hideUpload).toBe(true) + }) + + it('should compute acceptTypes correctly', () => { + const { result } = renderHook( + () => useFileUpload(defaultOptions), + { wrapper: createWrapper() }, + ) + + expect(result.current.acceptTypes).toEqual(['.pdf', '.docx', '.txt', '.md']) + }) + + it('should compute supportTypesShowNames correctly', () => { + const { result } = renderHook( + () => useFileUpload(defaultOptions), + { wrapper: createWrapper() }, + ) + + expect(result.current.supportTypesShowNames).toContain('PDF') + expect(result.current.supportTypesShowNames).toContain('DOCX') + expect(result.current.supportTypesShowNames).toContain('TXT') + // 'md' is mapped to 'markdown' in the extensionMap + expect(result.current.supportTypesShowNames).toContain('MARKDOWN') + }) + + it('should set batch limit to 1 when not batch upload', () => { + const { result } = renderHook( + () => useFileUpload({ + ...defaultOptions, + supportBatchUpload: false, + }), + { wrapper: createWrapper() }, + ) + + expect(result.current.fileUploadConfig.batch_count_limit).toBe(1) + expect(result.current.fileUploadConfig.file_upload_limit).toBe(1) + }) + }) + + describe('selectHandle', () => { + it('should trigger click on file input', () => { + const { result } = renderHook( + () => useFileUpload(defaultOptions), + { wrapper: createWrapper() }, + ) + + const mockClick = vi.fn() + const mockInput = { click: mockClick } as unknown as HTMLInputElement + Object.defineProperty(result.current.fileUploaderRef, 'current', { + value: mockInput, + writable: true, + }) + + act(() => { + result.current.selectHandle() + }) + + expect(mockClick).toHaveBeenCalled() + }) + + it('should do nothing when file input ref is null', () => { + const { result } = renderHook( + () => useFileUpload(defaultOptions), + { wrapper: createWrapper() }, + ) + + expect(() => { + act(() => { + result.current.selectHandle() + }) + }).not.toThrow() + }) + }) + + 
describe('handlePreview', () => { + it('should call onPreview when file has id', () => { + const onPreview = vi.fn() + const { result } = renderHook( + () => useFileUpload({ ...defaultOptions, onPreview }), + { wrapper: createWrapper() }, + ) + + const mockFile = { id: 'file-123', name: 'test.pdf', size: 1024 } as CustomFile + + act(() => { + result.current.handlePreview(mockFile) + }) + + expect(onPreview).toHaveBeenCalledWith(mockFile) + }) + + it('should not call onPreview when file has no id', () => { + const onPreview = vi.fn() + const { result } = renderHook( + () => useFileUpload({ ...defaultOptions, onPreview }), + { wrapper: createWrapper() }, + ) + + const mockFile = { name: 'test.pdf', size: 1024 } as CustomFile + + act(() => { + result.current.handlePreview(mockFile) + }) + + expect(onPreview).not.toHaveBeenCalled() + }) + }) + + describe('removeFile', () => { + it('should call onFileListUpdate with filtered list', () => { + const onFileListUpdate = vi.fn() + const { result } = renderHook( + () => useFileUpload({ ...defaultOptions, onFileListUpdate }), + { wrapper: createWrapper() }, + ) + + act(() => { + result.current.removeFile('file-to-remove') + }) + + expect(onFileListUpdate).toHaveBeenCalled() + }) + + it('should clear file input value', () => { + const { result } = renderHook( + () => useFileUpload(defaultOptions), + { wrapper: createWrapper() }, + ) + + const mockInput = { value: 'some-file' } as HTMLInputElement + Object.defineProperty(result.current.fileUploaderRef, 'current', { + value: mockInput, + writable: true, + }) + + act(() => { + result.current.removeFile('file-123') + }) + + expect(mockInput.value).toBe('') + }) + }) + + describe('fileChangeHandle', () => { + it('should handle valid files', async () => { + mockUpload.mockResolvedValue({ id: 'uploaded-id' }) + + const prepareFileList = vi.fn() + const { result } = renderHook( + () => useFileUpload({ ...defaultOptions, prepareFileList }), + { wrapper: createWrapper() }, + ) + + const 
mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { files: [mockFile] }, + } as unknown as React.ChangeEvent + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + expect(prepareFileList).toHaveBeenCalled() + }) + }) + + it('should limit files to batch count', () => { + const prepareFileList = vi.fn() + const { result } = renderHook( + () => useFileUpload({ ...defaultOptions, prepareFileList }), + { wrapper: createWrapper() }, + ) + + const files = Array.from({ length: 10 }, (_, i) => + new File(['content'], `file${i}.pdf`, { type: 'application/pdf' })) + + const event = { + target: { files }, + } as unknown as React.ChangeEvent + + act(() => { + result.current.fileChangeHandle(event) + }) + + // Should be called with at most batch_count_limit files + if (prepareFileList.mock.calls.length > 0) { + const calledFiles = prepareFileList.mock.calls[0][0] + expect(calledFiles.length).toBeLessThanOrEqual(mockFileUploadConfig.batch_count_limit) + } + }) + + it('should reject invalid file types', () => { + const { result } = renderHook( + () => useFileUpload(defaultOptions), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.exe', { type: 'application/x-msdownload' }) + const event = { + target: { files: [mockFile] }, + } as unknown as React.ChangeEvent + + act(() => { + result.current.fileChangeHandle(event) + }) + + expect(mockNotify).toHaveBeenCalledWith( + expect.objectContaining({ type: 'error' }), + ) + }) + + it('should reject files exceeding size limit', () => { + const { result } = renderHook( + () => useFileUpload(defaultOptions), + { wrapper: createWrapper() }, + ) + + // Create a file larger than the limit (15MB) + const largeFile = new File([new ArrayBuffer(20 * 1024 * 1024)], 'large.pdf', { type: 'application/pdf' }) + + const event = { + target: { files: [largeFile] }, + } as unknown as React.ChangeEvent + + act(() => { + 
result.current.fileChangeHandle(event) + }) + + expect(mockNotify).toHaveBeenCalledWith( + expect.objectContaining({ type: 'error' }), + ) + }) + + it('should handle null files', () => { + const prepareFileList = vi.fn() + const { result } = renderHook( + () => useFileUpload({ ...defaultOptions, prepareFileList }), + { wrapper: createWrapper() }, + ) + + const event = { + target: { files: null }, + } as unknown as React.ChangeEvent + + act(() => { + result.current.fileChangeHandle(event) + }) + + expect(prepareFileList).not.toHaveBeenCalled() + }) + }) + + describe('drag and drop handlers', () => { + const TestDropzone = ({ options }: { options: typeof defaultOptions }) => { + const { + dropRef, + dragRef, + dragging, + } = useFileUpload(options) + + return ( +
+
+ {dragging &&
} +
+ {String(dragging)} +
+ ) + } + + it('should set dragging true on dragenter', async () => { + const { getByTestId } = await act(async () => + render( + + + , + ), + ) + + const dropzone = getByTestId('dropzone') + + await act(async () => { + const dragEnterEvent = new Event('dragenter', { bubbles: true, cancelable: true }) + dropzone.dispatchEvent(dragEnterEvent) + }) + + expect(getByTestId('dragging').textContent).toBe('true') + }) + + it('should handle dragover event', async () => { + const { getByTestId } = await act(async () => + render( + + + , + ), + ) + + const dropzone = getByTestId('dropzone') + + await act(async () => { + const dragOverEvent = new Event('dragover', { bubbles: true, cancelable: true }) + dropzone.dispatchEvent(dragOverEvent) + }) + + expect(dropzone).toBeInTheDocument() + }) + + it('should set dragging false on dragleave from drag overlay', async () => { + const { getByTestId, queryByTestId } = await act(async () => + render( + + + , + ), + ) + + const dropzone = getByTestId('dropzone') + + await act(async () => { + const dragEnterEvent = new Event('dragenter', { bubbles: true, cancelable: true }) + dropzone.dispatchEvent(dragEnterEvent) + }) + + expect(getByTestId('dragging').textContent).toBe('true') + + const dragOverlay = queryByTestId('drag-overlay') + if (dragOverlay) { + await act(async () => { + const dragLeaveEvent = new Event('dragleave', { bubbles: true, cancelable: true }) + Object.defineProperty(dragLeaveEvent, 'target', { value: dragOverlay }) + dropzone.dispatchEvent(dragLeaveEvent) + }) + } + }) + + it('should handle drop with files', async () => { + mockUpload.mockResolvedValue({ id: 'uploaded-id' }) + const prepareFileList = vi.fn() + + const { getByTestId } = await act(async () => + render( + + + , + ), + ) + + const dropzone = getByTestId('dropzone') + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + + await act(async () => { + const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) as Event 
& { dataTransfer: DataTransfer | null } + Object.defineProperty(dropEvent, 'dataTransfer', { + value: { + items: [{ + getAsFile: () => mockFile, + webkitGetAsEntry: () => null, + }], + }, + }) + dropzone.dispatchEvent(dropEvent) + }) + + await waitFor(() => { + expect(prepareFileList).toHaveBeenCalled() + }) + }) + + it('should handle drop without dataTransfer', async () => { + const prepareFileList = vi.fn() + + const { getByTestId } = await act(async () => + render( + + + , + ), + ) + + const dropzone = getByTestId('dropzone') + + await act(async () => { + const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) as Event & { dataTransfer: DataTransfer | null } + Object.defineProperty(dropEvent, 'dataTransfer', { value: null }) + dropzone.dispatchEvent(dropEvent) + }) + + expect(prepareFileList).not.toHaveBeenCalled() + }) + + it('should limit to single file on drop when supportBatchUpload is false', async () => { + mockUpload.mockResolvedValue({ id: 'uploaded-id' }) + const prepareFileList = vi.fn() + + const { getByTestId } = await act(async () => + render( + + + , + ), + ) + + const dropzone = getByTestId('dropzone') + const files = [ + new File(['content1'], 'test1.pdf', { type: 'application/pdf' }), + new File(['content2'], 'test2.pdf', { type: 'application/pdf' }), + ] + + await act(async () => { + const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) as Event & { dataTransfer: DataTransfer | null } + Object.defineProperty(dropEvent, 'dataTransfer', { + value: { + items: files.map(f => ({ + getAsFile: () => f, + webkitGetAsEntry: () => null, + })), + }, + }) + dropzone.dispatchEvent(dropEvent) + }) + + await waitFor(() => { + if (prepareFileList.mock.calls.length > 0) { + const calledFiles = prepareFileList.mock.calls[0][0] + expect(calledFiles.length).toBe(1) + } + }) + }) + + it('should handle drop with FileSystemFileEntry', async () => { + mockUpload.mockResolvedValue({ id: 'uploaded-id' }) + const prepareFileList = 
vi.fn() + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + + const { getByTestId } = await act(async () => + render( + + + , + ), + ) + + const dropzone = getByTestId('dropzone') + + await act(async () => { + const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) as Event & { dataTransfer: DataTransfer | null } + Object.defineProperty(dropEvent, 'dataTransfer', { + value: { + items: [{ + getAsFile: () => mockFile, + webkitGetAsEntry: () => ({ + isFile: true, + isDirectory: false, + file: (callback: (file: File) => void) => callback(mockFile), + }), + }], + }, + }) + dropzone.dispatchEvent(dropEvent) + }) + + await waitFor(() => { + expect(prepareFileList).toHaveBeenCalled() + }) + }) + + it('should handle drop with FileSystemDirectoryEntry', async () => { + mockUpload.mockResolvedValue({ id: 'uploaded-id' }) + const prepareFileList = vi.fn() + const mockFile = new File(['content'], 'nested.pdf', { type: 'application/pdf' }) + + const { getByTestId } = await act(async () => + render( + + + , + ), + ) + + const dropzone = getByTestId('dropzone') + + await act(async () => { + let callCount = 0 + const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) as Event & { dataTransfer: DataTransfer | null } + Object.defineProperty(dropEvent, 'dataTransfer', { + value: { + items: [{ + getAsFile: () => null, + webkitGetAsEntry: () => ({ + isFile: false, + isDirectory: true, + name: 'folder', + createReader: () => ({ + readEntries: (callback: (entries: Array<{ isFile: boolean, isDirectory: boolean, name?: string, file?: (cb: (f: File) => void) => void }>) => void) => { + // First call returns file entry, second call returns empty (signals end) + if (callCount === 0) { + callCount++ + callback([{ + isFile: true, + isDirectory: false, + name: 'nested.pdf', + file: (cb: (f: File) => void) => cb(mockFile), + }]) + } + else { + callback([]) + } + }, + }), + }), + }], + }, + }) + dropzone.dispatchEvent(dropEvent) + }) 
+ + await waitFor(() => { + expect(prepareFileList).toHaveBeenCalled() + }) + }) + + it('should handle drop with empty directory', async () => { + const prepareFileList = vi.fn() + + const { getByTestId } = await act(async () => + render( + + + , + ), + ) + + const dropzone = getByTestId('dropzone') + + await act(async () => { + const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) as Event & { dataTransfer: DataTransfer | null } + Object.defineProperty(dropEvent, 'dataTransfer', { + value: { + items: [{ + getAsFile: () => null, + webkitGetAsEntry: () => ({ + isFile: false, + isDirectory: true, + name: 'empty-folder', + createReader: () => ({ + readEntries: (callback: (entries: never[]) => void) => { + callback([]) + }, + }), + }), + }], + }, + }) + dropzone.dispatchEvent(dropEvent) + }) + + // Should not prepare file list if no valid files + await new Promise(resolve => setTimeout(resolve, 100)) + }) + + it('should handle entry that is neither file nor directory', async () => { + const prepareFileList = vi.fn() + + const { getByTestId } = await act(async () => + render( + + + , + ), + ) + + const dropzone = getByTestId('dropzone') + + await act(async () => { + const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) as Event & { dataTransfer: DataTransfer | null } + Object.defineProperty(dropEvent, 'dataTransfer', { + value: { + items: [{ + getAsFile: () => null, + webkitGetAsEntry: () => ({ + isFile: false, + isDirectory: false, + }), + }], + }, + }) + dropzone.dispatchEvent(dropEvent) + }) + + // Should not throw and should handle gracefully + await new Promise(resolve => setTimeout(resolve, 100)) + }) + }) + + describe('file upload', () => { + it('should call upload with correct parameters', async () => { + mockUpload.mockResolvedValue({ id: 'uploaded-id', name: 'test.pdf' }) + const onFileUpdate = vi.fn() + + const { result } = renderHook( + () => useFileUpload({ ...defaultOptions, onFileUpdate }), + { wrapper: 
createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { files: [mockFile] }, + } as unknown as React.ChangeEvent + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + expect(mockUpload).toHaveBeenCalled() + }) + }) + + it('should update progress during upload', async () => { + let progressCallback: ((e: ProgressEvent) => void) | undefined + + mockUpload.mockImplementation(async (options: { onprogress: (e: ProgressEvent) => void }) => { + progressCallback = options.onprogress + return { id: 'uploaded-id' } + }) + + const onFileUpdate = vi.fn() + + const { result } = renderHook( + () => useFileUpload({ ...defaultOptions, onFileUpdate }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { files: [mockFile] }, + } as unknown as React.ChangeEvent + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + expect(mockUpload).toHaveBeenCalled() + }) + + if (progressCallback) { + act(() => { + progressCallback!({ + lengthComputable: true, + loaded: 50, + total: 100, + } as ProgressEvent) + }) + + expect(onFileUpdate).toHaveBeenCalled() + } + }) + + it('should handle upload error', async () => { + mockUpload.mockRejectedValue(new Error('Upload failed')) + const onFileUpdate = vi.fn() + + const { result } = renderHook( + () => useFileUpload({ ...defaultOptions, onFileUpdate }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { files: [mockFile] }, + } as unknown as React.ChangeEvent + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith( + expect.objectContaining({ type: 'error' }), + ) + }) + }) + + it('should update file with PROGRESS_COMPLETE on success', 
async () => { + mockUpload.mockResolvedValue({ id: 'uploaded-id', name: 'test.pdf' }) + const onFileUpdate = vi.fn() + + const { result } = renderHook( + () => useFileUpload({ ...defaultOptions, onFileUpdate }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { files: [mockFile] }, + } as unknown as React.ChangeEvent + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + const completeCalls = onFileUpdate.mock.calls.filter( + ([, progress]) => progress === PROGRESS_COMPLETE, + ) + expect(completeCalls.length).toBeGreaterThan(0) + }) + }) + + it('should update file with PROGRESS_ERROR on failure', async () => { + mockUpload.mockRejectedValue(new Error('Upload failed')) + const onFileUpdate = vi.fn() + + const { result } = renderHook( + () => useFileUpload({ ...defaultOptions, onFileUpdate }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { files: [mockFile] }, + } as unknown as React.ChangeEvent + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + const errorCalls = onFileUpdate.mock.calls.filter( + ([, progress]) => progress === PROGRESS_ERROR, + ) + expect(errorCalls.length).toBeGreaterThan(0) + }) + }) + }) + + describe('file count validation', () => { + it('should reject when total files exceed limit', () => { + const existingFiles: FileItem[] = Array.from({ length: 8 }, (_, i) => ({ + fileID: `existing-${i}`, + file: { name: `existing-${i}.pdf`, size: 1024 } as CustomFile, + progress: 100, + })) + + const { result } = renderHook( + () => useFileUpload({ + ...defaultOptions, + fileList: existingFiles, + }), + { wrapper: createWrapper() }, + ) + + const files = Array.from({ length: 5 }, (_, i) => + new File(['content'], `new-${i}.pdf`, { type: 'application/pdf' })) + + const event = { + target: 
{ files }, + } as unknown as React.ChangeEvent + + act(() => { + result.current.fileChangeHandle(event) + }) + + expect(mockNotify).toHaveBeenCalledWith( + expect.objectContaining({ type: 'error' }), + ) + }) + }) + + describe('progress constants', () => { + it('should use PROGRESS_NOT_STARTED for new files', async () => { + mockUpload.mockResolvedValue({ id: 'file-id' }) + + const prepareFileList = vi.fn() + const { result } = renderHook( + () => useFileUpload({ ...defaultOptions, prepareFileList }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { files: [mockFile] }, + } as unknown as React.ChangeEvent + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + if (prepareFileList.mock.calls.length > 0) { + const files = prepareFileList.mock.calls[0][0] + expect(files[0].progress).toBe(PROGRESS_NOT_STARTED) + } + }) + }) + }) +}) diff --git a/web/app/components/datasets/create/file-uploader/hooks/use-file-upload.ts b/web/app/components/datasets/create/file-uploader/hooks/use-file-upload.ts new file mode 100644 index 0000000000..e097bab755 --- /dev/null +++ b/web/app/components/datasets/create/file-uploader/hooks/use-file-upload.ts @@ -0,0 +1,351 @@ +'use client' +import type { RefObject } from 'react' +import type { CustomFile as File, FileItem } from '@/models/datasets' +import { useCallback, useEffect, useMemo, useRef, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { useContext } from 'use-context-selector' +import { getFileUploadErrorMessage } from '@/app/components/base/file-uploader/utils' +import { ToastContext } from '@/app/components/base/toast' +import { IS_CE_EDITION } from '@/config' +import { useLocale } from '@/context/i18n' +import { LanguagesSupported } from '@/i18n-config/language' +import { upload } from '@/service/base' +import { useFileSupportTypes, useFileUploadConfig } from 
'@/service/use-common' +import { getFileExtension } from '@/utils/format' +import { PROGRESS_COMPLETE, PROGRESS_ERROR, PROGRESS_NOT_STARTED } from '../constants' + +export type FileUploadConfig = { + file_size_limit: number + batch_count_limit: number + file_upload_limit: number +} + +export type UseFileUploadOptions = { + fileList: FileItem[] + prepareFileList: (files: FileItem[]) => void + onFileUpdate: (fileItem: FileItem, progress: number, list: FileItem[]) => void + onFileListUpdate?: (files: FileItem[]) => void + onPreview: (file: File) => void + supportBatchUpload?: boolean + /** + * Optional list of allowed file extensions. If not provided, fetches from API. + * Pass this when you need custom extension filtering instead of using the global config. + */ + allowedExtensions?: string[] +} + +export type UseFileUploadReturn = { + // Refs + dropRef: RefObject + dragRef: RefObject + fileUploaderRef: RefObject + + // State + dragging: boolean + + // Config + fileUploadConfig: FileUploadConfig + acceptTypes: string[] + supportTypesShowNames: string + hideUpload: boolean + + // Handlers + selectHandle: () => void + fileChangeHandle: (e: React.ChangeEvent) => void + removeFile: (fileID: string) => void + handlePreview: (file: File) => void +} + +type FileWithPath = { + relativePath?: string +} & File + +export const useFileUpload = ({ + fileList, + prepareFileList, + onFileUpdate, + onFileListUpdate, + onPreview, + supportBatchUpload = false, + allowedExtensions, +}: UseFileUploadOptions): UseFileUploadReturn => { + const { t } = useTranslation() + const { notify } = useContext(ToastContext) + const locale = useLocale() + + const [dragging, setDragging] = useState(false) + const dropRef = useRef(null) + const dragRef = useRef(null) + const fileUploaderRef = useRef(null) + const fileListRef = useRef([]) + + const hideUpload = !supportBatchUpload && fileList.length > 0 + + const { data: fileUploadConfigResponse } = useFileUploadConfig() + const { data: 
supportFileTypesResponse } = useFileSupportTypes() + // Use provided allowedExtensions or fetch from API + const supportTypes = useMemo( + () => allowedExtensions ?? supportFileTypesResponse?.allowed_extensions ?? [], + [allowedExtensions, supportFileTypesResponse?.allowed_extensions], + ) + + const supportTypesShowNames = useMemo(() => { + const extensionMap: { [key: string]: string } = { + md: 'markdown', + pptx: 'pptx', + htm: 'html', + xlsx: 'xlsx', + docx: 'docx', + } + + return [...supportTypes] + .map(item => extensionMap[item] || item) + .map(item => item.toLowerCase()) + .filter((item, index, self) => self.indexOf(item) === index) + .map(item => item.toUpperCase()) + .join(locale !== LanguagesSupported[1] ? ', ' : '、 ') + }, [supportTypes, locale]) + + const acceptTypes = useMemo(() => supportTypes.map((ext: string) => `.${ext}`), [supportTypes]) + + const fileUploadConfig = useMemo(() => ({ + file_size_limit: fileUploadConfigResponse?.file_size_limit ?? 15, + batch_count_limit: supportBatchUpload ? (fileUploadConfigResponse?.batch_count_limit ?? 5) : 1, + file_upload_limit: supportBatchUpload ? (fileUploadConfigResponse?.file_upload_limit ?? 
5) : 1, + }), [fileUploadConfigResponse, supportBatchUpload]) + + const isValid = useCallback((file: File) => { + const { size } = file + const ext = `.${getFileExtension(file.name)}` + const isValidType = acceptTypes.includes(ext.toLowerCase()) + if (!isValidType) + notify({ type: 'error', message: t('stepOne.uploader.validation.typeError', { ns: 'datasetCreation' }) }) + + const isValidSize = size <= fileUploadConfig.file_size_limit * 1024 * 1024 + if (!isValidSize) + notify({ type: 'error', message: t('stepOne.uploader.validation.size', { ns: 'datasetCreation', size: fileUploadConfig.file_size_limit }) }) + + return isValidType && isValidSize + }, [fileUploadConfig, notify, t, acceptTypes]) + + const fileUpload = useCallback(async (fileItem: FileItem): Promise => { + const formData = new FormData() + formData.append('file', fileItem.file) + const onProgress = (e: ProgressEvent) => { + if (e.lengthComputable) { + const percent = Math.floor(e.loaded / e.total * 100) + onFileUpdate(fileItem, percent, fileListRef.current) + } + } + + return upload({ + xhr: new XMLHttpRequest(), + data: formData, + onprogress: onProgress, + }, false, undefined, '?source=datasets') + .then((res) => { + const completeFile = { + fileID: fileItem.fileID, + file: res as unknown as File, + progress: PROGRESS_NOT_STARTED, + } + const index = fileListRef.current.findIndex(item => item.fileID === fileItem.fileID) + fileListRef.current[index] = completeFile + onFileUpdate(completeFile, PROGRESS_COMPLETE, fileListRef.current) + return Promise.resolve({ ...completeFile }) + }) + .catch((e) => { + const errorMessage = getFileUploadErrorMessage(e, t('stepOne.uploader.failed', { ns: 'datasetCreation' }), t) + notify({ type: 'error', message: errorMessage }) + onFileUpdate(fileItem, PROGRESS_ERROR, fileListRef.current) + return Promise.resolve({ ...fileItem }) + }) + .finally() + }, [notify, onFileUpdate, t]) + + const uploadBatchFiles = useCallback((bFiles: FileItem[]) => { + bFiles.forEach(bf => 
(bf.progress = 0)) + return Promise.all(bFiles.map(fileUpload)) + }, [fileUpload]) + + const uploadMultipleFiles = useCallback(async (files: FileItem[]) => { + const batchCountLimit = fileUploadConfig.batch_count_limit + const length = files.length + let start = 0 + let end = 0 + + while (start < length) { + if (start + batchCountLimit > length) + end = length + else + end = start + batchCountLimit + const bFiles = files.slice(start, end) + await uploadBatchFiles(bFiles) + start = end + } + }, [fileUploadConfig, uploadBatchFiles]) + + const initialUpload = useCallback((files: File[]) => { + const filesCountLimit = fileUploadConfig.file_upload_limit + if (!files.length) + return false + + if (files.length + fileList.length > filesCountLimit && !IS_CE_EDITION) { + notify({ type: 'error', message: t('stepOne.uploader.validation.filesNumber', { ns: 'datasetCreation', filesNumber: filesCountLimit }) }) + return false + } + + const preparedFiles = files.map((file, index) => ({ + fileID: `file${index}-${Date.now()}`, + file, + progress: PROGRESS_NOT_STARTED, + })) + const newFiles = [...fileListRef.current, ...preparedFiles] + prepareFileList(newFiles) + fileListRef.current = newFiles + uploadMultipleFiles(preparedFiles) + }, [prepareFileList, uploadMultipleFiles, notify, t, fileList, fileUploadConfig]) + + const traverseFileEntry = useCallback( + (entry: FileSystemEntry, prefix = ''): Promise => { + return new Promise((resolve) => { + if (entry.isFile) { + (entry as FileSystemFileEntry).file((file: FileWithPath) => { + file.relativePath = `${prefix}${file.name}` + resolve([file]) + }) + } + else if (entry.isDirectory) { + const reader = (entry as FileSystemDirectoryEntry).createReader() + const entries: FileSystemEntry[] = [] + const read = () => { + reader.readEntries(async (results: FileSystemEntry[]) => { + if (!results.length) { + const files = await Promise.all( + entries.map(ent => + traverseFileEntry(ent, `${prefix}${entry.name}/`), + ), + ) + 
resolve(files.flat()) + } + else { + entries.push(...results) + read() + } + }) + } + read() + } + else { + resolve([]) + } + }) + }, + [], + ) + + const handleDragEnter = useCallback((e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + if (e.target !== dragRef.current) + setDragging(true) + }, []) + + const handleDragOver = useCallback((e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + }, []) + + const handleDragLeave = useCallback((e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + if (e.target === dragRef.current) + setDragging(false) + }, []) + + const handleDrop = useCallback( + async (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + setDragging(false) + if (!e.dataTransfer) + return + const nested = await Promise.all( + Array.from(e.dataTransfer.items).map((it) => { + const entry = (it as DataTransferItem & { webkitGetAsEntry?: () => FileSystemEntry | null }).webkitGetAsEntry?.() + if (entry) + return traverseFileEntry(entry) + const f = it.getAsFile?.() + return f ? Promise.resolve([f as FileWithPath]) : Promise.resolve([]) + }), + ) + let files = nested.flat() + if (!supportBatchUpload) + files = files.slice(0, 1) + files = files.slice(0, fileUploadConfig.batch_count_limit) + const valid = files.filter(isValid) + initialUpload(valid) + }, + [initialUpload, isValid, supportBatchUpload, traverseFileEntry, fileUploadConfig], + ) + + const selectHandle = useCallback(() => { + if (fileUploaderRef.current) + fileUploaderRef.current.click() + }, []) + + const removeFile = useCallback((fileID: string) => { + if (fileUploaderRef.current) + fileUploaderRef.current.value = '' + + fileListRef.current = fileListRef.current.filter(item => item.fileID !== fileID) + onFileListUpdate?.([...fileListRef.current]) + }, [onFileListUpdate]) + + const fileChangeHandle = useCallback((e: React.ChangeEvent) => { + let files = Array.from(e.target.files ?? 
[]) as File[] + files = files.slice(0, fileUploadConfig.batch_count_limit) + initialUpload(files.filter(isValid)) + }, [isValid, initialUpload, fileUploadConfig]) + + const handlePreview = useCallback((file: File) => { + if (file?.id) + onPreview(file) + }, [onPreview]) + + useEffect(() => { + const dropArea = dropRef.current + dropArea?.addEventListener('dragenter', handleDragEnter) + dropArea?.addEventListener('dragover', handleDragOver) + dropArea?.addEventListener('dragleave', handleDragLeave) + dropArea?.addEventListener('drop', handleDrop) + return () => { + dropArea?.removeEventListener('dragenter', handleDragEnter) + dropArea?.removeEventListener('dragover', handleDragOver) + dropArea?.removeEventListener('dragleave', handleDragLeave) + dropArea?.removeEventListener('drop', handleDrop) + } + }, [handleDragEnter, handleDragOver, handleDragLeave, handleDrop]) + + return { + // Refs + dropRef, + dragRef, + fileUploaderRef, + + // State + dragging, + + // Config + fileUploadConfig, + acceptTypes, + supportTypesShowNames, + hideUpload, + + // Handlers + selectHandle, + fileChangeHandle, + removeFile, + handlePreview, + } +} diff --git a/web/app/components/datasets/create/file-uploader/index.spec.tsx b/web/app/components/datasets/create/file-uploader/index.spec.tsx new file mode 100644 index 0000000000..91f65652f3 --- /dev/null +++ b/web/app/components/datasets/create/file-uploader/index.spec.tsx @@ -0,0 +1,278 @@ +import type { CustomFile as File, FileItem } from '@/models/datasets' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { PROGRESS_NOT_STARTED } from './constants' +import FileUploader from './index' + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + const translations: Record = { + 'stepOne.uploader.title': 'Upload Files', + 'stepOne.uploader.button': 'Drag and drop files, or', + 
'stepOne.uploader.buttonSingleFile': 'Drag and drop file, or', + 'stepOne.uploader.browse': 'Browse', + 'stepOne.uploader.tip': 'Supports various file types', + } + return translations[key] || key + }, + }), +})) + +// Mock ToastContext +const mockNotify = vi.fn() +vi.mock('use-context-selector', async () => { + const actual = await vi.importActual('use-context-selector') + return { + ...actual, + useContext: vi.fn(() => ({ notify: mockNotify })), + } +}) + +// Mock services +vi.mock('@/service/base', () => ({ + upload: vi.fn().mockResolvedValue({ id: 'uploaded-id' }), +})) + +vi.mock('@/service/use-common', () => ({ + useFileUploadConfig: () => ({ + data: { file_size_limit: 15, batch_count_limit: 5, file_upload_limit: 10 }, + }), + useFileSupportTypes: () => ({ + data: { allowed_extensions: ['pdf', 'docx', 'txt'] }, + }), +})) + +vi.mock('@/context/i18n', () => ({ + useLocale: () => 'en-US', +})) + +vi.mock('@/i18n-config/language', () => ({ + LanguagesSupported: ['en-US', 'zh-Hans'], +})) + +vi.mock('@/config', () => ({ + IS_CE_EDITION: false, +})) + +vi.mock('@/app/components/base/file-uploader/utils', () => ({ + getFileUploadErrorMessage: () => 'Upload error', +})) + +// Mock theme +vi.mock('@/hooks/use-theme', () => ({ + default: () => ({ theme: 'light' }), +})) + +vi.mock('@/types/app', () => ({ + Theme: { dark: 'dark', light: 'light' }, +})) + +// Mock DocumentFileIcon - uses relative path from file-list-item.tsx +vi.mock('@/app/components/datasets/common/document-file-icon', () => ({ + default: ({ extension }: { extension: string }) =>
{extension}
, +})) + +// Mock SimplePieChart +vi.mock('next/dynamic', () => ({ + default: () => { + const Component = ({ percentage }: { percentage: number }) => ( +
+ {percentage} + % +
+ ) + return Component + }, +})) + +describe('FileUploader', () => { + const createMockFile = (overrides: Partial = {}): File => ({ + name: 'test.pdf', + size: 1024, + type: 'application/pdf', + ...overrides, + } as File) + + const createMockFileItem = (overrides: Partial = {}): FileItem => ({ + fileID: `file-${Date.now()}`, + file: createMockFile(overrides.file as Partial), + progress: PROGRESS_NOT_STARTED, + ...overrides, + }) + + const defaultProps = { + fileList: [] as FileItem[], + prepareFileList: vi.fn(), + onFileUpdate: vi.fn(), + onFileListUpdate: vi.fn(), + onPreview: vi.fn(), + supportBatchUpload: true, + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('rendering', () => { + it('should render the component', () => { + render() + expect(screen.getByText('Upload Files')).toBeInTheDocument() + }) + + it('should render dropzone when no files', () => { + render() + expect(screen.getByText(/Drag and drop files/i)).toBeInTheDocument() + }) + + it('should render browse button', () => { + render() + expect(screen.getByText('Browse')).toBeInTheDocument() + }) + + it('should apply custom title className', () => { + render() + const title = screen.getByText('Upload Files') + expect(title).toHaveClass('custom-class') + }) + }) + + describe('file list rendering', () => { + it('should render file items when fileList has items', () => { + const fileList = [ + createMockFileItem({ file: createMockFile({ name: 'file1.pdf' }) }), + createMockFileItem({ file: createMockFile({ name: 'file2.pdf' }) }), + ] + + render() + + expect(screen.getByText('file1.pdf')).toBeInTheDocument() + expect(screen.getByText('file2.pdf')).toBeInTheDocument() + }) + + it('should render document icons for files', () => { + const fileList = [createMockFileItem()] + render() + + expect(screen.getByTestId('document-icon')).toBeInTheDocument() + }) + }) + + describe('batch upload mode', () => { + it('should show dropzone with batch upload enabled', () => { + render() + 
expect(screen.getByText(/Drag and drop files/i)).toBeInTheDocument() + }) + + it('should show single file text when batch upload disabled', () => { + render() + expect(screen.getByText(/Drag and drop file/i)).toBeInTheDocument() + }) + + it('should hide dropzone when not batch upload and has files', () => { + const fileList = [createMockFileItem()] + render() + + expect(screen.queryByText(/Drag and drop/i)).not.toBeInTheDocument() + }) + }) + + describe('event handlers', () => { + it('should handle file preview click', () => { + const onPreview = vi.fn() + const fileItem = createMockFileItem({ + file: createMockFile({ id: 'file-id' } as Partial), + }) + + const { container } = render() + + // Find the file list item container by its class pattern + const fileElement = container.querySelector('[class*="flex h-12"]') + if (fileElement) + fireEvent.click(fileElement) + + expect(onPreview).toHaveBeenCalledWith(fileItem.file) + }) + + it('should handle file remove click', () => { + const onFileListUpdate = vi.fn() + const fileItem = createMockFileItem() + + const { container } = render( + , + ) + + // Find the delete button (the span with cursor-pointer containing the icon) + const deleteButtons = container.querySelectorAll('[class*="cursor-pointer"]') + // Get the last one which should be the delete button (not the browse label) + const deleteButton = deleteButtons[deleteButtons.length - 1] + if (deleteButton) + fireEvent.click(deleteButton) + + expect(onFileListUpdate).toHaveBeenCalled() + }) + + it('should handle browse button click', () => { + render() + + // The browse label should trigger file input click + const browseLabel = screen.getByText('Browse') + expect(browseLabel).toHaveClass('cursor-pointer') + }) + }) + + describe('upload progress', () => { + it('should show progress chart for uploading files', () => { + const fileItem = createMockFileItem({ progress: 50 }) + render() + + expect(screen.getByTestId('pie-chart')).toBeInTheDocument() + 
expect(screen.getByText('50%')).toBeInTheDocument() + }) + + it('should not show progress chart for completed files', () => { + const fileItem = createMockFileItem({ progress: 100 }) + render() + + expect(screen.queryByTestId('pie-chart')).not.toBeInTheDocument() + }) + + it('should not show progress chart for not started files', () => { + const fileItem = createMockFileItem({ progress: PROGRESS_NOT_STARTED }) + render() + + expect(screen.queryByTestId('pie-chart')).not.toBeInTheDocument() + }) + }) + + describe('multiple files', () => { + it('should render all files in the list', () => { + const fileList = [ + createMockFileItem({ fileID: 'f1', file: createMockFile({ name: 'doc1.pdf' }) }), + createMockFileItem({ fileID: 'f2', file: createMockFile({ name: 'doc2.docx' }) }), + createMockFileItem({ fileID: 'f3', file: createMockFile({ name: 'doc3.txt' }) }), + ] + + render() + + expect(screen.getByText('doc1.pdf')).toBeInTheDocument() + expect(screen.getByText('doc2.docx')).toBeInTheDocument() + expect(screen.getByText('doc3.txt')).toBeInTheDocument() + }) + }) + + describe('styling', () => { + it('should have correct container width', () => { + const { container } = render() + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('w-[640px]') + }) + + it('should have proper spacing', () => { + const { container } = render() + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('mb-5') + }) + }) +}) diff --git a/web/app/components/datasets/create/file-uploader/index.tsx b/web/app/components/datasets/create/file-uploader/index.tsx index 781b97200a..b649554a12 100644 --- a/web/app/components/datasets/create/file-uploader/index.tsx +++ b/web/app/components/datasets/create/file-uploader/index.tsx @@ -1,23 +1,10 @@ 'use client' import type { CustomFile as File, FileItem } from '@/models/datasets' -import { RiDeleteBinLine, RiUploadCloud2Line } from '@remixicon/react' -import * as React from 'react' -import { 
useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' -import { getFileUploadErrorMessage } from '@/app/components/base/file-uploader/utils' -import SimplePieChart from '@/app/components/base/simple-pie-chart' -import { ToastContext } from '@/app/components/base/toast' -import { IS_CE_EDITION } from '@/config' - -import { useLocale } from '@/context/i18n' -import useTheme from '@/hooks/use-theme' -import { LanguagesSupported } from '@/i18n-config/language' -import { upload } from '@/service/base' -import { useFileSupportTypes, useFileUploadConfig } from '@/service/use-common' -import { Theme } from '@/types/app' import { cn } from '@/utils/classnames' -import DocumentFileIcon from '../../common/document-file-icon' +import FileListItem from './components/file-list-item' +import UploadDropzone from './components/upload-dropzone' +import { useFileUpload } from './hooks/use-file-upload' type IFileUploaderProps = { fileList: FileItem[] @@ -39,358 +26,62 @@ const FileUploader = ({ supportBatchUpload = false, }: IFileUploaderProps) => { const { t } = useTranslation() - const { notify } = useContext(ToastContext) - const locale = useLocale() - const [dragging, setDragging] = useState(false) - const dropRef = useRef(null) - const dragRef = useRef(null) - const fileUploader = useRef(null) - const hideUpload = !supportBatchUpload && fileList.length > 0 - const { data: fileUploadConfigResponse } = useFileUploadConfig() - const { data: supportFileTypesResponse } = useFileSupportTypes() - const supportTypes = supportFileTypesResponse?.allowed_extensions || [] - const supportTypesShowNames = (() => { - const extensionMap: { [key: string]: string } = { - md: 'markdown', - pptx: 'pptx', - htm: 'html', - xlsx: 'xlsx', - docx: 'docx', - } - - return [...supportTypes] - .map(item => extensionMap[item] || item) // map to standardized extension - .map(item => item.toLowerCase()) 
// convert to lower case - .filter((item, index, self) => self.indexOf(item) === index) // remove duplicates - .map(item => item.toUpperCase()) // convert to upper case - .join(locale !== LanguagesSupported[1] ? ', ' : '、 ') - })() - const ACCEPTS = supportTypes.map((ext: string) => `.${ext}`) - const fileUploadConfig = useMemo(() => ({ - file_size_limit: fileUploadConfigResponse?.file_size_limit ?? 15, - batch_count_limit: supportBatchUpload ? (fileUploadConfigResponse?.batch_count_limit ?? 5) : 1, - file_upload_limit: supportBatchUpload ? (fileUploadConfigResponse?.file_upload_limit ?? 5) : 1, - }), [fileUploadConfigResponse, supportBatchUpload]) - - const fileListRef = useRef([]) - - // utils - const getFileType = (currentFile: File) => { - if (!currentFile) - return '' - - const arr = currentFile.name.split('.') - return arr[arr.length - 1] - } - - const getFileSize = (size: number) => { - if (size / 1024 < 10) - return `${(size / 1024).toFixed(2)}KB` - - return `${(size / 1024 / 1024).toFixed(2)}MB` - } - - const isValid = useCallback((file: File) => { - const { size } = file - const ext = `.${getFileType(file)}` - const isValidType = ACCEPTS.includes(ext.toLowerCase()) - if (!isValidType) - notify({ type: 'error', message: t('stepOne.uploader.validation.typeError', { ns: 'datasetCreation' }) }) - - const isValidSize = size <= fileUploadConfig.file_size_limit * 1024 * 1024 - if (!isValidSize) - notify({ type: 'error', message: t('stepOne.uploader.validation.size', { ns: 'datasetCreation', size: fileUploadConfig.file_size_limit }) }) - - return isValidType && isValidSize - }, [fileUploadConfig, notify, t, ACCEPTS]) - - const fileUpload = useCallback(async (fileItem: FileItem): Promise => { - const formData = new FormData() - formData.append('file', fileItem.file) - const onProgress = (e: ProgressEvent) => { - if (e.lengthComputable) { - const percent = Math.floor(e.loaded / e.total * 100) - onFileUpdate(fileItem, percent, fileListRef.current) - } - } - - return 
upload({ - xhr: new XMLHttpRequest(), - data: formData, - onprogress: onProgress, - }, false, undefined, '?source=datasets') - .then((res) => { - const completeFile = { - fileID: fileItem.fileID, - file: res as unknown as File, - progress: -1, - } - const index = fileListRef.current.findIndex(item => item.fileID === fileItem.fileID) - fileListRef.current[index] = completeFile - onFileUpdate(completeFile, 100, fileListRef.current) - return Promise.resolve({ ...completeFile }) - }) - .catch((e) => { - const errorMessage = getFileUploadErrorMessage(e, t('stepOne.uploader.failed', { ns: 'datasetCreation' }), t) - notify({ type: 'error', message: errorMessage }) - onFileUpdate(fileItem, -2, fileListRef.current) - return Promise.resolve({ ...fileItem }) - }) - .finally() - }, [fileListRef, notify, onFileUpdate, t]) - - const uploadBatchFiles = useCallback((bFiles: FileItem[]) => { - bFiles.forEach(bf => (bf.progress = 0)) - return Promise.all(bFiles.map(fileUpload)) - }, [fileUpload]) - - const uploadMultipleFiles = useCallback(async (files: FileItem[]) => { - const batchCountLimit = fileUploadConfig.batch_count_limit - const length = files.length - let start = 0 - let end = 0 - - while (start < length) { - if (start + batchCountLimit > length) - end = length - else - end = start + batchCountLimit - const bFiles = files.slice(start, end) - await uploadBatchFiles(bFiles) - start = end - } - }, [fileUploadConfig, uploadBatchFiles]) - - const initialUpload = useCallback((files: File[]) => { - const filesCountLimit = fileUploadConfig.file_upload_limit - if (!files.length) - return false - - if (files.length + fileList.length > filesCountLimit && !IS_CE_EDITION) { - notify({ type: 'error', message: t('stepOne.uploader.validation.filesNumber', { ns: 'datasetCreation', filesNumber: filesCountLimit }) }) - return false - } - - const preparedFiles = files.map((file, index) => ({ - fileID: `file${index}-${Date.now()}`, - file, - progress: -1, - })) - const newFiles = 
[...fileListRef.current, ...preparedFiles] - prepareFileList(newFiles) - fileListRef.current = newFiles - uploadMultipleFiles(preparedFiles) - }, [prepareFileList, uploadMultipleFiles, notify, t, fileList, fileUploadConfig]) - - const handleDragEnter = (e: DragEvent) => { - e.preventDefault() - e.stopPropagation() - if (e.target !== dragRef.current) - setDragging(true) - } - const handleDragOver = (e: DragEvent) => { - e.preventDefault() - e.stopPropagation() - } - const handleDragLeave = (e: DragEvent) => { - e.preventDefault() - e.stopPropagation() - if (e.target === dragRef.current) - setDragging(false) - } - type FileWithPath = { - relativePath?: string - } & File - const traverseFileEntry = useCallback( - (entry: any, prefix = ''): Promise => { - return new Promise((resolve) => { - if (entry.isFile) { - entry.file((file: FileWithPath) => { - file.relativePath = `${prefix}${file.name}` - resolve([file]) - }) - } - else if (entry.isDirectory) { - const reader = entry.createReader() - const entries: any[] = [] - const read = () => { - reader.readEntries(async (results: FileSystemEntry[]) => { - if (!results.length) { - const files = await Promise.all( - entries.map(ent => - traverseFileEntry(ent, `${prefix}${entry.name}/`), - ), - ) - resolve(files.flat()) - } - else { - entries.push(...results) - read() - } - }) - } - read() - } - else { - resolve([]) - } - }) - }, - [], - ) - - const handleDrop = useCallback( - async (e: DragEvent) => { - e.preventDefault() - e.stopPropagation() - setDragging(false) - if (!e.dataTransfer) - return - const nested = await Promise.all( - Array.from(e.dataTransfer.items).map((it) => { - const entry = (it as any).webkitGetAsEntry?.() - if (entry) - return traverseFileEntry(entry) - const f = it.getAsFile?.() - return f ? 
Promise.resolve([f]) : Promise.resolve([]) - }), - ) - let files = nested.flat() - if (!supportBatchUpload) - files = files.slice(0, 1) - files = files.slice(0, fileUploadConfig.batch_count_limit) - const valid = files.filter(isValid) - initialUpload(valid) - }, - [initialUpload, isValid, supportBatchUpload, traverseFileEntry, fileUploadConfig], - ) - const selectHandle = () => { - if (fileUploader.current) - fileUploader.current.click() - } - - const removeFile = (fileID: string) => { - if (fileUploader.current) - fileUploader.current.value = '' - - fileListRef.current = fileListRef.current.filter(item => item.fileID !== fileID) - onFileListUpdate?.([...fileListRef.current]) - } - const fileChangeHandle = useCallback((e: React.ChangeEvent) => { - let files = Array.from(e.target.files ?? []) as File[] - files = files.slice(0, fileUploadConfig.batch_count_limit) - initialUpload(files.filter(isValid)) - }, [isValid, initialUpload, fileUploadConfig]) - - const { theme } = useTheme() - const chartColor = useMemo(() => theme === Theme.dark ? 
'#5289ff' : '#296dff', [theme]) - - useEffect(() => { - dropRef.current?.addEventListener('dragenter', handleDragEnter) - dropRef.current?.addEventListener('dragover', handleDragOver) - dropRef.current?.addEventListener('dragleave', handleDragLeave) - dropRef.current?.addEventListener('drop', handleDrop) - return () => { - dropRef.current?.removeEventListener('dragenter', handleDragEnter) - dropRef.current?.removeEventListener('dragover', handleDragOver) - dropRef.current?.removeEventListener('dragleave', handleDragLeave) - dropRef.current?.removeEventListener('drop', handleDrop) - } - }, [handleDrop]) + const { + dropRef, + dragRef, + fileUploaderRef, + dragging, + fileUploadConfig, + acceptTypes, + supportTypesShowNames, + hideUpload, + selectHandle, + fileChangeHandle, + removeFile, + handlePreview, + } = useFileUpload({ + fileList, + prepareFileList, + onFileUpdate, + onFileListUpdate, + onPreview, + supportBatchUpload, + }) return (
+
+ {t('stepOne.uploader.title', { ns: 'datasetCreation' })} +
+ {!hideUpload && ( - )} -
{t('stepOne.uploader.title', { ns: 'datasetCreation' })}
- - {!hideUpload && ( -
-
- - - - {supportBatchUpload ? t('stepOne.uploader.button', { ns: 'datasetCreation' }) : t('stepOne.uploader.buttonSingleFile', { ns: 'datasetCreation' })} - {supportTypes.length > 0 && ( - - )} - -
-
- {t('stepOne.uploader.tip', { - ns: 'datasetCreation', - size: fileUploadConfig.file_size_limit, - supportTypes: supportTypesShowNames, - batchCount: fileUploadConfig.batch_count_limit, - totalCount: fileUploadConfig.file_upload_limit, - })} -
- {dragging &&
} + {fileList.length > 0 && ( +
+ {fileList.map(fileItem => ( + + ))}
)} -
- - {fileList.map((fileItem, index) => ( -
fileItem.file?.id && onPreview(fileItem.file)} - className={cn( - 'flex h-12 max-w-[640px] items-center rounded-lg border border-components-panel-border bg-components-panel-on-panel-item-bg text-xs leading-3 text-text-tertiary shadow-xs', - // 'border-state-destructive-border bg-state-destructive-hover', - )} - > -
- -
-
-
-
{fileItem.file.name}
-
-
- {getFileType(fileItem.file)} - · - {getFileSize(fileItem.file.size)} - {/* · - 10k characters */} -
-
-
- {/* - - */} - {(fileItem.progress < 100 && fileItem.progress >= 0) && ( - //
{`${fileItem.progress}%`}
- - )} - { - e.stopPropagation() - removeFile(fileItem.fileID) - }} - > - - -
-
- ))} -
) } diff --git a/web/app/components/datasets/create/step-two/components/general-chunking-options.tsx b/web/app/components/datasets/create/step-two/components/general-chunking-options.tsx index 84d742d734..0beda8f5c8 100644 --- a/web/app/components/datasets/create/step-two/components/general-chunking-options.tsx +++ b/web/app/components/datasets/create/step-two/components/general-chunking-options.tsx @@ -154,7 +154,7 @@ export const GeneralChunkingOptions: FC = ({
))} { - showSummaryIndexSetting && ( + showSummaryIndexSetting && IS_CE_EDITION && (
= ({
))} { - showSummaryIndexSetting && ( + showSummaryIndexSetting && IS_CE_EDITION && (
= {}): SimpleDocumentDetail => ({ + id: 'doc-1', + position: 1, + data_source_type: DataSourceType.FILE, + data_source_info: {}, + data_source_detail_dict: {}, + dataset_process_rule_id: 'rule-1', + dataset_id: 'dataset-1', + batch: 'batch-1', + name: 'test-document.txt', + created_from: 'web', + created_by: 'user-1', + created_at: Date.now(), + tokens: 100, + indexing_status: 'completed', + error: null, + enabled: true, + disabled_at: null, + disabled_by: null, + archived: false, + archived_reason: null, + archived_by: null, + archived_at: null, + updated_at: Date.now(), + doc_type: null, + doc_metadata: undefined, + doc_language: 'en', + display_status: 'available', + word_count: 100, + hit_count: 10, + doc_form: 'text_model', + ...overrides, +}) as unknown as SimpleDocumentDetail + +describe('DocumentSourceIcon', () => { + describe('Rendering', () => { + it('should render without crashing', () => { + const doc = createMockDoc() + const { container } = render() + expect(container.firstChild).toBeInTheDocument() + }) + }) + + describe('Local File Icon', () => { + it('should render FileTypeIcon for FILE data source type', () => { + const doc = createMockDoc({ + data_source_type: DataSourceType.FILE, + data_source_info: { + upload_file: { extension: 'pdf' }, + }, + }) + + const { container } = render() + const icon = container.querySelector('svg, img') + expect(icon).toBeInTheDocument() + }) + + it('should render FileTypeIcon for localFile data source type', () => { + const doc = createMockDoc({ + data_source_type: DatasourceType.localFile, + created_from: 'rag-pipeline', + data_source_info: { + extension: 'docx', + }, + }) + + const { container } = render() + const icon = container.querySelector('svg, img') + expect(icon).toBeInTheDocument() + }) + + it('should use extension from upload_file for legacy data source', () => { + const doc = createMockDoc({ + data_source_type: DataSourceType.FILE, + created_from: 'web', + data_source_info: { + upload_file: { extension: 
'txt' }, + }, + }) + + const { container } = render() + expect(container.firstChild).toBeInTheDocument() + }) + + it('should use fileType prop as fallback for extension', () => { + const doc = createMockDoc({ + data_source_type: DataSourceType.FILE, + created_from: 'web', + data_source_info: {}, + }) + + const { container } = render() + expect(container.firstChild).toBeInTheDocument() + }) + }) + + describe('Notion Icon', () => { + it('should render NotionIcon for NOTION data source type', () => { + const doc = createMockDoc({ + data_source_type: DataSourceType.NOTION, + created_from: 'web', + data_source_info: { + notion_page_icon: 'https://notion.so/icon.png', + }, + }) + + const { container } = render() + expect(container.firstChild).toBeInTheDocument() + }) + + it('should render NotionIcon for onlineDocument data source type', () => { + const doc = createMockDoc({ + data_source_type: DatasourceType.onlineDocument, + created_from: 'rag-pipeline', + data_source_info: { + page: { page_icon: 'https://notion.so/icon.png' }, + }, + }) + + const { container } = render() + expect(container.firstChild).toBeInTheDocument() + }) + + it('should use page_icon for rag-pipeline created documents', () => { + const doc = createMockDoc({ + data_source_type: DataSourceType.NOTION, + created_from: 'rag-pipeline', + data_source_info: { + page: { page_icon: 'https://notion.so/custom-icon.png' }, + }, + }) + + const { container } = render() + expect(container.firstChild).toBeInTheDocument() + }) + }) + + describe('Web Crawl Icon', () => { + it('should render globe icon for WEB data source type', () => { + const doc = createMockDoc({ + data_source_type: DataSourceType.WEB, + }) + + const { container } = render() + const icon = container.querySelector('svg') + expect(icon).toBeInTheDocument() + expect(icon).toHaveClass('mr-1.5') + expect(icon).toHaveClass('size-4') + }) + + it('should render globe icon for websiteCrawl data source type', () => { + const doc = createMockDoc({ + 
data_source_type: DatasourceType.websiteCrawl, + }) + + const { container } = render() + const icon = container.querySelector('svg') + expect(icon).toBeInTheDocument() + }) + }) + + describe('Online Drive Icon', () => { + it('should render FileTypeIcon for onlineDrive data source type', () => { + const doc = createMockDoc({ + data_source_type: DatasourceType.onlineDrive, + data_source_info: { + name: 'document.xlsx', + }, + }) + + const { container } = render() + expect(container.firstChild).toBeInTheDocument() + }) + + it('should extract extension from file name', () => { + const doc = createMockDoc({ + data_source_type: DatasourceType.onlineDrive, + data_source_info: { + name: 'spreadsheet.xlsx', + }, + }) + + const { container } = render() + expect(container.firstChild).toBeInTheDocument() + }) + + it('should handle file name without extension', () => { + const doc = createMockDoc({ + data_source_type: DatasourceType.onlineDrive, + data_source_info: { + name: 'noextension', + }, + }) + + const { container } = render() + expect(container.firstChild).toBeInTheDocument() + }) + + it('should handle empty file name', () => { + const doc = createMockDoc({ + data_source_type: DatasourceType.onlineDrive, + data_source_info: { + name: '', + }, + }) + + const { container } = render() + expect(container.firstChild).toBeInTheDocument() + }) + + it('should handle hidden files (starting with dot)', () => { + const doc = createMockDoc({ + data_source_type: DatasourceType.onlineDrive, + data_source_info: { + name: '.gitignore', + }, + }) + + const { container } = render() + expect(container.firstChild).toBeInTheDocument() + }) + }) + + describe('Unknown Data Source Type', () => { + it('should return null for unknown data source type', () => { + const doc = createMockDoc({ + data_source_type: 'unknown', + }) + + const { container } = render() + expect(container.firstChild).toBeNull() + }) + }) + + describe('Edge Cases', () => { + it('should handle undefined data_source_info', () 
=> { + const doc = createMockDoc({ + data_source_type: DataSourceType.FILE, + data_source_info: undefined, + }) + + const { container } = render() + expect(container.firstChild).toBeInTheDocument() + }) + + it('should memoize the component', () => { + const doc = createMockDoc() + const { rerender, container } = render() + + const firstRender = container.innerHTML + rerender() + expect(container.innerHTML).toBe(firstRender) + }) + }) +}) diff --git a/web/app/components/datasets/documents/components/document-list/components/document-source-icon.tsx b/web/app/components/datasets/documents/components/document-list/components/document-source-icon.tsx new file mode 100644 index 0000000000..5461f34921 --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/components/document-source-icon.tsx @@ -0,0 +1,100 @@ +import type { FC } from 'react' +import type { LegacyDataSourceInfo, LocalFileInfo, OnlineDocumentInfo, OnlineDriveInfo, SimpleDocumentDetail } from '@/models/datasets' +import { RiGlobalLine } from '@remixicon/react' +import * as React from 'react' +import FileTypeIcon from '@/app/components/base/file-uploader/file-type-icon' +import NotionIcon from '@/app/components/base/notion-icon' +import { extensionToFileType } from '@/app/components/datasets/hit-testing/utils/extension-to-file-type' +import { DataSourceType } from '@/models/datasets' +import { DatasourceType } from '@/models/pipeline' + +type DocumentSourceIconProps = { + doc: SimpleDocumentDetail + fileType?: string +} + +const isLocalFile = (dataSourceType: DataSourceType | DatasourceType) => { + return dataSourceType === DatasourceType.localFile || dataSourceType === DataSourceType.FILE +} + +const isOnlineDocument = (dataSourceType: DataSourceType | DatasourceType) => { + return dataSourceType === DatasourceType.onlineDocument || dataSourceType === DataSourceType.NOTION +} + +const isWebsiteCrawl = (dataSourceType: DataSourceType | DatasourceType) => { + return dataSourceType 
=== DatasourceType.websiteCrawl || dataSourceType === DataSourceType.WEB +} + +const isOnlineDrive = (dataSourceType: DataSourceType | DatasourceType) => { + return dataSourceType === DatasourceType.onlineDrive +} + +const isCreateFromRAGPipeline = (createdFrom: string) => { + return createdFrom === 'rag-pipeline' +} + +const getFileExtension = (fileName: string): string => { + if (!fileName) + return '' + const parts = fileName.split('.') + if (parts.length <= 1 || (parts[0] === '' && parts.length === 2)) + return '' + return parts[parts.length - 1].toLowerCase() +} + +const DocumentSourceIcon: FC = React.memo(({ + doc, + fileType, +}) => { + if (isOnlineDocument(doc.data_source_type)) { + return ( + + ) + } + + if (isLocalFile(doc.data_source_type)) { + return ( + + ) + } + + if (isOnlineDrive(doc.data_source_type)) { + return ( + + ) + } + + if (isWebsiteCrawl(doc.data_source_type)) { + return + } + + return null +}) + +DocumentSourceIcon.displayName = 'DocumentSourceIcon' + +export default DocumentSourceIcon diff --git a/web/app/components/datasets/documents/components/document-list/components/document-table-row.spec.tsx b/web/app/components/datasets/documents/components/document-list/components/document-table-row.spec.tsx new file mode 100644 index 0000000000..7157a9bf4b --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/components/document-table-row.spec.tsx @@ -0,0 +1,342 @@ +import type { ReactNode } from 'react' +import type { SimpleDocumentDetail } from '@/models/datasets' +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { DataSourceType } from '@/models/datasets' +import DocumentTableRow from './document-table-row' + +const mockPush = vi.fn() + +vi.mock('next/navigation', () => ({ + useRouter: () => ({ + push: mockPush, + }), +})) + +const createTestQueryClient 
= () => new QueryClient({ + defaultOptions: { + queries: { retry: false, gcTime: 0 }, + mutations: { retry: false }, + }, +}) + +const createWrapper = () => { + const queryClient = createTestQueryClient() + return ({ children }: { children: ReactNode }) => ( + + + + {children} + +
+
+ ) +} + +type LocalDoc = SimpleDocumentDetail & { percent?: number } + +const createMockDoc = (overrides: Record = {}): LocalDoc => ({ + id: 'doc-1', + position: 1, + data_source_type: DataSourceType.FILE, + data_source_info: {}, + data_source_detail_dict: { + upload_file: { name: 'test.txt', extension: 'txt' }, + }, + dataset_process_rule_id: 'rule-1', + dataset_id: 'dataset-1', + batch: 'batch-1', + name: 'test-document.txt', + created_from: 'web', + created_by: 'user-1', + created_at: Date.now(), + tokens: 100, + indexing_status: 'completed', + error: null, + enabled: true, + disabled_at: null, + disabled_by: null, + archived: false, + archived_reason: null, + archived_by: null, + archived_at: null, + updated_at: Date.now(), + doc_type: null, + doc_metadata: undefined, + doc_language: 'en', + display_status: 'available', + word_count: 500, + hit_count: 10, + doc_form: 'text_model', + ...overrides, +}) as unknown as LocalDoc + +// Helper to find the custom checkbox div (Checkbox component renders as a div, not a native checkbox) +const findCheckbox = (container: HTMLElement): HTMLElement | null => { + return container.querySelector('[class*="shadow-xs"]') +} + +describe('DocumentTableRow', () => { + const defaultProps = { + doc: createMockDoc(), + index: 0, + datasetId: 'dataset-1', + isSelected: false, + isGeneralMode: true, + isQAMode: false, + embeddingAvailable: true, + selectedIds: [], + onSelectOne: vi.fn(), + onSelectedIdChange: vi.fn(), + onShowRenameModal: vi.fn(), + onUpdate: vi.fn(), + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('Rendering', () => { + it('should render without crashing', () => { + render(, { wrapper: createWrapper() }) + expect(screen.getByText('test-document.txt')).toBeInTheDocument() + }) + + it('should render index number correctly', () => { + render(, { wrapper: createWrapper() }) + expect(screen.getByText('6')).toBeInTheDocument() + }) + + it('should render document name with tooltip', () => { + render(, { 
wrapper: createWrapper() }) + expect(screen.getByText('test-document.txt')).toBeInTheDocument() + }) + + it('should render checkbox element', () => { + const { container } = render(, { wrapper: createWrapper() }) + const checkbox = findCheckbox(container) + expect(checkbox).toBeInTheDocument() + }) + }) + + describe('Selection', () => { + it('should show check icon when isSelected is true', () => { + const { container } = render(, { wrapper: createWrapper() }) + // When selected, the checkbox should have a check icon (RiCheckLine svg) + const checkbox = findCheckbox(container) + expect(checkbox).toBeInTheDocument() + const checkIcon = checkbox?.querySelector('svg') + expect(checkIcon).toBeInTheDocument() + }) + + it('should not show check icon when isSelected is false', () => { + const { container } = render(, { wrapper: createWrapper() }) + const checkbox = findCheckbox(container) + expect(checkbox).toBeInTheDocument() + // When not selected, there should be no check icon inside the checkbox + const checkIcon = checkbox?.querySelector('svg') + expect(checkIcon).not.toBeInTheDocument() + }) + + it('should call onSelectOne when checkbox is clicked', () => { + const onSelectOne = vi.fn() + const { container } = render(, { wrapper: createWrapper() }) + + const checkbox = findCheckbox(container) + if (checkbox) { + fireEvent.click(checkbox) + expect(onSelectOne).toHaveBeenCalledWith('doc-1') + } + }) + + it('should stop propagation when checkbox container is clicked', () => { + const { container } = render(, { wrapper: createWrapper() }) + + // Click the div containing the checkbox (which has stopPropagation) + const checkboxContainer = container.querySelector('td')?.querySelector('div') + if (checkboxContainer) { + fireEvent.click(checkboxContainer) + expect(mockPush).not.toHaveBeenCalled() + } + }) + }) + + describe('Row Navigation', () => { + it('should navigate to document detail on row click', () => { + render(, { wrapper: createWrapper() }) + + const row = 
screen.getByRole('row') + fireEvent.click(row) + + expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-1/documents/doc-1') + }) + + it('should navigate with correct datasetId and documentId', () => { + render( + , + { wrapper: createWrapper() }, + ) + + const row = screen.getByRole('row') + fireEvent.click(row) + + expect(mockPush).toHaveBeenCalledWith('/datasets/custom-dataset/documents/custom-doc') + }) + }) + + describe('Word Count Display', () => { + it('should display word count less than 1000 as is', () => { + const doc = createMockDoc({ word_count: 500 }) + render(, { wrapper: createWrapper() }) + expect(screen.getByText('500')).toBeInTheDocument() + }) + + it('should display word count 1000 or more in k format', () => { + const doc = createMockDoc({ word_count: 1500 }) + render(, { wrapper: createWrapper() }) + expect(screen.getByText('1.5k')).toBeInTheDocument() + }) + + it('should display 0 with empty style when word_count is 0', () => { + const doc = createMockDoc({ word_count: 0 }) + const { container } = render(, { wrapper: createWrapper() }) + const zeroCells = container.querySelectorAll('.text-text-tertiary') + expect(zeroCells.length).toBeGreaterThan(0) + }) + + it('should handle undefined word_count', () => { + const doc = createMockDoc({ word_count: undefined as unknown as number }) + const { container } = render(, { wrapper: createWrapper() }) + expect(container).toBeInTheDocument() + }) + }) + + describe('Hit Count Display', () => { + it('should display hit count less than 1000 as is', () => { + const doc = createMockDoc({ hit_count: 100 }) + render(, { wrapper: createWrapper() }) + expect(screen.getByText('100')).toBeInTheDocument() + }) + + it('should display hit count 1000 or more in k format', () => { + const doc = createMockDoc({ hit_count: 2500 }) + render(, { wrapper: createWrapper() }) + expect(screen.getByText('2.5k')).toBeInTheDocument() + }) + + it('should display 0 with empty style when hit_count is 0', () => { + const doc = 
createMockDoc({ hit_count: 0 }) + const { container } = render(, { wrapper: createWrapper() }) + const zeroCells = container.querySelectorAll('.text-text-tertiary') + expect(zeroCells.length).toBeGreaterThan(0) + }) + }) + + describe('Chunking Mode', () => { + it('should render ChunkingModeLabel with general mode', () => { + render(, { wrapper: createWrapper() }) + // ChunkingModeLabel should be rendered + expect(screen.getByRole('row')).toBeInTheDocument() + }) + + it('should render ChunkingModeLabel with QA mode', () => { + render(, { wrapper: createWrapper() }) + expect(screen.getByRole('row')).toBeInTheDocument() + }) + }) + + describe('Summary Status', () => { + it('should render SummaryStatus when summary_index_status is present', () => { + const doc = createMockDoc({ summary_index_status: 'completed' }) + render(, { wrapper: createWrapper() }) + expect(screen.getByRole('row')).toBeInTheDocument() + }) + + it('should not render SummaryStatus when summary_index_status is absent', () => { + const doc = createMockDoc({ summary_index_status: undefined }) + render(, { wrapper: createWrapper() }) + expect(screen.getByRole('row')).toBeInTheDocument() + }) + }) + + describe('Rename Action', () => { + it('should call onShowRenameModal when rename button is clicked', () => { + const onShowRenameModal = vi.fn() + const { container } = render( + , + { wrapper: createWrapper() }, + ) + + // Find the rename button by finding the RiEditLine icon's parent + const renameButtons = container.querySelectorAll('.cursor-pointer.rounded-md') + if (renameButtons.length > 0) { + fireEvent.click(renameButtons[0]) + expect(onShowRenameModal).toHaveBeenCalledWith(defaultProps.doc) + expect(mockPush).not.toHaveBeenCalled() + } + }) + }) + + describe('Operations', () => { + it('should pass selectedIds to Operations component', () => { + render(, { wrapper: createWrapper() }) + expect(screen.getByRole('row')).toBeInTheDocument() + }) + + it('should pass onSelectedIdChange to Operations 
component', () => { + const onSelectedIdChange = vi.fn() + render(, { wrapper: createWrapper() }) + expect(screen.getByRole('row')).toBeInTheDocument() + }) + }) + + describe('Document Source Icon', () => { + it('should render with FILE data source type', () => { + const doc = createMockDoc({ data_source_type: DataSourceType.FILE }) + render(, { wrapper: createWrapper() }) + expect(screen.getByRole('row')).toBeInTheDocument() + }) + + it('should render with NOTION data source type', () => { + const doc = createMockDoc({ + data_source_type: DataSourceType.NOTION, + data_source_info: { notion_page_icon: 'icon.png' }, + }) + render(, { wrapper: createWrapper() }) + expect(screen.getByRole('row')).toBeInTheDocument() + }) + + it('should render with WEB data source type', () => { + const doc = createMockDoc({ data_source_type: DataSourceType.WEB }) + render(, { wrapper: createWrapper() }) + expect(screen.getByRole('row')).toBeInTheDocument() + }) + }) + + describe('Edge Cases', () => { + it('should handle document with very long name', () => { + const doc = createMockDoc({ name: `${'a'.repeat(500)}.txt` }) + render(, { wrapper: createWrapper() }) + expect(screen.getByRole('row')).toBeInTheDocument() + }) + + it('should handle document with special characters in name', () => { + const doc = createMockDoc({ name: '.txt' }) + render(, { wrapper: createWrapper() }) + expect(screen.getByText('.txt')).toBeInTheDocument() + }) + + it('should memoize the component', () => { + const wrapper = createWrapper() + const { rerender } = render(, { wrapper }) + + rerender() + expect(screen.getByRole('row')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/datasets/documents/components/document-list/components/document-table-row.tsx b/web/app/components/datasets/documents/components/document-list/components/document-table-row.tsx new file mode 100644 index 0000000000..731c14e731 --- /dev/null +++ 
b/web/app/components/datasets/documents/components/document-list/components/document-table-row.tsx @@ -0,0 +1,152 @@ +import type { FC } from 'react' +import type { SimpleDocumentDetail } from '@/models/datasets' +import { RiEditLine } from '@remixicon/react' +import { pick } from 'es-toolkit/object' +import { useRouter } from 'next/navigation' +import * as React from 'react' +import { useCallback } from 'react' +import { useTranslation } from 'react-i18next' +import Checkbox from '@/app/components/base/checkbox' +import Tooltip from '@/app/components/base/tooltip' +import ChunkingModeLabel from '@/app/components/datasets/common/chunking-mode-label' +import Operations from '@/app/components/datasets/documents/components/operations' +import SummaryStatus from '@/app/components/datasets/documents/detail/completed/common/summary-status' +import StatusItem from '@/app/components/datasets/documents/status-item' +import useTimestamp from '@/hooks/use-timestamp' +import { DataSourceType } from '@/models/datasets' +import { formatNumber } from '@/utils/format' +import DocumentSourceIcon from './document-source-icon' +import { renderTdValue } from './utils' + +type LocalDoc = SimpleDocumentDetail & { percent?: number } + +type DocumentTableRowProps = { + doc: LocalDoc + index: number + datasetId: string + isSelected: boolean + isGeneralMode: boolean + isQAMode: boolean + embeddingAvailable: boolean + selectedIds: string[] + onSelectOne: (docId: string) => void + onSelectedIdChange: (ids: string[]) => void + onShowRenameModal: (doc: LocalDoc) => void + onUpdate: () => void +} + +const renderCount = (count: number | undefined) => { + if (!count) + return renderTdValue(0, true) + + if (count < 1000) + return count + + return `${formatNumber((count / 1000).toFixed(1))}k` +} + +const DocumentTableRow: FC = React.memo(({ + doc, + index, + datasetId, + isSelected, + isGeneralMode, + isQAMode, + embeddingAvailable, + selectedIds, + onSelectOne, + onSelectedIdChange, + 
onShowRenameModal, + onUpdate, +}) => { + const { t } = useTranslation() + const { formatTime } = useTimestamp() + const router = useRouter() + + const isFile = doc.data_source_type === DataSourceType.FILE + const fileType = isFile ? doc.data_source_detail_dict?.upload_file?.extension : '' + + const handleRowClick = useCallback(() => { + router.push(`/datasets/${datasetId}/documents/${doc.id}`) + }, [router, datasetId, doc.id]) + + const handleCheckboxClick = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + }, []) + + const handleRenameClick = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + onShowRenameModal(doc) + }, [doc, onShowRenameModal]) + + return ( + + +
+ onSelectOne(doc.id)} + /> + {index + 1} +
+ + +
+
+ +
+ + {doc.name} + + {doc.summary_index_status && ( +
+ +
+ )} +
+ +
+ +
+
+
+
+ + + + + {renderCount(doc.word_count)} + {renderCount(doc.hit_count)} + + {formatTime(doc.created_at, t('dateTimeFormat', { ns: 'datasetHitTesting' }) as string)} + + + + + + + + + ) +}) + +DocumentTableRow.displayName = 'DocumentTableRow' + +export default DocumentTableRow diff --git a/web/app/components/datasets/documents/components/document-list/components/index.ts b/web/app/components/datasets/documents/components/document-list/components/index.ts new file mode 100644 index 0000000000..377f64a27f --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/components/index.ts @@ -0,0 +1,4 @@ +export { default as DocumentSourceIcon } from './document-source-icon' +export { default as DocumentTableRow } from './document-table-row' +export { default as SortHeader } from './sort-header' +export { renderTdValue } from './utils' diff --git a/web/app/components/datasets/documents/components/document-list/components/sort-header.spec.tsx b/web/app/components/datasets/documents/components/document-list/components/sort-header.spec.tsx new file mode 100644 index 0000000000..15cc55247b --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/components/sort-header.spec.tsx @@ -0,0 +1,124 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' +import SortHeader from './sort-header' + +describe('SortHeader', () => { + const defaultProps = { + field: 'name' as const, + label: 'File Name', + currentSortField: null, + sortOrder: 'desc' as const, + onSort: vi.fn(), + } + + describe('rendering', () => { + it('should render the label', () => { + render() + expect(screen.getByText('File Name')).toBeInTheDocument() + }) + + it('should render the sort icon', () => { + const { container } = render() + const icon = container.querySelector('svg') + expect(icon).toBeInTheDocument() + }) + }) + + describe('inactive state', () => { + it('should have disabled text color when 
not active', () => { + const { container } = render() + const icon = container.querySelector('svg') + expect(icon).toHaveClass('text-text-disabled') + }) + + it('should not be rotated when not active', () => { + const { container } = render() + const icon = container.querySelector('svg') + expect(icon).not.toHaveClass('rotate-180') + }) + }) + + describe('active state', () => { + it('should have tertiary text color when active', () => { + const { container } = render( + , + ) + const icon = container.querySelector('svg') + expect(icon).toHaveClass('text-text-tertiary') + }) + + it('should not be rotated when active and desc', () => { + const { container } = render( + , + ) + const icon = container.querySelector('svg') + expect(icon).not.toHaveClass('rotate-180') + }) + + it('should be rotated when active and asc', () => { + const { container } = render( + , + ) + const icon = container.querySelector('svg') + expect(icon).toHaveClass('rotate-180') + }) + }) + + describe('interaction', () => { + it('should call onSort when clicked', () => { + const onSort = vi.fn() + render() + + fireEvent.click(screen.getByText('File Name')) + + expect(onSort).toHaveBeenCalledWith('name') + }) + + it('should call onSort with correct field', () => { + const onSort = vi.fn() + render() + + fireEvent.click(screen.getByText('File Name')) + + expect(onSort).toHaveBeenCalledWith('word_count') + }) + }) + + describe('different fields', () => { + it('should work with word_count field', () => { + render( + , + ) + expect(screen.getByText('Words')).toBeInTheDocument() + }) + + it('should work with hit_count field', () => { + render( + , + ) + expect(screen.getByText('Hit Count')).toBeInTheDocument() + }) + + it('should work with created_at field', () => { + render( + , + ) + expect(screen.getByText('Upload Time')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/datasets/documents/components/document-list/components/sort-header.tsx 
b/web/app/components/datasets/documents/components/document-list/components/sort-header.tsx new file mode 100644 index 0000000000..1dc13df2b0 --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/components/sort-header.tsx @@ -0,0 +1,44 @@ +import type { FC } from 'react' +import type { SortField, SortOrder } from '../hooks' +import { RiArrowDownLine } from '@remixicon/react' +import * as React from 'react' +import { cn } from '@/utils/classnames' + +type SortHeaderProps = { + field: Exclude + label: string + currentSortField: SortField + sortOrder: SortOrder + onSort: (field: SortField) => void +} + +const SortHeader: FC = React.memo(({ + field, + label, + currentSortField, + sortOrder, + onSort, +}) => { + const isActive = currentSortField === field + const isDesc = isActive && sortOrder === 'desc' + + return ( +
onSort(field)} + > + {label} + +
+ ) +}) + +SortHeader.displayName = 'SortHeader' + +export default SortHeader diff --git a/web/app/components/datasets/documents/components/document-list/components/utils.spec.tsx b/web/app/components/datasets/documents/components/document-list/components/utils.spec.tsx new file mode 100644 index 0000000000..7dc66d4d39 --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/components/utils.spec.tsx @@ -0,0 +1,90 @@ +import { render, screen } from '@testing-library/react' +import { describe, expect, it } from 'vitest' +import { renderTdValue } from './utils' + +describe('renderTdValue', () => { + describe('Rendering', () => { + it('should render string value correctly', () => { + const { container } = render(<>{renderTdValue('test value')}) + expect(screen.getByText('test value')).toBeInTheDocument() + expect(container.querySelector('div')).toHaveClass('text-text-secondary') + }) + + it('should render number value correctly', () => { + const { container } = render(<>{renderTdValue(42)}) + expect(screen.getByText('42')).toBeInTheDocument() + expect(container.querySelector('div')).toHaveClass('text-text-secondary') + }) + + it('should render zero correctly', () => { + const { container } = render(<>{renderTdValue(0)}) + expect(screen.getByText('0')).toBeInTheDocument() + expect(container.querySelector('div')).toHaveClass('text-text-secondary') + }) + }) + + describe('Null and undefined handling', () => { + it('should render dash for null value', () => { + render(<>{renderTdValue(null)}) + expect(screen.getByText('-')).toBeInTheDocument() + }) + + it('should render dash for null value with empty style', () => { + const { container } = render(<>{renderTdValue(null, true)}) + expect(screen.getByText('-')).toBeInTheDocument() + expect(container.querySelector('div')).toHaveClass('text-text-tertiary') + }) + }) + + describe('Empty style', () => { + it('should apply text-text-tertiary class when isEmptyStyle is true', () => { + const { container } 
= render(<>{renderTdValue('value', true)}) + expect(container.querySelector('div')).toHaveClass('text-text-tertiary') + }) + + it('should apply text-text-secondary class when isEmptyStyle is false', () => { + const { container } = render(<>{renderTdValue('value', false)}) + expect(container.querySelector('div')).toHaveClass('text-text-secondary') + }) + + it('should apply text-text-secondary class when isEmptyStyle is not provided', () => { + const { container } = render(<>{renderTdValue('value')}) + expect(container.querySelector('div')).toHaveClass('text-text-secondary') + }) + }) + + describe('Edge Cases', () => { + it('should handle empty string', () => { + render(<>{renderTdValue('')}) + // Empty string should still render but with no visible text + const div = document.querySelector('div') + expect(div).toBeInTheDocument() + }) + + it('should handle large numbers', () => { + render(<>{renderTdValue(1234567890)}) + expect(screen.getByText('1234567890')).toBeInTheDocument() + }) + + it('should handle negative numbers', () => { + render(<>{renderTdValue(-42)}) + expect(screen.getByText('-42')).toBeInTheDocument() + }) + + it('should handle special characters in string', () => { + render(<>{renderTdValue('')}) + expect(screen.getByText('')).toBeInTheDocument() + }) + + it('should handle unicode characters', () => { + render(<>{renderTdValue('Test Unicode: \u4E2D\u6587')}) + expect(screen.getByText('Test Unicode: \u4E2D\u6587')).toBeInTheDocument() + }) + + it('should handle very long strings', () => { + const longString = 'a'.repeat(1000) + render(<>{renderTdValue(longString)}) + expect(screen.getByText(longString)).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/datasets/documents/components/document-list/components/utils.tsx b/web/app/components/datasets/documents/components/document-list/components/utils.tsx new file mode 100644 index 0000000000..4cb652108d --- /dev/null +++ 
b/web/app/components/datasets/documents/components/document-list/components/utils.tsx @@ -0,0 +1,16 @@ +import type { ReactNode } from 'react' +import { cn } from '@/utils/classnames' +import s from '../../../style.module.css' + +export const renderTdValue = (value: string | number | null, isEmptyStyle = false): ReactNode => { + const className = cn( + isEmptyStyle ? 'text-text-tertiary' : 'text-text-secondary', + s.tdValue, + ) + + return ( +
+ {value ?? '-'} +
+ ) +} diff --git a/web/app/components/datasets/documents/components/document-list/hooks/index.ts b/web/app/components/datasets/documents/components/document-list/hooks/index.ts new file mode 100644 index 0000000000..3ca7a920f2 --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/hooks/index.ts @@ -0,0 +1,4 @@ +export { useDocumentActions } from './use-document-actions' +export { useDocumentSelection } from './use-document-selection' +export { useDocumentSort } from './use-document-sort' +export type { SortField, SortOrder } from './use-document-sort' diff --git a/web/app/components/datasets/documents/components/document-list/hooks/use-document-actions.spec.tsx b/web/app/components/datasets/documents/components/document-list/hooks/use-document-actions.spec.tsx new file mode 100644 index 0000000000..bc84477744 --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/hooks/use-document-actions.spec.tsx @@ -0,0 +1,438 @@ +import type { ReactNode } from 'react' +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' +import { act, renderHook, waitFor } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { DocumentActionType } from '@/models/datasets' +import * as useDocument from '@/service/knowledge/use-document' +import { useDocumentActions } from './use-document-actions' + +vi.mock('@/service/knowledge/use-document') + +const mockUseDocumentArchive = vi.mocked(useDocument.useDocumentArchive) +const mockUseDocumentSummary = vi.mocked(useDocument.useDocumentSummary) +const mockUseDocumentEnable = vi.mocked(useDocument.useDocumentEnable) +const mockUseDocumentDisable = vi.mocked(useDocument.useDocumentDisable) +const mockUseDocumentDelete = vi.mocked(useDocument.useDocumentDelete) +const mockUseDocumentBatchRetryIndex = vi.mocked(useDocument.useDocumentBatchRetryIndex) +const mockUseDocumentDownloadZip = 
vi.mocked(useDocument.useDocumentDownloadZip) + +const createTestQueryClient = () => new QueryClient({ + defaultOptions: { + queries: { retry: false }, + mutations: { retry: false }, + }, +}) + +const createWrapper = () => { + const queryClient = createTestQueryClient() + return ({ children }: { children: ReactNode }) => ( + + {children} + + ) +} + +describe('useDocumentActions', () => { + const mockMutateAsync = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + + // Setup all mocks with default values + const createMockMutation = () => ({ + mutateAsync: mockMutateAsync, + isPending: false, + isError: false, + isSuccess: false, + isIdle: true, + data: undefined, + error: null, + mutate: vi.fn(), + reset: vi.fn(), + status: 'idle' as const, + variables: undefined, + context: undefined, + failureCount: 0, + failureReason: null, + submittedAt: 0, + }) + + mockUseDocumentArchive.mockReturnValue(createMockMutation() as unknown as ReturnType) + mockUseDocumentSummary.mockReturnValue(createMockMutation() as unknown as ReturnType) + mockUseDocumentEnable.mockReturnValue(createMockMutation() as unknown as ReturnType) + mockUseDocumentDisable.mockReturnValue(createMockMutation() as unknown as ReturnType) + mockUseDocumentDelete.mockReturnValue(createMockMutation() as unknown as ReturnType) + mockUseDocumentBatchRetryIndex.mockReturnValue(createMockMutation() as unknown as ReturnType) + mockUseDocumentDownloadZip.mockReturnValue({ + ...createMockMutation(), + isPending: false, + } as unknown as ReturnType) + }) + + describe('handleAction', () => { + it('should call archive mutation when archive action is triggered', async () => { + mockMutateAsync.mockResolvedValue({ result: 'success' }) + const onUpdate = vi.fn() + const onClearSelection = vi.fn() + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: ['doc1'], + downloadableSelectedIds: [], + onUpdate, + onClearSelection, + }), + { wrapper: createWrapper() }, + ) + + await 
act(async () => { + await result.current.handleAction(DocumentActionType.archive)() + }) + + expect(mockMutateAsync).toHaveBeenCalledWith({ + datasetId: 'ds1', + documentIds: ['doc1'], + }) + }) + + it('should call onUpdate on successful action', async () => { + mockMutateAsync.mockResolvedValue({ result: 'success' }) + const onUpdate = vi.fn() + const onClearSelection = vi.fn() + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: ['doc1'], + downloadableSelectedIds: [], + onUpdate, + onClearSelection, + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + await result.current.handleAction(DocumentActionType.enable)() + }) + + await waitFor(() => { + expect(onUpdate).toHaveBeenCalled() + }) + }) + + it('should call onClearSelection on delete action', async () => { + mockMutateAsync.mockResolvedValue({ result: 'success' }) + const onUpdate = vi.fn() + const onClearSelection = vi.fn() + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: ['doc1'], + downloadableSelectedIds: [], + onUpdate, + onClearSelection, + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + await result.current.handleAction(DocumentActionType.delete)() + }) + + await waitFor(() => { + expect(onClearSelection).toHaveBeenCalled() + }) + }) + }) + + describe('handleBatchReIndex', () => { + it('should call retry index mutation', async () => { + mockMutateAsync.mockResolvedValue({ result: 'success' }) + const onUpdate = vi.fn() + const onClearSelection = vi.fn() + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: ['doc1', 'doc2'], + downloadableSelectedIds: [], + onUpdate, + onClearSelection, + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + await result.current.handleBatchReIndex() + }) + + expect(mockMutateAsync).toHaveBeenCalledWith({ + datasetId: 'ds1', + documentIds: ['doc1', 'doc2'], + }) + }) + + 
it('should call onClearSelection on success', async () => { + mockMutateAsync.mockResolvedValue({ result: 'success' }) + const onUpdate = vi.fn() + const onClearSelection = vi.fn() + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: ['doc1'], + downloadableSelectedIds: [], + onUpdate, + onClearSelection, + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + await result.current.handleBatchReIndex() + }) + + await waitFor(() => { + expect(onClearSelection).toHaveBeenCalled() + expect(onUpdate).toHaveBeenCalled() + }) + }) + }) + + describe('handleBatchDownload', () => { + it('should not proceed when already downloading', async () => { + mockUseDocumentDownloadZip.mockReturnValue({ + mutateAsync: mockMutateAsync, + isPending: true, + } as unknown as ReturnType) + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: ['doc1'], + downloadableSelectedIds: ['doc1'], + onUpdate: vi.fn(), + onClearSelection: vi.fn(), + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + await result.current.handleBatchDownload() + }) + + expect(mockMutateAsync).not.toHaveBeenCalled() + }) + + it('should call download mutation with downloadable ids', async () => { + const mockBlob = new Blob(['test']) + mockMutateAsync.mockResolvedValue(mockBlob) + + mockUseDocumentDownloadZip.mockReturnValue({ + mutateAsync: mockMutateAsync, + isPending: false, + } as unknown as ReturnType) + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: ['doc1', 'doc2'], + downloadableSelectedIds: ['doc1'], + onUpdate: vi.fn(), + onClearSelection: vi.fn(), + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + await result.current.handleBatchDownload() + }) + + expect(mockMutateAsync).toHaveBeenCalledWith({ + datasetId: 'ds1', + documentIds: ['doc1'], + }) + }) + }) + + describe('isDownloadingZip', () => { + it('should reflect 
isPending state from mutation', () => { + mockUseDocumentDownloadZip.mockReturnValue({ + mutateAsync: mockMutateAsync, + isPending: true, + } as unknown as ReturnType) + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: [], + downloadableSelectedIds: [], + onUpdate: vi.fn(), + onClearSelection: vi.fn(), + }), + { wrapper: createWrapper() }, + ) + + expect(result.current.isDownloadingZip).toBe(true) + }) + }) + + describe('error handling', () => { + it('should show error toast when handleAction fails', async () => { + mockMutateAsync.mockRejectedValue(new Error('Action failed')) + const onUpdate = vi.fn() + const onClearSelection = vi.fn() + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: ['doc1'], + downloadableSelectedIds: [], + onUpdate, + onClearSelection, + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + await result.current.handleAction(DocumentActionType.archive)() + }) + + // onUpdate should not be called on error + expect(onUpdate).not.toHaveBeenCalled() + }) + + it('should show error toast when handleBatchReIndex fails', async () => { + mockMutateAsync.mockRejectedValue(new Error('Re-index failed')) + const onUpdate = vi.fn() + const onClearSelection = vi.fn() + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: ['doc1'], + downloadableSelectedIds: [], + onUpdate, + onClearSelection, + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + await result.current.handleBatchReIndex() + }) + + // onUpdate and onClearSelection should not be called on error + expect(onUpdate).not.toHaveBeenCalled() + expect(onClearSelection).not.toHaveBeenCalled() + }) + + it('should show error toast when handleBatchDownload fails', async () => { + mockMutateAsync.mockRejectedValue(new Error('Download failed')) + + mockUseDocumentDownloadZip.mockReturnValue({ + mutateAsync: mockMutateAsync, + isPending: 
false, + } as unknown as ReturnType) + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: ['doc1'], + downloadableSelectedIds: ['doc1'], + onUpdate: vi.fn(), + onClearSelection: vi.fn(), + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + await result.current.handleBatchDownload() + }) + + // Mutation was called but failed + expect(mockMutateAsync).toHaveBeenCalled() + }) + + it('should show error toast when handleBatchDownload returns null blob', async () => { + mockMutateAsync.mockResolvedValue(null) + + mockUseDocumentDownloadZip.mockReturnValue({ + mutateAsync: mockMutateAsync, + isPending: false, + } as unknown as ReturnType) + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: ['doc1'], + downloadableSelectedIds: ['doc1'], + onUpdate: vi.fn(), + onClearSelection: vi.fn(), + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + await result.current.handleBatchDownload() + }) + + // Mutation was called but returned null + expect(mockMutateAsync).toHaveBeenCalled() + }) + }) + + describe('all action types', () => { + it('should handle summary action', async () => { + mockMutateAsync.mockResolvedValue({ result: 'success' }) + const onUpdate = vi.fn() + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: ['doc1'], + downloadableSelectedIds: [], + onUpdate, + onClearSelection: vi.fn(), + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + await result.current.handleAction(DocumentActionType.summary)() + }) + + expect(mockMutateAsync).toHaveBeenCalled() + await waitFor(() => { + expect(onUpdate).toHaveBeenCalled() + }) + }) + + it('should handle disable action', async () => { + mockMutateAsync.mockResolvedValue({ result: 'success' }) + const onUpdate = vi.fn() + + const { result } = renderHook( + () => useDocumentActions({ + datasetId: 'ds1', + selectedIds: ['doc1'], + 
downloadableSelectedIds: [], + onUpdate, + onClearSelection: vi.fn(), + }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + await result.current.handleAction(DocumentActionType.disable)() + }) + + expect(mockMutateAsync).toHaveBeenCalled() + await waitFor(() => { + expect(onUpdate).toHaveBeenCalled() + }) + }) + }) +}) diff --git a/web/app/components/datasets/documents/components/document-list/hooks/use-document-actions.ts b/web/app/components/datasets/documents/components/document-list/hooks/use-document-actions.ts new file mode 100644 index 0000000000..56553faa9e --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/hooks/use-document-actions.ts @@ -0,0 +1,126 @@ +import type { CommonResponse } from '@/models/common' +import { useCallback, useMemo } from 'react' +import { useTranslation } from 'react-i18next' +import Toast from '@/app/components/base/toast' +import { DocumentActionType } from '@/models/datasets' +import { + useDocumentArchive, + useDocumentBatchRetryIndex, + useDocumentDelete, + useDocumentDisable, + useDocumentDownloadZip, + useDocumentEnable, + useDocumentSummary, +} from '@/service/knowledge/use-document' +import { asyncRunSafe } from '@/utils' +import { downloadBlob } from '@/utils/download' + +type UseDocumentActionsOptions = { + datasetId: string + selectedIds: string[] + downloadableSelectedIds: string[] + onUpdate: () => void + onClearSelection: () => void +} + +/** + * Generate a random ZIP filename for bulk document downloads. + * We intentionally avoid leaking dataset info in the exported archive name. + */ +const generateDocsZipFileName = (): string => { + const randomPart = (typeof crypto !== 'undefined' && typeof crypto.randomUUID === 'function') + ? 
crypto.randomUUID() + : `${Date.now().toString(36)}${Math.random().toString(36).slice(2, 10)}` + return `${randomPart}-docs.zip` +} + +export const useDocumentActions = ({ + datasetId, + selectedIds, + downloadableSelectedIds, + onUpdate, + onClearSelection, +}: UseDocumentActionsOptions) => { + const { t } = useTranslation() + + const { mutateAsync: archiveDocument } = useDocumentArchive() + const { mutateAsync: generateSummary } = useDocumentSummary() + const { mutateAsync: enableDocument } = useDocumentEnable() + const { mutateAsync: disableDocument } = useDocumentDisable() + const { mutateAsync: deleteDocument } = useDocumentDelete() + const { mutateAsync: retryIndexDocument } = useDocumentBatchRetryIndex() + const { mutateAsync: requestDocumentsZip, isPending: isDownloadingZip } = useDocumentDownloadZip() + + type SupportedActionType + = | typeof DocumentActionType.archive + | typeof DocumentActionType.summary + | typeof DocumentActionType.enable + | typeof DocumentActionType.disable + | typeof DocumentActionType.delete + + const actionMutationMap = useMemo(() => ({ + [DocumentActionType.archive]: archiveDocument, + [DocumentActionType.summary]: generateSummary, + [DocumentActionType.enable]: enableDocument, + [DocumentActionType.disable]: disableDocument, + [DocumentActionType.delete]: deleteDocument, + } as const), [archiveDocument, generateSummary, enableDocument, disableDocument, deleteDocument]) + + const handleAction = useCallback((actionName: SupportedActionType) => { + return async () => { + const opApi = actionMutationMap[actionName] + if (!opApi) + return + + const [e] = await asyncRunSafe( + opApi({ datasetId, documentIds: selectedIds }), + ) + + if (!e) { + if (actionName === DocumentActionType.delete) + onClearSelection() + Toast.notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) + onUpdate() + } + else { + Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) + } 
+ } + }, [actionMutationMap, datasetId, selectedIds, onClearSelection, onUpdate, t]) + + const handleBatchReIndex = useCallback(async () => { + const [e] = await asyncRunSafe( + retryIndexDocument({ datasetId, documentIds: selectedIds }), + ) + if (!e) { + onClearSelection() + Toast.notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) + onUpdate() + } + else { + Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) + } + }, [retryIndexDocument, datasetId, selectedIds, onClearSelection, onUpdate, t]) + + const handleBatchDownload = useCallback(async () => { + if (isDownloadingZip) + return + + const [e, blob] = await asyncRunSafe( + requestDocumentsZip({ datasetId, documentIds: downloadableSelectedIds }), + ) + if (e || !blob) { + Toast.notify({ type: 'error', message: t('actionMsg.downloadUnsuccessfully', { ns: 'common' }) }) + return + } + + downloadBlob({ data: blob, fileName: generateDocsZipFileName() }) + }, [datasetId, downloadableSelectedIds, isDownloadingZip, requestDocumentsZip, t]) + + return { + handleAction, + handleBatchReIndex, + handleBatchDownload, + isDownloadingZip, + } +} diff --git a/web/app/components/datasets/documents/components/document-list/hooks/use-document-selection.spec.ts b/web/app/components/datasets/documents/components/document-list/hooks/use-document-selection.spec.ts new file mode 100644 index 0000000000..7775c83f1c --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/hooks/use-document-selection.spec.ts @@ -0,0 +1,317 @@ +import type { SimpleDocumentDetail } from '@/models/datasets' +import { act, renderHook } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' +import { DataSourceType } from '@/models/datasets' +import { useDocumentSelection } from './use-document-selection' + +type LocalDoc = SimpleDocumentDetail & { percent?: number } + +const createMockDocument = (overrides: Partial = 
{}): LocalDoc => ({ + id: 'doc1', + name: 'Test Document', + data_source_type: DataSourceType.FILE, + data_source_info: {}, + data_source_detail_dict: {}, + word_count: 100, + hit_count: 10, + created_at: 1000000, + position: 1, + doc_form: 'text_model', + enabled: true, + archived: false, + display_status: 'available', + created_from: 'api', + ...overrides, +} as LocalDoc) + +describe('useDocumentSelection', () => { + describe('isAllSelected', () => { + it('should return false when documents is empty', () => { + const onSelectedIdChange = vi.fn() + const { result } = renderHook(() => + useDocumentSelection({ + documents: [], + selectedIds: [], + onSelectedIdChange, + }), + ) + + expect(result.current.isAllSelected).toBe(false) + }) + + it('should return true when all documents are selected', () => { + const docs = [ + createMockDocument({ id: 'doc1' }), + createMockDocument({ id: 'doc2' }), + ] + const onSelectedIdChange = vi.fn() + + const { result } = renderHook(() => + useDocumentSelection({ + documents: docs, + selectedIds: ['doc1', 'doc2'], + onSelectedIdChange, + }), + ) + + expect(result.current.isAllSelected).toBe(true) + }) + + it('should return false when not all documents are selected', () => { + const docs = [ + createMockDocument({ id: 'doc1' }), + createMockDocument({ id: 'doc2' }), + ] + const onSelectedIdChange = vi.fn() + + const { result } = renderHook(() => + useDocumentSelection({ + documents: docs, + selectedIds: ['doc1'], + onSelectedIdChange, + }), + ) + + expect(result.current.isAllSelected).toBe(false) + }) + }) + + describe('isSomeSelected', () => { + it('should return false when no documents are selected', () => { + const docs = [createMockDocument({ id: 'doc1' })] + const onSelectedIdChange = vi.fn() + + const { result } = renderHook(() => + useDocumentSelection({ + documents: docs, + selectedIds: [], + onSelectedIdChange, + }), + ) + + expect(result.current.isSomeSelected).toBe(false) + }) + + it('should return true when some documents 
are selected', () => { + const docs = [ + createMockDocument({ id: 'doc1' }), + createMockDocument({ id: 'doc2' }), + ] + const onSelectedIdChange = vi.fn() + + const { result } = renderHook(() => + useDocumentSelection({ + documents: docs, + selectedIds: ['doc1'], + onSelectedIdChange, + }), + ) + + expect(result.current.isSomeSelected).toBe(true) + }) + }) + + describe('onSelectAll', () => { + it('should select all documents when none are selected', () => { + const docs = [ + createMockDocument({ id: 'doc1' }), + createMockDocument({ id: 'doc2' }), + ] + const onSelectedIdChange = vi.fn() + + const { result } = renderHook(() => + useDocumentSelection({ + documents: docs, + selectedIds: [], + onSelectedIdChange, + }), + ) + + act(() => { + result.current.onSelectAll() + }) + + expect(onSelectedIdChange).toHaveBeenCalledWith(['doc1', 'doc2']) + }) + + it('should deselect all when all are selected', () => { + const docs = [ + createMockDocument({ id: 'doc1' }), + createMockDocument({ id: 'doc2' }), + ] + const onSelectedIdChange = vi.fn() + + const { result } = renderHook(() => + useDocumentSelection({ + documents: docs, + selectedIds: ['doc1', 'doc2'], + onSelectedIdChange, + }), + ) + + act(() => { + result.current.onSelectAll() + }) + + expect(onSelectedIdChange).toHaveBeenCalledWith([]) + }) + + it('should add to existing selection when some are selected', () => { + const docs = [ + createMockDocument({ id: 'doc1' }), + createMockDocument({ id: 'doc2' }), + createMockDocument({ id: 'doc3' }), + ] + const onSelectedIdChange = vi.fn() + + const { result } = renderHook(() => + useDocumentSelection({ + documents: docs, + selectedIds: ['doc1'], + onSelectedIdChange, + }), + ) + + act(() => { + result.current.onSelectAll() + }) + + expect(onSelectedIdChange).toHaveBeenCalledWith(['doc1', 'doc2', 'doc3']) + }) + }) + + describe('onSelectOne', () => { + it('should add document to selection when not selected', () => { + const onSelectedIdChange = vi.fn() + + const { 
result } = renderHook(() => + useDocumentSelection({ + documents: [], + selectedIds: [], + onSelectedIdChange, + }), + ) + + act(() => { + result.current.onSelectOne('doc1') + }) + + expect(onSelectedIdChange).toHaveBeenCalledWith(['doc1']) + }) + + it('should remove document from selection when already selected', () => { + const onSelectedIdChange = vi.fn() + + const { result } = renderHook(() => + useDocumentSelection({ + documents: [], + selectedIds: ['doc1', 'doc2'], + onSelectedIdChange, + }), + ) + + act(() => { + result.current.onSelectOne('doc1') + }) + + expect(onSelectedIdChange).toHaveBeenCalledWith(['doc2']) + }) + }) + + describe('hasErrorDocumentsSelected', () => { + it('should return false when no error documents are selected', () => { + const docs = [ + createMockDocument({ id: 'doc1', display_status: 'available' }), + createMockDocument({ id: 'doc2', display_status: 'error' }), + ] + const onSelectedIdChange = vi.fn() + + const { result } = renderHook(() => + useDocumentSelection({ + documents: docs, + selectedIds: ['doc1'], + onSelectedIdChange, + }), + ) + + expect(result.current.hasErrorDocumentsSelected).toBe(false) + }) + + it('should return true when an error document is selected', () => { + const docs = [ + createMockDocument({ id: 'doc1', display_status: 'available' }), + createMockDocument({ id: 'doc2', display_status: 'error' }), + ] + const onSelectedIdChange = vi.fn() + + const { result } = renderHook(() => + useDocumentSelection({ + documents: docs, + selectedIds: ['doc2'], + onSelectedIdChange, + }), + ) + + expect(result.current.hasErrorDocumentsSelected).toBe(true) + }) + }) + + describe('downloadableSelectedIds', () => { + it('should return only FILE type documents from selection', () => { + const docs = [ + createMockDocument({ id: 'doc1', data_source_type: DataSourceType.FILE }), + createMockDocument({ id: 'doc2', data_source_type: DataSourceType.NOTION }), + createMockDocument({ id: 'doc3', data_source_type: DataSourceType.FILE 
}), + ] + const onSelectedIdChange = vi.fn() + + const { result } = renderHook(() => + useDocumentSelection({ + documents: docs, + selectedIds: ['doc1', 'doc2', 'doc3'], + onSelectedIdChange, + }), + ) + + expect(result.current.downloadableSelectedIds).toEqual(['doc1', 'doc3']) + }) + + it('should return empty array when no FILE documents selected', () => { + const docs = [ + createMockDocument({ id: 'doc1', data_source_type: DataSourceType.NOTION }), + createMockDocument({ id: 'doc2', data_source_type: DataSourceType.WEB }), + ] + const onSelectedIdChange = vi.fn() + + const { result } = renderHook(() => + useDocumentSelection({ + documents: docs, + selectedIds: ['doc1', 'doc2'], + onSelectedIdChange, + }), + ) + + expect(result.current.downloadableSelectedIds).toEqual([]) + }) + }) + + describe('clearSelection', () => { + it('should call onSelectedIdChange with empty array', () => { + const onSelectedIdChange = vi.fn() + + const { result } = renderHook(() => + useDocumentSelection({ + documents: [], + selectedIds: ['doc1', 'doc2'], + onSelectedIdChange, + }), + ) + + act(() => { + result.current.clearSelection() + }) + + expect(onSelectedIdChange).toHaveBeenCalledWith([]) + }) + }) +}) diff --git a/web/app/components/datasets/documents/components/document-list/hooks/use-document-selection.ts b/web/app/components/datasets/documents/components/document-list/hooks/use-document-selection.ts new file mode 100644 index 0000000000..ad12b2b00f --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/hooks/use-document-selection.ts @@ -0,0 +1,66 @@ +import type { SimpleDocumentDetail } from '@/models/datasets' +import { uniq } from 'es-toolkit/array' +import { useCallback, useMemo } from 'react' +import { DataSourceType } from '@/models/datasets' + +type LocalDoc = SimpleDocumentDetail & { percent?: number } + +type UseDocumentSelectionOptions = { + documents: LocalDoc[] + selectedIds: string[] + onSelectedIdChange: (selectedIds: string[]) => 
void +} + +export const useDocumentSelection = ({ + documents, + selectedIds, + onSelectedIdChange, +}: UseDocumentSelectionOptions) => { + const isAllSelected = useMemo(() => { + return documents.length > 0 && documents.every(doc => selectedIds.includes(doc.id)) + }, [documents, selectedIds]) + + const isSomeSelected = useMemo(() => { + return documents.some(doc => selectedIds.includes(doc.id)) + }, [documents, selectedIds]) + + const onSelectAll = useCallback(() => { + if (isAllSelected) + onSelectedIdChange([]) + else + onSelectedIdChange(uniq([...selectedIds, ...documents.map(doc => doc.id)])) + }, [isAllSelected, documents, onSelectedIdChange, selectedIds]) + + const onSelectOne = useCallback((docId: string) => { + onSelectedIdChange( + selectedIds.includes(docId) + ? selectedIds.filter(id => id !== docId) + : [...selectedIds, docId], + ) + }, [selectedIds, onSelectedIdChange]) + + const hasErrorDocumentsSelected = useMemo(() => { + return documents.some(doc => selectedIds.includes(doc.id) && doc.display_status === 'error') + }, [documents, selectedIds]) + + const downloadableSelectedIds = useMemo(() => { + const selectedSet = new Set(selectedIds) + return documents + .filter(doc => selectedSet.has(doc.id) && doc.data_source_type === DataSourceType.FILE) + .map(doc => doc.id) + }, [documents, selectedIds]) + + const clearSelection = useCallback(() => { + onSelectedIdChange([]) + }, [onSelectedIdChange]) + + return { + isAllSelected, + isSomeSelected, + onSelectAll, + onSelectOne, + hasErrorDocumentsSelected, + downloadableSelectedIds, + clearSelection, + } +} diff --git a/web/app/components/datasets/documents/components/document-list/hooks/use-document-sort.spec.ts b/web/app/components/datasets/documents/components/document-list/hooks/use-document-sort.spec.ts new file mode 100644 index 0000000000..a41b42d6fa --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/hooks/use-document-sort.spec.ts @@ -0,0 +1,340 @@ +import type { 
SimpleDocumentDetail } from '@/models/datasets' +import { act, renderHook } from '@testing-library/react' +import { describe, expect, it } from 'vitest' +import { useDocumentSort } from './use-document-sort' + +type LocalDoc = SimpleDocumentDetail & { percent?: number } + +const createMockDocument = (overrides: Partial = {}): LocalDoc => ({ + id: 'doc1', + name: 'Test Document', + data_source_type: 'upload_file', + data_source_info: {}, + data_source_detail_dict: {}, + word_count: 100, + hit_count: 10, + created_at: 1000000, + position: 1, + doc_form: 'text_model', + enabled: true, + archived: false, + display_status: 'available', + created_from: 'api', + ...overrides, +} as LocalDoc) + +describe('useDocumentSort', () => { + describe('initial state', () => { + it('should return null sortField initially', () => { + const { result } = renderHook(() => + useDocumentSort({ + documents: [], + statusFilterValue: '', + remoteSortValue: '', + }), + ) + + expect(result.current.sortField).toBeNull() + expect(result.current.sortOrder).toBe('desc') + }) + + it('should return documents unchanged when no sort is applied', () => { + const docs = [ + createMockDocument({ id: 'doc1', name: 'B' }), + createMockDocument({ id: 'doc2', name: 'A' }), + ] + + const { result } = renderHook(() => + useDocumentSort({ + documents: docs, + statusFilterValue: '', + remoteSortValue: '', + }), + ) + + expect(result.current.sortedDocuments).toEqual(docs) + }) + }) + + describe('handleSort', () => { + it('should set sort field when called', () => { + const { result } = renderHook(() => + useDocumentSort({ + documents: [], + statusFilterValue: '', + remoteSortValue: '', + }), + ) + + act(() => { + result.current.handleSort('name') + }) + + expect(result.current.sortField).toBe('name') + expect(result.current.sortOrder).toBe('desc') + }) + + it('should toggle sort order when same field is clicked twice', () => { + const { result } = renderHook(() => + useDocumentSort({ + documents: [], + 
statusFilterValue: '', + remoteSortValue: '', + }), + ) + + act(() => { + result.current.handleSort('name') + }) + expect(result.current.sortOrder).toBe('desc') + + act(() => { + result.current.handleSort('name') + }) + expect(result.current.sortOrder).toBe('asc') + + act(() => { + result.current.handleSort('name') + }) + expect(result.current.sortOrder).toBe('desc') + }) + + it('should reset to desc when different field is selected', () => { + const { result } = renderHook(() => + useDocumentSort({ + documents: [], + statusFilterValue: '', + remoteSortValue: '', + }), + ) + + act(() => { + result.current.handleSort('name') + }) + act(() => { + result.current.handleSort('name') + }) + expect(result.current.sortOrder).toBe('asc') + + act(() => { + result.current.handleSort('word_count') + }) + expect(result.current.sortField).toBe('word_count') + expect(result.current.sortOrder).toBe('desc') + }) + + it('should not change state when null is passed', () => { + const { result } = renderHook(() => + useDocumentSort({ + documents: [], + statusFilterValue: '', + remoteSortValue: '', + }), + ) + + act(() => { + result.current.handleSort(null) + }) + + expect(result.current.sortField).toBeNull() + }) + }) + + describe('sorting documents', () => { + const docs = [ + createMockDocument({ id: 'doc1', name: 'Banana', word_count: 200, hit_count: 5, created_at: 3000 }), + createMockDocument({ id: 'doc2', name: 'Apple', word_count: 100, hit_count: 10, created_at: 1000 }), + createMockDocument({ id: 'doc3', name: 'Cherry', word_count: 300, hit_count: 1, created_at: 2000 }), + ] + + it('should sort by name descending', () => { + const { result } = renderHook(() => + useDocumentSort({ + documents: docs, + statusFilterValue: '', + remoteSortValue: '', + }), + ) + + act(() => { + result.current.handleSort('name') + }) + + const names = result.current.sortedDocuments.map(d => d.name) + expect(names).toEqual(['Cherry', 'Banana', 'Apple']) + }) + + it('should sort by name ascending', () 
=> { + const { result } = renderHook(() => + useDocumentSort({ + documents: docs, + statusFilterValue: '', + remoteSortValue: '', + }), + ) + + act(() => { + result.current.handleSort('name') + }) + act(() => { + result.current.handleSort('name') + }) + + const names = result.current.sortedDocuments.map(d => d.name) + expect(names).toEqual(['Apple', 'Banana', 'Cherry']) + }) + + it('should sort by word_count descending', () => { + const { result } = renderHook(() => + useDocumentSort({ + documents: docs, + statusFilterValue: '', + remoteSortValue: '', + }), + ) + + act(() => { + result.current.handleSort('word_count') + }) + + const counts = result.current.sortedDocuments.map(d => d.word_count) + expect(counts).toEqual([300, 200, 100]) + }) + + it('should sort by hit_count ascending', () => { + const { result } = renderHook(() => + useDocumentSort({ + documents: docs, + statusFilterValue: '', + remoteSortValue: '', + }), + ) + + act(() => { + result.current.handleSort('hit_count') + }) + act(() => { + result.current.handleSort('hit_count') + }) + + const counts = result.current.sortedDocuments.map(d => d.hit_count) + expect(counts).toEqual([1, 5, 10]) + }) + + it('should sort by created_at descending', () => { + const { result } = renderHook(() => + useDocumentSort({ + documents: docs, + statusFilterValue: '', + remoteSortValue: '', + }), + ) + + act(() => { + result.current.handleSort('created_at') + }) + + const times = result.current.sortedDocuments.map(d => d.created_at) + expect(times).toEqual([3000, 2000, 1000]) + }) + }) + + describe('status filtering', () => { + const docs = [ + createMockDocument({ id: 'doc1', display_status: 'available' }), + createMockDocument({ id: 'doc2', display_status: 'error' }), + createMockDocument({ id: 'doc3', display_status: 'available' }), + ] + + it('should not filter when statusFilterValue is empty', () => { + const { result } = renderHook(() => + useDocumentSort({ + documents: docs, + statusFilterValue: '', + 
remoteSortValue: '', + }), + ) + + expect(result.current.sortedDocuments.length).toBe(3) + }) + + it('should not filter when statusFilterValue is all', () => { + const { result } = renderHook(() => + useDocumentSort({ + documents: docs, + statusFilterValue: 'all', + remoteSortValue: '', + }), + ) + + expect(result.current.sortedDocuments.length).toBe(3) + }) + }) + + describe('remoteSortValue reset', () => { + it('should reset sort state when remoteSortValue changes', () => { + const { result, rerender } = renderHook( + ({ remoteSortValue }) => + useDocumentSort({ + documents: [], + statusFilterValue: '', + remoteSortValue, + }), + { initialProps: { remoteSortValue: 'initial' } }, + ) + + act(() => { + result.current.handleSort('name') + }) + act(() => { + result.current.handleSort('name') + }) + expect(result.current.sortField).toBe('name') + expect(result.current.sortOrder).toBe('asc') + + rerender({ remoteSortValue: 'changed' }) + + expect(result.current.sortField).toBeNull() + expect(result.current.sortOrder).toBe('desc') + }) + }) + + describe('edge cases', () => { + it('should handle documents with missing values', () => { + const docs = [ + createMockDocument({ id: 'doc1', name: undefined as unknown as string, word_count: undefined }), + createMockDocument({ id: 'doc2', name: 'Test', word_count: 100 }), + ] + + const { result } = renderHook(() => + useDocumentSort({ + documents: docs, + statusFilterValue: '', + remoteSortValue: '', + }), + ) + + act(() => { + result.current.handleSort('name') + }) + + expect(result.current.sortedDocuments.length).toBe(2) + }) + + it('should handle empty documents array', () => { + const { result } = renderHook(() => + useDocumentSort({ + documents: [], + statusFilterValue: '', + remoteSortValue: '', + }), + ) + + act(() => { + result.current.handleSort('name') + }) + + expect(result.current.sortedDocuments).toEqual([]) + }) + }) +}) diff --git 
a/web/app/components/datasets/documents/components/document-list/hooks/use-document-sort.ts b/web/app/components/datasets/documents/components/document-list/hooks/use-document-sort.ts new file mode 100644 index 0000000000..98cf244f36 --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/hooks/use-document-sort.ts @@ -0,0 +1,102 @@ +import type { SimpleDocumentDetail } from '@/models/datasets' +import { useCallback, useMemo, useRef, useState } from 'react' +import { normalizeStatusForQuery } from '@/app/components/datasets/documents/status-filter' + +export type SortField = 'name' | 'word_count' | 'hit_count' | 'created_at' | null +export type SortOrder = 'asc' | 'desc' + +type LocalDoc = SimpleDocumentDetail & { percent?: number } + +type UseDocumentSortOptions = { + documents: LocalDoc[] + statusFilterValue: string + remoteSortValue: string +} + +export const useDocumentSort = ({ + documents, + statusFilterValue, + remoteSortValue, +}: UseDocumentSortOptions) => { + const [sortField, setSortField] = useState(null) + const [sortOrder, setSortOrder] = useState('desc') + const prevRemoteSortValueRef = useRef(remoteSortValue) + + // Reset sort when remote sort changes + if (prevRemoteSortValueRef.current !== remoteSortValue) { + prevRemoteSortValueRef.current = remoteSortValue + setSortField(null) + setSortOrder('desc') + } + + const handleSort = useCallback((field: SortField) => { + if (field === null) + return + + if (sortField === field) { + setSortOrder(prev => prev === 'asc' ? 
'desc' : 'asc') + } + else { + setSortField(field) + setSortOrder('desc') + } + }, [sortField]) + + const sortedDocuments = useMemo(() => { + let filteredDocs = documents + + if (statusFilterValue && statusFilterValue !== 'all') { + filteredDocs = filteredDocs.filter(doc => + typeof doc.display_status === 'string' + && normalizeStatusForQuery(doc.display_status) === statusFilterValue, + ) + } + + if (!sortField) + return filteredDocs + + const sortedDocs = [...filteredDocs].sort((a, b) => { + let aValue: string | number + let bValue: string | number + + switch (sortField) { + case 'name': + aValue = a.name?.toLowerCase() || '' + bValue = b.name?.toLowerCase() || '' + break + case 'word_count': + aValue = a.word_count || 0 + bValue = b.word_count || 0 + break + case 'hit_count': + aValue = a.hit_count || 0 + bValue = b.hit_count || 0 + break + case 'created_at': + aValue = a.created_at + bValue = b.created_at + break + default: + return 0 + } + + if (sortField === 'name') { + const result = (aValue as string).localeCompare(bValue as string) + return sortOrder === 'asc' ? result : -result + } + else { + const result = (aValue as number) - (bValue as number) + return sortOrder === 'asc' ? 
result : -result + } + }) + + return sortedDocs + }, [documents, sortField, sortOrder, statusFilterValue]) + + return { + sortField, + sortOrder, + handleSort, + sortedDocuments, + } +} diff --git a/web/app/components/datasets/documents/components/document-list/index.spec.tsx b/web/app/components/datasets/documents/components/document-list/index.spec.tsx new file mode 100644 index 0000000000..32429cc0ac --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/index.spec.tsx @@ -0,0 +1,487 @@ +import type { ReactNode } from 'react' +import type { Props as PaginationProps } from '@/app/components/base/pagination' +import type { SimpleDocumentDetail } from '@/models/datasets' +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { ChunkingMode, DataSourceType } from '@/models/datasets' +import DocumentList from '../list' + +const mockPush = vi.fn() + +vi.mock('next/navigation', () => ({ + useRouter: () => ({ + push: mockPush, + }), +})) + +vi.mock('@/context/dataset-detail', () => ({ + useDatasetDetailContextWithSelector: (selector: (state: { dataset: { doc_form: string } }) => unknown) => + selector({ dataset: { doc_form: ChunkingMode.text } }), +})) + +const createTestQueryClient = () => new QueryClient({ + defaultOptions: { + queries: { retry: false, gcTime: 0 }, + mutations: { retry: false }, + }, +}) + +const createWrapper = () => { + const queryClient = createTestQueryClient() + return ({ children }: { children: ReactNode }) => ( + + {children} + + ) +} + +const createMockDoc = (overrides: Partial = {}): SimpleDocumentDetail => ({ + id: `doc-${Math.random().toString(36).substr(2, 9)}`, + position: 1, + data_source_type: DataSourceType.FILE, + data_source_info: {}, + data_source_detail_dict: { + upload_file: { name: 'test.txt', extension: 'txt' }, + }, + dataset_process_rule_id: 
'rule-1', + batch: 'batch-1', + name: 'test-document.txt', + created_from: 'web', + created_by: 'user-1', + created_at: Date.now(), + tokens: 100, + indexing_status: 'completed', + error: null, + enabled: true, + disabled_at: null, + disabled_by: null, + archived: false, + archived_reason: null, + archived_by: null, + archived_at: null, + updated_at: Date.now(), + doc_type: null, + doc_metadata: undefined, + display_status: 'available', + word_count: 500, + hit_count: 10, + doc_form: 'text_model', + ...overrides, +} as SimpleDocumentDetail) + +const defaultPagination: PaginationProps = { + current: 1, + onChange: vi.fn(), + total: 100, +} + +describe('DocumentList', () => { + const defaultProps = { + embeddingAvailable: true, + documents: [ + createMockDoc({ id: 'doc-1', name: 'Document 1.txt', word_count: 100, hit_count: 5 }), + createMockDoc({ id: 'doc-2', name: 'Document 2.txt', word_count: 200, hit_count: 10 }), + createMockDoc({ id: 'doc-3', name: 'Document 3.txt', word_count: 300, hit_count: 15 }), + ], + selectedIds: [] as string[], + onSelectedIdChange: vi.fn(), + datasetId: 'dataset-1', + pagination: defaultPagination, + onUpdate: vi.fn(), + onManageMetadata: vi.fn(), + statusFilterValue: '', + remoteSortValue: '', + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('Rendering', () => { + it('should render without crashing', () => { + render(, { wrapper: createWrapper() }) + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should render all documents', () => { + render(, { wrapper: createWrapper() }) + expect(screen.getByText('Document 1.txt')).toBeInTheDocument() + expect(screen.getByText('Document 2.txt')).toBeInTheDocument() + expect(screen.getByText('Document 3.txt')).toBeInTheDocument() + }) + + it('should render table headers', () => { + render(, { wrapper: createWrapper() }) + expect(screen.getByText('#')).toBeInTheDocument() + }) + + it('should render pagination when total is provided', () => { + render(, { wrapper: 
createWrapper() }) + // Pagination component should be present + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should not render pagination when total is 0', () => { + const props = { + ...defaultProps, + pagination: { ...defaultPagination, total: 0 }, + } + render(, { wrapper: createWrapper() }) + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should render empty table when no documents', () => { + const props = { ...defaultProps, documents: [] } + render(, { wrapper: createWrapper() }) + expect(screen.getByRole('table')).toBeInTheDocument() + }) + }) + + describe('Selection', () => { + // Helper to find checkboxes (custom div components, not native checkboxes) + const findCheckboxes = (container: HTMLElement): NodeListOf => { + return container.querySelectorAll('[class*="shadow-xs"]') + } + + it('should render header checkbox when embeddingAvailable', () => { + const { container } = render(, { wrapper: createWrapper() }) + const checkboxes = findCheckboxes(container) + expect(checkboxes.length).toBeGreaterThan(0) + }) + + it('should not render header checkbox when embedding not available', () => { + const props = { ...defaultProps, embeddingAvailable: false } + render(, { wrapper: createWrapper() }) + // Row checkboxes should still be there, but header checkbox should be hidden + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should call onSelectedIdChange when select all is clicked', () => { + const onSelectedIdChange = vi.fn() + const props = { ...defaultProps, onSelectedIdChange } + const { container } = render(, { wrapper: createWrapper() }) + + const checkboxes = findCheckboxes(container) + if (checkboxes.length > 0) { + fireEvent.click(checkboxes[0]) + expect(onSelectedIdChange).toHaveBeenCalled() + } + }) + + it('should show all checkboxes as checked when all are selected', () => { + const props = { + ...defaultProps, + selectedIds: ['doc-1', 'doc-2', 'doc-3'], + } + const { container } = 
render(, { wrapper: createWrapper() }) + + const checkboxes = findCheckboxes(container) + // When checked, checkbox should have a check icon (svg) inside + checkboxes.forEach((checkbox) => { + const checkIcon = checkbox.querySelector('svg') + expect(checkIcon).toBeInTheDocument() + }) + }) + + it('should show indeterminate state when some are selected', () => { + const props = { + ...defaultProps, + selectedIds: ['doc-1'], + } + const { container } = render(, { wrapper: createWrapper() }) + + // First checkbox is the header checkbox which should be indeterminate + const checkboxes = findCheckboxes(container) + expect(checkboxes.length).toBeGreaterThan(0) + // Header checkbox should show indeterminate icon, not check icon + // Just verify it's rendered + expect(checkboxes[0]).toBeInTheDocument() + }) + + it('should call onSelectedIdChange with single document when row checkbox is clicked', () => { + const onSelectedIdChange = vi.fn() + const props = { ...defaultProps, onSelectedIdChange } + const { container } = render(, { wrapper: createWrapper() }) + + // Click the second checkbox (first row checkbox) + const checkboxes = findCheckboxes(container) + if (checkboxes.length > 1) { + fireEvent.click(checkboxes[1]) + expect(onSelectedIdChange).toHaveBeenCalled() + } + }) + }) + + describe('Sorting', () => { + it('should render sort headers for sortable columns', () => { + render(, { wrapper: createWrapper() }) + // Find svg icons which indicate sortable columns + const sortIcons = document.querySelectorAll('svg') + expect(sortIcons.length).toBeGreaterThan(0) + }) + + it('should update sort order when sort header is clicked', () => { + render(, { wrapper: createWrapper() }) + + // Find and click a sort header by its parent div containing the label text + const sortableHeaders = document.querySelectorAll('[class*="cursor-pointer"]') + if (sortableHeaders.length > 0) { + fireEvent.click(sortableHeaders[0]) + } + + expect(screen.getByRole('table')).toBeInTheDocument() + }) 
+ }) + + describe('Batch Actions', () => { + it('should show batch action bar when documents are selected', () => { + const props = { + ...defaultProps, + selectedIds: ['doc-1', 'doc-2'], + } + render(, { wrapper: createWrapper() }) + + // BatchAction component should be visible + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should not show batch action bar when no documents selected', () => { + render(, { wrapper: createWrapper() }) + + // BatchAction should not be present + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should render batch action bar with archive option', () => { + const props = { + ...defaultProps, + selectedIds: ['doc-1'], + } + render(, { wrapper: createWrapper() }) + + // BatchAction component should be visible when documents are selected + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should render batch action bar with enable option', () => { + const props = { + ...defaultProps, + selectedIds: ['doc-1'], + } + render(, { wrapper: createWrapper() }) + + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should render batch action bar with disable option', () => { + const props = { + ...defaultProps, + selectedIds: ['doc-1'], + } + render(, { wrapper: createWrapper() }) + + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should render batch action bar with delete option', () => { + const props = { + ...defaultProps, + selectedIds: ['doc-1'], + } + render(, { wrapper: createWrapper() }) + + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should clear selection when cancel is clicked', () => { + const onSelectedIdChange = vi.fn() + const props = { + ...defaultProps, + selectedIds: ['doc-1'], + onSelectedIdChange, + } + render(, { wrapper: createWrapper() }) + + const cancelButton = screen.queryByRole('button', { name: /cancel/i }) + if (cancelButton) { + fireEvent.click(cancelButton) + 
expect(onSelectedIdChange).toHaveBeenCalledWith([]) + } + }) + + it('should show download option for downloadable documents', () => { + const props = { + ...defaultProps, + selectedIds: ['doc-1'], + documents: [ + createMockDoc({ id: 'doc-1', data_source_type: DataSourceType.FILE }), + ], + } + render(, { wrapper: createWrapper() }) + + // BatchAction should be visible + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should show re-index option for error documents', () => { + const props = { + ...defaultProps, + selectedIds: ['doc-1'], + documents: [ + createMockDoc({ id: 'doc-1', display_status: 'error' }), + ], + } + render(, { wrapper: createWrapper() }) + + // BatchAction with re-index should be present for error documents + expect(screen.getByRole('table')).toBeInTheDocument() + }) + }) + + describe('Row Click Navigation', () => { + it('should navigate to document detail when row is clicked', () => { + render(, { wrapper: createWrapper() }) + + const rows = screen.getAllByRole('row') + // First row is header, second row is first document + if (rows.length > 1) { + fireEvent.click(rows[1]) + expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-1/documents/doc-1') + } + }) + }) + + describe('Rename Modal', () => { + it('should not show rename modal initially', () => { + render(, { wrapper: createWrapper() }) + + // RenameModal should not be visible initially + const modal = screen.queryByRole('dialog') + expect(modal).not.toBeInTheDocument() + }) + + it('should show rename modal when rename button is clicked', () => { + const { container } = render(, { wrapper: createWrapper() }) + + // Find and click the rename button in the first row + const renameButtons = container.querySelectorAll('.cursor-pointer.rounded-md') + if (renameButtons.length > 0) { + fireEvent.click(renameButtons[0]) + } + + // After clicking rename, the modal should potentially be visible + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should call 
onUpdate when document is renamed', () => { + const onUpdate = vi.fn() + const props = { ...defaultProps, onUpdate } + render(, { wrapper: createWrapper() }) + + // The handleRenamed callback wraps onUpdate + expect(screen.getByRole('table')).toBeInTheDocument() + }) + }) + + describe('Edit Metadata Modal', () => { + it('should handle edit metadata action', () => { + const props = { + ...defaultProps, + selectedIds: ['doc-1'], + } + render(, { wrapper: createWrapper() }) + + const editButton = screen.queryByRole('button', { name: /metadata/i }) + if (editButton) { + fireEvent.click(editButton) + } + + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should call onManageMetadata when manage metadata is triggered', () => { + const onManageMetadata = vi.fn() + const props = { + ...defaultProps, + selectedIds: ['doc-1'], + onManageMetadata, + } + render(, { wrapper: createWrapper() }) + + // The onShowManage callback in EditMetadataBatchModal should call hideEditModal then onManageMetadata + expect(screen.getByRole('table')).toBeInTheDocument() + }) + }) + + describe('Chunking Mode', () => { + it('should render with general mode', () => { + render(, { wrapper: createWrapper() }) + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should render with QA mode', () => { + // This test uses the default mock which returns ChunkingMode.text + // The component will compute isQAMode based on doc_form + render(, { wrapper: createWrapper() }) + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should render with parent-child mode', () => { + render(, { wrapper: createWrapper() }) + expect(screen.getByRole('table')).toBeInTheDocument() + }) + }) + + describe('Edge Cases', () => { + it('should handle empty documents array', () => { + const props = { ...defaultProps, documents: [] } + render(, { wrapper: createWrapper() }) + + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should handle documents with missing 
optional fields', () => { + const docWithMissingFields = createMockDoc({ + word_count: undefined as unknown as number, + hit_count: undefined as unknown as number, + }) + const props = { + ...defaultProps, + documents: [docWithMissingFields], + } + render(, { wrapper: createWrapper() }) + + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should handle status filter value', () => { + const props = { + ...defaultProps, + statusFilterValue: 'completed', + } + render(, { wrapper: createWrapper() }) + + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should handle remote sort value', () => { + const props = { + ...defaultProps, + remoteSortValue: 'created_at', + } + render(, { wrapper: createWrapper() }) + + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should handle large number of documents', () => { + const manyDocs = Array.from({ length: 20 }, (_, i) => + createMockDoc({ id: `doc-${i}`, name: `Document ${i}.txt` })) + const props = { ...defaultProps, documents: manyDocs } + render(, { wrapper: createWrapper() }) + + expect(screen.getByRole('table')).toBeInTheDocument() + }, 10000) + }) +}) diff --git a/web/app/components/datasets/documents/components/document-list/index.tsx b/web/app/components/datasets/documents/components/document-list/index.tsx new file mode 100644 index 0000000000..46fd7a02d5 --- /dev/null +++ b/web/app/components/datasets/documents/components/document-list/index.tsx @@ -0,0 +1,3 @@ +// Re-export from parent for backwards compatibility +export { default } from '../list' +export { renderTdValue } from './components' diff --git a/web/app/components/datasets/documents/components/list.tsx b/web/app/components/datasets/documents/components/list.tsx index f63d6d987e..3106f6c30b 100644 --- a/web/app/components/datasets/documents/components/list.tsx +++ b/web/app/components/datasets/documents/components/list.tsx @@ -1,67 +1,26 @@ 'use client' import type { FC } from 'react' import type { Props 
as PaginationProps } from '@/app/components/base/pagination' -import type { CommonResponse } from '@/models/common' -import type { LegacyDataSourceInfo, LocalFileInfo, OnlineDocumentInfo, OnlineDriveInfo, SimpleDocumentDetail } from '@/models/datasets' -import { - RiArrowDownLine, - RiEditLine, - RiGlobalLine, -} from '@remixicon/react' +import type { SimpleDocumentDetail } from '@/models/datasets' import { useBoolean } from 'ahooks' -import { uniq } from 'es-toolkit/array' -import { pick } from 'es-toolkit/object' -import { useRouter } from 'next/navigation' import * as React from 'react' -import { useCallback, useEffect, useMemo, useState } from 'react' +import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import Checkbox from '@/app/components/base/checkbox' -import FileTypeIcon from '@/app/components/base/file-uploader/file-type-icon' -import NotionIcon from '@/app/components/base/notion-icon' import Pagination from '@/app/components/base/pagination' -import Toast from '@/app/components/base/toast' -import Tooltip from '@/app/components/base/tooltip' -import ChunkingModeLabel from '@/app/components/datasets/common/chunking-mode-label' -import { normalizeStatusForQuery } from '@/app/components/datasets/documents/status-filter' -import { extensionToFileType } from '@/app/components/datasets/hit-testing/utils/extension-to-file-type' import EditMetadataBatchModal from '@/app/components/datasets/metadata/edit-metadata-batch/modal' import useBatchEditDocumentMetadata from '@/app/components/datasets/metadata/hooks/use-batch-edit-document-metadata' import { useDatasetDetailContextWithSelector as useDatasetDetailContext } from '@/context/dataset-detail' -import useTimestamp from '@/hooks/use-timestamp' -import { ChunkingMode, DataSourceType, DocumentActionType } from '@/models/datasets' -import { DatasourceType } from '@/models/pipeline' -import { useDocumentArchive, useDocumentBatchRetryIndex, useDocumentDelete, 
useDocumentDisable, useDocumentDownloadZip, useDocumentEnable, useDocumentSummary } from '@/service/knowledge/use-document' -import { asyncRunSafe } from '@/utils' -import { cn } from '@/utils/classnames' -import { downloadBlob } from '@/utils/download' -import { formatNumber } from '@/utils/format' +import { ChunkingMode, DocumentActionType } from '@/models/datasets' import BatchAction from '../detail/completed/common/batch-action' -import SummaryStatus from '../detail/completed/common/summary-status' -import StatusItem from '../status-item' import s from '../style.module.css' -import Operations from './operations' +import { DocumentTableRow, renderTdValue, SortHeader } from './document-list/components' +import { useDocumentActions, useDocumentSelection, useDocumentSort } from './document-list/hooks' import RenameModal from './rename-modal' -export const renderTdValue = (value: string | number | null, isEmptyStyle = false) => { - return ( -
- {value ?? '-'} -
- ) -} - -const renderCount = (count: number | undefined) => { - if (!count) - return renderTdValue(0, true) - - if (count < 1000) - return count - - return `${formatNumber((count / 1000).toFixed(1))}k` -} - type LocalDoc = SimpleDocumentDetail & { percent?: number } -type IDocumentListProps = { + +type DocumentListProps = { embeddingAvailable: boolean documents: LocalDoc[] selectedIds: string[] @@ -77,7 +36,7 @@ type IDocumentListProps = { /** * Document list component including basic information */ -const DocumentList: FC = ({ +const DocumentList: FC = ({ embeddingAvailable, documents = [], selectedIds, @@ -90,20 +49,43 @@ const DocumentList: FC = ({ remoteSortValue, }) => { const { t } = useTranslation() - const { formatTime } = useTimestamp() - const router = useRouter() const datasetConfig = useDatasetDetailContext(s => s.dataset) const chunkingMode = datasetConfig?.doc_form const isGeneralMode = chunkingMode !== ChunkingMode.parentChild const isQAMode = chunkingMode === ChunkingMode.qa - const [sortField, setSortField] = useState<'name' | 'word_count' | 'hit_count' | 'created_at' | null>(null) - const [sortOrder, setSortOrder] = useState<'asc' | 'desc'>('desc') - useEffect(() => { - setSortField(null) - setSortOrder('desc') - }, [remoteSortValue]) + // Sorting + const { sortField, sortOrder, handleSort, sortedDocuments } = useDocumentSort({ + documents, + statusFilterValue, + remoteSortValue, + }) + // Selection + const { + isAllSelected, + isSomeSelected, + onSelectAll, + onSelectOne, + hasErrorDocumentsSelected, + downloadableSelectedIds, + clearSelection, + } = useDocumentSelection({ + documents: sortedDocuments, + selectedIds, + onSelectedIdChange, + }) + + // Actions + const { handleAction, handleBatchReIndex, handleBatchDownload } = useDocumentActions({ + datasetId, + selectedIds, + downloadableSelectedIds, + onUpdate, + onClearSelection: clearSelection, + }) + + // Batch edit metadata const { isShowEditModal, showEditModal, @@ -113,233 +95,26 @@ const 
DocumentList: FC = ({ } = useBatchEditDocumentMetadata({ datasetId, docList: documents.filter(doc => selectedIds.includes(doc.id)), - selectedDocumentIds: selectedIds, // Pass all selected IDs separately + selectedDocumentIds: selectedIds, onUpdate, }) - const localDocs = useMemo(() => { - let filteredDocs = documents - - if (statusFilterValue && statusFilterValue !== 'all') { - filteredDocs = filteredDocs.filter(doc => - typeof doc.display_status === 'string' - && normalizeStatusForQuery(doc.display_status) === statusFilterValue, - ) - } - - if (!sortField) - return filteredDocs - - const sortedDocs = [...filteredDocs].sort((a, b) => { - let aValue: any - let bValue: any - - switch (sortField) { - case 'name': - aValue = a.name?.toLowerCase() || '' - bValue = b.name?.toLowerCase() || '' - break - case 'word_count': - aValue = a.word_count || 0 - bValue = b.word_count || 0 - break - case 'hit_count': - aValue = a.hit_count || 0 - bValue = b.hit_count || 0 - break - case 'created_at': - aValue = a.created_at - bValue = b.created_at - break - default: - return 0 - } - - if (sortField === 'name') { - const result = aValue.localeCompare(bValue) - return sortOrder === 'asc' ? result : -result - } - else { - const result = aValue - bValue - return sortOrder === 'asc' ? result : -result - } - }) - - return sortedDocs - }, [documents, sortField, sortOrder, statusFilterValue]) - - const handleSort = (field: 'name' | 'word_count' | 'hit_count' | 'created_at') => { - if (sortField === field) { - setSortOrder(sortOrder === 'asc' ? 'desc' : 'asc') - } - else { - setSortField(field) - setSortOrder('desc') - } - } - - const renderSortHeader = (field: 'name' | 'word_count' | 'hit_count' | 'created_at', label: string) => { - const isActive = sortField === field - const isDesc = isActive && sortOrder === 'desc' - - return ( -
handleSort(field)}> - {label} - -
- ) - } - + // Rename modal const [currDocument, setCurrDocument] = useState(null) const [isShowRenameModal, { setTrue: setShowRenameModalTrue, setFalse: setShowRenameModalFalse, }] = useBoolean(false) + const handleShowRenameModal = useCallback((doc: LocalDoc) => { setCurrDocument(doc) setShowRenameModalTrue() }, [setShowRenameModalTrue]) + const handleRenamed = useCallback(() => { onUpdate() }, [onUpdate]) - const isAllSelected = useMemo(() => { - return localDocs.length > 0 && localDocs.every(doc => selectedIds.includes(doc.id)) - }, [localDocs, selectedIds]) - - const isSomeSelected = useMemo(() => { - return localDocs.some(doc => selectedIds.includes(doc.id)) - }, [localDocs, selectedIds]) - - const onSelectedAll = useCallback(() => { - if (isAllSelected) - onSelectedIdChange([]) - else - onSelectedIdChange(uniq([...selectedIds, ...localDocs.map(doc => doc.id)])) - }, [isAllSelected, localDocs, onSelectedIdChange, selectedIds]) - const { mutateAsync: archiveDocument } = useDocumentArchive() - const { mutateAsync: generateSummary } = useDocumentSummary() - const { mutateAsync: enableDocument } = useDocumentEnable() - const { mutateAsync: disableDocument } = useDocumentDisable() - const { mutateAsync: deleteDocument } = useDocumentDelete() - const { mutateAsync: retryIndexDocument } = useDocumentBatchRetryIndex() - const { mutateAsync: requestDocumentsZip, isPending: isDownloadingZip } = useDocumentDownloadZip() - - const handleAction = (actionName: DocumentActionType) => { - return async () => { - let opApi - switch (actionName) { - case DocumentActionType.archive: - opApi = archiveDocument - break - case DocumentActionType.summary: - opApi = generateSummary - break - case DocumentActionType.enable: - opApi = enableDocument - break - case DocumentActionType.disable: - opApi = disableDocument - break - default: - opApi = deleteDocument - break - } - const [e] = await asyncRunSafe(opApi({ datasetId, documentIds: selectedIds }) as Promise) - - if (!e) { - if 
(actionName === DocumentActionType.delete) - onSelectedIdChange([]) - Toast.notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) - onUpdate() - } - else { Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) } - } - } - - const handleBatchReIndex = async () => { - const [e] = await asyncRunSafe(retryIndexDocument({ datasetId, documentIds: selectedIds })) - if (!e) { - onSelectedIdChange([]) - Toast.notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) - onUpdate() - } - else { - Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) - } - } - - const hasErrorDocumentsSelected = useMemo(() => { - return localDocs.some(doc => selectedIds.includes(doc.id) && doc.display_status === 'error') - }, [localDocs, selectedIds]) - - const getFileExtension = useCallback((fileName: string): string => { - if (!fileName) - return '' - const parts = fileName.split('.') - if (parts.length <= 1 || (parts[0] === '' && parts.length === 2)) - return '' - - return parts[parts.length - 1].toLowerCase() - }, []) - - const isCreateFromRAGPipeline = useCallback((createdFrom: string) => { - return createdFrom === 'rag-pipeline' - }, []) - - /** - * Calculate the data source type - * DataSourceType: FILE, NOTION, WEB (legacy) - * DatasourceType: localFile, onlineDocument, websiteCrawl, onlineDrive (new) - */ - const isLocalFile = useCallback((dataSourceType: DataSourceType | DatasourceType) => { - return dataSourceType === DatasourceType.localFile || dataSourceType === DataSourceType.FILE - }, []) - const isOnlineDocument = useCallback((dataSourceType: DataSourceType | DatasourceType) => { - return dataSourceType === DatasourceType.onlineDocument || dataSourceType === DataSourceType.NOTION - }, []) - const isWebsiteCrawl = useCallback((dataSourceType: DataSourceType | DatasourceType) => { - return dataSourceType === 
DatasourceType.websiteCrawl || dataSourceType === DataSourceType.WEB - }, []) - const isOnlineDrive = useCallback((dataSourceType: DataSourceType | DatasourceType) => { - return dataSourceType === DatasourceType.onlineDrive - }, []) - - const downloadableSelectedIds = useMemo(() => { - const selectedSet = new Set(selectedIds) - return localDocs - .filter(doc => selectedSet.has(doc.id) && doc.data_source_type === DataSourceType.FILE) - .map(doc => doc.id) - }, [localDocs, selectedIds]) - - /** - * Generate a random ZIP filename for bulk document downloads. - * We intentionally avoid leaking dataset info in the exported archive name. - */ - const generateDocsZipFileName = useCallback((): string => { - // Prefer UUID for uniqueness; fall back to time+random when unavailable. - const randomPart = (typeof crypto !== 'undefined' && typeof crypto.randomUUID === 'function') - ? crypto.randomUUID() - : `${Date.now().toString(36)}${Math.random().toString(36).slice(2, 10)}` - return `${randomPart}-docs.zip` - }, []) - - const handleBatchDownload = useCallback(async () => { - if (isDownloadingZip) - return - - // Download as a single ZIP to avoid browser caps on multiple automatic downloads. - const [e, blob] = await asyncRunSafe(requestDocumentsZip({ datasetId, documentIds: downloadableSelectedIds })) - if (e || !blob) { - Toast.notify({ type: 'error', message: t('actionMsg.downloadUnsuccessfully', { ns: 'common' }) }) - return - } - - downloadBlob({ data: blob, fileName: generateDocsZipFileName() }) - }, [datasetId, downloadableSelectedIds, generateDocsZipFileName, isDownloadingZip, requestDocumentsZip, t]) - return (
@@ -353,157 +128,76 @@ const DocumentList: FC = ({ className="mr-2 shrink-0" checked={isAllSelected} indeterminate={!isAllSelected && isSomeSelected} - onCheck={onSelectedAll} + onCheck={onSelectAll} /> )} #
- {renderSortHeader('name', t('list.table.header.fileName', { ns: 'datasetDocuments' }))} + {t('list.table.header.chunkingMode', { ns: 'datasetDocuments' })} - {renderSortHeader('word_count', t('list.table.header.words', { ns: 'datasetDocuments' }))} + - {renderSortHeader('hit_count', t('list.table.header.hitCount', { ns: 'datasetDocuments' }))} + - {renderSortHeader('created_at', t('list.table.header.uploadTime', { ns: 'datasetDocuments' }))} + {t('list.table.header.status', { ns: 'datasetDocuments' })} {t('list.table.header.action', { ns: 'datasetDocuments' })} - {localDocs.map((doc, index) => { - const isFile = isLocalFile(doc.data_source_type) - const fileType = isFile ? doc.data_source_detail_dict?.upload_file?.extension : '' - return ( - { - router.push(`/datasets/${datasetId}/documents/${doc.id}`) - }} - > - -
e.stopPropagation()}> - { - onSelectedIdChange( - selectedIds.includes(doc.id) - ? selectedIds.filter(id => id !== doc.id) - : [...selectedIds, doc.id], - ) - }} - /> - {index + 1} -
- - -
-
- {isOnlineDocument(doc.data_source_type) && ( - - )} - {isLocalFile(doc.data_source_type) && ( - - )} - {isOnlineDrive(doc.data_source_type) && ( - - )} - {isWebsiteCrawl(doc.data_source_type) && ( - - )} -
- - {doc.name} - - { - doc.summary_index_status && ( -
- -
- ) - } -
- -
{ - e.stopPropagation() - handleShowRenameModal(doc) - }} - > - -
-
-
-
- - - - - {renderCount(doc.word_count)} - {renderCount(doc.hit_count)} - - {formatTime(doc.created_at, t('dateTimeFormat', { ns: 'datasetHitTesting' }) as string)} - - - - - - - - - ) - })} + {sortedDocuments.map((doc, index) => ( + + ))}
- {(selectedIds.length > 0) && ( + + {selectedIds.length > 0 && ( = ({ onBatchDelete={handleAction(DocumentActionType.delete)} onEditMetadata={showEditModal} onBatchReIndex={hasErrorDocumentsSelected ? handleBatchReIndex : undefined} - onCancel={() => { - onSelectedIdChange([]) - }} + onCancel={clearSelection} /> )} - {/* Show Pagination only if the total is more than the limit */} + {!!pagination.total && ( = ({ } export default DocumentList + +export { renderTdValue } diff --git a/web/app/components/datasets/documents/components/operations.tsx b/web/app/components/datasets/documents/components/operations.tsx index d3dcc23121..cdd694fad9 100644 --- a/web/app/components/datasets/documents/components/operations.tsx +++ b/web/app/components/datasets/documents/components/operations.tsx @@ -26,6 +26,7 @@ import CustomPopover from '@/app/components/base/popover' import Switch from '@/app/components/base/switch' import { ToastContext } from '@/app/components/base/toast' import Tooltip from '@/app/components/base/tooltip' +import { IS_CE_EDITION } from '@/config' import { DataSourceType, DocumentActionType } from '@/models/datasets' import { useDocumentArchive, @@ -263,10 +264,14 @@ const Operations = ({ {t('list.action.sync', { ns: 'datasetDocuments' })}
)} -
onOperate('summary')}> - - {t('list.action.summary', { ns: 'datasetDocuments' })} -
+ { + IS_CE_EDITION && ( +
onOperate('summary')}> + + {t('list.action.summary', { ns: 'datasetDocuments' })} +
+ ) + } )} diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/file-list-item.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/file-list-item.spec.tsx new file mode 100644 index 0000000000..7754ba6970 --- /dev/null +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/file-list-item.spec.tsx @@ -0,0 +1,351 @@ +import type { FileListItemProps } from './file-list-item' +import type { CustomFile as File, FileItem } from '@/models/datasets' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { PROGRESS_ERROR, PROGRESS_NOT_STARTED } from '../constants' +import FileListItem from './file-list-item' + +// Mock theme hook - can be changed per test +let mockTheme = 'light' +vi.mock('@/hooks/use-theme', () => ({ + default: () => ({ theme: mockTheme }), +})) + +// Mock theme types +vi.mock('@/types/app', () => ({ + Theme: { dark: 'dark', light: 'light' }, +})) + +// Mock SimplePieChart with dynamic import handling +vi.mock('next/dynamic', () => ({ + default: () => { + const DynamicComponent = ({ percentage, stroke, fill }: { percentage: number, stroke: string, fill: string }) => ( +
+ Pie Chart: + {' '} + {percentage} + % +
+ ) + DynamicComponent.displayName = 'SimplePieChart' + return DynamicComponent + }, +})) + +// Mock DocumentFileIcon +vi.mock('@/app/components/datasets/common/document-file-icon', () => ({ + default: ({ name, extension, size }: { name: string, extension: string, size: string }) => ( +
+ Document Icon +
+ ), +})) + +describe('FileListItem', () => { + const createMockFile = (overrides: Partial = {}): File => ({ + name: 'test-document.pdf', + size: 1024 * 100, // 100KB + type: 'application/pdf', + lastModified: Date.now(), + ...overrides, + } as File) + + const createMockFileItem = (overrides: Partial = {}): FileItem => ({ + fileID: 'file-123', + file: createMockFile(overrides.file as Partial), + progress: PROGRESS_NOT_STARTED, + ...overrides, + }) + + const defaultProps: FileListItemProps = { + fileItem: createMockFileItem(), + onPreview: vi.fn(), + onRemove: vi.fn(), + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('rendering', () => { + it('should render the file item container', () => { + const { container } = render() + + const item = container.firstChild as HTMLElement + expect(item).toHaveClass('flex', 'h-12', 'items-center', 'rounded-lg') + }) + + it('should render document icon with correct props', () => { + render() + + const icon = screen.getByTestId('document-icon') + expect(icon).toBeInTheDocument() + expect(icon).toHaveAttribute('data-name', 'test-document.pdf') + expect(icon).toHaveAttribute('data-extension', 'pdf') + expect(icon).toHaveAttribute('data-size', 'lg') + }) + + it('should render file name', () => { + render() + + expect(screen.getByText('test-document.pdf')).toBeInTheDocument() + }) + + it('should render file extension in uppercase via CSS class', () => { + render() + + // Extension is rendered in lowercase but styled with uppercase CSS + const extensionSpan = screen.getByText('pdf') + expect(extensionSpan).toBeInTheDocument() + expect(extensionSpan).toHaveClass('uppercase') + }) + + it('should render file size', () => { + render() + + // 100KB (102400 bytes) formatted with formatFileSize + expect(screen.getByText('100.00 KB')).toBeInTheDocument() + }) + + it('should render delete button', () => { + const { container } = render() + + const deleteButton = container.querySelector('.cursor-pointer') + 
expect(deleteButton).toBeInTheDocument() + }) + }) + + describe('progress states', () => { + it('should show progress chart when uploading (0-99)', () => { + const fileItem = createMockFileItem({ progress: 50 }) + render() + + const pieChart = screen.getByTestId('pie-chart') + expect(pieChart).toBeInTheDocument() + expect(pieChart).toHaveAttribute('data-percentage', '50') + }) + + it('should show progress chart at 0%', () => { + const fileItem = createMockFileItem({ progress: 0 }) + render() + + const pieChart = screen.getByTestId('pie-chart') + expect(pieChart).toHaveAttribute('data-percentage', '0') + }) + + it('should not show progress chart when complete (100)', () => { + const fileItem = createMockFileItem({ progress: 100 }) + render() + + expect(screen.queryByTestId('pie-chart')).not.toBeInTheDocument() + }) + + it('should not show progress chart when not started (-1)', () => { + const fileItem = createMockFileItem({ progress: PROGRESS_NOT_STARTED }) + render() + + expect(screen.queryByTestId('pie-chart')).not.toBeInTheDocument() + }) + }) + + describe('error state', () => { + it('should show error icon when progress is PROGRESS_ERROR', () => { + const fileItem = createMockFileItem({ progress: PROGRESS_ERROR }) + const { container } = render() + + const errorIcon = container.querySelector('.text-text-destructive') + expect(errorIcon).toBeInTheDocument() + }) + + it('should apply error styling to container', () => { + const fileItem = createMockFileItem({ progress: PROGRESS_ERROR }) + const { container } = render() + + const item = container.firstChild as HTMLElement + expect(item).toHaveClass('border-state-destructive-border', 'bg-state-destructive-hover') + }) + + it('should not show error styling when not in error state', () => { + const { container } = render() + + const item = container.firstChild as HTMLElement + expect(item).not.toHaveClass('border-state-destructive-border') + }) + }) + + describe('theme handling', () => { + it('should use correct chart 
color for light theme', () => { + mockTheme = 'light' + const fileItem = createMockFileItem({ progress: 50 }) + render() + + const pieChart = screen.getByTestId('pie-chart') + expect(pieChart).toHaveAttribute('data-stroke', '#296dff') + expect(pieChart).toHaveAttribute('data-fill', '#296dff') + }) + + it('should use correct chart color for dark theme', () => { + mockTheme = 'dark' + const fileItem = createMockFileItem({ progress: 50 }) + render() + + const pieChart = screen.getByTestId('pie-chart') + expect(pieChart).toHaveAttribute('data-stroke', '#5289ff') + expect(pieChart).toHaveAttribute('data-fill', '#5289ff') + }) + }) + + describe('event handlers', () => { + it('should call onPreview when item is clicked', () => { + const onPreview = vi.fn() + const fileItem = createMockFileItem() + render() + + const item = screen.getByText('test-document.pdf').closest('[class*="flex h-12"]')! + fireEvent.click(item) + + expect(onPreview).toHaveBeenCalledTimes(1) + expect(onPreview).toHaveBeenCalledWith(fileItem.file) + }) + + it('should call onRemove when delete button is clicked', () => { + const onRemove = vi.fn() + const fileItem = createMockFileItem() + const { container } = render() + + const deleteButton = container.querySelector('.cursor-pointer')! + fireEvent.click(deleteButton) + + expect(onRemove).toHaveBeenCalledTimes(1) + expect(onRemove).toHaveBeenCalledWith('file-123') + }) + + it('should stop propagation when delete button is clicked', () => { + const onPreview = vi.fn() + const onRemove = vi.fn() + const { container } = render() + + const deleteButton = container.querySelector('.cursor-pointer')! 
+ fireEvent.click(deleteButton) + + expect(onRemove).toHaveBeenCalledTimes(1) + expect(onPreview).not.toHaveBeenCalled() + }) + }) + + describe('file type handling', () => { + it('should handle files with multiple dots in name', () => { + const fileItem = createMockFileItem({ + file: createMockFile({ name: 'my.document.file.docx' }), + }) + render() + + expect(screen.getByText('my.document.file.docx')).toBeInTheDocument() + // Extension is lowercase with uppercase CSS class + expect(screen.getByText('docx')).toBeInTheDocument() + }) + + it('should handle files without extension', () => { + const fileItem = createMockFileItem({ + file: createMockFile({ name: 'README' }), + }) + render() + + // getFileType returns 'README' when there's no extension (last part after split) + expect(screen.getAllByText('README')).toHaveLength(2) // filename and extension + }) + + it('should handle various file extensions', () => { + const extensions = ['txt', 'md', 'json', 'csv', 'xlsx'] + + extensions.forEach((ext) => { + const fileItem = createMockFileItem({ + file: createMockFile({ name: `file.${ext}` }), + }) + const { unmount } = render() + // Extension is rendered in lowercase with uppercase CSS class + expect(screen.getByText(ext)).toBeInTheDocument() + unmount() + }) + }) + }) + + describe('file size display', () => { + it('should display size in KB for small files', () => { + const fileItem = createMockFileItem({ + file: createMockFile({ size: 5 * 1024 }), // 5KB + }) + render() + + expect(screen.getByText('5.00 KB')).toBeInTheDocument() + }) + + it('should display size in MB for larger files', () => { + const fileItem = createMockFileItem({ + file: createMockFile({ size: 5 * 1024 * 1024 }), // 5MB + }) + render() + + expect(screen.getByText('5.00 MB')).toBeInTheDocument() + }) + + it('should display size at threshold (10KB)', () => { + const fileItem = createMockFileItem({ + file: createMockFile({ size: 10 * 1024 }), // 10KB + }) + render() + + expect(screen.getByText('10.00 
KB')).toBeInTheDocument() + }) + }) + + describe('upload progress values', () => { + it('should show chart at progress 1', () => { + const fileItem = createMockFileItem({ progress: 1 }) + render() + + expect(screen.getByTestId('pie-chart')).toBeInTheDocument() + }) + + it('should show chart at progress 99', () => { + const fileItem = createMockFileItem({ progress: 99 }) + render() + + expect(screen.getByTestId('pie-chart')).toHaveAttribute('data-percentage', '99') + }) + + it('should not show chart at progress 100', () => { + const fileItem = createMockFileItem({ progress: 100 }) + render() + + expect(screen.queryByTestId('pie-chart')).not.toBeInTheDocument() + }) + }) + + describe('styling', () => { + it('should have proper shadow styling', () => { + const { container } = render() + + const item = container.firstChild as HTMLElement + expect(item).toHaveClass('shadow-xs') + }) + + it('should have proper border styling', () => { + const { container } = render() + + const item = container.firstChild as HTMLElement + expect(item).toHaveClass('border', 'border-components-panel-border') + }) + + it('should truncate long file names', () => { + const longFileName = 'this-is-a-very-long-file-name-that-should-be-truncated.pdf' + const fileItem = createMockFileItem({ + file: createMockFile({ name: longFileName }), + }) + render() + + const nameElement = screen.getByText(longFileName) + expect(nameElement).toHaveClass('truncate') + }) + }) +}) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/file-list-item.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/file-list-item.tsx new file mode 100644 index 0000000000..1a61fa04f0 --- /dev/null +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/file-list-item.tsx @@ -0,0 +1,85 @@ +import type { CustomFile as File, FileItem } from '@/models/datasets' +import { RiDeleteBinLine, 
RiErrorWarningFill } from '@remixicon/react' +import dynamic from 'next/dynamic' +import { useMemo } from 'react' +import DocumentFileIcon from '@/app/components/datasets/common/document-file-icon' +import { getFileType } from '@/app/components/datasets/common/image-uploader/utils' +import useTheme from '@/hooks/use-theme' +import { Theme } from '@/types/app' +import { cn } from '@/utils/classnames' +import { formatFileSize } from '@/utils/format' +import { PROGRESS_ERROR } from '../constants' + +const SimplePieChart = dynamic(() => import('@/app/components/base/simple-pie-chart'), { ssr: false }) + +export type FileListItemProps = { + fileItem: FileItem + onPreview: (file: File) => void + onRemove: (fileID: string) => void +} + +const FileListItem = ({ + fileItem, + onPreview, + onRemove, +}: FileListItemProps) => { + const { theme } = useTheme() + const chartColor = useMemo(() => theme === Theme.dark ? '#5289ff' : '#296dff', [theme]) + + const isUploading = fileItem.progress >= 0 && fileItem.progress < 100 + const isError = fileItem.progress === PROGRESS_ERROR + + const handleClick = () => { + onPreview(fileItem.file) + } + + const handleRemove = (e: React.MouseEvent) => { + e.stopPropagation() + onRemove(fileItem.fileID) + } + + return ( +
+
+ +
+
+
+
{fileItem.file.name}
+
+
+ {getFileType(fileItem.file)} + · + {formatFileSize(fileItem.file.size)} +
+
+
+ {isUploading && ( + + )} + {isError && ( + + )} + + + +
+
+ ) +} + +export default FileListItem diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/upload-dropzone.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/upload-dropzone.spec.tsx new file mode 100644 index 0000000000..21742b731c --- /dev/null +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/upload-dropzone.spec.tsx @@ -0,0 +1,231 @@ +import type { RefObject } from 'react' +import type { UploadDropzoneProps } from './upload-dropzone' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import UploadDropzone from './upload-dropzone' + +// Helper to create mock ref objects for testing +const createMockRef = (value: T | null = null): RefObject => ({ current: value }) + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string, options?: { ns?: string }) => { + const translations: Record = { + 'stepOne.uploader.button': 'Drag and drop files, or', + 'stepOne.uploader.buttonSingleFile': 'Drag and drop file, or', + 'stepOne.uploader.browse': 'Browse', + 'stepOne.uploader.tip': 'Supports {{supportTypes}}, Max {{size}}MB each, up to {{batchCount}} files at a time, {{totalCount}} files total', + } + let result = translations[key] || key + if (options && typeof options === 'object') { + Object.entries(options).forEach(([k, v]) => { + result = result.replace(`{{${k}}}`, String(v)) + }) + } + return result + }, + }), +})) + +describe('UploadDropzone', () => { + const defaultProps: UploadDropzoneProps = { + dropRef: createMockRef() as RefObject, + dragRef: createMockRef() as RefObject, + fileUploaderRef: createMockRef() as RefObject, + dragging: false, + supportBatchUpload: true, + supportTypesShowNames: 'PDF, DOCX, TXT', + fileUploadConfig: { + file_size_limit: 15, + batch_count_limit: 5, + 
file_upload_limit: 10, + }, + acceptTypes: ['.pdf', '.docx', '.txt'], + onSelectFile: vi.fn(), + onFileChange: vi.fn(), + allowedExtensions: ['pdf', 'docx', 'txt'], + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('rendering', () => { + it('should render the dropzone container', () => { + const { container } = render() + + const dropzone = container.querySelector('[class*="border-dashed"]') + expect(dropzone).toBeInTheDocument() + }) + + it('should render hidden file input', () => { + render() + + const input = document.getElementById('fileUploader') as HTMLInputElement + expect(input).toBeInTheDocument() + expect(input).toHaveClass('hidden') + expect(input).toHaveAttribute('type', 'file') + }) + + it('should render upload icon', () => { + render() + + const icon = document.querySelector('svg') + expect(icon).toBeInTheDocument() + }) + + it('should render browse label when extensions are allowed', () => { + render() + + expect(screen.getByText('Browse')).toBeInTheDocument() + }) + + it('should not render browse label when no extensions allowed', () => { + render() + + expect(screen.queryByText('Browse')).not.toBeInTheDocument() + }) + + it('should render file size and count limits', () => { + render() + + const tipText = screen.getByText(/Supports.*Max.*15MB/i) + expect(tipText).toBeInTheDocument() + }) + }) + + describe('file input configuration', () => { + it('should allow multiple files when supportBatchUpload is true', () => { + render() + + const input = document.getElementById('fileUploader') as HTMLInputElement + expect(input).toHaveAttribute('multiple') + }) + + it('should not allow multiple files when supportBatchUpload is false', () => { + render() + + const input = document.getElementById('fileUploader') as HTMLInputElement + expect(input).not.toHaveAttribute('multiple') + }) + + it('should set accept attribute with correct types', () => { + render() + + const input = document.getElementById('fileUploader') as HTMLInputElement + 
expect(input).toHaveAttribute('accept', '.pdf,.docx') + }) + }) + + describe('text content', () => { + it('should show batch upload text when supportBatchUpload is true', () => { + render() + + expect(screen.getByText(/Drag and drop files/i)).toBeInTheDocument() + }) + + it('should show single file text when supportBatchUpload is false', () => { + render() + + expect(screen.getByText(/Drag and drop file/i)).toBeInTheDocument() + }) + }) + + describe('dragging state', () => { + it('should apply dragging styles when dragging is true', () => { + const { container } = render() + + const dropzone = container.querySelector('[class*="border-components-dropzone-border-accent"]') + expect(dropzone).toBeInTheDocument() + }) + + it('should render drag overlay when dragging', () => { + const dragRef = createMockRef() + render(} />) + + const overlay = document.querySelector('.absolute.left-0.top-0') + expect(overlay).toBeInTheDocument() + }) + + it('should not render drag overlay when not dragging', () => { + render() + + const overlay = document.querySelector('.absolute.left-0.top-0') + expect(overlay).not.toBeInTheDocument() + }) + }) + + describe('event handlers', () => { + it('should call onSelectFile when browse label is clicked', () => { + const onSelectFile = vi.fn() + render() + + const browseLabel = screen.getByText('Browse') + fireEvent.click(browseLabel) + + expect(onSelectFile).toHaveBeenCalledTimes(1) + }) + + it('should call onFileChange when files are selected', () => { + const onFileChange = vi.fn() + render() + + const input = document.getElementById('fileUploader') as HTMLInputElement + const file = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + + fireEvent.change(input, { target: { files: [file] } }) + + expect(onFileChange).toHaveBeenCalledTimes(1) + }) + }) + + describe('refs', () => { + it('should attach dropRef to drop container', () => { + const dropRef = createMockRef() + render(} />) + + 
expect(dropRef.current).toBeInstanceOf(HTMLDivElement) + }) + + it('should attach fileUploaderRef to input element', () => { + const fileUploaderRef = createMockRef() + render(} />) + + expect(fileUploaderRef.current).toBeInstanceOf(HTMLInputElement) + }) + + it('should attach dragRef to overlay when dragging', () => { + const dragRef = createMockRef() + render(} />) + + expect(dragRef.current).toBeInstanceOf(HTMLDivElement) + }) + }) + + describe('styling', () => { + it('should have base dropzone styling', () => { + const { container } = render() + + const dropzone = container.querySelector('[class*="border-dashed"]') + expect(dropzone).toBeInTheDocument() + expect(dropzone).toHaveClass('rounded-xl') + }) + + it('should have cursor-pointer on browse label', () => { + render() + + const browseLabel = screen.getByText('Browse') + expect(browseLabel).toHaveClass('cursor-pointer') + }) + }) + + describe('accessibility', () => { + it('should have an accessible file input', () => { + render() + + const input = document.getElementById('fileUploader') as HTMLInputElement + expect(input).toHaveAttribute('id', 'fileUploader') + }) + }) +}) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/upload-dropzone.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/upload-dropzone.tsx new file mode 100644 index 0000000000..66bf42d365 --- /dev/null +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/components/upload-dropzone.tsx @@ -0,0 +1,83 @@ +import type { ChangeEvent, RefObject } from 'react' +import { RiUploadCloud2Line } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { cn } from '@/utils/classnames' + +type FileUploadConfig = { + file_size_limit: number + batch_count_limit: number + file_upload_limit: number +} + +export type UploadDropzoneProps = { + dropRef: RefObject + dragRef: RefObject + 
fileUploaderRef: RefObject + dragging: boolean + supportBatchUpload: boolean + supportTypesShowNames: string + fileUploadConfig: FileUploadConfig + acceptTypes: string[] + onSelectFile: () => void + onFileChange: (e: ChangeEvent) => void + allowedExtensions: string[] +} + +const UploadDropzone = ({ + dropRef, + dragRef, + fileUploaderRef, + dragging, + supportBatchUpload, + supportTypesShowNames, + fileUploadConfig, + acceptTypes, + onSelectFile, + onFileChange, + allowedExtensions, +}: UploadDropzoneProps) => { + const { t } = useTranslation() + + return ( + <> + +
+
+ + + {supportBatchUpload ? t('stepOne.uploader.button', { ns: 'datasetCreation' }) : t('stepOne.uploader.buttonSingleFile', { ns: 'datasetCreation' })} + {allowedExtensions.length > 0 && ( + + )} + +
+
+ {t('stepOne.uploader.tip', { + ns: 'datasetCreation', + size: fileUploadConfig.file_size_limit, + supportTypes: supportTypesShowNames, + batchCount: fileUploadConfig.batch_count_limit, + totalCount: fileUploadConfig.file_upload_limit, + })} +
+ {dragging &&
} +
+ + ) +} + +export default UploadDropzone diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/constants.ts b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/constants.ts new file mode 100644 index 0000000000..cda2dae868 --- /dev/null +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/constants.ts @@ -0,0 +1,3 @@ +export const PROGRESS_NOT_STARTED = -1 +export const PROGRESS_ERROR = -2 +export const PROGRESS_COMPLETE = 100 diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/use-local-file-upload.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/use-local-file-upload.spec.tsx new file mode 100644 index 0000000000..6248b70506 --- /dev/null +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/use-local-file-upload.spec.tsx @@ -0,0 +1,911 @@ +import type { ReactNode } from 'react' +import type { CustomFile, FileItem } from '@/models/datasets' +import { act, render, renderHook, waitFor } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { PROGRESS_ERROR, PROGRESS_NOT_STARTED } from '../constants' + +// Mock notify function - defined before mocks +const mockNotify = vi.fn() +const mockClose = vi.fn() + +// Mock ToastContext with factory function +vi.mock('@/app/components/base/toast', async () => { + const { createContext, useContext } = await import('use-context-selector') + const context = createContext({ notify: mockNotify, close: mockClose }) + return { + ToastContext: context, + useToastContext: () => useContext(context), + } +}) + +// Mock file uploader utils +vi.mock('@/app/components/base/file-uploader/utils', () => ({ + getFileUploadErrorMessage: (e: Error, defaultMsg: string) => e.message || defaultMsg, +})) + +// Mock format utils used by the shared hook 
+vi.mock('@/utils/format', () => ({ + getFileExtension: (filename: string) => { + const parts = filename.split('.') + return parts[parts.length - 1] || '' + }, +})) + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock locale context +vi.mock('@/context/i18n', () => ({ + useLocale: () => 'en-US', +})) + +// Mock i18n config +vi.mock('@/i18n-config/language', () => ({ + LanguagesSupported: ['en-US', 'zh-Hans'], +})) + +// Mock config +vi.mock('@/config', () => ({ + IS_CE_EDITION: false, +})) + +// Mock store functions +const mockSetLocalFileList = vi.fn() +const mockSetCurrentLocalFile = vi.fn() +const mockGetState = vi.fn(() => ({ + setLocalFileList: mockSetLocalFileList, + setCurrentLocalFile: mockSetCurrentLocalFile, +})) +const mockStore = { getState: mockGetState } + +vi.mock('../../store', () => ({ + useDataSourceStoreWithSelector: vi.fn((selector: (state: { localFileList: FileItem[] }) => FileItem[]) => + selector({ localFileList: [] }), + ), + useDataSourceStore: vi.fn(() => mockStore), +})) + +// Mock file upload config +vi.mock('@/service/use-common', () => ({ + useFileUploadConfig: vi.fn(() => ({ + data: { + file_size_limit: 15, + batch_count_limit: 5, + file_upload_limit: 10, + }, + })), + // Required by the shared useFileUpload hook + useFileSupportTypes: vi.fn(() => ({ + data: { + allowed_extensions: ['pdf', 'docx', 'txt'], + }, + })), +})) + +// Mock upload service +const mockUpload = vi.fn() +vi.mock('@/service/base', () => ({ + upload: (...args: unknown[]) => mockUpload(...args), +})) + +// Import after all mocks are set up +const { useLocalFileUpload } = await import('./use-local-file-upload') +const { ToastContext } = await import('@/app/components/base/toast') + +const createWrapper = () => { + return ({ children }: { children: ReactNode }) => ( + + {children} + + ) +} + +describe('useLocalFileUpload', () => { + beforeEach(() => { + vi.clearAllMocks() + 
mockUpload.mockReset() + }) + + describe('initialization', () => { + it('should initialize with default values', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf', 'docx'] }), + { wrapper: createWrapper() }, + ) + + expect(result.current.dragging).toBe(false) + expect(result.current.localFileList).toEqual([]) + expect(result.current.hideUpload).toBe(false) + }) + + it('should create refs for dropzone, drag area, and file uploader', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + expect(result.current.dropRef).toBeDefined() + expect(result.current.dragRef).toBeDefined() + expect(result.current.fileUploaderRef).toBeDefined() + }) + + it('should compute acceptTypes from allowedExtensions', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf', 'docx', 'txt'] }), + { wrapper: createWrapper() }, + ) + + expect(result.current.acceptTypes).toEqual(['.pdf', '.docx', '.txt']) + }) + + it('should compute supportTypesShowNames correctly', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf', 'docx', 'md'] }), + { wrapper: createWrapper() }, + ) + + expect(result.current.supportTypesShowNames).toContain('PDF') + expect(result.current.supportTypesShowNames).toContain('DOCX') + expect(result.current.supportTypesShowNames).toContain('MARKDOWN') + }) + + it('should provide file upload config with defaults', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + expect(result.current.fileUploadConfig.file_size_limit).toBe(15) + expect(result.current.fileUploadConfig.batch_count_limit).toBe(5) + expect(result.current.fileUploadConfig.file_upload_limit).toBe(10) + }) + }) + + describe('supportBatchUpload option', () => { + it('should use batch limits when supportBatchUpload is 
true', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'], supportBatchUpload: true }), + { wrapper: createWrapper() }, + ) + + expect(result.current.fileUploadConfig.batch_count_limit).toBe(5) + expect(result.current.fileUploadConfig.file_upload_limit).toBe(10) + }) + + it('should use single file limits when supportBatchUpload is false', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'], supportBatchUpload: false }), + { wrapper: createWrapper() }, + ) + + expect(result.current.fileUploadConfig.batch_count_limit).toBe(1) + expect(result.current.fileUploadConfig.file_upload_limit).toBe(1) + }) + }) + + describe('selectHandle', () => { + it('should trigger file input click', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + const mockClick = vi.fn() + const mockInput = { click: mockClick } as unknown as HTMLInputElement + Object.defineProperty(result.current.fileUploaderRef, 'current', { + value: mockInput, + writable: true, + }) + + act(() => { + result.current.selectHandle() + }) + + expect(mockClick).toHaveBeenCalled() + }) + + it('should handle null fileUploaderRef gracefully', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + expect(() => { + act(() => { + result.current.selectHandle() + }) + }).not.toThrow() + }) + }) + + describe('removeFile', () => { + it('should remove file from list', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + act(() => { + result.current.removeFile('file-id-123') + }) + + expect(mockSetLocalFileList).toHaveBeenCalled() + }) + + it('should clear file input value when removing', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] 
}), + { wrapper: createWrapper() }, + ) + + const mockInput = { value: 'some-file.pdf' } as HTMLInputElement + Object.defineProperty(result.current.fileUploaderRef, 'current', { + value: mockInput, + writable: true, + }) + + act(() => { + result.current.removeFile('file-id') + }) + + expect(mockInput.value).toBe('') + }) + }) + + describe('handlePreview', () => { + it('should set current local file when file has id', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + const mockFile = { id: 'file-123', name: 'test.pdf', size: 1024 } + + act(() => { + result.current.handlePreview(mockFile as unknown as CustomFile) + }) + + expect(mockSetCurrentLocalFile).toHaveBeenCalledWith(mockFile) + }) + + it('should not set current file when file has no id', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + const mockFile = { name: 'test.pdf', size: 1024 } + + act(() => { + result.current.handlePreview(mockFile as unknown as CustomFile) + }) + + expect(mockSetCurrentLocalFile).not.toHaveBeenCalled() + }) + }) + + describe('fileChangeHandle', () => { + it('should handle valid files', async () => { + mockUpload.mockResolvedValue({ id: 'uploaded-id' }) + + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { + files: [mockFile], + }, + } as unknown as React.ChangeEvent + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + expect(mockSetLocalFileList).toHaveBeenCalled() + }) + }) + + it('should handle empty file list', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + const event = { + target: { 
+ files: null, + }, + } as unknown as React.ChangeEvent + + act(() => { + result.current.fileChangeHandle(event) + }) + + expect(mockSetLocalFileList).not.toHaveBeenCalled() + }) + + it('should reject files with invalid type', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.exe', { type: 'application/exe' }) + const event = { + target: { + files: [mockFile], + }, + } as unknown as React.ChangeEvent + + act(() => { + result.current.fileChangeHandle(event) + }) + + expect(mockNotify).toHaveBeenCalledWith( + expect.objectContaining({ type: 'error' }), + ) + }) + + it('should reject files exceeding size limit', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + // Create a mock file larger than 15MB + const largeSize = 20 * 1024 * 1024 + const mockFile = new File([''], 'large.pdf', { type: 'application/pdf' }) + Object.defineProperty(mockFile, 'size', { value: largeSize }) + + const event = { + target: { + files: [mockFile], + }, + } as unknown as React.ChangeEvent + + act(() => { + result.current.fileChangeHandle(event) + }) + + expect(mockNotify).toHaveBeenCalledWith( + expect.objectContaining({ type: 'error' }), + ) + }) + + it('should limit files to batch count limit', async () => { + mockUpload.mockResolvedValue({ id: 'uploaded-id' }) + + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + // Create 10 files but batch limit is 5 + const files = Array.from({ length: 10 }, (_, i) => + new File(['content'], `file${i}.pdf`, { type: 'application/pdf' })) + + const event = { + target: { + files, + }, + } as unknown as React.ChangeEvent + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + 
expect(mockSetLocalFileList).toHaveBeenCalled() + }) + + // Should only process first 5 files (batch_count_limit) + const firstCall = mockSetLocalFileList.mock.calls[0] + expect(firstCall[0].length).toBeLessThanOrEqual(5) + }) + }) + + describe('upload handling', () => { + it('should handle successful upload', async () => { + const uploadedResponse = { id: 'server-file-id' } + mockUpload.mockResolvedValue(uploadedResponse) + + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { + files: [mockFile], + }, + } as unknown as React.ChangeEvent + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + expect(mockUpload).toHaveBeenCalled() + }) + }) + + it('should handle upload error', async () => { + mockUpload.mockRejectedValue(new Error('Upload failed')) + + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { + files: [mockFile], + }, + } as unknown as React.ChangeEvent + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith( + expect.objectContaining({ type: 'error' }), + ) + }) + }) + + it('should call upload with correct parameters', async () => { + mockUpload.mockResolvedValue({ id: 'file-id' }) + + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { + files: [mockFile], + }, + } as unknown as React.ChangeEvent + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + 
expect(mockUpload).toHaveBeenCalledWith( + expect.objectContaining({ + xhr: expect.any(XMLHttpRequest), + data: expect.any(FormData), + }), + false, + undefined, + '?source=datasets', + ) + }) + }) + }) + + describe('extension mapping', () => { + it('should map md to markdown', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['md'] }), + { wrapper: createWrapper() }, + ) + + expect(result.current.supportTypesShowNames).toContain('MARKDOWN') + }) + + it('should map htm to html', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['htm'] }), + { wrapper: createWrapper() }, + ) + + expect(result.current.supportTypesShowNames).toContain('HTML') + }) + + it('should preserve unmapped extensions', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf', 'txt'] }), + { wrapper: createWrapper() }, + ) + + expect(result.current.supportTypesShowNames).toContain('PDF') + expect(result.current.supportTypesShowNames).toContain('TXT') + }) + + it('should remove duplicate extensions', () => { + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf', 'pdf', 'PDF'] }), + { wrapper: createWrapper() }, + ) + + const count = (result.current.supportTypesShowNames.match(/PDF/g) || []).length + expect(count).toBe(1) + }) + }) + + describe('drag and drop handlers', () => { + // Helper component that renders with the hook and connects refs + const TestDropzone = ({ allowedExtensions, supportBatchUpload = true }: { + allowedExtensions: string[] + supportBatchUpload?: boolean + }) => { + const { + dropRef, + dragRef, + dragging, + } = useLocalFileUpload({ allowedExtensions, supportBatchUpload }) + + return ( +
+
+ {dragging &&
} +
+ {String(dragging)} +
+ ) + } + + it('should set dragging true on dragenter', async () => { + const { getByTestId } = await act(async () => + render( + + + , + ), + ) + + const dropzone = getByTestId('dropzone') + + await act(async () => { + const dragEnterEvent = new Event('dragenter', { bubbles: true, cancelable: true }) + dropzone.dispatchEvent(dragEnterEvent) + }) + + expect(getByTestId('dragging').textContent).toBe('true') + }) + + it('should handle dragover event', async () => { + const { getByTestId } = await act(async () => + render( + + + , + ), + ) + + const dropzone = getByTestId('dropzone') + + await act(async () => { + const dragOverEvent = new Event('dragover', { bubbles: true, cancelable: true }) + dropzone.dispatchEvent(dragOverEvent) + }) + + // dragover should not throw + expect(dropzone).toBeInTheDocument() + }) + + it('should set dragging false on dragleave from drag overlay', async () => { + const { getByTestId, queryByTestId } = await act(async () => + render( + + + , + ), + ) + + const dropzone = getByTestId('dropzone') + + // First trigger dragenter to set dragging true + await act(async () => { + const dragEnterEvent = new Event('dragenter', { bubbles: true, cancelable: true }) + dropzone.dispatchEvent(dragEnterEvent) + }) + + expect(getByTestId('dragging').textContent).toBe('true') + + // Now the drag overlay should be rendered + const dragOverlay = queryByTestId('drag-overlay') + if (dragOverlay) { + await act(async () => { + const dragLeaveEvent = new Event('dragleave', { bubbles: true, cancelable: true }) + Object.defineProperty(dragLeaveEvent, 'target', { value: dragOverlay }) + dropzone.dispatchEvent(dragLeaveEvent) + }) + } + }) + + it('should handle drop with files', async () => { + mockUpload.mockResolvedValue({ id: 'uploaded-id' }) + + const { getByTestId } = await act(async () => + render( + + + , + ), + ) + + const dropzone = getByTestId('dropzone') + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + + await act(async 
() => { + const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) as Event & { + dataTransfer: { items: DataTransferItem[], files: File[] } | null + } + // Mock dataTransfer with items array (used by the shared hook for directory traversal) + dropEvent.dataTransfer = { + items: [{ + kind: 'file', + getAsFile: () => mockFile, + }] as unknown as DataTransferItem[], + files: [mockFile], + } + dropzone.dispatchEvent(dropEvent) + }) + + await waitFor(() => { + expect(mockSetLocalFileList).toHaveBeenCalled() + }) + }) + + it('should handle drop without dataTransfer', async () => { + const { getByTestId } = await act(async () => + render( + + + , + ), + ) + + const dropzone = getByTestId('dropzone') + mockSetLocalFileList.mockClear() + + await act(async () => { + const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) as Event & { dataTransfer: { files: File[] } | null } + dropEvent.dataTransfer = null + dropzone.dispatchEvent(dropEvent) + }) + + // Should not upload when no dataTransfer + expect(mockSetLocalFileList).not.toHaveBeenCalled() + }) + + it('should limit to single file on drop when supportBatchUpload is false', async () => { + mockUpload.mockResolvedValue({ id: 'uploaded-id' }) + + const { getByTestId } = await act(async () => + render( + + + , + ), + ) + + const dropzone = getByTestId('dropzone') + const files = [ + new File(['content1'], 'test1.pdf', { type: 'application/pdf' }), + new File(['content2'], 'test2.pdf', { type: 'application/pdf' }), + ] + + await act(async () => { + const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) as Event & { + dataTransfer: { items: DataTransferItem[], files: File[] } | null + } + // Mock dataTransfer with items array (used by the shared hook for directory traversal) + dropEvent.dataTransfer = { + items: files.map(f => ({ + kind: 'file', + getAsFile: () => f, + })) as unknown as DataTransferItem[], + files, + } + dropzone.dispatchEvent(dropEvent) + }) + + await 
waitFor(() => { + expect(mockSetLocalFileList).toHaveBeenCalled() + // Should only have 1 file (limited by supportBatchUpload: false) + const callArgs = mockSetLocalFileList.mock.calls[0][0] + expect(callArgs.length).toBe(1) + }) + }) + }) + + describe('file upload limit', () => { + it('should reject files exceeding total file upload limit', async () => { + // Mock store to return existing files + const { useDataSourceStoreWithSelector } = vi.mocked(await import('../../store')) + const existingFiles: FileItem[] = Array.from({ length: 8 }, (_, i) => ({ + fileID: `existing-${i}`, + file: { name: `existing-${i}.pdf`, size: 1024 } as CustomFile, + progress: 100, + })) + vi.mocked(useDataSourceStoreWithSelector).mockImplementation(selector => + selector({ localFileList: existingFiles } as Parameters[0]), + ) + + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + // Try to add 5 more files when limit is 10 and we already have 8 + const files = Array.from({ length: 5 }, (_, i) => + new File(['content'], `new-${i}.pdf`, { type: 'application/pdf' })) + + const event = { + target: { files }, + } as unknown as React.ChangeEvent + + act(() => { + result.current.fileChangeHandle(event) + }) + + // Should show error about files number limit + expect(mockNotify).toHaveBeenCalledWith( + expect.objectContaining({ type: 'error' }), + ) + + // Reset mock for other tests + vi.mocked(useDataSourceStoreWithSelector).mockImplementation(selector => + selector({ localFileList: [] as FileItem[] } as Parameters[0]), + ) + }) + }) + + describe('upload progress tracking', () => { + it('should track upload progress', async () => { + let progressCallback: ((e: ProgressEvent) => void) | undefined + + mockUpload.mockImplementation(async (options: { onprogress: (e: ProgressEvent) => void }) => { + progressCallback = options.onprogress + return { id: 'uploaded-id' } + }) + + const { result } = renderHook( + () => 
useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { files: [mockFile] }, + } as unknown as React.ChangeEvent + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + expect(mockUpload).toHaveBeenCalled() + }) + + // Simulate progress event + if (progressCallback) { + act(() => { + progressCallback!({ + lengthComputable: true, + loaded: 50, + total: 100, + } as ProgressEvent) + }) + + expect(mockSetLocalFileList).toHaveBeenCalled() + } + }) + + it('should not update progress when not lengthComputable', async () => { + let progressCallback: ((e: ProgressEvent) => void) | undefined + const uploadCallCount = { value: 0 } + + mockUpload.mockImplementation(async (options: { onprogress: (e: ProgressEvent) => void }) => { + progressCallback = options.onprogress + uploadCallCount.value++ + return { id: 'uploaded-id' } + }) + + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { files: [mockFile] }, + } as unknown as React.ChangeEvent + + mockSetLocalFileList.mockClear() + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + expect(mockUpload).toHaveBeenCalled() + }) + + const callsBeforeProgress = mockSetLocalFileList.mock.calls.length + + // Simulate progress event without lengthComputable + if (progressCallback) { + act(() => { + progressCallback!({ + lengthComputable: false, + loaded: 50, + total: 100, + } as ProgressEvent) + }) + + // Should not have additional calls + expect(mockSetLocalFileList.mock.calls.length).toBe(callsBeforeProgress) + } + }) + }) + + describe('file progress constants', () => { + it('should use PROGRESS_NOT_STARTED for new files', async () => { + 
mockUpload.mockResolvedValue({ id: 'file-id' }) + + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { + files: [mockFile], + }, + } as unknown as React.ChangeEvent + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + const callArgs = mockSetLocalFileList.mock.calls[0][0] + expect(callArgs[0].progress).toBe(PROGRESS_NOT_STARTED) + }) + }) + + it('should set PROGRESS_ERROR on upload failure', async () => { + mockUpload.mockRejectedValue(new Error('Upload failed')) + + const { result } = renderHook( + () => useLocalFileUpload({ allowedExtensions: ['pdf'] }), + { wrapper: createWrapper() }, + ) + + const mockFile = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const event = { + target: { + files: [mockFile], + }, + } as unknown as React.ChangeEvent + + act(() => { + result.current.fileChangeHandle(event) + }) + + await waitFor(() => { + const calls = mockSetLocalFileList.mock.calls + const lastCall = calls[calls.length - 1][0] + expect(lastCall.some((f: FileItem) => f.progress === PROGRESS_ERROR)).toBe(true) + }) + }) + }) +}) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/use-local-file-upload.ts b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/use-local-file-upload.ts new file mode 100644 index 0000000000..1f7c9ecfed --- /dev/null +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/use-local-file-upload.ts @@ -0,0 +1,105 @@ +import type { CustomFile as File, FileItem } from '@/models/datasets' +import { produce } from 'immer' +import { useCallback, useRef } from 'react' +import { useFileUpload } from '@/app/components/datasets/create/file-uploader/hooks/use-file-upload' +import { 
useDataSourceStore, useDataSourceStoreWithSelector } from '../../store' + +export type UseLocalFileUploadOptions = { + allowedExtensions: string[] + supportBatchUpload?: boolean +} + +/** + * Hook for handling local file uploads in the create-from-pipeline flow. + * This is a thin wrapper around the generic useFileUpload hook that provides + * Zustand store integration for state management. + */ +export const useLocalFileUpload = ({ + allowedExtensions, + supportBatchUpload = true, +}: UseLocalFileUploadOptions) => { + const localFileList = useDataSourceStoreWithSelector(state => state.localFileList) + const dataSourceStore = useDataSourceStore() + const fileListRef = useRef([]) + + // Sync fileListRef with localFileList for internal tracking + fileListRef.current = localFileList + + const prepareFileList = useCallback((files: FileItem[]) => { + const { setLocalFileList } = dataSourceStore.getState() + setLocalFileList(files) + fileListRef.current = files + }, [dataSourceStore]) + + const onFileUpdate = useCallback((fileItem: FileItem, progress: number, list: FileItem[]) => { + const { setLocalFileList } = dataSourceStore.getState() + const newList = produce(list, (draft) => { + const targetIndex = draft.findIndex(file => file.fileID === fileItem.fileID) + if (targetIndex !== -1) { + draft[targetIndex] = { + ...draft[targetIndex], + ...fileItem, + progress, + } + } + }) + setLocalFileList(newList) + }, [dataSourceStore]) + + const onFileListUpdate = useCallback((files: FileItem[]) => { + const { setLocalFileList } = dataSourceStore.getState() + setLocalFileList(files) + fileListRef.current = files + }, [dataSourceStore]) + + const onPreview = useCallback((file: File) => { + const { setCurrentLocalFile } = dataSourceStore.getState() + setCurrentLocalFile(file) + }, [dataSourceStore]) + + const { + dropRef, + dragRef, + fileUploaderRef, + dragging, + fileUploadConfig, + acceptTypes, + supportTypesShowNames, + hideUpload, + selectHandle, + fileChangeHandle, + 
removeFile, + handlePreview, + } = useFileUpload({ + fileList: localFileList, + prepareFileList, + onFileUpdate, + onFileListUpdate, + onPreview, + supportBatchUpload, + allowedExtensions, + }) + + return { + // Refs + dropRef, + dragRef, + fileUploaderRef, + + // State + dragging, + localFileList, + + // Config + fileUploadConfig, + acceptTypes, + supportTypesShowNames, + hideUpload, + + // Handlers + selectHandle, + fileChangeHandle, + removeFile, + handlePreview, + } +} diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.spec.tsx new file mode 100644 index 0000000000..66f13be84f --- /dev/null +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.spec.tsx @@ -0,0 +1,398 @@ +import type { FileItem } from '@/models/datasets' +import { render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import LocalFile from './index' + +// Mock the hook +const mockUseLocalFileUpload = vi.fn() +vi.mock('./hooks/use-local-file-upload', () => ({ + useLocalFileUpload: (...args: unknown[]) => mockUseLocalFileUpload(...args), +})) + +// Mock react-i18next for sub-components +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock theme hook for sub-components +vi.mock('@/hooks/use-theme', () => ({ + default: () => ({ theme: 'light' }), +})) + +// Mock theme types +vi.mock('@/types/app', () => ({ + Theme: { dark: 'dark', light: 'light' }, +})) + +// Mock DocumentFileIcon +vi.mock('@/app/components/datasets/common/document-file-icon', () => ({ + default: ({ name }: { name: string }) =>
{name}
, +})) + +// Mock SimplePieChart +vi.mock('next/dynamic', () => ({ + default: () => { + const Component = ({ percentage }: { percentage: number }) => ( +
+ {percentage} + % +
+ ) + return Component + }, +})) + +describe('LocalFile', () => { + const mockDropRef = { current: null } + const mockDragRef = { current: null } + const mockFileUploaderRef = { current: null } + + const defaultHookReturn = { + dropRef: mockDropRef, + dragRef: mockDragRef, + fileUploaderRef: mockFileUploaderRef, + dragging: false, + localFileList: [] as FileItem[], + fileUploadConfig: { + file_size_limit: 15, + batch_count_limit: 5, + file_upload_limit: 10, + }, + acceptTypes: ['.pdf', '.docx'], + supportTypesShowNames: 'PDF, DOCX', + hideUpload: false, + selectHandle: vi.fn(), + fileChangeHandle: vi.fn(), + removeFile: vi.fn(), + handlePreview: vi.fn(), + } + + beforeEach(() => { + vi.clearAllMocks() + mockUseLocalFileUpload.mockReturnValue(defaultHookReturn) + }) + + describe('rendering', () => { + it('should render the component container', () => { + const { container } = render( + , + ) + + expect(container.firstChild).toHaveClass('flex', 'flex-col') + }) + + it('should render UploadDropzone when hideUpload is false', () => { + render() + + const fileInput = document.getElementById('fileUploader') + expect(fileInput).toBeInTheDocument() + }) + + it('should not render UploadDropzone when hideUpload is true', () => { + mockUseLocalFileUpload.mockReturnValue({ + ...defaultHookReturn, + hideUpload: true, + }) + + render() + + const fileInput = document.getElementById('fileUploader') + expect(fileInput).not.toBeInTheDocument() + }) + }) + + describe('file list rendering', () => { + it('should not render file list when empty', () => { + render() + + expect(screen.queryByTestId('document-icon')).not.toBeInTheDocument() + }) + + it('should render file list when files exist', () => { + const mockFile = { + name: 'test.pdf', + size: 1024, + type: 'application/pdf', + lastModified: Date.now(), + } as File + + mockUseLocalFileUpload.mockReturnValue({ + ...defaultHookReturn, + localFileList: [ + { + fileID: 'file-1', + file: mockFile, + progress: -1, + }, + ], + }) + + 
render() + + expect(screen.getByTestId('document-icon')).toBeInTheDocument() + }) + + it('should render multiple file items', () => { + const createMockFile = (name: string) => ({ + name, + size: 1024, + type: 'application/pdf', + lastModified: Date.now(), + }) as File + + mockUseLocalFileUpload.mockReturnValue({ + ...defaultHookReturn, + localFileList: [ + { fileID: 'file-1', file: createMockFile('doc1.pdf'), progress: -1 }, + { fileID: 'file-2', file: createMockFile('doc2.pdf'), progress: -1 }, + { fileID: 'file-3', file: createMockFile('doc3.pdf'), progress: -1 }, + ], + }) + + render() + + const icons = screen.getAllByTestId('document-icon') + expect(icons).toHaveLength(3) + }) + + it('should use correct key for file items', () => { + const mockFile = { + name: 'test.pdf', + size: 1024, + type: 'application/pdf', + lastModified: Date.now(), + } as File + + mockUseLocalFileUpload.mockReturnValue({ + ...defaultHookReturn, + localFileList: [ + { fileID: 'unique-id-123', file: mockFile, progress: -1 }, + ], + }) + + render() + + // The component should render without errors (key is used internally) + expect(screen.getByTestId('document-icon')).toBeInTheDocument() + }) + }) + + describe('hook integration', () => { + it('should pass allowedExtensions to hook', () => { + render() + + expect(mockUseLocalFileUpload).toHaveBeenCalledWith({ + allowedExtensions: ['pdf', 'docx', 'txt'], + supportBatchUpload: true, + }) + }) + + it('should pass supportBatchUpload true by default', () => { + render() + + expect(mockUseLocalFileUpload).toHaveBeenCalledWith( + expect.objectContaining({ supportBatchUpload: true }), + ) + }) + + it('should pass supportBatchUpload false when specified', () => { + render() + + expect(mockUseLocalFileUpload).toHaveBeenCalledWith( + expect.objectContaining({ supportBatchUpload: false }), + ) + }) + }) + + describe('props passed to UploadDropzone', () => { + it('should pass all required props to UploadDropzone', () => { + const selectHandle = vi.fn() 
+ const fileChangeHandle = vi.fn() + + mockUseLocalFileUpload.mockReturnValue({ + ...defaultHookReturn, + selectHandle, + fileChangeHandle, + supportTypesShowNames: 'PDF, DOCX', + acceptTypes: ['.pdf', '.docx'], + fileUploadConfig: { + file_size_limit: 20, + batch_count_limit: 10, + file_upload_limit: 50, + }, + }) + + render() + + // Verify the dropzone is rendered with correct configuration + const fileInput = document.getElementById('fileUploader') + expect(fileInput).toBeInTheDocument() + expect(fileInput).toHaveAttribute('accept', '.pdf,.docx') + expect(fileInput).toHaveAttribute('multiple') + }) + }) + + describe('props passed to FileListItem', () => { + it('should pass correct props to file items', () => { + const handlePreview = vi.fn() + const removeFile = vi.fn() + const mockFile = { + name: 'document.pdf', + size: 2048, + type: 'application/pdf', + lastModified: Date.now(), + } as File + + mockUseLocalFileUpload.mockReturnValue({ + ...defaultHookReturn, + handlePreview, + removeFile, + localFileList: [ + { fileID: 'test-id', file: mockFile, progress: 50 }, + ], + }) + + render() + + expect(screen.getByTestId('document-icon')).toHaveTextContent('document.pdf') + }) + }) + + describe('conditional rendering', () => { + it('should show both dropzone and file list when files exist and hideUpload is false', () => { + const mockFile = { + name: 'test.pdf', + size: 1024, + type: 'application/pdf', + lastModified: Date.now(), + } as File + + mockUseLocalFileUpload.mockReturnValue({ + ...defaultHookReturn, + hideUpload: false, + localFileList: [ + { fileID: 'file-1', file: mockFile, progress: -1 }, + ], + }) + + render() + + expect(document.getElementById('fileUploader')).toBeInTheDocument() + expect(screen.getByTestId('document-icon')).toBeInTheDocument() + }) + + it('should show only file list when hideUpload is true', () => { + const mockFile = { + name: 'test.pdf', + size: 1024, + type: 'application/pdf', + lastModified: Date.now(), + } as File + + 
mockUseLocalFileUpload.mockReturnValue({ + ...defaultHookReturn, + hideUpload: true, + localFileList: [ + { fileID: 'file-1', file: mockFile, progress: -1 }, + ], + }) + + render() + + expect(document.getElementById('fileUploader')).not.toBeInTheDocument() + expect(screen.getByTestId('document-icon')).toBeInTheDocument() + }) + }) + + describe('file list container styling', () => { + it('should apply correct container classes for file list', () => { + const mockFile = { + name: 'test.pdf', + size: 1024, + type: 'application/pdf', + lastModified: Date.now(), + } as File + + mockUseLocalFileUpload.mockReturnValue({ + ...defaultHookReturn, + localFileList: [ + { fileID: 'file-1', file: mockFile, progress: -1 }, + ], + }) + + const { container } = render() + + const fileListContainer = container.querySelector('.mt-1.flex.flex-col.gap-y-1') + expect(fileListContainer).toBeInTheDocument() + }) + }) + + describe('edge cases', () => { + it('should handle empty allowedExtensions', () => { + render() + + expect(mockUseLocalFileUpload).toHaveBeenCalledWith({ + allowedExtensions: [], + supportBatchUpload: true, + }) + }) + + it('should handle files with same fileID but different index', () => { + const mockFile = { + name: 'test.pdf', + size: 1024, + type: 'application/pdf', + lastModified: Date.now(), + } as File + + mockUseLocalFileUpload.mockReturnValue({ + ...defaultHookReturn, + localFileList: [ + { fileID: 'same-id', file: { ...mockFile, name: 'doc1.pdf' } as File, progress: -1 }, + { fileID: 'same-id', file: { ...mockFile, name: 'doc2.pdf' } as File, progress: -1 }, + ], + }) + + // Should render without key collision errors due to index in key + render() + + const icons = screen.getAllByTestId('document-icon') + expect(icons).toHaveLength(2) + }) + }) + + describe('component integration', () => { + it('should render complete component tree', () => { + const mockFile = { + name: 'complete-test.pdf', + size: 5 * 1024, + type: 'application/pdf', + lastModified: 
Date.now(), + } as File + + mockUseLocalFileUpload.mockReturnValue({ + ...defaultHookReturn, + hideUpload: false, + localFileList: [ + { fileID: 'file-1', file: mockFile, progress: 50 }, + ], + dragging: false, + }) + + const { container } = render( + , + ) + + // Main container + expect(container.firstChild).toHaveClass('flex', 'flex-col') + + // Dropzone exists + expect(document.getElementById('fileUploader')).toBeInTheDocument() + + // File list exists + expect(screen.getByTestId('document-icon')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.tsx index d02d5927f2..cb3632ba9d 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.tsx @@ -1,26 +1,7 @@ 'use client' -import type { CustomFile as File, FileItem } from '@/models/datasets' -import { RiDeleteBinLine, RiErrorWarningFill, RiUploadCloud2Line } from '@remixicon/react' -import { produce } from 'immer' -import dynamic from 'next/dynamic' -import * as React from 'react' -import { useCallback, useEffect, useMemo, useRef, useState } from 'react' -import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' -import { getFileUploadErrorMessage } from '@/app/components/base/file-uploader/utils' -import { ToastContext } from '@/app/components/base/toast' -import DocumentFileIcon from '@/app/components/datasets/common/document-file-icon' -import { IS_CE_EDITION } from '@/config' -import { useLocale } from '@/context/i18n' -import useTheme from '@/hooks/use-theme' -import { LanguagesSupported } from '@/i18n-config/language' -import { upload } from '@/service/base' -import { useFileUploadConfig } from '@/service/use-common' -import { Theme } from '@/types/app' 
-import { cn } from '@/utils/classnames' -import { useDataSourceStore, useDataSourceStoreWithSelector } from '../store' - -const SimplePieChart = dynamic(() => import('@/app/components/base/simple-pie-chart'), { ssr: false }) +import FileListItem from './components/file-list-item' +import UploadDropzone from './components/upload-dropzone' +import { useLocalFileUpload } from './hooks/use-local-file-upload' export type LocalFileProps = { allowedExtensions: string[] @@ -31,345 +12,49 @@ const LocalFile = ({ allowedExtensions, supportBatchUpload = true, }: LocalFileProps) => { - const { t } = useTranslation() - const { notify } = useContext(ToastContext) - const locale = useLocale() - const localFileList = useDataSourceStoreWithSelector(state => state.localFileList) - const dataSourceStore = useDataSourceStore() - const [dragging, setDragging] = useState(false) - - const dropRef = useRef(null) - const dragRef = useRef(null) - const fileUploader = useRef(null) - const fileListRef = useRef([]) - - const hideUpload = !supportBatchUpload && localFileList.length > 0 - - const { data: fileUploadConfigResponse } = useFileUploadConfig() - const supportTypesShowNames = useMemo(() => { - const extensionMap: { [key: string]: string } = { - md: 'markdown', - pptx: 'pptx', - htm: 'html', - xlsx: 'xlsx', - docx: 'docx', - } - - return allowedExtensions - .map(item => extensionMap[item] || item) // map to standardized extension - .map(item => item.toLowerCase()) // convert to lower case - .filter((item, index, self) => self.indexOf(item) === index) // remove duplicates - .map(item => item.toUpperCase()) // convert to upper case - .join(locale !== LanguagesSupported[1] ? ', ' : '、 ') - }, [locale, allowedExtensions]) - const ACCEPTS = allowedExtensions.map((ext: string) => `.${ext}`) - const fileUploadConfig = useMemo(() => ({ - file_size_limit: fileUploadConfigResponse?.file_size_limit ?? 15, - batch_count_limit: supportBatchUpload ? (fileUploadConfigResponse?.batch_count_limit ?? 
5) : 1, - file_upload_limit: supportBatchUpload ? (fileUploadConfigResponse?.file_upload_limit ?? 5) : 1, - }), [fileUploadConfigResponse, supportBatchUpload]) - - const updateFile = useCallback((fileItem: FileItem, progress: number, list: FileItem[]) => { - const { setLocalFileList } = dataSourceStore.getState() - const newList = produce(list, (draft) => { - const targetIndex = draft.findIndex(file => file.fileID === fileItem.fileID) - draft[targetIndex] = { - ...draft[targetIndex], - progress, - } - }) - setLocalFileList(newList) - }, [dataSourceStore]) - - const updateFileList = useCallback((preparedFiles: FileItem[]) => { - const { setLocalFileList } = dataSourceStore.getState() - setLocalFileList(preparedFiles) - }, [dataSourceStore]) - - const handlePreview = useCallback((file: File) => { - const { setCurrentLocalFile } = dataSourceStore.getState() - if (file.id) - setCurrentLocalFile(file) - }, [dataSourceStore]) - - // utils - const getFileType = (currentFile: File) => { - if (!currentFile) - return '' - - const arr = currentFile.name.split('.') - return arr[arr.length - 1] - } - - const getFileSize = (size: number) => { - if (size / 1024 < 10) - return `${(size / 1024).toFixed(2)}KB` - - return `${(size / 1024 / 1024).toFixed(2)}MB` - } - - const isValid = useCallback((file: File) => { - const { size } = file - const ext = `.${getFileType(file)}` - const isValidType = ACCEPTS.includes(ext.toLowerCase()) - if (!isValidType) - notify({ type: 'error', message: t('stepOne.uploader.validation.typeError', { ns: 'datasetCreation' }) }) - - const isValidSize = size <= fileUploadConfig.file_size_limit * 1024 * 1024 - if (!isValidSize) - notify({ type: 'error', message: t('stepOne.uploader.validation.size', { ns: 'datasetCreation', size: fileUploadConfig.file_size_limit }) }) - - return isValidType && isValidSize - }, [notify, t, ACCEPTS, fileUploadConfig.file_size_limit]) - - type UploadResult = Awaited> - - const fileUpload = useCallback(async (fileItem: 
FileItem): Promise => { - const formData = new FormData() - formData.append('file', fileItem.file) - const onProgress = (e: ProgressEvent) => { - if (e.lengthComputable) { - const percent = Math.floor(e.loaded / e.total * 100) - updateFile(fileItem, percent, fileListRef.current) - } - } - - return upload({ - xhr: new XMLHttpRequest(), - data: formData, - onprogress: onProgress, - }, false, undefined, '?source=datasets') - .then((res: UploadResult) => { - const updatedFile = Object.assign({}, fileItem.file, { - id: res.id, - ...(res as Partial), - }) as File - const completeFile: FileItem = { - fileID: fileItem.fileID, - file: updatedFile, - progress: -1, - } - const index = fileListRef.current.findIndex(item => item.fileID === fileItem.fileID) - fileListRef.current[index] = completeFile - updateFile(completeFile, 100, fileListRef.current) - return Promise.resolve({ ...completeFile }) - }) - .catch((e) => { - const errorMessage = getFileUploadErrorMessage(e, t('stepOne.uploader.failed', { ns: 'datasetCreation' }), t) - notify({ type: 'error', message: errorMessage }) - updateFile(fileItem, -2, fileListRef.current) - return Promise.resolve({ ...fileItem }) - }) - .finally() - }, [fileListRef, notify, updateFile, t]) - - const uploadBatchFiles = useCallback((bFiles: FileItem[]) => { - bFiles.forEach(bf => (bf.progress = 0)) - return Promise.all(bFiles.map(fileUpload)) - }, [fileUpload]) - - const uploadMultipleFiles = useCallback(async (files: FileItem[]) => { - const batchCountLimit = fileUploadConfig.batch_count_limit - const length = files.length - let start = 0 - let end = 0 - - while (start < length) { - if (start + batchCountLimit > length) - end = length - else - end = start + batchCountLimit - const bFiles = files.slice(start, end) - await uploadBatchFiles(bFiles) - start = end - } - }, [fileUploadConfig, uploadBatchFiles]) - - const initialUpload = useCallback((files: File[]) => { - const filesCountLimit = fileUploadConfig.file_upload_limit - if 
(!files.length) - return false - - if (files.length + localFileList.length > filesCountLimit && !IS_CE_EDITION) { - notify({ type: 'error', message: t('stepOne.uploader.validation.filesNumber', { ns: 'datasetCreation', filesNumber: filesCountLimit }) }) - return false - } - - const preparedFiles = files.map((file, index) => ({ - fileID: `file${index}-${Date.now()}`, - file, - progress: -1, - })) - const newFiles = [...fileListRef.current, ...preparedFiles] - updateFileList(newFiles) - fileListRef.current = newFiles - uploadMultipleFiles(preparedFiles) - }, [fileUploadConfig.file_upload_limit, localFileList.length, updateFileList, uploadMultipleFiles, notify, t]) - - const handleDragEnter = (e: DragEvent) => { - e.preventDefault() - e.stopPropagation() - if (e.target !== dragRef.current) - setDragging(true) - } - const handleDragOver = (e: DragEvent) => { - e.preventDefault() - e.stopPropagation() - } - const handleDragLeave = (e: DragEvent) => { - e.preventDefault() - e.stopPropagation() - if (e.target === dragRef.current) - setDragging(false) - } - - const handleDrop = useCallback((e: DragEvent) => { - e.preventDefault() - e.stopPropagation() - setDragging(false) - if (!e.dataTransfer) - return - - let files = Array.from(e.dataTransfer.files) as File[] - if (!supportBatchUpload) - files = files.slice(0, 1) - - const validFiles = files.filter(isValid) - initialUpload(validFiles) - }, [initialUpload, isValid, supportBatchUpload]) - - const selectHandle = useCallback(() => { - if (fileUploader.current) - fileUploader.current.click() - }, []) - - const removeFile = (fileID: string) => { - if (fileUploader.current) - fileUploader.current.value = '' - - fileListRef.current = fileListRef.current.filter(item => item.fileID !== fileID) - updateFileList([...fileListRef.current]) - } - const fileChangeHandle = useCallback((e: React.ChangeEvent) => { - let files = Array.from(e.target.files ?? 
[]) as File[] - files = files.slice(0, fileUploadConfig.batch_count_limit) - initialUpload(files.filter(isValid)) - }, [isValid, initialUpload, fileUploadConfig.batch_count_limit]) - - const { theme } = useTheme() - const chartColor = useMemo(() => theme === Theme.dark ? '#5289ff' : '#296dff', [theme]) - - useEffect(() => { - const dropElement = dropRef.current - dropElement?.addEventListener('dragenter', handleDragEnter) - dropElement?.addEventListener('dragover', handleDragOver) - dropElement?.addEventListener('dragleave', handleDragLeave) - dropElement?.addEventListener('drop', handleDrop) - return () => { - dropElement?.removeEventListener('dragenter', handleDragEnter) - dropElement?.removeEventListener('dragover', handleDragOver) - dropElement?.removeEventListener('dragleave', handleDragLeave) - dropElement?.removeEventListener('drop', handleDrop) - } - }, [handleDrop]) + const { + dropRef, + dragRef, + fileUploaderRef, + dragging, + localFileList, + fileUploadConfig, + acceptTypes, + supportTypesShowNames, + hideUpload, + selectHandle, + fileChangeHandle, + removeFile, + handlePreview, + } = useLocalFileUpload({ allowedExtensions, supportBatchUpload }) return (
{!hideUpload && ( - )} - {!hideUpload && ( -
-
- - - - {supportBatchUpload ? t('stepOne.uploader.button', { ns: 'datasetCreation' }) : t('stepOne.uploader.buttonSingleFile', { ns: 'datasetCreation' })} - {allowedExtensions.length > 0 && ( - - )} - -
-
- {t('stepOne.uploader.tip', { - ns: 'datasetCreation', - size: fileUploadConfig.file_size_limit, - supportTypes: supportTypesShowNames, - batchCount: fileUploadConfig.batch_count_limit, - totalCount: fileUploadConfig.file_upload_limit, - })} -
- {dragging &&
} -
- )} {localFileList.length > 0 && (
- {localFileList.map((fileItem, index) => { - const isUploading = fileItem.progress >= 0 && fileItem.progress < 100 - const isError = fileItem.progress === -2 - return ( -
-
- -
-
-
-
{fileItem.file.name}
-
-
- {getFileType(fileItem.file)} - · - {getFileSize(fileItem.file.size)} -
-
-
- {isUploading && ( - - )} - { - isError && ( - - ) - } - { - e.stopPropagation() - removeFile(fileItem.fileID) - }} - > - - -
-
- ) - })} + {localFileList.map((fileItem, index) => ( + + ))}
)}
diff --git a/web/app/components/datasets/documents/detail/completed/common/batch-action.tsx b/web/app/components/datasets/documents/detail/completed/common/batch-action.tsx index 486ba2ffdf..ca5a56ec2a 100644 --- a/web/app/components/datasets/documents/detail/completed/common/batch-action.tsx +++ b/web/app/components/datasets/documents/detail/completed/common/batch-action.tsx @@ -7,6 +7,7 @@ import Button from '@/app/components/base/button' import Confirm from '@/app/components/base/confirm' import Divider from '@/app/components/base/divider' import { SearchLinesSparkle } from '@/app/components/base/icons/src/vender/knowledge' +import { IS_CE_EDITION } from '@/config' import { cn } from '@/utils/classnames' const i18nPrefix = 'batchAction' @@ -87,7 +88,7 @@ const BatchAction: FC = ({ {t('metadata.metadata', { ns: 'dataset' })} )} - {onBatchSummary && ( + {onBatchSummary && IS_CE_EDITION && ( + )} + {isPaused && ( + + )} +
+ ) +}) + +StatusHeader.displayName = 'StatusHeader' + +export default StatusHeader diff --git a/web/app/components/datasets/documents/detail/embedding/hooks/index.ts b/web/app/components/datasets/documents/detail/embedding/hooks/index.ts new file mode 100644 index 0000000000..603c16dda5 --- /dev/null +++ b/web/app/components/datasets/documents/detail/embedding/hooks/index.ts @@ -0,0 +1,10 @@ +export { + calculatePercent, + isEmbeddingStatus, + isTerminalStatus, + useEmbeddingStatus, + useInvalidateEmbeddingStatus, + usePauseIndexing, + useResumeIndexing, +} from './use-embedding-status' +export type { EmbeddingStatusType } from './use-embedding-status' diff --git a/web/app/components/datasets/documents/detail/embedding/hooks/use-embedding-status.spec.tsx b/web/app/components/datasets/documents/detail/embedding/hooks/use-embedding-status.spec.tsx new file mode 100644 index 0000000000..7cadc12dfc --- /dev/null +++ b/web/app/components/datasets/documents/detail/embedding/hooks/use-embedding-status.spec.tsx @@ -0,0 +1,462 @@ +import type { ReactNode } from 'react' +import type { IndexingStatusResponse } from '@/models/datasets' +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' +import { act, renderHook, waitFor } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import * as datasetsService from '@/service/datasets' +import { + calculatePercent, + isEmbeddingStatus, + isTerminalStatus, + useEmbeddingStatus, + useInvalidateEmbeddingStatus, + usePauseIndexing, + useResumeIndexing, +} from './use-embedding-status' + +vi.mock('@/service/datasets') + +const mockFetchIndexingStatus = vi.mocked(datasetsService.fetchIndexingStatus) +const mockPauseDocIndexing = vi.mocked(datasetsService.pauseDocIndexing) +const mockResumeDocIndexing = vi.mocked(datasetsService.resumeDocIndexing) + +const createTestQueryClient = () => new QueryClient({ + defaultOptions: { + queries: { retry: false }, + mutations: { retry: false 
}, + }, +}) + +const createWrapper = () => { + const queryClient = createTestQueryClient() + return ({ children }: { children: ReactNode }) => ( + + {children} + + ) +} + +const mockIndexingStatus = (overrides: Partial = {}): IndexingStatusResponse => ({ + id: 'doc1', + indexing_status: 'indexing', + completed_segments: 50, + total_segments: 100, + processing_started_at: 0, + parsing_completed_at: 0, + cleaning_completed_at: 0, + splitting_completed_at: 0, + completed_at: null, + paused_at: null, + error: null, + stopped_at: null, + ...overrides, +}) + +describe('use-embedding-status', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('isEmbeddingStatus', () => { + it('should return true for indexing status', () => { + expect(isEmbeddingStatus('indexing')).toBe(true) + }) + + it('should return true for splitting status', () => { + expect(isEmbeddingStatus('splitting')).toBe(true) + }) + + it('should return true for parsing status', () => { + expect(isEmbeddingStatus('parsing')).toBe(true) + }) + + it('should return true for cleaning status', () => { + expect(isEmbeddingStatus('cleaning')).toBe(true) + }) + + it('should return false for completed status', () => { + expect(isEmbeddingStatus('completed')).toBe(false) + }) + + it('should return false for paused status', () => { + expect(isEmbeddingStatus('paused')).toBe(false) + }) + + it('should return false for error status', () => { + expect(isEmbeddingStatus('error')).toBe(false) + }) + + it('should return false for undefined', () => { + expect(isEmbeddingStatus(undefined)).toBe(false) + }) + + it('should return false for empty string', () => { + expect(isEmbeddingStatus('')).toBe(false) + }) + }) + + describe('isTerminalStatus', () => { + it('should return true for completed status', () => { + expect(isTerminalStatus('completed')).toBe(true) + }) + + it('should return true for error status', () => { + expect(isTerminalStatus('error')).toBe(true) + }) + + it('should return true for paused status', 
() => { + expect(isTerminalStatus('paused')).toBe(true) + }) + + it('should return false for indexing status', () => { + expect(isTerminalStatus('indexing')).toBe(false) + }) + + it('should return false for undefined', () => { + expect(isTerminalStatus(undefined)).toBe(false) + }) + }) + + describe('calculatePercent', () => { + it('should calculate percent correctly', () => { + expect(calculatePercent(50, 100)).toBe(50) + }) + + it('should return 0 when total is 0', () => { + expect(calculatePercent(50, 0)).toBe(0) + }) + + it('should return 0 when total is undefined', () => { + expect(calculatePercent(50, undefined)).toBe(0) + }) + + it('should return 0 when completed is undefined', () => { + expect(calculatePercent(undefined, 100)).toBe(0) + }) + + it('should cap at 100 when percent exceeds 100', () => { + expect(calculatePercent(150, 100)).toBe(100) + }) + + it('should round to nearest integer', () => { + expect(calculatePercent(33, 100)).toBe(33) + expect(calculatePercent(1, 3)).toBe(33) + }) + }) + + describe('useEmbeddingStatus', () => { + it('should return initial state when disabled', () => { + const { result } = renderHook( + () => useEmbeddingStatus({ datasetId: 'ds1', documentId: 'doc1', enabled: false }), + { wrapper: createWrapper() }, + ) + + expect(result.current.isEmbedding).toBe(false) + expect(result.current.isCompleted).toBe(false) + expect(result.current.isPaused).toBe(false) + expect(result.current.isError).toBe(false) + expect(result.current.percent).toBe(0) + }) + + it('should not fetch when datasetId is missing', () => { + renderHook( + () => useEmbeddingStatus({ documentId: 'doc1' }), + { wrapper: createWrapper() }, + ) + + expect(mockFetchIndexingStatus).not.toHaveBeenCalled() + }) + + it('should not fetch when documentId is missing', () => { + renderHook( + () => useEmbeddingStatus({ datasetId: 'ds1' }), + { wrapper: createWrapper() }, + ) + + expect(mockFetchIndexingStatus).not.toHaveBeenCalled() + }) + + it('should fetch indexing status 
when enabled with valid ids', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus()) + + const { result } = renderHook( + () => useEmbeddingStatus({ datasetId: 'ds1', documentId: 'doc1' }), + { wrapper: createWrapper() }, + ) + + await waitFor(() => { + expect(result.current.isEmbedding).toBe(true) + }) + + expect(mockFetchIndexingStatus).toHaveBeenCalledWith({ + datasetId: 'ds1', + documentId: 'doc1', + }) + expect(result.current.percent).toBe(50) + }) + + it('should set isCompleted when status is completed', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus({ + indexing_status: 'completed', + completed_segments: 100, + })) + + const { result } = renderHook( + () => useEmbeddingStatus({ datasetId: 'ds1', documentId: 'doc1' }), + { wrapper: createWrapper() }, + ) + + await waitFor(() => { + expect(result.current.isCompleted).toBe(true) + }) + + expect(result.current.percent).toBe(100) + }) + + it('should set isPaused when status is paused', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus({ + indexing_status: 'paused', + })) + + const { result } = renderHook( + () => useEmbeddingStatus({ datasetId: 'ds1', documentId: 'doc1' }), + { wrapper: createWrapper() }, + ) + + await waitFor(() => { + expect(result.current.isPaused).toBe(true) + }) + }) + + it('should set isError when status is error', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus({ + indexing_status: 'error', + completed_segments: 25, + })) + + const { result } = renderHook( + () => useEmbeddingStatus({ datasetId: 'ds1', documentId: 'doc1' }), + { wrapper: createWrapper() }, + ) + + await waitFor(() => { + expect(result.current.isError).toBe(true) + }) + }) + + it('should provide invalidate function', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus()) + + const { result } = renderHook( + () => useEmbeddingStatus({ datasetId: 'ds1', documentId: 'doc1' }), + { wrapper: 
createWrapper() }, + ) + + await waitFor(() => { + expect(result.current.isEmbedding).toBe(true) + }) + + expect(typeof result.current.invalidate).toBe('function') + + // Call invalidate should not throw + await act(async () => { + result.current.invalidate() + }) + }) + + it('should provide resetStatus function that clears data', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus()) + + const { result } = renderHook( + () => useEmbeddingStatus({ datasetId: 'ds1', documentId: 'doc1' }), + { wrapper: createWrapper() }, + ) + + await waitFor(() => { + expect(result.current.data).toBeDefined() + }) + + // Reset status should clear the data + await act(async () => { + result.current.resetStatus() + }) + + await waitFor(() => { + expect(result.current.data).toBeNull() + }) + }) + }) + + describe('usePauseIndexing', () => { + it('should call pauseDocIndexing when mutate is called', async () => { + mockPauseDocIndexing.mockResolvedValue({ result: 'success' }) + + const { result } = renderHook( + () => usePauseIndexing({ datasetId: 'ds1', documentId: 'doc1' }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.mutate() + }) + + await waitFor(() => { + expect(mockPauseDocIndexing).toHaveBeenCalledWith({ + datasetId: 'ds1', + documentId: 'doc1', + }) + }) + }) + + it('should call onSuccess callback on successful pause', async () => { + mockPauseDocIndexing.mockResolvedValue({ result: 'success' }) + const onSuccess = vi.fn() + + const { result } = renderHook( + () => usePauseIndexing({ datasetId: 'ds1', documentId: 'doc1', onSuccess }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.mutate() + }) + + await waitFor(() => { + expect(onSuccess).toHaveBeenCalled() + }) + }) + + it('should call onError callback on failed pause', async () => { + const error = new Error('Network error') + mockPauseDocIndexing.mockRejectedValue(error) + const onError = vi.fn() + + const { result } = renderHook( 
+ () => usePauseIndexing({ datasetId: 'ds1', documentId: 'doc1', onError }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.mutate() + }) + + await waitFor(() => { + expect(onError).toHaveBeenCalled() + expect(onError.mock.calls[0][0]).toEqual(error) + }) + }) + }) + + describe('useResumeIndexing', () => { + it('should call resumeDocIndexing when mutate is called', async () => { + mockResumeDocIndexing.mockResolvedValue({ result: 'success' }) + + const { result } = renderHook( + () => useResumeIndexing({ datasetId: 'ds1', documentId: 'doc1' }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.mutate() + }) + + await waitFor(() => { + expect(mockResumeDocIndexing).toHaveBeenCalledWith({ + datasetId: 'ds1', + documentId: 'doc1', + }) + }) + }) + + it('should call onSuccess callback on successful resume', async () => { + mockResumeDocIndexing.mockResolvedValue({ result: 'success' }) + const onSuccess = vi.fn() + + const { result } = renderHook( + () => useResumeIndexing({ datasetId: 'ds1', documentId: 'doc1', onSuccess }), + { wrapper: createWrapper() }, + ) + + await act(async () => { + result.current.mutate() + }) + + await waitFor(() => { + expect(onSuccess).toHaveBeenCalled() + }) + }) + }) + + describe('useInvalidateEmbeddingStatus', () => { + it('should return a function', () => { + const { result } = renderHook( + () => useInvalidateEmbeddingStatus(), + { wrapper: createWrapper() }, + ) + + expect(typeof result.current).toBe('function') + }) + + it('should invalidate specific query when datasetId and documentId are provided', async () => { + const queryClient = createTestQueryClient() + const wrapper = ({ children }: { children: ReactNode }) => ( + + {children} + + ) + + // Set some initial data in the cache + queryClient.setQueryData(['embedding', 'indexing-status', 'ds1', 'doc1'], { + id: 'doc1', + indexing_status: 'indexing', + }) + + const { result } = renderHook( + () => 
useInvalidateEmbeddingStatus(), + { wrapper }, + ) + + await act(async () => { + result.current('ds1', 'doc1') + }) + + // The query should be invalidated (marked as stale) + const queryState = queryClient.getQueryState(['embedding', 'indexing-status', 'ds1', 'doc1']) + expect(queryState?.isInvalidated).toBe(true) + }) + + it('should invalidate all embedding status queries when ids are not provided', async () => { + const queryClient = createTestQueryClient() + const wrapper = ({ children }: { children: ReactNode }) => ( + + {children} + + ) + + // Set some initial data in the cache for multiple documents + queryClient.setQueryData(['embedding', 'indexing-status', 'ds1', 'doc1'], { + id: 'doc1', + indexing_status: 'indexing', + }) + queryClient.setQueryData(['embedding', 'indexing-status', 'ds2', 'doc2'], { + id: 'doc2', + indexing_status: 'completed', + }) + + const { result } = renderHook( + () => useInvalidateEmbeddingStatus(), + { wrapper }, + ) + + await act(async () => { + result.current() + }) + + // Both queries should be invalidated + const queryState1 = queryClient.getQueryState(['embedding', 'indexing-status', 'ds1', 'doc1']) + const queryState2 = queryClient.getQueryState(['embedding', 'indexing-status', 'ds2', 'doc2']) + expect(queryState1?.isInvalidated).toBe(true) + expect(queryState2?.isInvalidated).toBe(true) + }) + }) +}) diff --git a/web/app/components/datasets/documents/detail/embedding/hooks/use-embedding-status.ts b/web/app/components/datasets/documents/detail/embedding/hooks/use-embedding-status.ts new file mode 100644 index 0000000000..e55cd8f9aa --- /dev/null +++ b/web/app/components/datasets/documents/detail/embedding/hooks/use-embedding-status.ts @@ -0,0 +1,149 @@ +import type { CommonResponse } from '@/models/common' +import type { IndexingStatusResponse } from '@/models/datasets' +import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query' +import { useCallback, useEffect, useMemo, useRef } from 'react' +import { + 
fetchIndexingStatus, + pauseDocIndexing, + resumeDocIndexing, +} from '@/service/datasets' + +const NAME_SPACE = 'embedding' + +export type EmbeddingStatusType = 'indexing' | 'splitting' | 'parsing' | 'cleaning' | 'completed' | 'paused' | 'error' | 'waiting' | '' + +const EMBEDDING_STATUSES = ['indexing', 'splitting', 'parsing', 'cleaning'] as const +const TERMINAL_STATUSES = ['completed', 'error', 'paused'] as const + +export const isEmbeddingStatus = (status?: string): boolean => { + return EMBEDDING_STATUSES.includes(status as typeof EMBEDDING_STATUSES[number]) +} + +export const isTerminalStatus = (status?: string): boolean => { + return TERMINAL_STATUSES.includes(status as typeof TERMINAL_STATUSES[number]) +} + +export const calculatePercent = (completed?: number, total?: number): number => { + if (!total || total === 0) + return 0 + const percent = Math.round((completed || 0) * 100 / total) + return Math.min(percent, 100) +} + +type UseEmbeddingStatusOptions = { + datasetId?: string + documentId?: string + enabled?: boolean + onComplete?: () => void +} + +export const useEmbeddingStatus = ({ + datasetId, + documentId, + enabled = true, + onComplete, +}: UseEmbeddingStatusOptions) => { + const queryClient = useQueryClient() + const isPolling = useRef(false) + const onCompleteRef = useRef(onComplete) + onCompleteRef.current = onComplete + + const queryKey = useMemo( + () => [NAME_SPACE, 'indexing-status', datasetId, documentId] as const, + [datasetId, documentId], + ) + + const query = useQuery({ + queryKey, + queryFn: () => fetchIndexingStatus({ datasetId: datasetId!, documentId: documentId! 
}), + enabled: enabled && !!datasetId && !!documentId, + refetchInterval: (query) => { + const status = query.state.data?.indexing_status + if (isTerminalStatus(status)) { + return false + } + return 2500 + }, + refetchOnWindowFocus: false, + }) + + const status = query.data?.indexing_status || '' + const isEmbedding = isEmbeddingStatus(status) + const isCompleted = status === 'completed' + const isPaused = status === 'paused' + const isError = status === 'error' + const percent = calculatePercent(query.data?.completed_segments, query.data?.total_segments) + + // Handle completion callback + useEffect(() => { + if (isTerminalStatus(status) && isPolling.current) { + isPolling.current = false + onCompleteRef.current?.() + } + if (isEmbedding) { + isPolling.current = true + } + }, [status, isEmbedding]) + + const invalidate = useCallback(() => { + queryClient.invalidateQueries({ queryKey }) + }, [queryClient, queryKey]) + + const resetStatus = useCallback(() => { + queryClient.setQueryData(queryKey, null) + }, [queryClient, queryKey]) + + return { + data: query.data, + isLoading: query.isLoading, + isEmbedding, + isCompleted, + isPaused, + isError, + percent, + invalidate, + resetStatus, + refetch: query.refetch, + } +} + +type UsePauseResumeOptions = { + datasetId?: string + documentId?: string + onSuccess?: () => void + onError?: (error: Error) => void +} + +export const usePauseIndexing = ({ datasetId, documentId, onSuccess, onError }: UsePauseResumeOptions) => { + return useMutation({ + mutationKey: [NAME_SPACE, 'pause', datasetId, documentId], + mutationFn: () => pauseDocIndexing({ datasetId: datasetId!, documentId: documentId! }), + onSuccess, + onError, + }) +} + +export const useResumeIndexing = ({ datasetId, documentId, onSuccess, onError }: UsePauseResumeOptions) => { + return useMutation({ + mutationKey: [NAME_SPACE, 'resume', datasetId, documentId], + mutationFn: () => resumeDocIndexing({ datasetId: datasetId!, documentId: documentId! 
}), + onSuccess, + onError, + }) +} + +export const useInvalidateEmbeddingStatus = () => { + const queryClient = useQueryClient() + return useCallback((datasetId?: string, documentId?: string) => { + if (datasetId && documentId) { + queryClient.invalidateQueries({ + queryKey: [NAME_SPACE, 'indexing-status', datasetId, documentId], + }) + } + else { + queryClient.invalidateQueries({ + queryKey: [NAME_SPACE, 'indexing-status'], + }) + } + }, [queryClient]) +} diff --git a/web/app/components/datasets/documents/detail/embedding/index.spec.tsx b/web/app/components/datasets/documents/detail/embedding/index.spec.tsx new file mode 100644 index 0000000000..699de4f12a --- /dev/null +++ b/web/app/components/datasets/documents/detail/embedding/index.spec.tsx @@ -0,0 +1,337 @@ +import type { ReactNode } from 'react' +import type { DocumentContextValue } from '../context' +import type { IndexingStatusResponse, ProcessRuleResponse } from '@/models/datasets' +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { ProcessMode } from '@/models/datasets' +import * as datasetsService from '@/service/datasets' +import * as useDataset from '@/service/knowledge/use-dataset' +import { RETRIEVE_METHOD } from '@/types/app' +import { IndexingType } from '../../../create/step-two' +import { DocumentContext } from '../context' +import EmbeddingDetail from './index' + +vi.mock('@/service/datasets') +vi.mock('@/service/knowledge/use-dataset') + +const mockFetchIndexingStatus = vi.mocked(datasetsService.fetchIndexingStatus) +const mockPauseDocIndexing = vi.mocked(datasetsService.pauseDocIndexing) +const mockResumeDocIndexing = vi.mocked(datasetsService.resumeDocIndexing) +const mockUseProcessRule = vi.mocked(useDataset.useProcessRule) + +const createTestQueryClient = () => new QueryClient({ 
+ defaultOptions: { + queries: { retry: false, gcTime: 0 }, + mutations: { retry: false }, + }, +}) + +const createWrapper = (contextValue: DocumentContextValue = { datasetId: 'ds1', documentId: 'doc1' }) => { + const queryClient = createTestQueryClient() + return ({ children }: { children: ReactNode }) => ( + + + {children} + + + ) +} + +const mockIndexingStatus = (overrides: Partial = {}): IndexingStatusResponse => ({ + id: 'doc1', + indexing_status: 'indexing', + completed_segments: 50, + total_segments: 100, + processing_started_at: Date.now(), + parsing_completed_at: 0, + cleaning_completed_at: 0, + splitting_completed_at: 0, + completed_at: null, + paused_at: null, + error: null, + stopped_at: null, + ...overrides, +}) + +const mockProcessRule = (overrides: Partial = {}): ProcessRuleResponse => ({ + mode: ProcessMode.general, + rules: { + segmentation: { separator: '\n', max_tokens: 500, chunk_overlap: 50 }, + pre_processing_rules: [{ id: 'remove_extra_spaces', enabled: true }], + parent_mode: 'full-doc', + subchunk_segmentation: { separator: '\n', max_tokens: 200, chunk_overlap: 20 }, + }, + limits: { indexing_max_segmentation_tokens_length: 4000 }, + ...overrides, +}) + +describe('EmbeddingDetail', () => { + const defaultProps = { + detailUpdate: vi.fn(), + indexingType: IndexingType.QUALIFIED, + retrievalMethod: RETRIEVE_METHOD.semantic, + } + + beforeEach(() => { + vi.clearAllMocks() + + mockUseProcessRule.mockReturnValue({ + data: mockProcessRule(), + isLoading: false, + error: null, + } as ReturnType) + }) + + describe('Rendering', () => { + it('should render without crashing', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus()) + + render(, { wrapper: createWrapper() }) + + await waitFor(() => { + expect(screen.getByText(/embedding\.processing/i)).toBeInTheDocument() + }) + }) + + it('should render with provided datasetId and documentId props', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus()) + 
+ render( + , + { wrapper: createWrapper({ datasetId: '', documentId: '' }) }, + ) + + await waitFor(() => { + expect(mockFetchIndexingStatus).toHaveBeenCalledWith({ + datasetId: 'custom-ds', + documentId: 'custom-doc', + }) + }) + }) + + it('should fall back to context values when props are not provided', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus()) + + render(, { wrapper: createWrapper() }) + + await waitFor(() => { + expect(mockFetchIndexingStatus).toHaveBeenCalledWith({ + datasetId: 'ds1', + documentId: 'doc1', + }) + }) + }) + }) + + describe('Status Display', () => { + it('should show processing status when indexing', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus({ indexing_status: 'indexing' })) + + render(, { wrapper: createWrapper() }) + + await waitFor(() => { + expect(screen.getByText(/embedding\.processing/i)).toBeInTheDocument() + }) + }) + + it('should show completed status', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus({ indexing_status: 'completed' })) + + render(, { wrapper: createWrapper() }) + + await waitFor(() => { + expect(screen.getByText(/embedding\.completed/i)).toBeInTheDocument() + }) + }) + + it('should show paused status', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus({ indexing_status: 'paused' })) + + render(, { wrapper: createWrapper() }) + + await waitFor(() => { + expect(screen.getByText(/embedding\.paused/i)).toBeInTheDocument() + }) + }) + + it('should show error status', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus({ indexing_status: 'error' })) + + render(, { wrapper: createWrapper() }) + + await waitFor(() => { + expect(screen.getByText(/embedding\.error/i)).toBeInTheDocument() + }) + }) + }) + + describe('Progress Display', () => { + it('should display segment progress', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus({ + 
completed_segments: 50, + total_segments: 100, + })) + + render(, { wrapper: createWrapper() }) + + await waitFor(() => { + expect(screen.getByText(/50\/100/)).toBeInTheDocument() + expect(screen.getByText(/50%/)).toBeInTheDocument() + }) + }) + }) + + describe('Pause/Resume Actions', () => { + it('should show pause button when embedding is in progress', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus({ indexing_status: 'indexing' })) + + render(, { wrapper: createWrapper() }) + + await waitFor(() => { + expect(screen.getByText(/embedding\.pause/i)).toBeInTheDocument() + }) + }) + + it('should show resume button when paused', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus({ indexing_status: 'paused' })) + + render(, { wrapper: createWrapper() }) + + await waitFor(() => { + expect(screen.getByText(/embedding\.resume/i)).toBeInTheDocument() + }) + }) + + it('should call pause API when pause button is clicked', async () => { + const user = userEvent.setup() + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus({ indexing_status: 'indexing' })) + mockPauseDocIndexing.mockResolvedValue({ result: 'success' }) + + render(, { wrapper: createWrapper() }) + + await waitFor(() => { + expect(screen.getByText(/embedding\.pause/i)).toBeInTheDocument() + }) + + await user.click(screen.getByRole('button', { name: /pause/i })) + + await waitFor(() => { + expect(mockPauseDocIndexing).toHaveBeenCalledWith({ + datasetId: 'ds1', + documentId: 'doc1', + }) + }) + }) + + it('should call resume API when resume button is clicked', async () => { + const user = userEvent.setup() + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus({ indexing_status: 'paused' })) + mockResumeDocIndexing.mockResolvedValue({ result: 'success' }) + + render(, { wrapper: createWrapper() }) + + await waitFor(() => { + expect(screen.getByText(/embedding\.resume/i)).toBeInTheDocument() + }) + + await user.click(screen.getByRole('button', { 
name: /resume/i })) + + await waitFor(() => { + expect(mockResumeDocIndexing).toHaveBeenCalledWith({ + datasetId: 'ds1', + documentId: 'doc1', + }) + }) + }) + }) + + describe('Rule Detail', () => { + it('should display rule detail section', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus()) + + render(, { wrapper: createWrapper() }) + + await waitFor(() => { + expect(screen.getByText(/stepTwo\.indexMode/i)).toBeInTheDocument() + }) + }) + + it('should display qualified index mode', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus()) + + render( + , + { wrapper: createWrapper() }, + ) + + await waitFor(() => { + expect(screen.getByText(/stepTwo\.qualified/i)).toBeInTheDocument() + }) + }) + + it('should display economical index mode', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus()) + + render( + , + { wrapper: createWrapper() }, + ) + + await waitFor(() => { + expect(screen.getByText(/stepTwo\.economical/i)).toBeInTheDocument() + }) + }) + }) + + describe('detailUpdate Callback', () => { + it('should call detailUpdate when status becomes terminal', async () => { + const detailUpdate = vi.fn() + // First call returns indexing, subsequent call returns completed + mockFetchIndexingStatus + .mockResolvedValueOnce(mockIndexingStatus({ indexing_status: 'indexing' })) + .mockResolvedValueOnce(mockIndexingStatus({ indexing_status: 'completed' })) + + render( + , + { wrapper: createWrapper() }, + ) + + // Wait for the terminal status to trigger detailUpdate + await waitFor(() => { + expect(mockFetchIndexingStatus).toHaveBeenCalled() + }, { timeout: 5000 }) + }) + }) + + describe('Edge Cases', () => { + it('should handle missing context values', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus()) + + render( + , + { wrapper: createWrapper({ datasetId: undefined, documentId: undefined }) }, + ) + + await waitFor(() => { + 
expect(mockFetchIndexingStatus).toHaveBeenCalledWith({ + datasetId: 'explicit-ds', + documentId: 'explicit-doc', + }) + }) + }) + + it('should render skeleton component', async () => { + mockFetchIndexingStatus.mockResolvedValue(mockIndexingStatus()) + + const { container } = render(, { wrapper: createWrapper() }) + + // EmbeddingSkeleton should be rendered - check for the skeleton wrapper element + await waitFor(() => { + const skeletonWrapper = container.querySelector('.bg-dataset-chunk-list-mask-bg') + expect(skeletonWrapper).toBeInTheDocument() + }) + }) + }) +}) diff --git a/web/app/components/datasets/documents/detail/embedding/index.tsx b/web/app/components/datasets/documents/detail/embedding/index.tsx index 37b5bb85e7..e89a85c6de 100644 --- a/web/app/components/datasets/documents/detail/embedding/index.tsx +++ b/web/app/components/datasets/documents/detail/embedding/index.tsx @@ -1,31 +1,18 @@ import type { FC } from 'react' -import type { CommonResponse } from '@/models/common' -import type { IndexingStatusResponse, ProcessRuleResponse } from '@/models/datasets' -import { RiLoader2Line, RiPauseCircleLine, RiPlayCircleLine } from '@remixicon/react' -import Image from 'next/image' +import type { IndexingType } from '../../../create/step-two' +import type { RETRIEVE_METHOD } from '@/types/app' import * as React from 'react' -import { useCallback, useEffect, useMemo, useRef, useState } from 'react' +import { useCallback } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' -import Divider from '@/app/components/base/divider' import { ToastContext } from '@/app/components/base/toast' -import { ProcessMode } from '@/models/datasets' -import { - fetchIndexingStatus as doFetchIndexingStatus, - pauseDocIndexing, - resumeDocIndexing, -} from '@/service/datasets' import { useProcessRule } from '@/service/knowledge/use-dataset' -import { RETRIEVE_METHOD } from '@/types/app' -import { asyncRunSafe, sleep } from 
'@/utils' -import { cn } from '@/utils/classnames' -import { indexMethodIcon, retrievalIcon } from '../../../create/icons' -import { IndexingType } from '../../../create/step-two' import { useDocumentContext } from '../context' -import { FieldInfo } from '../metadata' +import { ProgressBar, RuleDetail, SegmentProgress, StatusHeader } from './components' +import { useEmbeddingStatus, usePauseIndexing, useResumeIndexing } from './hooks' import EmbeddingSkeleton from './skeleton' -type IEmbeddingDetailProps = { +type EmbeddingDetailProps = { datasetId?: string documentId?: string indexingType?: IndexingType @@ -33,128 +20,7 @@ type IEmbeddingDetailProps = { detailUpdate: VoidFunction } -type IRuleDetailProps = { - sourceData?: ProcessRuleResponse - indexingType?: IndexingType - retrievalMethod?: RETRIEVE_METHOD -} - -const RuleDetail: FC = React.memo(({ - sourceData, - indexingType, - retrievalMethod, -}) => { - const { t } = useTranslation() - - const segmentationRuleMap = { - mode: t('embedding.mode', { ns: 'datasetDocuments' }), - segmentLength: t('embedding.segmentLength', { ns: 'datasetDocuments' }), - textCleaning: t('embedding.textCleaning', { ns: 'datasetDocuments' }), - } - - const getRuleName = (key: string) => { - if (key === 'remove_extra_spaces') - return t('stepTwo.removeExtraSpaces', { ns: 'datasetCreation' }) - - if (key === 'remove_urls_emails') - return t('stepTwo.removeUrlEmails', { ns: 'datasetCreation' }) - - if (key === 'remove_stopwords') - return t('stepTwo.removeStopwords', { ns: 'datasetCreation' }) - } - - const isNumber = (value: unknown) => { - return typeof value === 'number' - } - - const getValue = useCallback((field: string) => { - let value: string | number | undefined = '-' - const maxTokens = isNumber(sourceData?.rules?.segmentation?.max_tokens) - ? sourceData.rules.segmentation.max_tokens - : value - const childMaxTokens = isNumber(sourceData?.rules?.subchunk_segmentation?.max_tokens) - ? 
sourceData.rules.subchunk_segmentation.max_tokens - : value - switch (field) { - case 'mode': - value = !sourceData?.mode - ? value - : sourceData.mode === ProcessMode.general - ? (t('embedding.custom', { ns: 'datasetDocuments' }) as string) - : `${t('embedding.hierarchical', { ns: 'datasetDocuments' })} · ${sourceData?.rules?.parent_mode === 'paragraph' - ? t('parentMode.paragraph', { ns: 'dataset' }) - : t('parentMode.fullDoc', { ns: 'dataset' })}` - break - case 'segmentLength': - value = !sourceData?.mode - ? value - : sourceData.mode === ProcessMode.general - ? maxTokens - : `${t('embedding.parentMaxTokens', { ns: 'datasetDocuments' })} ${maxTokens}; ${t('embedding.childMaxTokens', { ns: 'datasetDocuments' })} ${childMaxTokens}` - break - default: - value = !sourceData?.mode - ? value - : sourceData?.rules?.pre_processing_rules?.filter(rule => - rule.enabled).map(rule => getRuleName(rule.id)).join(',') - break - } - return value - }, [sourceData]) - - return ( -
-
- {Object.keys(segmentationRuleMap).map((field) => { - return ( - - ) - })} -
- - - )} - /> - - )} - /> -
- ) -}) - -RuleDetail.displayName = 'RuleDetail' - -const EmbeddingDetail: FC = ({ +const EmbeddingDetail: FC = ({ datasetId: dstId, documentId: docId, detailUpdate, @@ -164,144 +30,95 @@ const EmbeddingDetail: FC = ({ const { t } = useTranslation() const { notify } = useContext(ToastContext) - const datasetId = useDocumentContext(s => s.datasetId) - const documentId = useDocumentContext(s => s.documentId) - const localDatasetId = dstId ?? datasetId - const localDocumentId = docId ?? documentId + const contextDatasetId = useDocumentContext(s => s.datasetId) + const contextDocumentId = useDocumentContext(s => s.documentId) + const datasetId = dstId ?? contextDatasetId + const documentId = docId ?? contextDocumentId - const [indexingStatusDetail, setIndexingStatusDetail] = useState(null) - const fetchIndexingStatus = async () => { - const status = await doFetchIndexingStatus({ datasetId: localDatasetId, documentId: localDocumentId }) - setIndexingStatusDetail(status) - return status - } + const { + data: indexingStatus, + isEmbedding, + isCompleted, + isPaused, + isError, + percent, + resetStatus, + refetch, + } = useEmbeddingStatus({ + datasetId, + documentId, + onComplete: detailUpdate, + }) - const isStopQuery = useRef(false) - const stopQueryStatus = useCallback(() => { - isStopQuery.current = true - }, []) + const { data: ruleDetail } = useProcessRule(documentId) - const startQueryStatus = useCallback(async () => { - if (isStopQuery.current) - return + const handleSuccess = useCallback(() => { + notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) + }, [notify, t]) - try { - const indexingStatusDetail = await fetchIndexingStatus() - if (['completed', 'error', 'paused'].includes(indexingStatusDetail?.indexing_status)) { - stopQueryStatus() - detailUpdate() - return - } + const handleError = useCallback(() => { + notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) + }, [notify, t]) - 
await sleep(2500) - await startQueryStatus() - } - catch { - await sleep(2500) - await startQueryStatus() - } - }, [stopQueryStatus]) + const pauseMutation = usePauseIndexing({ + datasetId, + documentId, + onSuccess: () => { + handleSuccess() + resetStatus() + }, + onError: handleError, + }) - useEffect(() => { - isStopQuery.current = false - startQueryStatus() - return () => { - stopQueryStatus() - } - }, [startQueryStatus, stopQueryStatus]) + const resumeMutation = useResumeIndexing({ + datasetId, + documentId, + onSuccess: () => { + handleSuccess() + refetch() + detailUpdate() + }, + onError: handleError, + }) - const { data: ruleDetail } = useProcessRule(localDocumentId) + const handlePause = useCallback(() => { + pauseMutation.mutate() + }, [pauseMutation]) - const isEmbedding = useMemo(() => ['indexing', 'splitting', 'parsing', 'cleaning'].includes(indexingStatusDetail?.indexing_status || ''), [indexingStatusDetail]) - const isEmbeddingCompleted = useMemo(() => ['completed'].includes(indexingStatusDetail?.indexing_status || ''), [indexingStatusDetail]) - const isEmbeddingPaused = useMemo(() => ['paused'].includes(indexingStatusDetail?.indexing_status || ''), [indexingStatusDetail]) - const isEmbeddingError = useMemo(() => ['error'].includes(indexingStatusDetail?.indexing_status || ''), [indexingStatusDetail]) - const percent = useMemo(() => { - const completedCount = indexingStatusDetail?.completed_segments || 0 - const totalCount = indexingStatusDetail?.total_segments || 0 - if (totalCount === 0) - return 0 - const percent = Math.round(completedCount * 100 / totalCount) - return percent > 100 ? 100 : percent - }, [indexingStatusDetail]) - - const handleSwitch = async () => { - const opApi = isEmbedding ? 
pauseDocIndexing : resumeDocIndexing - const [e] = await asyncRunSafe(opApi({ datasetId: localDatasetId, documentId: localDocumentId }) as Promise) - if (!e) { - notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) - // if the embedding is resumed from paused, we need to start the query status - if (isEmbeddingPaused) { - isStopQuery.current = false - startQueryStatus() - detailUpdate() - } - setIndexingStatusDetail(null) - } - else { - notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) - } - } + const handleResume = useCallback(() => { + resumeMutation.mutate() + }, [resumeMutation]) return ( <>
-
- {isEmbedding && } - - {isEmbedding && t('embedding.processing', { ns: 'datasetDocuments' })} - {isEmbeddingCompleted && t('embedding.completed', { ns: 'datasetDocuments' })} - {isEmbeddingPaused && t('embedding.paused', { ns: 'datasetDocuments' })} - {isEmbeddingError && t('embedding.error', { ns: 'datasetDocuments' })} - - {isEmbedding && ( - - )} - {isEmbeddingPaused && ( - - )} -
- {/* progress bar */} -
-
-
-
- - {`${t('embedding.segments', { ns: 'datasetDocuments' })} ${indexingStatusDetail?.completed_segments || '--'}/${indexingStatusDetail?.total_segments || '--'} · ${percent}%`} - -
- + + + +
diff --git a/web/app/components/datasets/formatted-text/flavours/edit-slice.tsx b/web/app/components/datasets/formatted-text/flavours/edit-slice.tsx index 3ae95ec531..04ebd16f6c 100644 --- a/web/app/components/datasets/formatted-text/flavours/edit-slice.tsx +++ b/web/app/components/datasets/formatted-text/flavours/edit-slice.tsx @@ -3,8 +3,6 @@ import type { FC, ReactNode } from 'react' import type { SliceProps } from './type' import { autoUpdate, flip, FloatingFocusManager, offset, shift, useDismiss, useFloating, useHover, useInteractions, useRole } from '@floating-ui/react' import { RiDeleteBinLine } from '@remixicon/react' -// @ts-expect-error no types available -import lineClamp from 'line-clamp' import { useState } from 'react' import ActionButton, { ActionButtonState } from '@/app/components/base/action-button' import { cn } from '@/utils/classnames' @@ -58,12 +56,8 @@ export const EditSlice: FC = (props) => { <> { - refs.setReference(ref) - if (ref) - lineClamp(ref, 4) - }} + className={cn('mr-0 line-clamp-4 block', className)} + ref={refs.setReference} {...getReferenceProps()} > ({ + default: ({ icon, className }: { icon?: string, className?: string }) => ( +
{icon}
+ ), +})) + // Mock useFormatTimeFromNow hook vi.mock('@/hooks/use-format-time-from-now', () => ({ useFormatTimeFromNow: () => ({ diff --git a/web/app/components/datasets/list/dataset-card/components/dataset-card-modals.spec.tsx b/web/app/components/datasets/list/dataset-card/components/dataset-card-modals.spec.tsx index ebee72159e..607830661d 100644 --- a/web/app/components/datasets/list/dataset-card/components/dataset-card-modals.spec.tsx +++ b/web/app/components/datasets/list/dataset-card/components/dataset-card-modals.spec.tsx @@ -19,6 +19,28 @@ vi.mock('../../../rename-modal', () => ({ ), })) +// Mock Confirm component since it uses createPortal which can cause issues in tests +vi.mock('@/app/components/base/confirm', () => ({ + default: ({ isShow, title, content, onConfirm, onCancel }: { + isShow: boolean + title: string + content?: React.ReactNode + onConfirm: () => void + onCancel: () => void + }) => ( + isShow + ? ( +
+
{title}
+
{content}
+ + +
+ ) + : null + ), +})) + describe('DatasetCardModals', () => { const mockDataset: DataSet = { id: 'dataset-1', @@ -172,11 +194,9 @@ describe('DatasetCardModals', () => { />, ) - // Find and click the confirm button - const confirmButton = screen.getByRole('button', { name: /confirm|ok|delete/i }) - || screen.getAllByRole('button').find(btn => btn.textContent?.toLowerCase().includes('confirm')) - if (confirmButton) - fireEvent.click(confirmButton) + // Find and click the confirm button using our mocked Confirm component + const confirmButton = screen.getByRole('button', { name: /confirm/i }) + fireEvent.click(confirmButton) expect(onConfirmDelete).toHaveBeenCalledTimes(1) }) diff --git a/web/app/components/datasets/settings/form/components/basic-info-section.spec.tsx b/web/app/components/datasets/settings/form/components/basic-info-section.spec.tsx new file mode 100644 index 0000000000..28085e52fa --- /dev/null +++ b/web/app/components/datasets/settings/form/components/basic-info-section.spec.tsx @@ -0,0 +1,441 @@ +import type { Member } from '@/models/common' +import type { DataSet, IconInfo } from '@/models/datasets' +import type { RetrievalConfig } from '@/types/app' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { ChunkingMode, DatasetPermission, DataSourceType } from '@/models/datasets' +import { RETRIEVE_METHOD } from '@/types/app' +import { IndexingType } from '../../../create/step-two' +import BasicInfoSection from './basic-info-section' + +// Mock app-context +vi.mock('@/context/app-context', () => ({ + useSelector: () => ({ + id: 'user-1', + name: 'Current User', + email: 'current@example.com', + avatar_url: '', + role: 'owner', + }), +})) + +// Mock image uploader hooks for AppIconPicker +vi.mock('@/app/components/base/image-uploader/hooks', () => ({ + useLocalFileUploader: () => ({ + disabled: false, + handleLocalFileUpload: vi.fn(), + }), + useImageFiles: () => ({ + files: [], + onUpload: vi.fn(), + onRemove: 
vi.fn(), + onReUpload: vi.fn(), + onImageLinkLoadError: vi.fn(), + onImageLinkLoadSuccess: vi.fn(), + onClear: vi.fn(), + }), +})) + +describe('BasicInfoSection', () => { + const mockDataset: DataSet = { + id: 'dataset-1', + name: 'Test Dataset', + description: 'Test description', + permission: DatasetPermission.onlyMe, + icon_info: { + icon_type: 'emoji', + icon: '📚', + icon_background: '#FFFFFF', + icon_url: '', + }, + indexing_technique: IndexingType.QUALIFIED, + indexing_status: 'completed', + data_source_type: DataSourceType.FILE, + doc_form: ChunkingMode.text, + embedding_model: 'text-embedding-ada-002', + embedding_model_provider: 'openai', + embedding_available: true, + app_count: 0, + document_count: 5, + total_document_count: 5, + word_count: 1000, + provider: 'vendor', + tags: [], + partial_member_list: [], + external_knowledge_info: { + external_knowledge_id: 'ext-1', + external_knowledge_api_id: 'api-1', + external_knowledge_api_name: 'External API', + external_knowledge_api_endpoint: 'https://api.example.com', + }, + external_retrieval_model: { + top_k: 3, + score_threshold: 0.7, + score_threshold_enabled: true, + }, + retrieval_model_dict: { + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 3, + score_threshold_enabled: false, + score_threshold: 0.5, + } as RetrievalConfig, + retrieval_model: { + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 3, + score_threshold_enabled: false, + score_threshold: 0.5, + } as RetrievalConfig, + built_in_field_enabled: false, + keyword_number: 10, + created_by: 'user-1', + updated_by: 'user-1', + updated_at: Date.now(), + runtime_mode: 'general', + enable_api: true, + is_multimodal: false, + } + + const mockMemberList: Member[] = [ + { id: 'user-1', name: 'User 1', email: 'user1@example.com', 
role: 'owner', avatar: '', avatar_url: '', last_login_at: '', created_at: '', status: 'active' }, + { id: 'user-2', name: 'User 2', email: 'user2@example.com', role: 'admin', avatar: '', avatar_url: '', last_login_at: '', created_at: '', status: 'active' }, + ] + + const mockIconInfo: IconInfo = { + icon_type: 'emoji', + icon: '📚', + icon_background: '#FFFFFF', + icon_url: '', + } + + const defaultProps = { + currentDataset: mockDataset, + isCurrentWorkspaceDatasetOperator: false, + name: 'Test Dataset', + setName: vi.fn(), + description: 'Test description', + setDescription: vi.fn(), + iconInfo: mockIconInfo, + showAppIconPicker: false, + handleOpenAppIconPicker: vi.fn(), + handleSelectAppIcon: vi.fn(), + handleCloseAppIconPicker: vi.fn(), + permission: DatasetPermission.onlyMe, + setPermission: vi.fn(), + selectedMemberIDs: ['user-1'], + setSelectedMemberIDs: vi.fn(), + memberList: mockMemberList, + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('Rendering', () => { + it('should render without crashing', () => { + render() + expect(screen.getByText(/form\.nameAndIcon/i)).toBeInTheDocument() + }) + + it('should render name and icon section', () => { + render() + expect(screen.getByText(/form\.nameAndIcon/i)).toBeInTheDocument() + }) + + it('should render description section', () => { + render() + expect(screen.getByText(/form\.desc/i)).toBeInTheDocument() + }) + + it('should render permissions section', () => { + render() + // Use exact match to avoid matching "permissionsOnlyMe" + expect(screen.getByText('datasetSettings.form.permissions')).toBeInTheDocument() + }) + + it('should render name input with correct value', () => { + render() + const nameInput = screen.getByDisplayValue('Test Dataset') + expect(nameInput).toBeInTheDocument() + }) + + it('should render description textarea with correct value', () => { + render() + const descriptionTextarea = screen.getByDisplayValue('Test description') + expect(descriptionTextarea).toBeInTheDocument() + 
}) + + it('should render app icon with emoji', () => { + const { container } = render() + // The icon section should be rendered (emoji may be in a span or SVG) + const iconSection = container.querySelector('[class*="cursor-pointer"]') + expect(iconSection).toBeInTheDocument() + }) + }) + + describe('Name Input', () => { + it('should call setName when name input changes', () => { + const setName = vi.fn() + render() + + const nameInput = screen.getByDisplayValue('Test Dataset') + fireEvent.change(nameInput, { target: { value: 'New Name' } }) + + expect(setName).toHaveBeenCalledWith('New Name') + }) + + it('should disable name input when embedding is not available', () => { + const datasetWithoutEmbedding = { ...mockDataset, embedding_available: false } + render() + + const nameInput = screen.getByDisplayValue('Test Dataset') + expect(nameInput).toBeDisabled() + }) + + it('should enable name input when embedding is available', () => { + render() + + const nameInput = screen.getByDisplayValue('Test Dataset') + expect(nameInput).not.toBeDisabled() + }) + + it('should display empty name', () => { + const { container } = render() + + // Find the name input by its structure - may be type=text or just input + const nameInput = container.querySelector('input') + expect(nameInput).toHaveValue('') + }) + }) + + describe('Description Textarea', () => { + it('should call setDescription when description changes', () => { + const setDescription = vi.fn() + render() + + const descriptionTextarea = screen.getByDisplayValue('Test description') + fireEvent.change(descriptionTextarea, { target: { value: 'New Description' } }) + + expect(setDescription).toHaveBeenCalledWith('New Description') + }) + + it('should disable description textarea when embedding is not available', () => { + const datasetWithoutEmbedding = { ...mockDataset, embedding_available: false } + render() + + const descriptionTextarea = screen.getByDisplayValue('Test description') + 
expect(descriptionTextarea).toBeDisabled() + }) + + it('should render placeholder', () => { + render() + + const descriptionTextarea = screen.getByPlaceholderText(/form\.descPlaceholder/i) + expect(descriptionTextarea).toBeInTheDocument() + }) + }) + + describe('App Icon', () => { + it('should call handleOpenAppIconPicker when icon is clicked', () => { + const handleOpenAppIconPicker = vi.fn() + const { container } = render() + + // Find the clickable icon element - it's inside a wrapper that handles the click + const iconWrapper = container.querySelector('[class*="cursor-pointer"]') + if (iconWrapper) { + fireEvent.click(iconWrapper) + expect(handleOpenAppIconPicker).toHaveBeenCalled() + } + }) + + it('should render AppIconPicker when showAppIconPicker is true', () => { + const { baseElement } = render() + + // AppIconPicker renders a modal with emoji tabs and options via portal + // We just verify the component renders without crashing when picker is shown + expect(baseElement).toBeInTheDocument() + }) + + it('should not render AppIconPicker when showAppIconPicker is false', () => { + const { container } = render() + + // Check that AppIconPicker is not rendered + expect(container.querySelector('[data-testid="app-icon-picker"]')).not.toBeInTheDocument() + }) + + it('should render image icon when icon_type is image', () => { + const imageIconInfo: IconInfo = { + icon_type: 'image', + icon: 'file-123', + icon_background: undefined, + icon_url: 'https://example.com/icon.png', + } + render() + + // For image type, it renders an img element + const img = screen.queryByRole('img') + if (img) { + expect(img).toHaveAttribute('src', expect.stringContaining('icon.png')) + } + }) + }) + + describe('Permission Selector', () => { + it('should render with correct permission value', () => { + render() + + expect(screen.getByText(/form\.permissionsOnlyMe/i)).toBeInTheDocument() + }) + + it('should render all team members permission', () => { + render() + + 
expect(screen.getByText(/form\.permissionsAllMember/i)).toBeInTheDocument() + }) + + it('should be disabled when embedding is not available', () => { + const datasetWithoutEmbedding = { ...mockDataset, embedding_available: false } + const { container } = render( + , + ) + + // Check for disabled state via cursor-not-allowed class + const disabledElement = container.querySelector('[class*="cursor-not-allowed"]') + expect(disabledElement).toBeInTheDocument() + }) + + it('should be disabled when user is dataset operator', () => { + const { container } = render( + , + ) + + const disabledElement = container.querySelector('[class*="cursor-not-allowed"]') + expect(disabledElement).toBeInTheDocument() + }) + + it('should call setPermission when permission changes', async () => { + const setPermission = vi.fn() + render() + + // Open dropdown + const trigger = screen.getByText(/form\.permissionsOnlyMe/i) + fireEvent.click(trigger) + + await waitFor(() => { + // Click All Team Members option + const allMemberOptions = screen.getAllByText(/form\.permissionsAllMember/i) + fireEvent.click(allMemberOptions[0]) + }) + + expect(setPermission).toHaveBeenCalledWith(DatasetPermission.allTeamMembers) + }) + + it('should call setSelectedMemberIDs when members are selected', async () => { + const setSelectedMemberIDs = vi.fn() + const { container } = render( + , + ) + + // For partial members permission, the member selector should be visible + // The exact interaction depends on the MemberSelector component + // We verify the component renders without crashing + expect(container).toBeInTheDocument() + }) + }) + + describe('Undefined Dataset', () => { + it('should handle undefined currentDataset gracefully', () => { + render() + + // Should still render but inputs might behave differently + expect(screen.getByText(/form\.nameAndIcon/i)).toBeInTheDocument() + }) + }) + + describe('Props Validation', () => { + it('should update when name prop changes', () => { + const { rerender } = 
render() + + expect(screen.getByDisplayValue('Initial Name')).toBeInTheDocument() + + rerender() + + expect(screen.getByDisplayValue('Updated Name')).toBeInTheDocument() + }) + + it('should update when description prop changes', () => { + const { rerender } = render() + + expect(screen.getByDisplayValue('Initial Description')).toBeInTheDocument() + + rerender() + + expect(screen.getByDisplayValue('Updated Description')).toBeInTheDocument() + }) + + it('should update when permission prop changes', () => { + const { rerender } = render() + + expect(screen.getByText(/form\.permissionsOnlyMe/i)).toBeInTheDocument() + + rerender() + + expect(screen.getByText(/form\.permissionsAllMember/i)).toBeInTheDocument() + }) + }) + + describe('Member List', () => { + it('should pass member list to PermissionSelector', () => { + const { container } = render( + , + ) + + // For partial members, a member selector component should be rendered + // We verify it renders without crashing + expect(container).toBeInTheDocument() + }) + + it('should handle empty member list', () => { + render( + , + ) + + expect(screen.getByText(/form\.permissionsOnlyMe/i)).toBeInTheDocument() + }) + }) + + describe('Accessibility', () => { + it('should have accessible name input', () => { + render() + + const nameInput = screen.getByDisplayValue('Test Dataset') + expect(nameInput.tagName.toLowerCase()).toBe('input') + }) + + it('should have accessible description textarea', () => { + render() + + const descriptionTextarea = screen.getByDisplayValue('Test description') + expect(descriptionTextarea.tagName.toLowerCase()).toBe('textarea') + }) + }) +}) diff --git a/web/app/components/datasets/settings/form/components/basic-info-section.tsx b/web/app/components/datasets/settings/form/components/basic-info-section.tsx new file mode 100644 index 0000000000..3d3cf75851 --- /dev/null +++ b/web/app/components/datasets/settings/form/components/basic-info-section.tsx @@ -0,0 +1,124 @@ +'use client' +import type { 
AppIconSelection } from '@/app/components/base/app-icon-picker' +import type { Member } from '@/models/common' +import type { DataSet, DatasetPermission, IconInfo } from '@/models/datasets' +import type { AppIconType } from '@/types/app' +import { useTranslation } from 'react-i18next' +import AppIcon from '@/app/components/base/app-icon' +import AppIconPicker from '@/app/components/base/app-icon-picker' +import Input from '@/app/components/base/input' +import Textarea from '@/app/components/base/textarea' +import PermissionSelector from '../../permission-selector' + +const rowClass = 'flex gap-x-1' +const labelClass = 'flex items-center shrink-0 w-[180px] h-7 pt-1' + +type BasicInfoSectionProps = { + currentDataset: DataSet | undefined + isCurrentWorkspaceDatasetOperator: boolean + name: string + setName: (value: string) => void + description: string + setDescription: (value: string) => void + iconInfo: IconInfo + showAppIconPicker: boolean + handleOpenAppIconPicker: () => void + handleSelectAppIcon: (icon: AppIconSelection) => void + handleCloseAppIconPicker: () => void + permission: DatasetPermission | undefined + setPermission: (value: DatasetPermission | undefined) => void + selectedMemberIDs: string[] + setSelectedMemberIDs: (value: string[]) => void + memberList: Member[] +} + +const BasicInfoSection = ({ + currentDataset, + isCurrentWorkspaceDatasetOperator, + name, + setName, + description, + setDescription, + iconInfo, + showAppIconPicker, + handleOpenAppIconPicker, + handleSelectAppIcon, + handleCloseAppIconPicker, + permission, + setPermission, + selectedMemberIDs, + setSelectedMemberIDs, + memberList, +}: BasicInfoSectionProps) => { + const { t } = useTranslation() + + return ( + <> + {/* Dataset name and icon */} +
+
+
{t('form.nameAndIcon', { ns: 'datasetSettings' })}
+
+
+ + setName(e.target.value)} + /> +
+
+ + {/* Dataset description */} +
+
+
{t('form.desc', { ns: 'datasetSettings' })}
+
+
+