diff --git a/api/.env.example b/api/.env.example index d5d4e4486f..5a13ac69de 100644 --- a/api/.env.example +++ b/api/.env.example @@ -65,7 +65,7 @@ OPENDAL_FS_ROOT=storage # S3 Storage configuration S3_USE_AWS_MANAGED_IAM=false -S3_ENDPOINT=https://your-bucket-name.storage.s3.clooudflare.com +S3_ENDPOINT=https://your-bucket-name.storage.s3.cloudflare.com S3_BUCKET_NAME=your-bucket-name S3_ACCESS_KEY=your-access-key S3_SECRET_KEY=your-secret-key @@ -74,7 +74,7 @@ S3_REGION=your-region # Azure Blob Storage configuration AZURE_BLOB_ACCOUNT_NAME=your-account-name AZURE_BLOB_ACCOUNT_KEY=your-account-key -AZURE_BLOB_CONTAINER_NAME=yout-container-name +AZURE_BLOB_CONTAINER_NAME=your-container-name AZURE_BLOB_ACCOUNT_URL=https://.blob.core.windows.net # Aliyun oss Storage configuration @@ -88,7 +88,7 @@ ALIYUN_OSS_REGION=your-region ALIYUN_OSS_PATH=your-path # Google Storage configuration -GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name +GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string # Tencent COS Storage configuration diff --git a/api/.ruff.toml b/api/.ruff.toml index 26a1b977a9..f30275a943 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -67,7 +67,7 @@ ignore = [ "SIM105", # suppressible-exception "SIM107", # return-in-try-except-finally "SIM108", # if-else-block-instead-of-if-exp - "SIM113", # eumerate-for-loop + "SIM113", # enumerate-for-loop "SIM117", # multiple-with-statements "SIM210", # if-expr-with-true-false ] diff --git a/api/commands.py b/api/commands.py index c6e450b3ee..76c8d3e382 100644 --- a/api/commands.py +++ b/api/commands.py @@ -563,8 +563,13 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str new_password = secrets.token_urlsafe(16) # register account - account = RegisterService.register(email=email, name=account_name, password=new_password, language=language) - + account = RegisterService.register( + email=email, + name=account_name, + password=new_password, + language=language, + create_workspace_required=False, + ) TenantService.create_owner_tenant_if_not_exist(account, name) click.echo( @@ -584,7 +589,7 @@ def upgrade_db(): click.echo(click.style("Starting database migration.", fg="green")) # run db migration - import flask_migrate + import flask_migrate # type: ignore flask_migrate.upgrade() diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index fcecb346b0..5865ddcc8b 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -659,7 +659,7 @@ class RagEtlConfig(BaseSettings): UNSTRUCTURED_API_KEY: Optional[str] = Field( description="API key for Unstructured.io service", - default=None, + default="", ) SCARF_NO_ANALYTICS: Optional[str] = Field( diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 89a638ae54..30cd93a010 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -232,7 +232,7 @@ class DataSourceNotionApi(Resource): args["doc_form"], args["doc_language"], ) - return response, 200 + return response.model_dump(), 200 class DataSourceNotionDatasetSyncApi(Resource): diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index f3c3736b25..0c0d2e2003 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -464,7 +464,7 @@ class DatasetIndexingEstimateApi(Resource): 
except Exception as e: raise IndexingEstimateError(str(e)) - return response, 200 + return response.model_dump(), 200 class DatasetRelatedAppListApi(Resource): @@ -733,6 +733,18 @@ class DatasetPermissionUserListApi(Resource): }, 200 +class DatasetAutoDisableLogApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id): + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) + if dataset is None: + raise NotFound("Dataset not found.") + return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200 + + api.add_resource(DatasetListApi, "/datasets") api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>") api.add_resource(DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check") @@ -747,3 +759,4 @@ api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info") api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting") api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>") api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users") +api.add_resource(DatasetAutoDisableLogApi, "/datasets/<uuid:dataset_id>/auto-disable-logs") diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index c236e1a431..3c132bc3d0 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -52,6 +52,7 @@ from fields.document_fields import ( from libs.login import login_required from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile from services.dataset_service import DatasetService, DocumentService +from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from tasks.add_document_to_index_task import add_document_to_index_task from tasks.remove_document_from_index_task import remove_document_from_index_task @@ -267,20 +268,22 @@ class DatasetDocumentListApi(Resource): parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json") parser.add_argument("original_document_id", type=str, required=False, location="json") parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") + parser.add_argument( "doc_language", type=str, default="English", required=False, nullable=False, location="json" ) - parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") args = parser.parse_args() + knowledge_config = KnowledgeConfig(**args) - if not dataset.indexing_technique and not args["indexing_technique"]: + if not dataset.indexing_technique and not knowledge_config.indexing_technique: raise ValueError("indexing_technique is required.") # validate args - DocumentService.document_create_args_validate(args) + DocumentService.document_create_args_validate(knowledge_config) try: - documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user) + documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except QuotaExceededError: @@ -290,6 +293,25 @@ class DatasetDocumentListApi(Resource): return {"documents": documents, "batch": batch} + @setup_required + @login_required + @account_initialization_required + def delete(self, dataset_id):
dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if dataset is None: + raise NotFound("Dataset not found.") + # check user's model setting + DatasetService.check_dataset_model_setting(dataset) + + try: + document_ids = request.args.getlist("document_id") + DocumentService.delete_documents(dataset, document_ids) + except services.errors.document.DocumentIndexingError: + raise DocumentIndexingError("Cannot delete document during indexing.") + + return {"result": "success"}, 204 + class DatasetInitApi(Resource): @setup_required @@ -325,9 +347,9 @@ class DatasetInitApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator if not current_user.is_dataset_editor: raise Forbidden() - - if args["indexing_technique"] == "high_quality": - if args["embedding_model"] is None or args["embedding_model_provider"] is None: + knowledge_config = KnowledgeConfig(**args) + if knowledge_config.indexing_technique == "high_quality": + if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None: raise ValueError("embedding model and embedding model provider are required for high quality indexing.") try: model_manager = ModelManager() @@ -346,11 +368,11 @@ class DatasetInitApi(Resource): raise ProviderNotInitializeError(ex.description) # validate args - DocumentService.document_create_args_validate(args) + DocumentService.document_create_args_validate(knowledge_config) try: dataset, documents, batch = DocumentService.save_document_without_dataset_id( - tenant_id=current_user.current_tenant_id, document_data=args, account=current_user + tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -403,7 +425,7 @@ class DocumentIndexingEstimateApi(DocumentResource): indexing_runner = IndexingRunner() try: - response = indexing_runner.indexing_estimate( + estimate_response = indexing_runner.indexing_estimate( current_user.current_tenant_id, [extract_setting], data_process_rule_dict, @@ -411,6 +433,7 @@ class DocumentIndexingEstimateApi(DocumentResource): "English", dataset_id, ) + return estimate_response.model_dump(), 200 except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " @@ -423,7 +446,7 @@ class DocumentIndexingEstimateApi(DocumentResource): except Exception as e: raise IndexingEstimateError(str(e)) - return response + return response, 200 class DocumentBatchIndexingEstimateApi(DocumentResource): @@ -434,9 +457,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): dataset_id = str(dataset_id) batch = str(batch) documents = self.get_batch_documents(dataset_id, batch) - response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []} if not documents: - return response + return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200 data_process_rule = documents[0].dataset_process_rule data_process_rule_dict = data_process_rule.to_dict() info_list = [] @@ -514,6 +536,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): "English", dataset_id, ) + return response.model_dump(), 200 except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. 
Please configure a valid provider " @@ -525,7 +548,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): raise ProviderNotInitializeError(ex.description) except Exception as e: raise IndexingEstimateError(str(e)) - return response class DocumentBatchIndexingStatusApi(DocumentResource): @@ -598,7 +620,8 @@ class DocumentDetailApi(DocumentResource): if metadata == "only": response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata} elif metadata == "without": - process_rules = DatasetService.get_process_rules(dataset_id) + dataset_process_rules = DatasetService.get_process_rules(dataset_id) + document_process_rules = document.dataset_process_rule.to_dict() data_source_info = document.data_source_detail_dict response = { "id": document.id, @@ -606,7 +629,8 @@ class DocumentDetailApi(DocumentResource): "data_source_type": document.data_source_type, "data_source_info": data_source_info, "dataset_process_rule_id": document.dataset_process_rule_id, - "dataset_process_rule": process_rules, + "dataset_process_rule": dataset_process_rules, + "document_process_rule": document_process_rules, "name": document.name, "created_from": document.created_from, "created_by": document.created_by, @@ -629,7 +653,8 @@ class DocumentDetailApi(DocumentResource): "doc_language": document.doc_language, } else: - process_rules = DatasetService.get_process_rules(dataset_id) + dataset_process_rules = DatasetService.get_process_rules(dataset_id) + document_process_rules = document.dataset_process_rule.to_dict() data_source_info = document.data_source_detail_dict response = { "id": document.id, @@ -637,7 +662,8 @@ class DocumentDetailApi(DocumentResource): "data_source_type": document.data_source_type, "data_source_info": data_source_info, "dataset_process_rule_id": document.dataset_process_rule_id, - "dataset_process_rule": process_rules, + "dataset_process_rule": dataset_process_rules, + "document_process_rule": document_process_rules, "name": document.name, "created_from": document.created_from, "created_by": document.created_by, @@ -773,9 +799,8 @@ class DocumentStatusApi(DocumentResource): @login_required @account_initialization_required @cloud_edition_billing_resource_check("vector_space") - def patch(self, dataset_id, document_id, action): + def patch(self, dataset_id, action): dataset_id = str(dataset_id) - document_id = str(document_id) dataset = DatasetService.get_dataset(dataset_id) if dataset is None: raise NotFound("Dataset not found.") @@ -790,84 +815,79 @@ class DocumentStatusApi(DocumentResource): # check user's permission DatasetService.check_dataset_permission(dataset, current_user) - document = self.get_document(dataset_id, document_id) + document_ids = request.args.getlist("document_id") + for document_id in document_ids: + document = self.get_document(dataset_id, document_id) - indexing_cache_key = "document_{}_indexing".format(document.id) - cache_result = redis_client.get(indexing_cache_key) - if cache_result is not None: - raise InvalidActionError("Document is being indexed, please try again later") + indexing_cache_key = "document_{}_indexing".format(document.id) + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + raise InvalidActionError(f"Document:{document.name} is being indexed, please try again later") - if action == "enable": - if document.enabled: - raise InvalidActionError("Document already enabled.") + if action == "enable": + if document.enabled: + continue + document.enabled = True + document.disabled_at 
= None + document.disabled_by = None + document.updated_at = datetime.now(UTC).replace(tzinfo=None) + db.session.commit() - document.enabled = True - document.disabled_at = None - document.disabled_by = None - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() + # Set cache to prevent indexing the same document multiple times + redis_client.setex(indexing_cache_key, 600, 1) - # Set cache to prevent indexing the same document multiple times - redis_client.setex(indexing_cache_key, 600, 1) + add_document_to_index_task.delay(document_id) - add_document_to_index_task.delay(document_id) + elif action == "disable": + if not document.completed_at or document.indexing_status != "completed": + raise InvalidActionError(f"Document: {document.name} is not completed.") + if not document.enabled: + continue - return {"result": "success"}, 200 + document.enabled = False + document.disabled_at = datetime.now(UTC).replace(tzinfo=None) + document.disabled_by = current_user.id + document.updated_at = datetime.now(UTC).replace(tzinfo=None) + db.session.commit() - elif action == "disable": - if not document.completed_at or document.indexing_status != "completed": - raise InvalidActionError("Document is not completed.") - if not document.enabled: - raise InvalidActionError("Document already disabled.") - - document.enabled = False - document.disabled_at = datetime.now(UTC).replace(tzinfo=None) - document.disabled_by = current_user.id - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() - - # Set cache to prevent indexing the same document multiple times - redis_client.setex(indexing_cache_key, 600, 1) - - remove_document_from_index_task.delay(document_id) - - return {"result": "success"}, 200 - - elif action == "archive": - if document.archived: - raise InvalidActionError("Document already archived.") - - document.archived = True - document.archived_at = datetime.now(UTC).replace(tzinfo=None) - document.archived_by = current_user.id - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() - - if document.enabled: # Set cache to prevent indexing the same document multiple times redis_client.setex(indexing_cache_key, 600, 1) remove_document_from_index_task.delay(document_id) - return {"result": "success"}, 200 - elif action == "un_archive": - if not document.archived: - raise InvalidActionError("Document is not archived.") + elif action == "archive": + if document.archived: + continue - document.archived = False - document.archived_at = None - document.archived_by = None - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() + document.archived = True + document.archived_at = datetime.now(UTC).replace(tzinfo=None) + document.archived_by = current_user.id + document.updated_at = datetime.now(UTC).replace(tzinfo=None) + db.session.commit() - # Set cache to prevent indexing the same document multiple times - redis_client.setex(indexing_cache_key, 600, 1) + if document.enabled: + # Set cache to prevent indexing the same document multiple times + redis_client.setex(indexing_cache_key, 600, 1) - add_document_to_index_task.delay(document_id) + remove_document_from_index_task.delay(document_id) - return {"result": "success"}, 200 - else: - raise InvalidActionError() + elif action == "un_archive": + if not document.archived: + continue + document.archived = False + document.archived_at = None + document.archived_by = None + document.updated_at = datetime.now(UTC).replace(tzinfo=None) + 
db.session.commit() + + # Set cache to prevent indexing the same document multiple times + redis_client.setex(indexing_cache_key, 600, 1) + + add_document_to_index_task.delay(document_id) + + else: + raise InvalidActionError() + return {"result": "success"}, 200 class DocumentPauseApi(DocumentResource): @@ -1038,7 +1058,7 @@ api.add_resource( ) api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>") api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata") -api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>") +api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>/batch") api.add_resource(DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause") api.add_resource(DocumentRecoverApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume") api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry") diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 2d5933ca23..96654c09fd 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -1,5 +1,4 @@ import uuid -from datetime import UTC, datetime import pandas as pd from flask import request @@ -10,7 +9,13 @@ from werkzeug.exceptions import Forbidden, NotFound import services from controllers.console import api from controllers.console.app.error import ProviderNotInitializeError -from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError +from controllers.console.datasets.error import ( + ChildChunkDeleteIndexError, + ChildChunkIndexingError, + InvalidActionError, + NoFileUploadedError, + TooManyFilesError, +) from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_knowledge_limit_check, @@ -20,15 +25,15 @@ from controllers.console.wraps import ( from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from extensions.ext_database import db from extensions.ext_redis import redis_client -from fields.segment_fields import segment_fields +from fields.segment_fields import child_chunk_fields, segment_fields from libs.login import login_required -from models import DocumentSegment +from models.dataset import ChildChunk, DocumentSegment from services.dataset_service import DatasetService, DocumentService, SegmentService +from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs +from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError +from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task -from tasks.disable_segment_from_index_task import disable_segment_from_index_task -from tasks.enable_segment_to_index_task import enable_segment_to_index_task class DatasetDocumentSegmentListApi(Resource): @@ -53,15 +58,16 @@ class DatasetDocumentSegmentListApi(Resource): raise NotFound("Document not found.") parser = reqparse.RequestParser() - parser.add_argument("last_id", type=str, default=None, location="args") parser.add_argument("limit", type=int, default=20, location="args") parser.add_argument("status", type=str, action="append", default=[], location="args") parser.add_argument("hit_count_gte", type=int, default=None, location="args")
parser.add_argument("enabled", type=str, default="all", location="args") parser.add_argument("keyword", type=str, default=None, location="args") + parser.add_argument("page", type=int, default=1, location="args") + args = parser.parse_args() - last_id = args["last_id"] + page = args["page"] limit = min(args["limit"], 100) status_list = args["status"] hit_count_gte = args["hit_count_gte"] @@ -69,14 +75,7 @@ class DatasetDocumentSegmentListApi(Resource): query = DocumentSegment.query.filter( DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id - ) - - if last_id is not None: - last_segment = db.session.get(DocumentSegment, str(last_id)) - if last_segment: - query = query.filter(DocumentSegment.position > last_segment.position) - else: - return {"data": [], "has_more": False, "limit": limit}, 200 + ).order_by(DocumentSegment.position.asc()) if status_list: query = query.filter(DocumentSegment.status.in_(status_list)) @@ -93,21 +92,44 @@ class DatasetDocumentSegmentListApi(Resource): elif args["enabled"].lower() == "false": query = query.filter(DocumentSegment.enabled == False) - total = query.count() - segments = query.order_by(DocumentSegment.position).limit(limit + 1).all() + segments = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) - has_more = False - if len(segments) > limit: - has_more = True - segments = segments[:-1] - - return { - "data": marshal(segments, segment_fields), - "doc_form": document.doc_form, - "has_more": has_more, + response = { + "data": marshal(segments.items, segment_fields), "limit": limit, - "total": total, - }, 200 + "total": segments.total, + "total_pages": segments.pages, + "page": page, + } + return response, 200 + + @setup_required + @login_required + @account_initialization_required + def delete(self, dataset_id, document_id): + # check dataset + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + # check user's model setting + DatasetService.check_dataset_model_setting(dataset) + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound("Document not found.") + segment_ids = request.args.getlist("segment_id") + + # The role of the current user in the ta table must be admin or owner + if not current_user.is_editor: + raise Forbidden() + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + SegmentService.delete_segments(segment_ids, document, dataset) + return {"result": "success"}, 200 class DatasetDocumentSegmentApi(Resource): @@ -115,11 +137,15 @@ class DatasetDocumentSegmentApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_resource_check("vector_space") - def patch(self, dataset_id, segment_id, action): + def patch(self, dataset_id, document_id, action): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise NotFound("Dataset not found.") + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound("Document not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # The role of the current user in the ta table must be admin, owner, or editor @@ -147,59 +173,17 @@ class 
DatasetDocumentSegmentApi(Resource): ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) + segment_ids = request.args.getlist("segment_id") - segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id - ).first() - - if not segment: - raise NotFound("Segment not found.") - - if segment.status != "completed": - raise NotFound("Segment is not completed, enable or disable function is not allowed") - - document_indexing_cache_key = "document_{}_indexing".format(segment.document_id) + document_indexing_cache_key = "document_{}_indexing".format(document.id) cache_result = redis_client.get(document_indexing_cache_key) if cache_result is not None: raise InvalidActionError("Document is being indexed, please try again later") - - indexing_cache_key = "segment_{}_indexing".format(segment.id) - cache_result = redis_client.get(indexing_cache_key) - if cache_result is not None: - raise InvalidActionError("Segment is being indexed, please try again later") - - if action == "enable": - if segment.enabled: - raise InvalidActionError("Segment is already enabled.") - - segment.enabled = True - segment.disabled_at = None - segment.disabled_by = None - db.session.commit() - - # Set cache to prevent indexing the same segment multiple times - redis_client.setex(indexing_cache_key, 600, 1) - - enable_segment_to_index_task.delay(segment.id) - - return {"result": "success"}, 200 - elif action == "disable": - if not segment.enabled: - raise InvalidActionError("Segment is already disabled.") - - segment.enabled = False - segment.disabled_at = datetime.now(UTC).replace(tzinfo=None) - segment.disabled_by = current_user.id - db.session.commit() - - # Set cache to prevent indexing the same segment multiple times - redis_client.setex(indexing_cache_key, 600, 1) - - disable_segment_from_index_task.delay(segment.id) - - return {"result": "success"}, 200 - else: - raise InvalidActionError() + try: + SegmentService.update_segments_status(segment_ids, action, dataset, document) + except Exception as e: + raise InvalidActionError(str(e)) + return {"result": "success"}, 200 class DatasetDocumentSegmentAddApi(Resource): @@ -307,9 +291,12 @@ class DatasetDocumentSegmentUpdateApi(Resource): parser.add_argument("content", type=str, required=True, nullable=False, location="json") parser.add_argument("answer", type=str, required=False, nullable=True, location="json") parser.add_argument("keywords", type=list, required=False, nullable=True, location="json") + parser.add_argument( + "regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json" + ) args = parser.parse_args() SegmentService.segment_create_args_validate(args, document) - segment = SegmentService.update_segment(args, segment, document, dataset) + segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset) return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 @setup_required @@ -412,8 +399,248 @@ class DatasetDocumentSegmentBatchImportApi(Resource): return {"job_id": job_id, "job_status": cache_result.decode()}, 200 +class ChildChunkAddApi(Resource): + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_resource_check("vector_space") + @cloud_edition_billing_knowledge_limit_check("add_segment") + def post(self, dataset_id, document_id, segment_id): + # check dataset + dataset_id = str(dataset_id) + dataset = 
DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound("Document not found.") + # check segment + segment_id = str(segment_id) + segment = DocumentSegment.query.filter( + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id + ).first() + if not segment: + raise NotFound("Segment not found.") + if not current_user.is_editor: + raise Forbidden() + # check embedding model setting + if dataset.indexing_technique == "high_quality": + try: + model_manager = ModelManager() + model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + except LLMBadRequestError: + raise ProviderNotInitializeError( + "No Embedding Model available. Please configure a valid provider " + "in the Settings -> Model Provider." + ) + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + # validate args + parser = reqparse.RequestParser() + parser.add_argument("content", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + try: + child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset) + except ChildChunkIndexingServiceError as e: + raise ChildChunkIndexingError(str(e)) + return {"data": marshal(child_chunk, child_chunk_fields)}, 200 + + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id, document_id, segment_id): + # check dataset + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + # check user's model setting + DatasetService.check_dataset_model_setting(dataset) + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound("Document not found.") + # check segment + segment_id = str(segment_id) + segment = DocumentSegment.query.filter( + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id + ).first() + if not segment: + raise NotFound("Segment not found.") + parser = reqparse.RequestParser() + parser.add_argument("limit", type=int, default=20, location="args") + parser.add_argument("keyword", type=str, default=None, location="args") + parser.add_argument("page", type=int, default=1, location="args") + + args = parser.parse_args() + + page = args["page"] + limit = min(args["limit"], 100) + keyword = args["keyword"] + + child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword) + return { + "data": marshal(child_chunks.items, child_chunk_fields), + "total": child_chunks.total, + "total_pages": child_chunks.pages, + "page": page, + "limit": limit, + }, 200 + + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_resource_check("vector_space") + def patch(self, dataset_id, document_id, segment_id): + # check dataset + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise 
NotFound("Dataset not found.") + # check user's model setting + DatasetService.check_dataset_model_setting(dataset) + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound("Document not found.") + # check segment + segment_id = str(segment_id) + segment = DocumentSegment.query.filter( + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id + ).first() + if not segment: + raise NotFound("Segment not found.") + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + # validate args + parser = reqparse.RequestParser() + parser.add_argument("chunks", type=list, required=True, nullable=False, location="json") + args = parser.parse_args() + try: + chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")] + child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset) + except ChildChunkIndexingServiceError as e: + raise ChildChunkIndexingError(str(e)) + return {"data": marshal(child_chunks, child_chunk_fields)}, 200 + + +class ChildChunkUpdateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def delete(self, dataset_id, document_id, segment_id, child_chunk_id): + # check dataset + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + # check user's model setting + DatasetService.check_dataset_model_setting(dataset) + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound("Document not found.") + # check segment + segment_id = str(segment_id) + segment = DocumentSegment.query.filter( + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id + ).first() + if not segment: + raise NotFound("Segment not found.") + # check child chunk + child_chunk_id = str(child_chunk_id) + child_chunk = ChildChunk.query.filter( + ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id + ).first() + if not child_chunk: + raise NotFound("Child chunk not found.") + # The role of the current user in the ta table must be admin or owner + if not current_user.is_editor: + raise Forbidden() + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + try: + SegmentService.delete_child_chunk(child_chunk, dataset) + except ChildChunkDeleteIndexServiceError as e: + raise ChildChunkDeleteIndexError(str(e)) + return {"result": "success"}, 200 + + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_resource_check("vector_space") + def patch(self, dataset_id, document_id, segment_id, child_chunk_id): + # check dataset + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + # check user's model setting + DatasetService.check_dataset_model_setting(dataset) + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise 
NotFound("Document not found.") + # check segment + segment_id = str(segment_id) + segment = DocumentSegment.query.filter( + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id + ).first() + if not segment: + raise NotFound("Segment not found.") + # check child chunk + child_chunk_id = str(child_chunk_id) + child_chunk = ChildChunk.query.filter( + ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id + ).first() + if not child_chunk: + raise NotFound("Child chunk not found.") + # The role of the current user in the ta table must be admin or owner + if not current_user.is_editor: + raise Forbidden() + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + # validate args + parser = reqparse.RequestParser() + parser.add_argument("content", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + try: + child_chunk = SegmentService.update_child_chunk( + args.get("content"), child_chunk, segment, document, dataset + ) + except ChildChunkIndexingServiceError as e: + raise ChildChunkIndexingError(str(e)) + return {"data": marshal(child_chunk, child_chunk_fields)}, 200 + + api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments") -api.add_resource(DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>") +api.add_resource( + DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>" +) api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment") api.add_resource( DatasetDocumentSegmentUpdateApi, @@ -424,3 +651,11 @@ api.add_resource( "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import", "/datasets/batch_import_status/<uuid:job_id>", ) +api.add_resource( + ChildChunkAddApi, + "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks", +) +api.add_resource( + ChildChunkUpdateApi, + "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>", +) diff --git a/api/controllers/console/datasets/error.py b/api/controllers/console/datasets/error.py index 6a7a3971a8..2f00a84de6 100644 --- a/api/controllers/console/datasets/error.py +++ b/api/controllers/console/datasets/error.py @@ -89,3 +89,15 @@ class IndexingEstimateError(BaseHTTPException): error_code = "indexing_estimate_error" description = "Knowledge indexing estimate failed: {message}" code = 500 + + +class ChildChunkIndexingError(BaseHTTPException): + error_code = "child_chunk_indexing_error" + description = "Create child chunk index failed: {message}" + code = 500 + + +class ChildChunkDeleteIndexError(BaseHTTPException): + error_code = "child_chunk_delete_index_error" + description = "Delete child chunk index failed: {message}" + code = 500 diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index c3488de299..405d5ed607 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -66,10 +66,17 @@ class MessageFeedbackApi(InstalledAppResource): parser = reqparse.RequestParser() parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") + parser.add_argument("content", type=str, location="json") args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, current_user, args["rating"], args["content"]) + MessageService.create_feedback( + app_model=app_model, + message_id=message_id, + user=current_user, + rating=args.get("rating"), + content=args.get("content"), + ) except
services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 522c7509b9..773ea0e0c6 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -108,7 +108,13 @@ class MessageFeedbackApi(Resource): args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"]) + MessageService.create_feedback( + app_model=app_model, + message_id=message_id, + user=end_user, + rating=args.get("rating"), + content=args.get("content"), + ) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 34afe2837f..ea664b8f1b 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -8,12 +8,16 @@ from werkzeug.exceptions import NotFound import services.dataset_service from controllers.common.errors import FilenameNotExistsError from controllers.service_api import api -from controllers.service_api.app.error import ProviderNotInitializeError +from controllers.service_api.app.error import ( + FileTooLargeError, + NoFileUploadedError, + ProviderNotInitializeError, + TooManyFilesError, + UnsupportedFileTypeError, +) from controllers.service_api.dataset.error import ( ArchivedDocumentImmutableError, DocumentIndexingError, - NoFileUploadedError, - TooManyFilesError, ) from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check from core.errors.error import ProviderTokenNotInitError @@ -22,6 +26,7 @@ from fields.document_fields import document_fields, document_status_fields from libs.login import current_user from models.dataset import Dataset, Document, DocumentSegment from services.dataset_service import DocumentService +from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.file_service import FileService @@ -67,13 +72,14 @@ class DocumentAddByTextApi(DatasetApiResource): "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, } args["data_source"] = data_source + knowledge_config = KnowledgeConfig(**args) # validate args - DocumentService.document_create_args_validate(args) + DocumentService.document_create_args_validate(knowledge_config) try: documents, batch = DocumentService.save_document_with_dataset_id( dataset=dataset, - document_data=args, + knowledge_config=knowledge_config, account=current_user, dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, created_from="api", @@ -122,12 +128,13 @@ class DocumentUpdateByTextApi(DatasetApiResource): args["data_source"] = data_source # validate args args["original_document_id"] = str(document_id) - DocumentService.document_create_args_validate(args) + knowledge_config = KnowledgeConfig(**args) + DocumentService.document_create_args_validate(knowledge_config) try: documents, batch = DocumentService.save_document_with_dataset_id( dataset=dataset, - document_data=args, + knowledge_config=knowledge_config, account=current_user, dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, created_from="api", @@ -186,12 +193,13 @@ class DocumentAddByFileApi(DatasetApiResource): data_source = {"type": "upload_file", "info_list": 
{"file_info_list": {"file_ids": [upload_file.id]}}} args["data_source"] = data_source # validate args - DocumentService.document_create_args_validate(args) + knowledge_config = KnowledgeConfig(**args) + DocumentService.document_create_args_validate(knowledge_config) try: documents, batch = DocumentService.save_document_with_dataset_id( dataset=dataset, - document_data=args, + knowledge_config=knowledge_config, account=dataset.created_by_account, dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, created_from="api", @@ -234,23 +242,30 @@ class DocumentUpdateByFileApi(DatasetApiResource): if not file.filename: raise FilenameNotExistsError - upload_file = FileService.upload_file( - filename=file.filename, - content=file.read(), - mimetype=file.mimetype, - user=current_user, - source="datasets", - ) + try: + upload_file = FileService.upload_file( + filename=file.filename, + content=file.read(), + mimetype=file.mimetype, + user=current_user, + source="datasets", + ) + except services.errors.file.FileTooLargeError as file_too_large_error: + raise FileTooLargeError(file_too_large_error.description) + except services.errors.file.UnsupportedFileTypeError: + raise UnsupportedFileTypeError() data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} args["data_source"] = data_source # validate args args["original_document_id"] = str(document_id) - DocumentService.document_create_args_validate(args) + + knowledge_config = KnowledgeConfig(**args) + DocumentService.document_create_args_validate(knowledge_config) try: documents, batch = DocumentService.save_document_with_dataset_id( dataset=dataset, - document_data=args, + knowledge_config=knowledge_config, account=dataset.created_by_account, dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, created_from="api", diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 34904574a8..1c500f51bf 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -16,6 +16,7 @@ from extensions.ext_database import db from fields.segment_fields import segment_fields from models.dataset import Dataset, DocumentSegment from services.dataset_service import DatasetService, DocumentService, SegmentService +from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs class SegmentApi(DatasetApiResource): @@ -193,7 +194,7 @@ class DatasetSegmentApi(DatasetApiResource): args = parser.parse_args() SegmentService.segment_create_args_validate(args["segment"], document) - segment = SegmentService.update_segment(args["segment"], segment, document, dataset) + segment = SegmentService.update_segment(SegmentUpdateArgs(**args["segment"]), segment, document, dataset) return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 0f47e64370..2afc11f601 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -105,10 +105,17 @@ class MessageFeedbackApi(WebApiResource): parser = reqparse.RequestParser() parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") + parser.add_argument("content", type=str, location="json", default=None) args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"]) + 
MessageService.create_feedback( + app_model=app_model, + message_id=message_id, + user=end_user, + rating=args.get("rating"), + content=args.get("content"), + ) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 684f2bc8a3..cb2a361f17 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -393,7 +393,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): try: return generate_task_pipeline.process() except ValueError as e: - if e.args[0] == "I/O operation on closed file.": # ignore this error + if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error raise GenerateTaskStoppedError() else: logger.exception(f"Failed to process generate task pipeline, conversation_id: {conversation.id}") diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 2e9b643d8b..ab0f0763f4 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -5,6 +5,9 @@ from collections.abc import Generator, Mapping from threading import Thread from typing import Any, Optional, Union +from sqlalchemy import select +from sqlalchemy.orm import Session + from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -66,7 +69,6 @@ from models.enums import CreatedByRole from models.workflow import ( Workflow, WorkflowNodeExecution, - WorkflowRun, WorkflowRunStatus, ) @@ -80,8 +82,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc _task_state: WorkflowTaskState _application_generate_entity: AdvancedChatAppGenerateEntity - _workflow: Workflow - _user: Union[Account, EndUser] _workflow_system_variables: dict[SystemVariableKey, Any] _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] _conversation_name_generate_thread: Optional[Thread] = None @@ -97,32 +97,37 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc stream: bool, dialogue_count: int, ) -> None: - """ - Initialize AdvancedChatAppGenerateTaskPipeline. 
- :param application_generate_entity: application generate entity - :param workflow: workflow - :param queue_manager: queue manager - :param conversation: conversation - :param message: message - :param user: user - :param stream: stream - :param dialogue_count: dialogue count - """ - super().__init__(application_generate_entity, queue_manager, user, stream) + super().__init__( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + stream=stream, + ) - if isinstance(self._user, EndUser): - user_id = self._user.session_id + if isinstance(user, EndUser): + self._user_id = user.id + user_session_id = user.session_id + self._created_by_role = CreatedByRole.END_USER + elif isinstance(user, Account): + self._user_id = user.id + user_session_id = user.id + self._created_by_role = CreatedByRole.ACCOUNT else: - user_id = self._user.id + raise NotImplementedError(f"User type not supported: {type(user)}") + + self._workflow_id = workflow.id + self._workflow_features_dict = workflow.features_dict + + self._conversation_id = conversation.id + self._conversation_mode = conversation.mode + + self._message_id = message.id + self._message_created_at = int(message.created_at.timestamp()) - self._workflow = workflow - self._conversation = conversation - self._message = message self._workflow_system_variables = { SystemVariableKey.QUERY: message.query, SystemVariableKey.FILES: application_generate_entity.files, SystemVariableKey.CONVERSATION_ID: conversation.id, - SystemVariableKey.USER_ID: user_id, + SystemVariableKey.USER_ID: user_session_id, SystemVariableKey.DIALOGUE_COUNT: dialogue_count, SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, SystemVariableKey.WORKFLOW_ID: workflow.id, @@ -135,19 +140,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._conversation_name_generate_thread = None self._recorded_files: list[Mapping[str, Any]] = [] + self._workflow_run_id = "" def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: """ Process generate task pipeline. 
:return: """ - db.session.refresh(self._workflow) - db.session.refresh(self._user) - db.session.close() - # start generate conversation name thread self._conversation_name_generate_thread = self._generate_conversation_name( - self._conversation, self._application_generate_entity.query + conversation_id=self._conversation_id, query=self._application_generate_entity.query ) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) @@ -173,12 +175,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc return ChatbotAppBlockingResponse( task_id=stream_response.task_id, data=ChatbotAppBlockingResponse.Data( - id=self._message.id, - mode=self._conversation.mode, - conversation_id=self._conversation.id, - message_id=self._message.id, + id=self._message_id, + mode=self._conversation_mode, + conversation_id=self._conversation_id, + message_id=self._message_id, answer=self._task_state.answer, - created_at=int(self._message.created_at.timestamp()), + created_at=self._message_created_at, **extras, ), ) @@ -196,9 +198,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc """ for stream_response in generator: yield ChatbotAppStreamResponse( - conversation_id=self._conversation.id, - message_id=self._message.id, - created_at=int(self._message.created_at.timestamp()), + conversation_id=self._conversation_id, + message_id=self._message_id, + created_at=self._message_created_at, stream_response=stream_response, ) @@ -216,7 +218,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc tts_publisher = None task_id = self._application_generate_entity.task_id tenant_id = self._application_generate_entity.app_config.tenant_id - features_dict = self._workflow.features_dict + features_dict = self._workflow_features_dict if ( features_dict.get("text_to_speech") @@ -268,7 +270,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc """ # init fake graph runtime state graph_runtime_state: Optional[GraphRuntimeState] = None - workflow_run: Optional[WorkflowRun] = None for queue_message in self._queue_manager.listen(): event = queue_message.event @@ -276,237 +277,303 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc if isinstance(event, QueuePingEvent): yield self._ping_stream_response() elif isinstance(event, QueueErrorEvent): - err = self._handle_error(event, self._message) + with Session(db.engine) as session: + err = self._handle_error(event=event, session=session, message_id=self._message_id) + session.commit() yield self._error_to_stream_response(err) break elif isinstance(event, QueueWorkflowStartedEvent): # override graph runtime state graph_runtime_state = event.graph_runtime_state - # init workflow run - workflow_run = self._handle_workflow_run_start() + with Session(db.engine) as session: + # init workflow run + workflow_run = self._handle_workflow_run_start( + session=session, + workflow_id=self._workflow_id, + user_id=self._user_id, + created_by_role=self._created_by_role, + ) + self._workflow_run_id = workflow_run.id + message = self._get_message(session=session) + if not message: + raise ValueError(f"Message not found: {self._message_id}") + message.workflow_run_id = workflow_run.id + workflow_start_resp = self._workflow_start_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + session.commit() - self._refetch_message() - 
self._message.workflow_run_id = workflow_run.id - - db.session.commit() - db.session.refresh(self._message) - db.session.close() - - yield self._workflow_start_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + yield workflow_start_resp elif isinstance( event, QueueNodeRetryEvent, ): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - workflow_node_execution = self._handle_workflow_node_execution_retried( - workflow_run=workflow_run, event=event - ) - response = self._workflow_node_retry_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + workflow_node_execution = self._handle_workflow_node_execution_retried( + session=session, workflow_run=workflow_run, event=event + ) + node_retry_resp = self._workflow_node_retry_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() - if response: - yield response + if node_retry_resp: + yield node_retry_resp elif isinstance(event, QueueNodeStartedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + workflow_node_execution = self._handle_node_execution_start( + session=session, workflow_run=workflow_run, event=event + ) - response_start = self._workflow_node_start_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + node_start_resp = self._workflow_node_start_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() - if response_start: - yield response_start + if node_start_resp: + yield node_start_resp elif isinstance(event, QueueNodeSucceededEvent): - workflow_node_execution = self._handle_workflow_node_execution_success(event) - # Record files if it's an answer node or end node if event.node_type in [NodeType.ANSWER, NodeType.END]: self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {})) - response_finish = self._workflow_node_finish_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + with Session(db.engine) as session: + workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event) - if response_finish: - yield response_finish + node_finish_resp = self._workflow_node_finish_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() + + if node_finish_resp: + yield node_finish_resp elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): - workflow_node_execution = self._handle_workflow_node_execution_failed(event) + with Session(db.engine) as session: + 
workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event) - response_finish = self._workflow_node_finish_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - - if response_finish: - yield response_finish + node_finish_resp = self._workflow_node_finish_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() + if node_finish_resp: + yield node_finish_resp elif isinstance(event, QueueParallelBranchRunStartedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_parallel_branch_start_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) - elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): - if not workflow_run: - raise ValueError("workflow run not initialized.") - - yield self._workflow_parallel_branch_finished_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) - elif isinstance(event, QueueIterationStartEvent): - if not workflow_run: - raise ValueError("workflow run not initialized.") - - yield self._workflow_iteration_start_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) - elif isinstance(event, QueueIterationNextEvent): - if not workflow_run: - raise ValueError("workflow run not initialized.") - - yield self._workflow_iteration_next_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) - elif isinstance(event, QueueIterationCompletedEvent): - if not workflow_run: - raise ValueError("workflow run not initialized.") - - yield self._workflow_iteration_completed_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) - elif isinstance(event, QueueWorkflowSucceededEvent): - if not workflow_run: - raise ValueError("workflow run not initialized.") - - if not graph_runtime_state: - raise ValueError("workflow run not initialized.") - - workflow_run = self._handle_workflow_run_success( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - conversation_id=self._conversation.id, - trace_manager=trace_manager, - ) - - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) - - self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) - elif isinstance(event, QueueWorkflowPartialSuccessEvent): - if not workflow_run: - raise ValueError("workflow run not initialized.") - - if not graph_runtime_state: - raise ValueError("graph runtime state not initialized.") - - workflow_run = self._handle_workflow_run_partial_success( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - exceptions_count=event.exceptions_count, - conversation_id=None, - trace_manager=trace_manager, - ) - - yield self._workflow_finish_to_stream_response( - 
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) - - self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) - elif isinstance(event, QueueWorkflowFailedEvent): - if not workflow_run: - raise ValueError("workflow run not initialized.") - - if not graph_runtime_state: - raise ValueError("graph runtime state not initialized.") - - workflow_run = self._handle_workflow_run_failed( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - status=WorkflowRunStatus.FAILED, - error=event.error, - conversation_id=self._conversation.id, - trace_manager=trace_manager, - exceptions_count=event.exceptions_count, - ) - - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) - - err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}")) - yield self._error_to_stream_response(self._handle_error(err_event, self._message)) - break - elif isinstance(event, QueueStopEvent): - if workflow_run and graph_runtime_state: - workflow_run = self._handle_workflow_run_failed( + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, + event=event, + ) + + yield parallel_start_resp + elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield parallel_finish_resp + elif isinstance(event, QueueIterationStartEvent): + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + iter_start_resp = self._workflow_iteration_start_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield iter_start_resp + elif isinstance(event, QueueIterationNextEvent): + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + iter_next_resp = self._workflow_iteration_next_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield iter_next_resp + elif isinstance(event, QueueIterationCompletedEvent): + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + iter_finish_resp = self._workflow_iteration_completed_to_stream_response( + session=session, + 
task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield iter_finish_resp + elif isinstance(event, QueueWorkflowSucceededEvent): + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + + if not graph_runtime_state: + raise ValueError("workflow run not initialized.") + + with Session(db.engine) as session: + workflow_run = self._handle_workflow_run_success( + session=session, + workflow_run_id=self._workflow_run_id, start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, - status=WorkflowRunStatus.STOPPED, - error=event.get_stop_reason(), - conversation_id=self._conversation.id, + outputs=event.outputs, + conversation_id=self._conversation_id, trace_manager=trace_manager, ) - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) + session.commit() - # Save message - self._save_message(graph_runtime_state=graph_runtime_state) + yield workflow_finish_resp + self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) + elif isinstance(event, QueueWorkflowPartialSuccessEvent): + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + if not graph_runtime_state: + raise ValueError("graph runtime state not initialized.") + + with Session(db.engine) as session: + workflow_run = self._handle_workflow_run_partial_success( + session=session, + workflow_run_id=self._workflow_run_id, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=event.outputs, + exceptions_count=event.exceptions_count, + conversation_id=None, + trace_manager=trace_manager, + ) + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + session.commit() + + yield workflow_finish_resp + self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) + elif isinstance(event, QueueWorkflowFailedEvent): + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + if not graph_runtime_state: + raise ValueError("graph runtime state not initialized.") + + with Session(db.engine) as session: + workflow_run = self._handle_workflow_run_failed( + session=session, + workflow_run_id=self._workflow_run_id, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.FAILED, + error=event.error, + conversation_id=self._conversation_id, + trace_manager=trace_manager, + exceptions_count=event.exceptions_count, + ) + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}")) + err = self._handle_error(event=err_event, session=session, message_id=self._message_id) + session.commit() + + yield workflow_finish_resp + yield self._error_to_stream_response(err) + break + elif isinstance(event, QueueStopEvent): + if self._workflow_run_id and graph_runtime_state: 
+ with Session(db.engine) as session: + workflow_run = self._handle_workflow_run_failed( + session=session, + workflow_run_id=self._workflow_run_id, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.STOPPED, + error=event.get_stop_reason(), + conversation_id=self._conversation_id, + trace_manager=trace_manager, + ) + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + ) + # Save message + self._save_message(session=session, graph_runtime_state=graph_runtime_state) + session.commit() + + yield workflow_finish_resp yield self._message_end_to_stream_response() break elif isinstance(event, QueueRetrieverResourcesEvent): self._handle_retriever_resources(event) - self._refetch_message() - - self._message.message_metadata = ( - json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None - ) - - db.session.commit() - db.session.refresh(self._message) - db.session.close() + with Session(db.engine) as session: + message = self._get_message(session=session) + message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) + session.commit() elif isinstance(event, QueueAnnotationReplyEvent): self._handle_annotation_reply(event) - self._refetch_message() - - self._message.message_metadata = ( - json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None - ) - - db.session.commit() - db.session.refresh(self._message) - db.session.close() + with Session(db.engine) as session: + message = self._get_message(session=session) + message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) + session.commit() elif isinstance(event, QueueTextChunkEvent): delta_text = event.text if delta_text is None: @@ -523,7 +590,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._task_state.answer += delta_text yield self._message_to_stream_response( - answer=delta_text, message_id=self._message.id, from_variable_selector=event.from_variable_selector + answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector ) elif isinstance(event, QueueMessageReplaceEvent): # published by moderation @@ -538,7 +605,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc yield self._message_replace_to_stream_response(answer=output_moderation_answer) # Save message - self._save_message(graph_runtime_state=graph_runtime_state) + with Session(db.engine) as session: + self._save_message(session=session, graph_runtime_state=graph_runtime_state) + session.commit() yield self._message_end_to_stream_response() elif isinstance(event, QueueAgentLogEvent): @@ -553,54 +622,46 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: - self._refetch_message() - - self._message.answer = self._task_state.answer - self._message.provider_response_latency = time.perf_counter() - self._start_at - self._message.message_metadata = ( + def _save_message(self, *, session: Session, graph_runtime_state: 
Optional[GraphRuntimeState] = None) -> None: + message = self._get_message(session=session) + message.answer = self._task_state.answer + message.provider_response_latency = time.perf_counter() - self._start_at + message.message_metadata = ( json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None ) message_files = [ MessageFile( - message_id=self._message.id, + message_id=message.id, type=file["type"], transfer_method=file["transfer_method"], url=file["remote_url"], belongs_to="assistant", upload_file_id=file["related_id"], created_by_role=CreatedByRole.ACCOUNT - if self._message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} + if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else CreatedByRole.END_USER, - created_by=self._message.from_account_id or self._message.from_end_user_id or "", + created_by=message.from_account_id or message.from_end_user_id or "", ) for file in self._recorded_files ] - db.session.add_all(message_files) + session.add_all(message_files) if graph_runtime_state and graph_runtime_state.llm_usage: usage = graph_runtime_state.llm_usage - self._message.message_tokens = usage.prompt_tokens - self._message.message_unit_price = usage.prompt_unit_price - self._message.message_price_unit = usage.prompt_price_unit - self._message.answer_tokens = usage.completion_tokens - self._message.answer_unit_price = usage.completion_unit_price - self._message.answer_price_unit = usage.completion_price_unit - self._message.total_price = usage.total_price - self._message.currency = usage.currency - + message.message_tokens = usage.prompt_tokens + message.message_unit_price = usage.prompt_unit_price + message.message_price_unit = usage.prompt_price_unit + message.answer_tokens = usage.completion_tokens + message.answer_unit_price = usage.completion_unit_price + message.answer_price_unit = usage.completion_price_unit + message.total_price = usage.total_price + message.currency = usage.currency self._task_state.metadata["usage"] = jsonable_encoder(usage) else: self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage()) - - db.session.commit() - message_was_created.send( - self._message, + message, application_generate_entity=self._application_generate_entity, - conversation=self._conversation, - is_first_message=self._application_generate_entity.conversation_id is None, - extras=self._application_generate_entity.extras, ) def _message_end_to_stream_response(self) -> MessageEndStreamResponse: @@ -617,7 +678,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc return MessageEndStreamResponse( task_id=self._application_generate_entity.task_id, - id=self._message.id, + id=self._message_id, files=self._recorded_files, metadata=extras.get("metadata", {}), ) @@ -645,11 +706,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc return False - def _refetch_message(self) -> None: - """ - Refetch message. 
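_save_message above now takes its session as a keyword-only argument, mutates the attached Message, adds the MessageFile rows with session.add_all, and leaves the commit to the caller. A hedged, self-contained sketch of that caller-owns-the-transaction convention, with a stand-in Note model rather than the real Message:

    from sqlalchemy import create_engine, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class Note(Base):  # stand-in for models.model.Message
        __tablename__ = "notes"
        id: Mapped[int] = mapped_column(primary_key=True)
        answer: Mapped[str] = mapped_column(default="")

    def save_note(*, session: Session, note_id: int, answer: str) -> Note:
        # Like _save_message: load the row through the caller's session, mutate it,
        # and do NOT commit here.
        note = session.scalar(select(Note).where(Note.id == note_id))
        if not note:
            raise ValueError(f"Note not found: {note_id}")
        note.answer = answer
        return note

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(Note(id=1))
        session.commit()

    with Session(engine) as session:
        save_note(session=session, note_id=1, answer="hello")
        session.commit()  # the caller decides when the unit of work is flushed
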
- :return: - """ - message = db.session.query(Message).filter(Message.id == self._message.id).first() - if message: - self._message = message + def _get_message(self, *, session: Session): + stmt = select(Message).where(Message.id == self._message_id) + message = session.scalar(stmt) + if not message: + raise ValueError(f"Message not found: {self._message_id}") + return message diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index c2e35faf89..4e3aa840ce 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -70,14 +70,13 @@ class MessageBasedAppGenerator(BaseAppGenerator): queue_manager=queue_manager, conversation=conversation, message=message, - user=user, stream=stream, ) try: return generate_task_pipeline.process() except ValueError as e: - if e.args[0] == "I/O operation on closed file.": # ignore this error + if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error raise GenerateTaskStoppedError() else: logger.exception(f"Failed to handle response, conversation_id: {conversation.id}") diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 9a5f90f998..cbfb535848 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -325,7 +325,7 @@ class WorkflowAppGenerator(BaseAppGenerator): try: return generate_task_pipeline.process() except ValueError as e: - if e.args[0] == "I/O operation on closed file.": # ignore this error + if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error raise GenerateTaskStoppedError() else: logger.exception( diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index af0698d701..df48a83316 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -3,6 +3,8 @@ import time from collections.abc import Generator from typing import Any, Optional, Union +from sqlalchemy.orm import Session + from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.base_app_queue_manager import AppQueueManager @@ -51,6 +53,7 @@ from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.enums import SystemVariableKey from extensions.ext_database import db from models.account import Account +from models.enums import CreatedByRole from models.model import EndUser from models.workflow import ( Workflow, @@ -69,8 +72,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ - _workflow: Workflow - _user: Union[Account, EndUser] _task_state: WorkflowTaskState _application_generate_entity: WorkflowAppGenerateEntity _workflow_system_variables: dict[SystemVariableKey, Any] @@ -84,44 +85,42 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa user: Union[Account, EndUser], stream: bool, ) -> None: - """ - Initialize GenerateTaskPipeline. 
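The two generator hunks above add a len(e.args) > 0 check before comparing e.args[0], since a ValueError raised without arguments would otherwise turn the comparison into an IndexError. A small reproduction of why the guard matters (GenerateTaskStoppedError is stubbed here purely for illustration):

    class GenerateTaskStoppedError(Exception):
        """Stub of the pipeline's stop signal, for illustration only."""

    def classify(e: ValueError) -> str:
        # Without the length guard, e.args[0] raises IndexError for ValueError().
        if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.":
            raise GenerateTaskStoppedError()
        return "unexpected"

    print(classify(ValueError("boom")))  # -> unexpected
    print(classify(ValueError()))        # -> unexpected, no IndexError
    try:
        classify(ValueError("I/O operation on closed file."))
    except GenerateTaskStoppedError:
        print("stopped")
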
- :param application_generate_entity: application generate entity - :param workflow: workflow - :param queue_manager: queue manager - :param user: user - :param stream: is streamed - """ - super().__init__(application_generate_entity, queue_manager, user, stream) + super().__init__( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + stream=stream, + ) - if isinstance(self._user, EndUser): - user_id = self._user.session_id + if isinstance(user, EndUser): + self._user_id = user.id + user_session_id = user.session_id + self._created_by_role = CreatedByRole.END_USER + elif isinstance(user, Account): + self._user_id = user.id + user_session_id = user.id + self._created_by_role = CreatedByRole.ACCOUNT else: - user_id = self._user.id + raise ValueError(f"Invalid user type: {type(user)}") + + self._workflow_id = workflow.id + self._workflow_features_dict = workflow.features_dict - self._workflow = workflow self._workflow_system_variables = { SystemVariableKey.FILES: application_generate_entity.files, - SystemVariableKey.USER_ID: user_id, + SystemVariableKey.USER_ID: user_session_id, SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, SystemVariableKey.WORKFLOW_ID: workflow.id, SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, } self._task_state = WorkflowTaskState() - self._wip_workflow_node_executions = {} - self._wip_workflow_agent_logs = {} - self.total_tokens: int = 0 + self._workflow_run_id = "" def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: """ Process generate task pipeline. :return: """ - db.session.refresh(self._workflow) - db.session.refresh(self._user) - db.session.close() - generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) if self._stream: return self._to_stream_response(generator) @@ -188,7 +187,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa tts_publisher = None task_id = self._application_generate_entity.task_id tenant_id = self._application_generate_entity.app_config.tenant_id - features_dict = self._workflow.features_dict + features_dict = self._workflow_features_dict if ( features_dict.get("text_to_speech") @@ -237,7 +236,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa :return: """ graph_runtime_state = None - workflow_run = None for queue_message in self._queue_manager.listen(): event = queue_message.event @@ -245,180 +243,261 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa if isinstance(event, QueuePingEvent): yield self._ping_stream_response() elif isinstance(event, QueueErrorEvent): - err = self._handle_error(event) + err = self._handle_error(event=event) yield self._error_to_stream_response(err) break elif isinstance(event, QueueWorkflowStartedEvent): # override graph runtime state graph_runtime_state = event.graph_runtime_state - # init workflow run - workflow_run = self._handle_workflow_run_start() - yield self._workflow_start_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + with Session(db.engine) as session: + # init workflow run + workflow_run = self._handle_workflow_run_start( + session=session, + workflow_id=self._workflow_id, + user_id=self._user_id, + created_by_role=self._created_by_role, + ) + self._workflow_run_id = workflow_run.id + start_resp = self._workflow_start_to_stream_response( 
+ session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + session.commit() + + yield start_resp elif isinstance( event, QueueNodeRetryEvent, ): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - workflow_node_execution = self._handle_workflow_node_execution_retried( - workflow_run=workflow_run, event=event - ) - - response = self._workflow_node_retry_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + workflow_node_execution = self._handle_workflow_node_execution_retried( + session=session, workflow_run=workflow_run, event=event + ) + response = self._workflow_node_retry_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() if response: yield response elif isinstance(event, QueueNodeStartedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) - - node_start_response = self._workflow_node_start_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + workflow_node_execution = self._handle_node_execution_start( + session=session, workflow_run=workflow_run, event=event + ) + node_start_response = self._workflow_node_start_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() if node_start_response: yield node_start_response elif isinstance(event, QueueNodeSucceededEvent): - workflow_node_execution = self._handle_workflow_node_execution_success(event) - - node_success_response = self._workflow_node_finish_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + with Session(db.engine) as session: + workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event) + node_success_response = self._workflow_node_finish_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() if node_success_response: yield node_success_response elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): - workflow_node_execution = self._handle_workflow_node_execution_failed(event) + with Session(db.engine) as session: + workflow_node_execution = self._handle_workflow_node_execution_failed( + session=session, + event=event, + ) + node_failed_response = self._workflow_node_finish_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() - node_failed_response = self._workflow_node_finish_to_stream_response( - event=event, - 
task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) if node_failed_response: yield node_failed_response elif isinstance(event, QueueParallelBranchRunStartedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_parallel_branch_start_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield parallel_start_resp + elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_parallel_branch_finished_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield parallel_finish_resp + elif isinstance(event, QueueIterationStartEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_iteration_start_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + iter_start_resp = self._workflow_iteration_start_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield iter_start_resp + elif isinstance(event, QueueIterationNextEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_iteration_next_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + iter_next_resp = self._workflow_iteration_next_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield iter_next_resp + elif isinstance(event, QueueIterationCompletedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_iteration_completed_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + iter_finish_resp = self._workflow_iteration_completed_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, 
+ ) + + yield iter_finish_resp + elif isinstance(event, QueueWorkflowSucceededEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") - workflow_run = self._handle_workflow_run_success( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - conversation_id=None, - trace_manager=trace_manager, - ) + with Session(db.engine) as session: + workflow_run = self._handle_workflow_run_success( + session=session, + workflow_run_id=self._workflow_run_id, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=event.outputs, + conversation_id=None, + trace_manager=trace_manager, + ) - # save workflow app log - self._save_workflow_app_log(workflow_run) + # save workflow app log + self._save_workflow_app_log(session=session, workflow_run=workflow_run) - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + ) + session.commit() + + yield workflow_finish_resp elif isinstance(event, QueueWorkflowPartialSuccessEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") - workflow_run = self._handle_workflow_run_partial_success( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - exceptions_count=event.exceptions_count, - conversation_id=None, - trace_manager=trace_manager, - ) + with Session(db.engine) as session: + workflow_run = self._handle_workflow_run_partial_success( + session=session, + workflow_run_id=self._workflow_run_id, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=event.outputs, + exceptions_count=event.exceptions_count, + conversation_id=None, + trace_manager=trace_manager, + ) - # save workflow app log - self._save_workflow_app_log(workflow_run) + # save workflow app log + self._save_workflow_app_log(session=session, workflow_run=workflow_run) - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + session.commit() + + yield workflow_finish_resp elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") - workflow_run = self._handle_workflow_run_failed( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - status=WorkflowRunStatus.FAILED - if 
isinstance(event, QueueWorkflowFailedEvent) - else WorkflowRunStatus.STOPPED, - error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), - conversation_id=None, - trace_manager=trace_manager, - exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, - ) - # save workflow app log - self._save_workflow_app_log(workflow_run) + with Session(db.engine) as session: + workflow_run = self._handle_workflow_run_failed( + session=session, + workflow_run_id=self._workflow_run_id, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.FAILED + if isinstance(event, QueueWorkflowFailedEvent) + else WorkflowRunStatus.STOPPED, + error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), + conversation_id=None, + trace_manager=trace_manager, + exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, + ) - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + # save workflow app log + self._save_workflow_app_log(session=session, workflow_run=workflow_run) + + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + session.commit() + + yield workflow_finish_resp elif isinstance(event, QueueTextChunkEvent): delta_text = event.text if delta_text is None: @@ -440,7 +519,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa if tts_publisher: tts_publisher.publish(None) - def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: + def _save_workflow_app_log(self, *, session: Session, workflow_run: WorkflowRun) -> None: """ Save workflow app log. 
:return: @@ -462,12 +541,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa workflow_app_log.workflow_id = workflow_run.workflow_id workflow_app_log.workflow_run_id = workflow_run.id workflow_app_log.created_from = created_from.value - workflow_app_log.created_by_role = "account" if isinstance(self._user, Account) else "end_user" - workflow_app_log.created_by = self._user.id + workflow_app_log.created_by_role = self._created_by_role + workflow_app_log.created_by = self._user_id - db.session.add(workflow_app_log) - db.session.commit() - db.session.close() + session.add(workflow_app_log) def _text_chunk_to_stream_response( self, text: str, from_variable_selector: Optional[list[str]] = None diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 03a81353d0..e363a7f642 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -1,6 +1,9 @@ import logging import time -from typing import Optional, Union +from typing import Optional + +from sqlalchemy import select +from sqlalchemy.orm import Session from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import ( @@ -17,9 +20,7 @@ from core.app.entities.task_entities import ( from core.errors.error import QuotaExceededError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.moderation.output_moderation import ModerationRule, OutputModeration -from extensions.ext_database import db -from models.account import Account -from models.model import EndUser, Message +from models.model import Message logger = logging.getLogger(__name__) @@ -36,7 +37,6 @@ class BasedGenerateTaskPipeline: self, application_generate_entity: AppGenerateEntity, queue_manager: AppQueueManager, - user: Union[Account, EndUser], stream: bool, ) -> None: """ @@ -48,18 +48,11 @@ class BasedGenerateTaskPipeline: """ self._application_generate_entity = application_generate_entity self._queue_manager = queue_manager - self._user = user self._start_at = time.perf_counter() self._output_moderation_handler = self._init_output_moderation() self._stream = stream - def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = None): - """ - Handle error event. 
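Rather than holding on to the Account/EndUser ORM object (which detaches once its session closes), the constructor captures plain values that _save_workflow_app_log later writes directly: the user id, the session id used for system variables, and a CreatedByRole. A sketch of that mapping, with plain dataclasses standing in for the real models:

    from dataclasses import dataclass
    from enum import Enum

    class CreatedByRole(str, Enum):
        ACCOUNT = "account"
        END_USER = "end_user"

    @dataclass
    class Account:  # stand-in for models.account.Account
        id: str

    @dataclass
    class EndUser:  # stand-in for models.model.EndUser
        id: str
        session_id: str

    def capture_user(user: Account | EndUser) -> tuple[str, str, CreatedByRole]:
        # Returns (user_id, user_session_id, created_by_role), mirroring the branching
        # in the pipeline constructors above.
        if isinstance(user, EndUser):
            return user.id, user.session_id, CreatedByRole.END_USER
        if isinstance(user, Account):
            return user.id, user.id, CreatedByRole.ACCOUNT
        raise ValueError(f"Invalid user type: {type(user)}")

    print(capture_user(EndUser(id="u-1", session_id="s-1")))
    print(capture_user(Account(id="a-1")))
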
- :param event: event - :param message: message - :return: - """ + def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""): logger.debug("error: %s", event.error) e = event.error err: Exception @@ -71,16 +64,17 @@ class BasedGenerateTaskPipeline: else: err = Exception(e.description if getattr(e, "description", None) is not None else str(e)) - if message: - refetch_message = db.session.query(Message).filter(Message.id == message.id).first() + if not message_id or not session: + return err - if refetch_message: - err_desc = self._error_to_desc(err) - refetch_message.status = "error" - refetch_message.error = err_desc - - db.session.commit() + stmt = select(Message).where(Message.id == message_id) + message = session.scalar(stmt) + if not message: + return err + err_desc = self._error_to_desc(err) + message.status = "error" + message.error = err_desc return err def _error_to_desc(self, e: Exception) -> str: diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index b9f8e7ca56..c84f8ba3e4 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -5,6 +5,9 @@ from collections.abc import Generator from threading import Thread from typing import Optional, Union, cast +from sqlalchemy import select +from sqlalchemy.orm import Session + from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -55,8 +58,7 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_template_parser import PromptTemplateParser from events.message_event import message_was_created from extensions.ext_database import db -from models.account import Account -from models.model import AppMode, Conversation, EndUser, Message, MessageAgentThought +from models.model import AppMode, Conversation, Message, MessageAgentThought logger = logging.getLogger(__name__) @@ -77,23 +79,21 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan queue_manager: AppQueueManager, conversation: Conversation, message: Message, - user: Union[Account, EndUser], stream: bool, ) -> None: - """ - Initialize GenerateTaskPipeline. 
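_handle_error above keeps returning the normalized exception in every case, and only marks the message row as errored when the caller supplied both a session and a message id; the caller then commits. A deliberately database-free sketch of that contract, with FakeMessage standing in for the real Message row and a simplified normalization step:

    from typing import Optional

    class FakeMessage:
        # Illustrative stand-in for models.model.Message.
        def __init__(self) -> None:
            self.status = "normal"
            self.error: Optional[str] = None

    def handle_error(error: Exception, *, message: Optional[FakeMessage] = None) -> Exception:
        # Simplified normalization; the real method special-cases InvokeError and quota errors.
        err = error if isinstance(error, ValueError) else Exception(str(error))
        if message is None:
            # No persistence context provided: just hand the error back to the stream.
            return err
        message.status = "error"
        message.error = str(err)
        return err

    msg = FakeMessage()
    print(handle_error(ValueError("bad input"), message=msg), msg.status)  # -> bad input error
    print(handle_error(RuntimeError("boom")))  # returned, nothing to update
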
- :param application_generate_entity: application generate entity - :param queue_manager: queue manager - :param conversation: conversation - :param message: message - :param user: user - :param stream: stream - """ - super().__init__(application_generate_entity, queue_manager, user, stream) + super().__init__( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + stream=stream, + ) self._model_config = application_generate_entity.model_conf self._app_config = application_generate_entity.app_config - self._conversation = conversation - self._message = message + + self._conversation_id = conversation.id + self._conversation_mode = conversation.mode + + self._message_id = message.id + self._message_created_at = int(message.created_at.timestamp()) self._task_state = EasyUITaskState( llm_result=LLMResult( @@ -113,18 +113,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan CompletionAppBlockingResponse, Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None], ]: - """ - Process generate task pipeline. - :return: - """ - db.session.refresh(self._conversation) - db.session.refresh(self._message) - db.session.close() - if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: # start generate conversation name thread self._conversation_name_generate_thread = self._generate_conversation_name( - self._conversation, self._application_generate_entity.query or "" + conversation_id=self._conversation_id, query=self._application_generate_entity.query or "" ) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) @@ -148,15 +140,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan if self._task_state.metadata: extras["metadata"] = self._task_state.metadata response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse] - if self._conversation.mode == AppMode.COMPLETION.value: + if self._conversation_mode == AppMode.COMPLETION.value: response = CompletionAppBlockingResponse( task_id=self._application_generate_entity.task_id, data=CompletionAppBlockingResponse.Data( - id=self._message.id, - mode=self._conversation.mode, - message_id=self._message.id, + id=self._message_id, + mode=self._conversation_mode, + message_id=self._message_id, answer=cast(str, self._task_state.llm_result.message.content), - created_at=int(self._message.created_at.timestamp()), + created_at=self._message_created_at, **extras, ), ) @@ -164,12 +156,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan response = ChatbotAppBlockingResponse( task_id=self._application_generate_entity.task_id, data=ChatbotAppBlockingResponse.Data( - id=self._message.id, - mode=self._conversation.mode, - conversation_id=self._conversation.id, - message_id=self._message.id, + id=self._message_id, + mode=self._conversation_mode, + conversation_id=self._conversation_id, + message_id=self._message_id, answer=cast(str, self._task_state.llm_result.message.content), - created_at=int(self._message.created_at.timestamp()), + created_at=self._message_created_at, **extras, ), ) @@ -190,15 +182,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan for stream_response in generator: if isinstance(self._application_generate_entity, CompletionAppGenerateEntity): yield CompletionAppStreamResponse( - message_id=self._message.id, - created_at=int(self._message.created_at.timestamp()), + 
message_id=self._message_id, + created_at=self._message_created_at, stream_response=stream_response, ) else: yield ChatbotAppStreamResponse( - conversation_id=self._conversation.id, - message_id=self._message.id, - created_at=int(self._message.created_at.timestamp()), + conversation_id=self._conversation_id, + message_id=self._message_id, + created_at=self._message_created_at, stream_response=stream_response, ) @@ -265,7 +257,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan event = message.event if isinstance(event, QueueErrorEvent): - err = self._handle_error(event, self._message) + with Session(db.engine) as session: + err = self._handle_error(event=event, session=session, message_id=self._message_id) + session.commit() yield self._error_to_stream_response(err) break elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): @@ -283,10 +277,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan self._task_state.llm_result.message.content = output_moderation_answer yield self._message_replace_to_stream_response(answer=output_moderation_answer) - # Save message - self._save_message(trace_manager) - - yield self._message_end_to_stream_response() + with Session(db.engine) as session: + # Save message + self._save_message(session=session, trace_manager=trace_manager) + session.commit() + message_end_resp = self._message_end_to_stream_response() + yield message_end_resp elif isinstance(event, QueueRetrieverResourcesEvent): self._handle_retriever_resources(event) elif isinstance(event, QueueAnnotationReplyEvent): @@ -320,9 +316,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan self._task_state.llm_result.message.content = current_content if isinstance(event, QueueLLMChunkEvent): - yield self._message_to_stream_response(cast(str, delta_text), self._message.id) + yield self._message_to_stream_response( + answer=cast(str, delta_text), + message_id=self._message_id, + ) else: - yield self._agent_message_to_stream_response(cast(str, delta_text), self._message.id) + yield self._agent_message_to_stream_response( + answer=cast(str, delta_text), + message_id=self._message_id, + ) elif isinstance(event, QueueMessageReplaceEvent): yield self._message_replace_to_stream_response(answer=event.text) elif isinstance(event, QueuePingEvent): @@ -334,7 +336,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> None: + def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None) -> None: """ Save message. 
:return: @@ -342,53 +344,46 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan llm_result = self._task_state.llm_result usage = llm_result.usage - message = db.session.query(Message).filter(Message.id == self._message.id).first() + message_stmt = select(Message).where(Message.id == self._message_id) + message = session.scalar(message_stmt) if not message: - raise Exception(f"Message {self._message.id} not found") - self._message = message - conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() + raise ValueError(f"message {self._message_id} not found") + conversation_stmt = select(Conversation).where(Conversation.id == self._conversation_id) + conversation = session.scalar(conversation_stmt) if not conversation: - raise Exception(f"Conversation {self._conversation.id} not found") - self._conversation = conversation + raise ValueError(f"Conversation {self._conversation_id} not found") - self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( + message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( self._model_config.mode, self._task_state.llm_result.prompt_messages ) - self._message.message_tokens = usage.prompt_tokens - self._message.message_unit_price = usage.prompt_unit_price - self._message.message_price_unit = usage.prompt_price_unit - self._message.answer = ( + message.message_tokens = usage.prompt_tokens + message.message_unit_price = usage.prompt_unit_price + message.message_price_unit = usage.prompt_price_unit + message.answer = ( PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip()) if llm_result.message.content else "" ) - self._message.answer_tokens = usage.completion_tokens - self._message.answer_unit_price = usage.completion_unit_price - self._message.answer_price_unit = usage.completion_price_unit - self._message.provider_response_latency = time.perf_counter() - self._start_at - self._message.total_price = usage.total_price - self._message.currency = usage.currency - self._message.message_metadata = ( + message.answer_tokens = usage.completion_tokens + message.answer_unit_price = usage.completion_unit_price + message.answer_price_unit = usage.completion_price_unit + message.provider_response_latency = time.perf_counter() - self._start_at + message.total_price = usage.total_price + message.currency = usage.currency + message.message_metadata = ( json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None ) - db.session.commit() - if trace_manager: trace_manager.add_trace_task( TraceTask( - TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation.id, message_id=self._message.id + TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id ) ) message_was_created.send( - self._message, + message, application_generate_entity=self._application_generate_entity, - conversation=self._conversation, - is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT} - and hasattr(self._application_generate_entity, "conversation_id") - and self._application_generate_entity.conversation_id is None, - extras=self._application_generate_entity.extras, ) def _handle_stop(self, event: QueueStopEvent) -> None: @@ -434,7 +429,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan return MessageEndStreamResponse( task_id=self._application_generate_entity.task_id, - 
id=self._message.id, + id=self._message_id, metadata=extras.get("metadata", {}), ) diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py index 007543f6d0..15f2c25c66 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manage.py @@ -36,7 +36,7 @@ class MessageCycleManage: ] _task_state: Union[EasyUITaskState, WorkflowTaskState] - def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]: + def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: """ Generate conversation name. :param conversation: conversation @@ -56,7 +56,7 @@ class MessageCycleManage: target=self._generate_conversation_name_worker, kwargs={ "flask_app": current_app._get_current_object(), # type: ignore - "conversation_id": conversation.id, + "conversation_id": conversation_id, "query": query, }, ) diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 2d958b460d..1a2e67f7e7 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -5,6 +5,7 @@ from datetime import UTC, datetime from typing import Any, Optional, Union, cast from uuid import uuid4 +from sqlalchemy import func, select from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity @@ -47,7 +48,6 @@ from core.workflow.enums import SystemVariableKey from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.workflow_entry import WorkflowEntry -from extensions.ext_database import db from models.account import Account from models.enums import CreatedByRole, WorkflowRunTriggeredFrom from models.model import EndUser @@ -65,28 +65,33 @@ from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError class WorkflowCycleManage: _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] - _workflow: Workflow - _user: Union[Account, EndUser] _task_state: WorkflowTaskState _workflow_system_variables: dict[SystemVariableKey, Any] - _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] - _wip_workflow_agent_logs: dict[str, list[AgentLogStreamResponse.Data]] - def _handle_workflow_run_start(self) -> WorkflowRun: - max_sequence = ( - db.session.query(db.func.max(WorkflowRun.sequence_number)) - .filter(WorkflowRun.tenant_id == self._workflow.tenant_id) - .filter(WorkflowRun.app_id == self._workflow.app_id) - .scalar() - or 0 + def _handle_workflow_run_start( + self, + *, + session: Session, + workflow_id: str, + user_id: str, + created_by_role: CreatedByRole, + ) -> WorkflowRun: + workflow_stmt = select(Workflow).where(Workflow.id == workflow_id) + workflow = session.scalar(workflow_stmt) + if not workflow: + raise ValueError(f"Workflow not found: {workflow_id}") + + max_sequence_stmt = select(func.max(WorkflowRun.sequence_number)).where( + WorkflowRun.tenant_id == workflow.tenant_id, + WorkflowRun.app_id == workflow.app_id, ) + max_sequence = session.scalar(max_sequence_stmt) or 0 new_sequence_number = max_sequence + 1 inputs = {**self._application_generate_entity.inputs} for key, value in (self._workflow_system_variables or {}).items(): if key.value == "conversation": continue - inputs[f"sys.{key.value}"] = value triggered_from = ( @@ 
-99,34 +104,33 @@ class WorkflowCycleManage: inputs = dict(WorkflowEntry.handle_special_values(inputs) or {}) # init workflow run - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = WorkflowRun() - system_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID] - workflow_run.id = system_id or str(uuid4()) - workflow_run.tenant_id = self._workflow.tenant_id - workflow_run.app_id = self._workflow.app_id - workflow_run.sequence_number = new_sequence_number - workflow_run.workflow_id = self._workflow.id - workflow_run.type = self._workflow.type - workflow_run.triggered_from = triggered_from.value - workflow_run.version = self._workflow.version - workflow_run.graph = self._workflow.graph - workflow_run.inputs = json.dumps(inputs) - workflow_run.status = WorkflowRunStatus.RUNNING - workflow_run.created_by_role = ( - CreatedByRole.ACCOUNT if isinstance(self._user, Account) else CreatedByRole.END_USER - ) - workflow_run.created_by = self._user.id - workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) + workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID, uuid4())) - session.add(workflow_run) - session.commit() + workflow_run = WorkflowRun() + workflow_run.id = workflow_run_id + workflow_run.tenant_id = workflow.tenant_id + workflow_run.app_id = workflow.app_id + workflow_run.sequence_number = new_sequence_number + workflow_run.workflow_id = workflow.id + workflow_run.type = workflow.type + workflow_run.triggered_from = triggered_from.value + workflow_run.version = workflow.version + workflow_run.graph = workflow.graph + workflow_run.inputs = json.dumps(inputs) + workflow_run.status = WorkflowRunStatus.RUNNING + workflow_run.created_by_role = created_by_role + workflow_run.created_by = user_id + workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) + + session.add(workflow_run) return workflow_run def _handle_workflow_run_success( self, - workflow_run: WorkflowRun, + *, + session: Session, + workflow_run_id: str, start_at: float, total_tokens: int, total_steps: int, @@ -144,7 +148,7 @@ class WorkflowCycleManage: :param conversation_id: conversation id :return: """ - workflow_run = self._refetch_workflow_run(workflow_run.id) + workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id) outputs = WorkflowEntry.handle_special_values(outputs) @@ -155,9 +159,6 @@ class WorkflowCycleManage: workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() - db.session.refresh(workflow_run) - if trace_manager: trace_manager.add_trace_task( TraceTask( @@ -168,13 +169,13 @@ class WorkflowCycleManage: ) ) - db.session.close() - return workflow_run def _handle_workflow_run_partial_success( self, - workflow_run: WorkflowRun, + *, + session: Session, + workflow_run_id: str, start_at: float, total_tokens: int, total_steps: int, @@ -183,18 +184,7 @@ class WorkflowCycleManage: conversation_id: Optional[str] = None, trace_manager: Optional[TraceQueueManager] = None, ) -> WorkflowRun: - """ - Workflow run success - :param workflow_run: workflow run - :param start_at: start time - :param total_tokens: total tokens - :param total_steps: total steps - :param outputs: outputs - :param conversation_id: conversation id - :return: - """ - workflow_run = self._refetch_workflow_run(workflow_run.id) - + workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id) outputs = 
WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value @@ -204,8 +194,6 @@ class WorkflowCycleManage: workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.exceptions_count = exceptions_count - db.session.commit() - db.session.refresh(workflow_run) if trace_manager: trace_manager.add_trace_task( @@ -217,13 +205,13 @@ class WorkflowCycleManage: ) ) - db.session.close() - return workflow_run def _handle_workflow_run_failed( self, - workflow_run: WorkflowRun, + *, + session: Session, + workflow_run_id: str, start_at: float, total_tokens: int, total_steps: int, @@ -243,7 +231,7 @@ class WorkflowCycleManage: :param error: error message :return: """ - workflow_run = self._refetch_workflow_run(workflow_run.id) + workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id) workflow_run.status = status.value workflow_run.error = error @@ -252,21 +240,18 @@ class WorkflowCycleManage: workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.exceptions_count = exceptions_count - db.session.commit() - running_workflow_node_executions = ( - db.session.query(WorkflowNodeExecution) - .filter( - WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, - WorkflowNodeExecution.app_id == workflow_run.app_id, - WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, - WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - WorkflowNodeExecution.workflow_run_id == workflow_run.id, - WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, - ) - .all() + stmt = select(WorkflowNodeExecution).where( + WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, + WorkflowNodeExecution.app_id == workflow_run.app_id, + WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + WorkflowNodeExecution.workflow_run_id == workflow_run.id, + WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, ) + running_workflow_node_executions = session.scalars(stmt).all() + for workflow_node_execution in running_workflow_node_executions: workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value workflow_node_execution.error = error @@ -274,13 +259,6 @@ class WorkflowCycleManage: workflow_node_execution.elapsed_time = ( workflow_node_execution.finished_at - workflow_node_execution.created_at ).total_seconds() - db.session.commit() - - db.session.close() - - # with Session(db.engine, expire_on_commit=False) as session: - # session.add(workflow_run) - # session.refresh(workflow_run) if trace_manager: trace_manager.add_trace_task( @@ -295,79 +273,49 @@ class WorkflowCycleManage: return workflow_run def _handle_node_execution_start( - self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent + self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeStartedEvent ) -> WorkflowNodeExecution: - # init workflow node execution + workflow_node_execution = WorkflowNodeExecution() + workflow_node_execution.id = event.node_execution_id + workflow_node_execution.tenant_id = workflow_run.tenant_id + workflow_node_execution.app_id = workflow_run.app_id + workflow_node_execution.workflow_id = workflow_run.workflow_id + workflow_node_execution.triggered_from = 
WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value + workflow_node_execution.workflow_run_id = workflow_run.id + workflow_node_execution.predecessor_node_id = event.predecessor_node_id + workflow_node_execution.index = event.node_run_index + workflow_node_execution.node_execution_id = event.node_execution_id + workflow_node_execution.node_id = event.node_id + workflow_node_execution.node_type = event.node_type.value + workflow_node_execution.title = event.node_data.title + workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value + workflow_node_execution.created_by_role = workflow_run.created_by_role + workflow_node_execution.created_by = workflow_run.created_by + workflow_node_execution.execution_metadata = json.dumps( + { + NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, + NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, + } + ) + workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) - with Session(db.engine, expire_on_commit=False) as session: - workflow_node_execution = WorkflowNodeExecution() - workflow_node_execution.tenant_id = workflow_run.tenant_id - workflow_node_execution.app_id = workflow_run.app_id - workflow_node_execution.workflow_id = workflow_run.workflow_id - workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value - workflow_node_execution.workflow_run_id = workflow_run.id - workflow_node_execution.predecessor_node_id = event.predecessor_node_id - workflow_node_execution.index = event.node_run_index - workflow_node_execution.node_execution_id = event.node_execution_id - workflow_node_execution.node_id = event.node_id - workflow_node_execution.node_type = event.node_type.value - workflow_node_execution.title = event.node_data.title - workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value - workflow_node_execution.created_by_role = workflow_run.created_by_role - workflow_node_execution.created_by = workflow_run.created_by - workflow_node_execution.execution_metadata = json.dumps( - { - NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, - NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, - } - ) - workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) - - session.add(workflow_node_execution) - session.commit() - session.refresh(workflow_node_execution) - - self._wip_workflow_node_executions[workflow_node_execution.node_execution_id] = workflow_node_execution + session.add(workflow_node_execution) return workflow_node_execution - def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: - """ - Workflow node execution success - :param event: queue node succeeded event - :return: - """ - workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id) - + def _handle_workflow_node_execution_success( + self, *, session: Session, event: QueueNodeSucceededEvent + ) -> WorkflowNodeExecution: + workflow_node_execution = self._get_workflow_node_execution( + session=session, node_execution_id=event.node_execution_id + ) inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) outputs = WorkflowEntry.handle_special_values(event.outputs) execution_metadata_dict = dict(event.execution_metadata or {}) - if self._wip_workflow_agent_logs.get(workflow_node_execution.id): - if not execution_metadata_dict: - execution_metadata_dict = {} - - 
execution_metadata_dict[NodeRunMetadataKey.AGENT_LOG] = self._wip_workflow_agent_logs.get( - workflow_node_execution.id, [] - ) - execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None finished_at = datetime.now(UTC).replace(tzinfo=None) elapsed_time = (finished_at - event.start_at).total_seconds() - db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update( - { - WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.SUCCEEDED.value, - WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None, - WorkflowNodeExecution.process_data: json.dumps(process_data) if process_data else None, - WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None, - WorkflowNodeExecution.execution_metadata: execution_metadata, - WorkflowNodeExecution.finished_at: finished_at, - WorkflowNodeExecution.elapsed_time: elapsed_time, - } - ) - - db.session.commit() - db.session.close() process_data = WorkflowEntry.handle_special_values(event.process_data) workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value @@ -378,54 +326,31 @@ class WorkflowCycleManage: workflow_node_execution.finished_at = finished_at workflow_node_execution.elapsed_time = elapsed_time - self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id) - return workflow_node_execution def _handle_workflow_node_execution_failed( - self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent + self, + *, + session: Session, + event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent, ) -> WorkflowNodeExecution: """ Workflow node execution failed :param event: queue node failed event :return: """ - workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id) + workflow_node_execution = self._get_workflow_node_execution( + session=session, node_execution_id=event.node_execution_id + ) inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) outputs = WorkflowEntry.handle_special_values(event.outputs) finished_at = datetime.now(UTC).replace(tzinfo=None) elapsed_time = (finished_at - event.start_at).total_seconds() - execution_metadata_dict = dict(event.execution_metadata or {}) - if self._wip_workflow_agent_logs.get(workflow_node_execution.id): - if not execution_metadata_dict: - execution_metadata_dict = {} - - execution_metadata_dict[NodeRunMetadataKey.AGENT_LOG] = self._wip_workflow_agent_logs.get( - workflow_node_execution.id, [] - ) - - execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None - db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update( - { - WorkflowNodeExecution.status: ( - WorkflowNodeExecutionStatus.FAILED.value - if not isinstance(event, QueueNodeExceptionEvent) - else WorkflowNodeExecutionStatus.EXCEPTION.value - ), - WorkflowNodeExecution.error: event.error, - WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None, - WorkflowNodeExecution.process_data: json.dumps(process_data) if process_data else None, - WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None, - WorkflowNodeExecution.finished_at: finished_at, - WorkflowNodeExecution.elapsed_time: elapsed_time, - WorkflowNodeExecution.execution_metadata: execution_metadata, - } + execution_metadata = ( + 
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None ) - - db.session.commit() - db.session.close() process_data = WorkflowEntry.handle_special_values(event.process_data) workflow_node_execution.status = ( WorkflowNodeExecutionStatus.FAILED.value @@ -440,12 +365,10 @@ class WorkflowCycleManage: workflow_node_execution.elapsed_time = elapsed_time workflow_node_execution.execution_metadata = execution_metadata - self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id) - return workflow_node_execution def _handle_workflow_node_execution_retried( - self, workflow_run: WorkflowRun, event: QueueNodeRetryEvent + self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeRetryEvent ) -> WorkflowNodeExecution: """ Workflow node execution failed @@ -469,6 +392,7 @@ class WorkflowCycleManage: execution_metadata = json.dumps(merged_metadata) workflow_node_execution = WorkflowNodeExecution() + workflow_node_execution.id = event.node_execution_id workflow_node_execution.tenant_id = workflow_run.tenant_id workflow_node_execution.app_id = workflow_run.app_id workflow_node_execution.workflow_id = workflow_run.workflow_id @@ -491,10 +415,7 @@ class WorkflowCycleManage: workflow_node_execution.execution_metadata = execution_metadata workflow_node_execution.index = event.node_run_index - db.session.add(workflow_node_execution) - db.session.commit() - db.session.refresh(workflow_node_execution) - + session.add(workflow_node_execution) return workflow_node_execution ################################################# @@ -502,14 +423,14 @@ class WorkflowCycleManage: ################################################# def _workflow_start_to_stream_response( - self, task_id: str, workflow_run: WorkflowRun + self, + *, + session: Session, + task_id: str, + workflow_run: WorkflowRun, ) -> WorkflowStartStreamResponse: - """ - Workflow start to stream response. - :param task_id: task id - :param workflow_run: workflow run - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session return WorkflowStartStreamResponse( task_id=task_id, workflow_run_id=workflow_run.id, @@ -523,36 +444,32 @@ class WorkflowCycleManage: ) def _workflow_finish_to_stream_response( - self, task_id: str, workflow_run: WorkflowRun + self, + *, + session: Session, + task_id: str, + workflow_run: WorkflowRun, ) -> WorkflowFinishStreamResponse: - """ - Workflow finish to stream response. - :param task_id: task id - :param workflow_run: workflow run - :return: - """ - # Attach WorkflowRun to an active session so "created_by_role" can be accessed. 
- workflow_run = db.session.merge(workflow_run) - - # Refresh to ensure any expired attributes are fully loaded - db.session.refresh(workflow_run) - created_by = None - if workflow_run.created_by_role == CreatedByRole.ACCOUNT.value: - created_by_account = workflow_run.created_by_account - if created_by_account: + if workflow_run.created_by_role == CreatedByRole.ACCOUNT: + stmt = select(Account).where(Account.id == workflow_run.created_by) + account = session.scalar(stmt) + if account: created_by = { - "id": created_by_account.id, - "name": created_by_account.name, - "email": created_by_account.email, + "id": account.id, + "name": account.name, + "email": account.email, + } + elif workflow_run.created_by_role == CreatedByRole.END_USER: + stmt = select(EndUser).where(EndUser.id == workflow_run.created_by) + end_user = session.scalar(stmt) + if end_user: + created_by = { + "id": end_user.id, + "user": end_user.session_id, } else: - created_by_end_user = workflow_run.created_by_end_user - if created_by_end_user: - created_by = { - "id": created_by_end_user.id, - "user": created_by_end_user.session_id, - } + raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}") return WorkflowFinishStreamResponse( task_id=task_id, @@ -576,17 +493,20 @@ class WorkflowCycleManage: ) def _workflow_node_start_to_stream_response( - self, event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution + self, + *, + session: Session, + event: QueueNodeStartedEvent, + task_id: str, + workflow_node_execution: WorkflowNodeExecution, ) -> Optional[NodeStartStreamResponse]: - """ - Workflow node start to stream response. - :param event: queue node started event - :param task_id: task id - :param workflow_node_execution: workflow node execution - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session + if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None + if not workflow_node_execution.workflow_run_id: + return None response = NodeStartStreamResponse( task_id=task_id, @@ -622,6 +542,8 @@ class WorkflowCycleManage: def _workflow_node_finish_to_stream_response( self, + *, + session: Session, event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent @@ -629,15 +551,14 @@ class WorkflowCycleManage: task_id: str, workflow_node_execution: WorkflowNodeExecution, ) -> Optional[NodeFinishStreamResponse]: - """ - Workflow node finish to stream response. - :param event: queue node succeeded or failed event - :param task_id: task id - :param workflow_node_execution: workflow node execution - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None + if not workflow_node_execution.workflow_run_id: + return None + if not workflow_node_execution.finished_at: + return None return NodeFinishStreamResponse( task_id=task_id, @@ -669,19 +590,20 @@ class WorkflowCycleManage: def _workflow_node_retry_to_stream_response( self, + *, + session: Session, event: QueueNodeRetryEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution, ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: - """ - Workflow node finish to stream response. 
- :param event: queue node succeeded or failed event - :param task_id: task id - :param workflow_node_execution: workflow node execution - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None + if not workflow_node_execution.workflow_run_id: + return None + if not workflow_node_execution.finished_at: + return None return NodeRetryStreamResponse( task_id=task_id, @@ -713,15 +635,10 @@ class WorkflowCycleManage: ) def _workflow_parallel_branch_start_to_stream_response( - self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent + self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent ) -> ParallelBranchStartStreamResponse: - """ - Workflow parallel branch start to stream response - :param task_id: task id - :param workflow_run: workflow run - :param event: parallel branch run started event - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session return ParallelBranchStartStreamResponse( task_id=task_id, workflow_run_id=workflow_run.id, @@ -737,17 +654,14 @@ class WorkflowCycleManage: def _workflow_parallel_branch_finished_to_stream_response( self, + *, + session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent, ) -> ParallelBranchFinishedStreamResponse: - """ - Workflow parallel branch finished to stream response - :param task_id: task id - :param workflow_run: workflow run - :param event: parallel branch run succeeded or failed event - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session return ParallelBranchFinishedStreamResponse( task_id=task_id, workflow_run_id=workflow_run.id, @@ -764,15 +678,10 @@ class WorkflowCycleManage: ) def _workflow_iteration_start_to_stream_response( - self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent + self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent ) -> IterationNodeStartStreamResponse: - """ - Workflow iteration start to stream response - :param task_id: task id - :param workflow_run: workflow run - :param event: iteration start event - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session return IterationNodeStartStreamResponse( task_id=task_id, workflow_run_id=workflow_run.id, @@ -791,15 +700,10 @@ class WorkflowCycleManage: ) def _workflow_iteration_next_to_stream_response( - self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent + self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent ) -> IterationNodeNextStreamResponse: - """ - Workflow iteration next to stream response - :param task_id: task id - :param workflow_run: workflow run - :param event: iteration next event - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session return IterationNodeNextStreamResponse( task_id=task_id, workflow_run_id=workflow_run.id, @@ -820,15 +724,10 @@ class WorkflowCycleManage: ) def _workflow_iteration_completed_to_stream_response( - self, task_id: str, 
workflow_run: WorkflowRun, event: QueueIterationCompletedEvent + self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent ) -> IterationNodeCompletedStreamResponse: - """ - Workflow iteration completed to stream response - :param task_id: task id - :param workflow_run: workflow run - :param event: iteration completed event - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session return IterationNodeCompletedStreamResponse( task_id=task_id, workflow_run_id=workflow_run.id, @@ -912,27 +811,22 @@ class WorkflowCycleManage: return None - def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun: + def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun: """ Refetch workflow run :param workflow_run_id: workflow run id :return: """ - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() - + stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id) + workflow_run = session.scalar(stmt) if not workflow_run: raise WorkflowRunNotFoundError(workflow_run_id) return workflow_run - def _refetch_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution: - """ - Refetch workflow node execution - :param node_execution_id: workflow node execution id - :return: - """ - workflow_node_execution = self._wip_workflow_node_executions.get(node_execution_id) - + def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution: + stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.id == node_execution_id) + workflow_node_execution = session.scalar(stmt) if not workflow_node_execution: raise WorkflowNodeExecutionNotFoundError(node_execution_id) @@ -945,41 +839,10 @@ class WorkflowCycleManage: :param event: agent log event :return: """ - node_execution = self._wip_workflow_node_executions.get(event.node_execution_id) - if not node_execution: - raise Exception(f"Workflow node execution not found: {event.node_execution_id}") - - node_execution_id = node_execution.id - original_agent_logs = self._wip_workflow_agent_logs.get(node_execution_id, []) - - # try to find the log with the same id - for log in original_agent_logs: - if log.id == event.id: - # update the log - log.status = event.status - log.error = event.error - log.data = event.data - break - else: - # append the log - original_agent_logs.append( - AgentLogStreamResponse.Data( - id=event.id, - parent_id=event.parent_id, - node_execution_id=node_execution_id, - error=event.error, - status=event.status, - data=event.data, - label=event.label, - ) - ) - - self._wip_workflow_agent_logs[node_execution_id] = original_agent_logs - return AgentLogStreamResponse( task_id=task_id, data=AgentLogStreamResponse.Data( - node_execution_id=node_execution_id, + node_execution_id=event.node_execution_id, id=event.id, parent_id=event.parent_id, label=event.label, diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py new file mode 100644 index 0000000000..90c9879733 --- /dev/null +++ b/api/core/entities/knowledge_entities.py @@ -0,0 +1,19 @@ +from typing import Optional + +from pydantic import BaseModel + + +class PreviewDetail(BaseModel): + content: str + child_chunks: Optional[list[str]] = None + + +class QAPreviewDetail(BaseModel): + question: str + answer: str + + +class IndexingEstimate(BaseModel): + total_segments: int + preview: 
list[PreviewDetail] + qa_preview: Optional[list[QAPreviewDetail]] = None diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index eed2d7e49a..0261a6309e 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -881,7 +881,7 @@ class ProviderConfiguration(BaseModel): # if llm name not in restricted llm list, remove it restrict_model_names = [rm.model for rm in restrict_models] for model in provider_models: - if model.model_type == ModelType.LLM and m.model not in restrict_model_names: + if model.model_type == ModelType.LLM and model.model not in restrict_model_names: model.status = ModelStatus.NO_PERMISSION elif not quota_configuration.is_valid: model.status = ModelStatus.QUOTA_EXCEEDED diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 685dbc8ed4..da79b1bf03 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -8,34 +8,34 @@ import time import uuid from typing import Any, Optional, cast -from flask import Flask, current_app +from flask import current_app from flask_login import current_user # type: ignore from sqlalchemy.orm.exc import ObjectDeletedError from configs import dify_config +from core.entities.knowledge_entities import IndexingEstimate, PreviewDetail, QAPreviewDetail from core.errors.error import ProviderTokenNotInitError -from core.llm_generator.llm_generator import LLMGenerator from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import Document +from core.rag.models.document import ChildDocument, Document from core.rag.splitter.fixed_text_splitter import ( EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter, ) from core.rag.splitter.text_splitter import TextSplitter from core.tools.utils.rag_web_reader import get_image_upload_file_ids -from core.tools.utils.text_processing_utils import remove_leading_symbols from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage from libs import helper -from models.dataset import Dataset, DatasetProcessRule, DocumentSegment +from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment from models.dataset import Document as DatasetDocument from models.model import UploadFile from services.feature_service import FeatureService @@ -115,6 +115,9 @@ class IndexingRunner: for document_segment in document_segments: db.session.delete(document_segment) + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + # delete child chunks + db.session.query(ChildChunk).filter(ChildChunk.segment_id == document_segment.id).delete() db.session.commit() # get the process rule processing_rule = ( @@ -183,7 +186,22 @@ class IndexingRunner: "dataset_id": document_segment.dataset_id, }, ) - + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_chunks = document_segment.child_chunks + if 
child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": document_segment.document_id, + "dataset_id": document_segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents documents.append(document) # build index @@ -222,7 +240,7 @@ class IndexingRunner: doc_language: str = "English", dataset_id: Optional[str] = None, indexing_technique: str = "economy", - ) -> dict: + ) -> IndexingEstimate: """ Estimate the indexing for the document. """ @@ -258,31 +276,38 @@ class IndexingRunner: tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) - preview_texts: list[str] = [] + preview_texts = [] # type: ignore + total_segments = 0 index_type = doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() - all_text_docs = [] for extract_setting in extract_settings: # extract - text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) - all_text_docs.extend(text_docs) processing_rule = DatasetProcessRule( mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"]) ) - - # get splitter - splitter = self._get_splitter(processing_rule, embedding_model_instance) - - # split to documents - documents = self._split_to_documents_for_estimate( - text_docs=text_docs, splitter=splitter, processing_rule=processing_rule + text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) + documents = index_processor.transform( + text_docs, + embedding_model_instance=embedding_model_instance, + process_rule=processing_rule.to_dict(), + tenant_id=current_user.current_tenant_id, + doc_language=doc_language, + preview=True, ) - total_segments += len(documents) for document in documents: - if len(preview_texts) < 5: - preview_texts.append(document.page_content) + if len(preview_texts) < 10: + if doc_form and doc_form == "qa_model": + preview_detail = QAPreviewDetail( + question=document.page_content, answer=document.metadata.get("answer") or "" + ) + preview_texts.append(preview_detail) + else: + preview_detail = PreviewDetail(content=document.page_content) # type: ignore + if document.children: + preview_detail.child_chunks = [child.page_content for child in document.children] # type: ignore + preview_texts.append(preview_detail) # delete image files and related db records image_upload_file_ids = get_image_upload_file_ids(document.page_content) @@ -299,15 +324,8 @@ class IndexingRunner: db.session.delete(image_file) if doc_form and doc_form == "qa_model": - if len(preview_texts) > 0: - # qa model document - response = LLMGenerator.generate_qa_document( - current_user.current_tenant_id, preview_texts[0], doc_language - ) - document_qa_list = self.format_split_text(response) - - return {"total_segments": total_segments * 20, "qa_preview": document_qa_list, "preview": preview_texts} - return {"total_segments": total_segments, "preview": preview_texts} + return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[]) + return IndexingEstimate(total_segments=total_segments, preview=preview_texts) # type: ignore def _extract( self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict @@ -401,31 +419,26 @@ class IndexingRunner: @staticmethod def _get_splitter( - processing_rule: 
DatasetProcessRule, embedding_model_instance: Optional[ModelInstance] + processing_rule_mode: str, + max_tokens: int, + chunk_overlap: int, + separator: str, + embedding_model_instance: Optional[ModelInstance], ) -> TextSplitter: """ Get the NodeParser object according to the processing rule. """ - character_splitter: TextSplitter - if processing_rule.mode == "custom": + if processing_rule_mode in ["custom", "hierarchical"]: # The user-defined segmentation rule - rules = json.loads(processing_rule.rules) - segmentation = rules["segmentation"] max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH - if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length: + if max_tokens < 50 or max_tokens > max_segmentation_tokens_length: raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.") - separator = segmentation["separator"] if separator: separator = separator.replace("\\n", "\n") - if segmentation.get("chunk_overlap"): - chunk_overlap = segmentation["chunk_overlap"] - else: - chunk_overlap = 0 - character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( - chunk_size=segmentation["max_tokens"], + chunk_size=max_tokens, chunk_overlap=chunk_overlap, fixed_separator=separator, separators=["\n\n", "。", ". ", " ", ""], @@ -441,143 +454,7 @@ class IndexingRunner: embedding_model_instance=embedding_model_instance, ) - return character_splitter - - def _step_split( - self, - text_docs: list[Document], - splitter: TextSplitter, - dataset: Dataset, - dataset_document: DatasetDocument, - processing_rule: DatasetProcessRule, - ) -> list[Document]: - """ - Split the text documents into documents and save them to the document segment. - """ - documents = self._split_to_documents( - text_docs=text_docs, - splitter=splitter, - processing_rule=processing_rule, - tenant_id=dataset.tenant_id, - document_form=dataset_document.doc_form, - document_language=dataset_document.doc_language, - ) - - # save node to document segment - doc_store = DatasetDocumentStore( - dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id - ) - - # add document segments - doc_store.add_documents(documents) - - # update document status to indexing - cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - self._update_document_index_status( - document_id=dataset_document.id, - after_indexing_status="indexing", - extra_update_params={ - DatasetDocument.cleaning_completed_at: cur_time, - DatasetDocument.splitting_completed_at: cur_time, - }, - ) - - # update segment status to indexing - self._update_segments_by_document( - dataset_document_id=dataset_document.id, - update_params={ - DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), - }, - ) - - return documents - - def _split_to_documents( - self, - text_docs: list[Document], - splitter: TextSplitter, - processing_rule: DatasetProcessRule, - tenant_id: str, - document_form: str, - document_language: str, - ) -> list[Document]: - """ - Split the text documents into nodes. 
- """ - all_documents: list[Document] = [] - all_qa_documents: list[Document] = [] - for text_doc in text_docs: - # document clean - document_text = self._document_clean(text_doc.page_content, processing_rule) - text_doc.page_content = document_text - - # parse document to nodes - documents = splitter.split_documents([text_doc]) - split_documents = [] - for document_node in documents: - if document_node.page_content.strip(): - if document_node.metadata is not None: - doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata["doc_id"] = doc_id - document_node.metadata["doc_hash"] = hash - # delete Splitter character - page_content = document_node.page_content - document_node.page_content = remove_leading_symbols(page_content) - - if document_node.page_content: - split_documents.append(document_node) - all_documents.extend(split_documents) - # processing qa document - if document_form == "qa_model": - for i in range(0, len(all_documents), 10): - threads = [] - sub_documents = all_documents[i : i + 10] - for doc in sub_documents: - document_format_thread = threading.Thread( - target=self.format_qa_document, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "tenant_id": tenant_id, - "document_node": doc, - "all_qa_documents": all_qa_documents, - "document_language": document_language, - }, - ) - threads.append(document_format_thread) - document_format_thread.start() - for thread in threads: - thread.join() - return all_qa_documents - return all_documents - - def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language): - format_documents = [] - if document_node.page_content is None or not document_node.page_content.strip(): - return - with flask_app.app_context(): - try: - # qa model document - response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language) - document_qa_list = self.format_split_text(response) - qa_documents = [] - for result in document_qa_list: - qa_document = Document( - page_content=result["question"], metadata=document_node.metadata.model_copy() - ) - if qa_document.metadata is not None: - doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(result["question"]) - qa_document.metadata["answer"] = result["answer"] - qa_document.metadata["doc_id"] = doc_id - qa_document.metadata["doc_hash"] = hash - qa_documents.append(qa_document) - format_documents.extend(qa_documents) - except Exception as e: - logging.exception("Failed to format qa document") - - all_qa_documents.extend(format_documents) + return character_splitter # type: ignore def _split_to_documents_for_estimate( self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule @@ -624,11 +501,11 @@ class IndexingRunner: return document_text @staticmethod - def format_split_text(text): + def format_split_text(text: str) -> list[QAPreviewDetail]: regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" matches = re.findall(regex, text, re.UNICODE) - return [{"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())} for q, a in matches if q and a] + return [QAPreviewDetail(question=q, answer=re.sub(r"\n\s*", "\n", a.strip())) for q, a in matches if q and a] def _load( self, @@ -654,13 +531,14 @@ class IndexingRunner: indexing_start_at = time.perf_counter() tokens = 0 chunk_size = 10 + if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX: + # create keyword index + create_keyword_thread = 
threading.Thread( + target=self._process_keyword_index, + args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), # type: ignore + ) + create_keyword_thread.start() - # create keyword index - create_keyword_thread = threading.Thread( - target=self._process_keyword_index, - args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), # type: ignore - ) - create_keyword_thread.start() if dataset.indexing_technique == "high_quality": with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: futures = [] @@ -680,8 +558,8 @@ class IndexingRunner: for future in futures: tokens += future.result() - - create_keyword_thread.join() + if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX: + create_keyword_thread.join() indexing_end_at = time.perf_counter() # update document status to completed @@ -791,28 +669,6 @@ class IndexingRunner: DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) db.session.commit() - @staticmethod - def batch_add_segments(segments: list[DocumentSegment], dataset: Dataset): - """ - Batch add segments index processing - """ - documents = [] - for segment in segments: - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - }, - ) - documents.append(document) - # save vector index - index_type = dataset.doc_form - index_processor = IndexProcessorFactory(index_type).init_index_processor() - index_processor.load(dataset, documents) - def _transform( self, index_processor: BaseIndexProcessor, @@ -854,7 +710,7 @@ class IndexingRunner: ) # add document segments - doc_store.add_documents(documents) + doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX) # update document status to indexing cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index f538eaef5b..691cb8d400 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -9,6 +9,8 @@ from typing import Any, Optional, Union from uuid import UUID, uuid4 from flask import current_app +from sqlalchemy import select +from sqlalchemy.orm import Session from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token from core.ops.entities.config_entity import ( @@ -329,15 +331,15 @@ class TraceTask: ): self.trace_type = trace_type self.message_id = message_id - self.workflow_run = workflow_run + self.workflow_run_id = workflow_run.id if workflow_run else None self.conversation_id = conversation_id self.user_id = user_id self.timer = timer - self.kwargs = kwargs self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") - self.app_id = None + self.kwargs = kwargs + def execute(self): return self.preprocess() @@ -345,19 +347,23 @@ class TraceTask: preprocess_map = { TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs), TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace( - self.workflow_run, self.conversation_id, self.user_id + workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id + ), + TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id), + TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace( + message_id=self.message_id, timer=self.timer, 
**self.kwargs ), - TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(self.message_id), - TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(self.message_id, self.timer, **self.kwargs), TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace( - self.message_id, self.timer, **self.kwargs + message_id=self.message_id, timer=self.timer, **self.kwargs ), TraceTaskName.DATASET_RETRIEVAL_TRACE: lambda: self.dataset_retrieval_trace( - self.message_id, self.timer, **self.kwargs + message_id=self.message_id, timer=self.timer, **self.kwargs + ), + TraceTaskName.TOOL_TRACE: lambda: self.tool_trace( + message_id=self.message_id, timer=self.timer, **self.kwargs ), - TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(self.message_id, self.timer, **self.kwargs), TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace( - self.conversation_id, self.timer, **self.kwargs + conversation_id=self.conversation_id, timer=self.timer, **self.kwargs ), } @@ -367,86 +373,100 @@ class TraceTask: def conversation_trace(self, **kwargs): return kwargs - def workflow_trace(self, workflow_run: WorkflowRun | None, conversation_id, user_id): - if not workflow_run: - raise ValueError("Workflow run not found") + def workflow_trace( + self, + *, + workflow_run_id: str | None, + conversation_id: str | None, + user_id: str | None, + ): + if not workflow_run_id: + return {} - db.session.merge(workflow_run) - db.session.refresh(workflow_run) + with Session(db.engine) as session: + workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id) + workflow_run = session.scalars(workflow_run_stmt).first() + if not workflow_run: + raise ValueError("Workflow run not found") - workflow_id = workflow_run.workflow_id - tenant_id = workflow_run.tenant_id - workflow_run_id = workflow_run.id - workflow_run_elapsed_time = workflow_run.elapsed_time - workflow_run_status = workflow_run.status - workflow_run_inputs = workflow_run.inputs_dict - workflow_run_outputs = workflow_run.outputs_dict - workflow_run_version = workflow_run.version - error = workflow_run.error or "" + workflow_id = workflow_run.workflow_id + tenant_id = workflow_run.tenant_id + workflow_run_id = workflow_run.id + workflow_run_elapsed_time = workflow_run.elapsed_time + workflow_run_status = workflow_run.status + workflow_run_inputs = workflow_run.inputs_dict + workflow_run_outputs = workflow_run.outputs_dict + workflow_run_version = workflow_run.version + error = workflow_run.error or "" - total_tokens = workflow_run.total_tokens + total_tokens = workflow_run.total_tokens - file_list = workflow_run_inputs.get("sys.file") or [] - query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" + file_list = workflow_run_inputs.get("sys.file") or [] + query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" - # get workflow_app_log_id - workflow_app_log_data = ( - db.session.query(WorkflowAppLog) - .filter_by(tenant_id=tenant_id, app_id=workflow_run.app_id, workflow_run_id=workflow_run.id) - .first() - ) - workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None - # get message_id - message_data = ( - db.session.query(Message.id) - .filter_by(conversation_id=conversation_id, workflow_run_id=workflow_run_id) - .first() - ) - message_id = str(message_data.id) if message_data else None + # get workflow_app_log_id + workflow_app_log_data_stmt = select(WorkflowAppLog.id).where( + WorkflowAppLog.tenant_id == tenant_id, + 
WorkflowAppLog.app_id == workflow_run.app_id, + WorkflowAppLog.workflow_run_id == workflow_run.id, + ) + workflow_app_log_id = session.scalar(workflow_app_log_data_stmt) + # get message_id + message_id = None + if conversation_id: + message_data_stmt = select(Message.id).where( + Message.conversation_id == conversation_id, + Message.workflow_run_id == workflow_run_id, + ) + message_id = session.scalar(message_data_stmt) - metadata = { - "workflow_id": workflow_id, - "conversation_id": conversation_id, - "workflow_run_id": workflow_run_id, - "tenant_id": tenant_id, - "elapsed_time": workflow_run_elapsed_time, - "status": workflow_run_status, - "version": workflow_run_version, - "total_tokens": total_tokens, - "file_list": file_list, - "triggered_form": workflow_run.triggered_from, - "user_id": user_id, - } - - workflow_trace_info = WorkflowTraceInfo( - workflow_data=workflow_run.to_dict(), - conversation_id=conversation_id, - workflow_id=workflow_id, - tenant_id=tenant_id, - workflow_run_id=workflow_run_id, - workflow_run_elapsed_time=workflow_run_elapsed_time, - workflow_run_status=workflow_run_status, - workflow_run_inputs=workflow_run_inputs, - workflow_run_outputs=workflow_run_outputs, - workflow_run_version=workflow_run_version, - error=error, - total_tokens=total_tokens, - file_list=file_list, - query=query, - metadata=metadata, - workflow_app_log_id=workflow_app_log_id, - message_id=message_id, - start_time=workflow_run.created_at, - end_time=workflow_run.finished_at, - ) + metadata = { + "workflow_id": workflow_id, + "conversation_id": conversation_id, + "workflow_run_id": workflow_run_id, + "tenant_id": tenant_id, + "elapsed_time": workflow_run_elapsed_time, + "status": workflow_run_status, + "version": workflow_run_version, + "total_tokens": total_tokens, + "file_list": file_list, + "triggered_form": workflow_run.triggered_from, + "user_id": user_id, + } + workflow_trace_info = WorkflowTraceInfo( + workflow_data=workflow_run.to_dict(), + conversation_id=conversation_id, + workflow_id=workflow_id, + tenant_id=tenant_id, + workflow_run_id=workflow_run_id, + workflow_run_elapsed_time=workflow_run_elapsed_time, + workflow_run_status=workflow_run_status, + workflow_run_inputs=workflow_run_inputs, + workflow_run_outputs=workflow_run_outputs, + workflow_run_version=workflow_run_version, + error=error, + total_tokens=total_tokens, + file_list=file_list, + query=query, + metadata=metadata, + workflow_app_log_id=workflow_app_log_id, + message_id=message_id, + start_time=workflow_run.created_at, + end_time=workflow_run.finished_at, + ) return workflow_trace_info - def message_trace(self, message_id): + def message_trace(self, message_id: str | None): + if not message_id: + return {} message_data = get_message_data(message_id) if not message_data: return {} - conversation_mode = db.session.query(Conversation.mode).filter_by(id=message_data.conversation_id).first() + conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id) + conversation_mode = db.session.scalars(conversation_mode_stmt).all() + if not conversation_mode or len(conversation_mode) == 0: + return {} conversation_mode = conversation_mode[0] created_at = message_data.created_at inputs = message_data.message diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py index 998eba9ea9..8b06df1930 100644 --- a/api/core/ops/utils.py +++ b/api/core/ops/utils.py @@ -18,7 +18,7 @@ def filter_none_values(data: dict): return new_data -def get_message_data(message_id): +def 
get_message_data(message_id: str): return db.session.query(Message).filter(Message.id == message_id).first() diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py index 8b17e8dc0a..a6214d955b 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py +++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py @@ -1,5 +1,5 @@ import re -from typing import Optional +from typing import Optional, cast class JiebaKeywordTableHandler: @@ -8,18 +8,20 @@ class JiebaKeywordTableHandler: from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS - jieba.analyse.default_tfidf.stop_words = STOPWORDS + jieba.analyse.default_tfidf.stop_words = STOPWORDS # type: ignore def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]: """Extract keywords with JIEBA tfidf.""" - import jieba # type: ignore + import jieba.analyse # type: ignore keywords = jieba.analyse.extract_tags( sentence=text, topK=max_keywords_per_chunk, ) + # jieba.analyse.extract_tags returns list[Any] when withFlag is False by default. + keywords = cast(list[str], keywords) - return set(self._expand_tokens_with_subtokens(keywords)) + return set(self._expand_tokens_with_subtokens(set(keywords))) def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]: """Get subtokens from a list of tokens., filtering for stopwords.""" diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 34343ad60e..3a8200bc7b 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -6,11 +6,14 @@ from flask import Flask, current_app from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.embedding.retrieval import RetrievalSegments +from core.rag.index_processor.constant.index_type import IndexType from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db -from models.dataset import Dataset +from models.dataset import ChildChunk, Dataset, DocumentSegment +from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService default_retrieval_model = { @@ -248,3 +251,89 @@ class RetrievalService: @staticmethod def escape_query_for_search(query: str) -> str: return query.replace('"', '\\"') + + @staticmethod + def format_retrieval_documents(documents: list[Document]) -> list[RetrievalSegments]: + records = [] + include_segment_ids = [] + segment_child_map = {} + for document in documents: + document_id = document.metadata.get("document_id") + dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() + if dataset_document: + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_index_node_id = document.metadata.get("doc_id") + result = ( + db.session.query(ChildChunk, DocumentSegment) + .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) + .filter( + ChildChunk.index_node_id == child_index_node_id, + DocumentSegment.dataset_id == dataset_document.dataset_id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + ) + .first() 
+ ) + if result: + child_chunk, segment = result + if not segment: + continue + if segment.id not in include_segment_ids: + include_segment_ids.append(segment.id) + child_chunk_detail = { + "id": child_chunk.id, + "content": child_chunk.content, + "position": child_chunk.position, + "score": document.metadata.get("score", 0.0), + } + map_detail = { + "max_score": document.metadata.get("score", 0.0), + "child_chunks": [child_chunk_detail], + } + segment_child_map[segment.id] = map_detail + record = { + "segment": segment, + } + records.append(record) + else: + child_chunk_detail = { + "id": child_chunk.id, + "content": child_chunk.content, + "position": child_chunk.position, + "score": document.metadata.get("score", 0.0), + } + segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) + segment_child_map[segment.id]["max_score"] = max( + segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0) + ) + else: + continue + else: + index_node_id = document.metadata["doc_id"] + + segment = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.dataset_id == dataset_document.dataset_id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.index_node_id == index_node_id, + ) + .first() + ) + + if not segment: + continue + include_segment_ids.append(segment.id) + record = { + "segment": segment, + "score": document.metadata.get("score", None), + } + + records.append(record) + for record in records: + if record["segment"].id in segment_child_map: + record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None) + record["score"] = segment_child_map[record["segment"].id]["max_score"] + + return [RetrievalSegments(**record) for record in records] diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 6d16a9bdc2..398b0daad9 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -7,7 +7,7 @@ from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.models.document import Document from extensions.ext_database import db -from models.dataset import Dataset, DocumentSegment +from models.dataset import ChildChunk, Dataset, DocumentSegment class DatasetDocumentStore: @@ -60,7 +60,7 @@ class DatasetDocumentStore: return output - def add_documents(self, docs: Sequence[Document], allow_update: bool = True) -> None: + def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False) -> None: max_position = ( db.session.query(func.max(DocumentSegment.position)) .filter(DocumentSegment.document_id == self._document_id) @@ -120,13 +120,55 @@ class DatasetDocumentStore: segment_document.answer = doc.metadata.pop("answer", "") db.session.add(segment_document) + db.session.flush() + if save_child: + if doc.children: + for position, child in enumerate(doc.children, start=1): + child_segment = ChildChunk( + tenant_id=self._dataset.tenant_id, + dataset_id=self._dataset.id, + document_id=self._document_id, + segment_id=segment_document.id, + position=position, + index_node_id=child.metadata.get("doc_id"), + index_node_hash=child.metadata.get("doc_hash"), + content=child.page_content, + word_count=len(child.page_content), + type="automatic", + created_by=self._user_id, + ) + db.session.add(child_segment) else: segment_document.content = doc.page_content if doc.metadata.get("answer"): segment_document.answer =
doc.metadata.pop("answer", "") - segment_document.index_node_hash = doc.metadata["doc_hash"] + segment_document.index_node_hash = doc.metadata.get("doc_hash") segment_document.word_count = len(doc.page_content) segment_document.tokens = tokens + if save_child and doc.children: + # delete the existing child chunks + db.session.query(ChildChunk).filter( + ChildChunk.tenant_id == self._dataset.tenant_id, + ChildChunk.dataset_id == self._dataset.id, + ChildChunk.document_id == self._document_id, + ChildChunk.segment_id == segment_document.id, + ).delete() + # add new child chunks + for position, child in enumerate(doc.children, start=1): + child_segment = ChildChunk( + tenant_id=self._dataset.tenant_id, + dataset_id=self._dataset.id, + document_id=self._document_id, + segment_id=segment_document.id, + position=position, + index_node_id=child.metadata.get("doc_id"), + index_node_hash=child.metadata.get("doc_hash"), + content=child.page_content, + word_count=len(child.page_content), + type="automatic", + created_by=self._user_id, + ) + db.session.add(child_segment) db.session.commit() diff --git a/api/core/rag/embedding/retrieval.py b/api/core/rag/embedding/retrieval.py new file mode 100644 index 0000000000..800422d888 --- /dev/null +++ b/api/core/rag/embedding/retrieval.py @@ -0,0 +1,23 @@ +from typing import Optional + +from pydantic import BaseModel + +from models.dataset import DocumentSegment + + +class RetrievalChildChunk(BaseModel): + """Retrieval segments.""" + + id: str + content: str + score: float + position: int + + +class RetrievalSegments(BaseModel): + """Retrieval segments.""" + + model_config = {"arbitrary_types_allowed": True} + segment: DocumentSegment + child_chunks: Optional[list[RetrievalChildChunk]] = None + score: Optional[float] = None diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py index c444105bb5..a3b35458df 100644 --- a/api/core/rag/extractor/excel_extractor.py +++ b/api/core/rag/extractor/excel_extractor.py @@ -4,7 +4,7 @@ import os from typing import Optional, cast import pandas as pd -from openpyxl import load_workbook +from openpyxl import load_workbook # type: ignore from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index a473b3dfa7..f9fd7f92a1 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -24,7 +24,6 @@ from core.rag.extractor.unstructured.unstructured_markdown_extractor import Unst from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor -from core.rag.extractor.unstructured.unstructured_text_extractor import UnstructuredTextExtractor from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor from core.rag.extractor.word_extractor import WordExtractor from core.rag.models.document import Document @@ -103,12 +102,11 @@ class ExtractProcessor: input_file = Path(file_path) file_extension = input_file.suffix.lower() etl_type = dify_config.ETL_TYPE - unstructured_api_url = dify_config.UNSTRUCTURED_API_URL - unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY - assert unstructured_api_url is not None, "unstructured_api_url is 
required" - assert unstructured_api_key is not None, "unstructured_api_key is required" extractor: Optional[BaseExtractor] = None if etl_type == "Unstructured": + unstructured_api_url = dify_config.UNSTRUCTURED_API_URL + unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY or "" + if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) elif file_extension == ".pdf": @@ -141,11 +139,7 @@ class ExtractProcessor: extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url, unstructured_api_key) else: # txt - extractor = ( - UnstructuredTextExtractor(file_path, unstructured_api_url) - if is_automatic - else TextExtractor(file_path, autodetect_encoding=True) - ) + extractor = TextExtractor(file_path, autodetect_encoding=True) else: if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) diff --git a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py index 9647dedfff..f1fa5dde5c 100644 --- a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py @@ -1,5 +1,6 @@ import base64 import logging +from typing import Optional from bs4 import BeautifulSoup # type: ignore @@ -15,7 +16,7 @@ class UnstructuredEmailExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, api_url: str, api_key: str): + def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py index 80c29157aa..35ca686f62 100644 --- a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py @@ -19,7 +19,7 @@ class UnstructuredEpubExtractor(BaseExtractor): self, file_path: str, api_url: Optional[str] = None, - api_key: Optional[str] = None, + api_key: str = "", ): """Initialize with file path.""" self._file_path = file_path @@ -30,9 +30,6 @@ class UnstructuredEpubExtractor(BaseExtractor): if self._api_url: from unstructured.partition.api import partition_via_api - if self._api_key is None: - raise ValueError("api_key is required") - elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key) else: from unstructured.partition.epub import partition_epub diff --git a/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py b/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py index 4173d4d122..d5418e612a 100644 --- a/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py @@ -1,4 +1,5 @@ import logging +from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -24,7 +25,7 @@ class UnstructuredMarkdownExtractor(BaseExtractor): if the specified encoding fails. 
""" - def __init__(self, file_path: str, api_url: str, api_key: str): + def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py b/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py index 57affb8d36..d363449c29 100644 --- a/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py @@ -1,4 +1,5 @@ import logging +from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -14,7 +15,7 @@ class UnstructuredMsgExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, api_url: str, api_key: str): + def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py index e504d4bc23..ecc272a2f0 100644 --- a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py @@ -1,4 +1,5 @@ import logging +from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -14,7 +15,7 @@ class UnstructuredPPTExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, api_url: str, api_key: str): + def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py index cefe72b290..e7bf6fd2e6 100644 --- a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py @@ -1,4 +1,5 @@ import logging +from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -14,7 +15,7 @@ class UnstructuredPPTXExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, api_url: str, api_key: str): + def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py index ef46ab0e70..916cdc3f2b 100644 --- a/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py @@ -1,4 +1,5 @@ import logging +from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -14,7 +15,7 @@ class UnstructuredXmlExtractor(BaseExtractor): file_path: Path to the file to load. 
""" - def __init__(self, file_path: str, api_url: str, api_key: str): + def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index c3161bc812..d93de5fef9 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -267,8 +267,10 @@ class WordExtractor(BaseExtractor): if isinstance(element.tag, str) and element.tag.endswith("p"): # paragraph para = paragraphs.pop(0) parsed_paragraph = parse_paragraph(para) - if parsed_paragraph: + if parsed_paragraph.strip(): content.append(parsed_paragraph) + else: + content.append("\n") elif isinstance(element.tag, str) and element.tag.endswith("tbl"): # table table = tables.pop(0) content.append(self._table_to_markdown(table, image_map)) diff --git a/api/core/rag/index_processor/constant/index_type.py b/api/core/rag/index_processor/constant/index_type.py index e42cc44c6f..0845b58e25 100644 --- a/api/core/rag/index_processor/constant/index_type.py +++ b/api/core/rag/index_processor/constant/index_type.py @@ -1,8 +1,7 @@ from enum import Enum -class IndexType(Enum): +class IndexType(str, Enum): PARAGRAPH_INDEX = "text_model" QA_INDEX = "qa_model" - PARENT_CHILD_INDEX = "parent_child_index" - SUMMARY_INDEX = "summary_index" + PARENT_CHILD_INDEX = "hierarchical_model" diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index 7e5efdc66e..2bcd1c79bb 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -27,10 +27,10 @@ class BaseIndexProcessor(ABC): raise NotImplementedError @abstractmethod - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): + def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): raise NotImplementedError - def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): + def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): raise NotImplementedError @abstractmethod @@ -45,26 +45,29 @@ class BaseIndexProcessor(ABC): ) -> list[Document]: raise NotImplementedError - def _get_splitter(self, processing_rule: dict, embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: + def _get_splitter( + self, + processing_rule_mode: str, + max_tokens: int, + chunk_overlap: int, + separator: str, + embedding_model_instance: Optional[ModelInstance], + ) -> TextSplitter: """ Get the NodeParser object according to the processing rule. 
""" - character_splitter: TextSplitter - if processing_rule["mode"] == "custom": + if processing_rule_mode in ["custom", "hierarchical"]: # The user-defined segmentation rule - rules = processing_rule["rules"] - segmentation = rules["segmentation"] max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH - if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length: + if max_tokens < 50 or max_tokens > max_segmentation_tokens_length: raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.") - separator = segmentation["separator"] if separator: separator = separator.replace("\\n", "\n") character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( - chunk_size=segmentation["max_tokens"], - chunk_overlap=segmentation.get("chunk_overlap", 0) or 0, + chunk_size=max_tokens, + chunk_overlap=chunk_overlap, fixed_separator=separator, separators=["\n\n", "。", ". ", " ", ""], embedding_model_instance=embedding_model_instance, @@ -78,4 +81,4 @@ class BaseIndexProcessor(ABC): embedding_model_instance=embedding_model_instance, ) - return character_splitter + return character_splitter # type: ignore diff --git a/api/core/rag/index_processor/index_processor_factory.py b/api/core/rag/index_processor/index_processor_factory.py index c5ba6295f3..c987edf342 100644 --- a/api/core/rag/index_processor/index_processor_factory.py +++ b/api/core/rag/index_processor/index_processor_factory.py @@ -3,6 +3,7 @@ from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor +from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor @@ -18,9 +19,11 @@ class IndexProcessorFactory: if not self._index_type: raise ValueError("Index type must be specified.") - if self._index_type == IndexType.PARAGRAPH_INDEX.value: + if self._index_type == IndexType.PARAGRAPH_INDEX: return ParagraphIndexProcessor() - elif self._index_type == IndexType.QA_INDEX.value: + elif self._index_type == IndexType.QA_INDEX: return QAIndexProcessor() + elif self._index_type == IndexType.PARENT_CHILD_INDEX: + return ParentChildIndexProcessor() else: raise ValueError(f"Index type {self._index_type} is not supported.") diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index c66fa54d50..dca84b9041 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -13,21 +13,40 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.models.document import Document from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper -from models.dataset import Dataset +from models.dataset import Dataset, DatasetProcessRule +from services.entities.knowledge_entities.knowledge_entities import Rule class ParagraphIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: text_docs = ExtractProcessor.extract( - extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic" + extract_setting=extract_setting, 
+ is_automatic=( + kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical" + ), ) return text_docs def transform(self, documents: list[Document], **kwargs) -> list[Document]: + process_rule = kwargs.get("process_rule") + if not process_rule: + raise ValueError("No process rule found.") + if process_rule.get("mode") == "automatic": + automatic_rule = DatasetProcessRule.AUTOMATIC_RULES + rules = Rule(**automatic_rule) + else: + if not process_rule.get("rules"): + raise ValueError("No rules found in process rule.") + rules = Rule(**process_rule.get("rules")) # Split the text documents into nodes. + if not rules.segmentation: + raise ValueError("No segmentation found in rules.") splitter = self._get_splitter( - processing_rule=kwargs.get("process_rule", {}), + processing_rule_mode=process_rule.get("mode"), + max_tokens=rules.segmentation.max_tokens, + chunk_overlap=rules.segmentation.chunk_overlap, + separator=rules.segmentation.separator, embedding_model_instance=kwargs.get("embedding_model_instance"), ) all_documents = [] @@ -53,15 +72,19 @@ class ParagraphIndexProcessor(BaseIndexProcessor): all_documents.extend(split_documents) return all_documents - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): + def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) if with_keywords: + keywords_list = kwargs.get("keywords_list") keyword = Keyword(dataset) - keyword.create(documents) + if keywords_list and len(keywords_list) > 0: + keyword.add_texts(documents, keywords_list=keywords_list) + else: + keyword.add_texts(documents) - def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): + def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) if node_ids: diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py new file mode 100644 index 0000000000..e8423e2b77 --- /dev/null +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -0,0 +1,195 @@ +"""Paragraph index processor.""" + +import uuid +from typing import Optional + +from core.model_manager import ModelInstance +from core.rag.cleaner.clean_processor import CleanProcessor +from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.extractor.extract_processor import ExtractProcessor +from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.models.document import ChildDocument, Document +from extensions.ext_database import db +from libs import helper +from models.dataset import ChildChunk, Dataset, DocumentSegment +from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule + + +class ParentChildIndexProcessor(BaseIndexProcessor): + def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: + text_docs = ExtractProcessor.extract( + extract_setting=extract_setting, + is_automatic=( + kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical" + ), + ) + + return text_docs + + def 
transform(self, documents: list[Document], **kwargs) -> list[Document]: + process_rule = kwargs.get("process_rule") + if not process_rule: + raise ValueError("No process rule found.") + if not process_rule.get("rules"): + raise ValueError("No rules found in process rule.") + rules = Rule(**process_rule.get("rules")) + all_documents = [] # type: ignore + if rules.parent_mode == ParentMode.PARAGRAPH: + # Split the text documents into nodes. + splitter = self._get_splitter( + processing_rule_mode=process_rule.get("mode"), + max_tokens=rules.segmentation.max_tokens, + chunk_overlap=rules.segmentation.chunk_overlap, + separator=rules.segmentation.separator, + embedding_model_instance=kwargs.get("embedding_model_instance"), + ) + for document in documents: + # document clean + document_text = CleanProcessor.clean(document.page_content, process_rule) + document.page_content = document_text + # parse document to nodes + document_nodes = splitter.split_documents([document]) + split_documents = [] + for document_node in document_nodes: + if document_node.page_content.strip(): + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(document_node.page_content) + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash + # delete Splitter character + page_content = document_node.page_content + if page_content.startswith(".") or page_content.startswith("。"): + page_content = page_content[1:].strip() + else: + page_content = page_content + if len(page_content) > 0: + document_node.page_content = page_content + # parse document to child nodes + child_nodes = self._split_child_nodes( + document_node, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance") + ) + document_node.children = child_nodes + split_documents.append(document_node) + all_documents.extend(split_documents) + elif rules.parent_mode == ParentMode.FULL_DOC: + page_content = "\n".join([document.page_content for document in documents]) + document = Document(page_content=page_content, metadata=documents[0].metadata) + # parse document to child nodes + child_nodes = self._split_child_nodes( + document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance") + ) + document.children = child_nodes + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(document.page_content) + document.metadata["doc_id"] = doc_id + document.metadata["doc_hash"] = hash + all_documents.append(document) + + return all_documents + + def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): + if dataset.indexing_technique == "high_quality": + vector = Vector(dataset) + for document in documents: + child_documents = document.children + if child_documents: + formatted_child_documents = [ + Document(**child_document.model_dump()) for child_document in child_documents + ] + vector.create(formatted_child_documents) + + def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): + # node_ids is segment's node_ids + if dataset.indexing_technique == "high_quality": + delete_child_chunks = kwargs.get("delete_child_chunks") or False + vector = Vector(dataset) + if node_ids: + child_node_ids = ( + db.session.query(ChildChunk.index_node_id) + .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) + .filter( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.index_node_id.in_(node_ids), + ChildChunk.dataset_id == dataset.id, + ) + .all() + ) + child_node_ids = [child_node_id[0] for child_node_id in 
child_node_ids] + vector.delete_by_ids(child_node_ids) + if delete_child_chunks: + db.session.query(ChildChunk).filter( + ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids) + ).delete() + db.session.commit() + else: + vector.delete() + + if delete_child_chunks: + db.session.query(ChildChunk).filter(ChildChunk.dataset_id == dataset.id).delete() + db.session.commit() + + def retrieve( + self, + retrieval_method: str, + query: str, + dataset: Dataset, + top_k: int, + score_threshold: float, + reranking_model: dict, + ) -> list[Document]: + # Set search parameters. + results = RetrievalService.retrieve( + retrieval_method=retrieval_method, + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + ) + # Organize results. + docs = [] + for result in results: + metadata = result.metadata + metadata["score"] = result.score + if result.score > score_threshold: + doc = Document(page_content=result.page_content, metadata=metadata) + docs.append(doc) + return docs + + def _split_child_nodes( + self, + document_node: Document, + rules: Rule, + process_rule_mode: str, + embedding_model_instance: Optional[ModelInstance], + ) -> list[ChildDocument]: + if not rules.subchunk_segmentation: + raise ValueError("No subchunk segmentation found in rules.") + child_splitter = self._get_splitter( + processing_rule_mode=process_rule_mode, + max_tokens=rules.subchunk_segmentation.max_tokens, + chunk_overlap=rules.subchunk_segmentation.chunk_overlap, + separator=rules.subchunk_segmentation.separator, + embedding_model_instance=embedding_model_instance, + ) + # parse document to child nodes + child_nodes = [] + child_documents = child_splitter.split_documents([document_node]) + for child_document_node in child_documents: + if child_document_node.page_content.strip(): + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(child_document_node.page_content) + child_document = ChildDocument( + page_content=child_document_node.page_content, metadata=document_node.metadata + ) + child_document.metadata["doc_id"] = doc_id + child_document.metadata["doc_hash"] = hash + child_page_content = child_document.page_content + if child_page_content.startswith(".") or child_page_content.startswith("。"): + child_page_content = child_page_content[1:].strip() + if len(child_page_content) > 0: + child_document.page_content = child_page_content + child_nodes.append(child_document) + return child_nodes diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 20fd16e8f3..58b50a9fcb 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -21,18 +21,32 @@ from core.rag.models.document import Document from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper from models.dataset import Dataset +from services.entities.knowledge_entities.knowledge_entities import Rule class QAIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: text_docs = ExtractProcessor.extract( - extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic" + extract_setting=extract_setting, + is_automatic=( + kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical" + ), ) return text_docs def transform(self, documents: 
list[Document], **kwargs) -> list[Document]: + preview = kwargs.get("preview") + process_rule = kwargs.get("process_rule") + if not process_rule: + raise ValueError("No process rule found.") + if not process_rule.get("rules"): + raise ValueError("No rules found in process rule.") + rules = Rule(**process_rule.get("rules")) splitter = self._get_splitter( - processing_rule=kwargs.get("process_rule") or {}, + processing_rule_mode=process_rule.get("mode"), + max_tokens=rules.segmentation.max_tokens if rules.segmentation else 0, + chunk_overlap=rules.segmentation.chunk_overlap if rules.segmentation else 0, + separator=rules.segmentation.separator if rules.segmentation else "", embedding_model_instance=kwargs.get("embedding_model_instance"), ) @@ -59,24 +73,33 @@ class QAIndexProcessor(BaseIndexProcessor): document_node.page_content = remove_leading_symbols(page_content) split_documents.append(document_node) all_documents.extend(split_documents) - for i in range(0, len(all_documents), 10): - threads = [] - sub_documents = all_documents[i : i + 10] - for doc in sub_documents: - document_format_thread = threading.Thread( - target=self._format_qa_document, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "tenant_id": kwargs.get("tenant_id"), - "document_node": doc, - "all_qa_documents": all_qa_documents, - "document_language": kwargs.get("doc_language", "English"), - }, - ) - threads.append(document_format_thread) - document_format_thread.start() - for thread in threads: - thread.join() + if preview: + self._format_qa_document( + current_app._get_current_object(), # type: ignore + kwargs.get("tenant_id"), # type: ignore + all_documents[0], + all_qa_documents, + kwargs.get("doc_language", "English"), + ) + else: + for i in range(0, len(all_documents), 10): + threads = [] + sub_documents = all_documents[i : i + 10] + for doc in sub_documents: + document_format_thread = threading.Thread( + target=self._format_qa_document, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "tenant_id": kwargs.get("tenant_id"), # type: ignore + "document_node": doc, + "all_qa_documents": all_qa_documents, + "document_language": kwargs.get("doc_language", "English"), + }, + ) + threads.append(document_format_thread) + document_format_thread.start() + for thread in threads: + thread.join() return all_qa_documents def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]: @@ -98,12 +121,12 @@ class QAIndexProcessor(BaseIndexProcessor): raise ValueError(str(e)) return text_docs - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): + def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) - def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): + def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): vector = Vector(dataset) if node_ids: vector.delete_by_ids(node_ids) diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 1e9aaa24f0..421cdc05df 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -2,7 +2,20 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from typing import Any, Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel + + +class ChildDocument(BaseModel): + 
"""Class for storing a piece of text and associated metadata.""" + + page_content: str + + vector: Optional[list[float]] = None + + """Arbitrary metadata about the page content (e.g., source, relationships to other + documents, etc.). + """ + metadata: dict = {} class Document(BaseModel): @@ -15,10 +28,12 @@ class Document(BaseModel): """Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.). """ - metadata: Optional[dict] = Field(default_factory=dict) + metadata: dict = {} provider: Optional[str] = "dify" + children: Optional[list[ChildDocument]] = None + class BaseDocumentTransformer(ABC): """Abstract base class for document transformation systems. diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 8a7172f27c..290d9e6e61 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -164,43 +164,29 @@ class DatasetRetrieval: "content": item.page_content, } retrieval_resource_list.append(source) - document_score_list = {} # deal with dify documents if dify_documents: - for item in dify_documents: - if item.metadata.get("score"): - document_score_list[item.metadata["doc_id"]] = item.metadata["score"] - - index_node_ids = [document.metadata["doc_id"] for document in dify_documents] - segments = DocumentSegment.query.filter( - DocumentSegment.dataset_id.in_(dataset_ids), - DocumentSegment.status == "completed", - DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids), - ).all() - - if segments: - index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted( - segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) - ) - for segment in sorted_segments: + records = RetrievalService.format_retrieval_documents(dify_documents) + if records: + for record in records: + segment = record.segment if segment.answer: document_context_list.append( DocumentContext( content=f"question:{segment.get_sign_content()} answer:{segment.answer}", - score=document_score_list.get(segment.index_node_id, None), + score=record.score, ) ) else: document_context_list.append( DocumentContext( content=segment.get_sign_content(), - score=document_score_list.get(segment.index_node_id, None), + score=record.score, ) ) if show_retrieve_source: - for segment in sorted_segments: + for record in records: + segment = record.segment dataset = Dataset.query.filter_by(id=segment.dataset_id).first() document = DatasetDocument.query.filter( DatasetDocument.id == segment.document_id, @@ -216,7 +202,7 @@ class DatasetRetrieval: "data_source_type": document.data_source_type, "segment_id": segment.id, "retriever_from": invoke_from.to_source(), - "score": document_score_list.get(segment.index_node_id, 0.0), + "score": record.score or 0.0, } if invoke_from.to_source() == "dev": diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 38980f6d75..6941ff8fa2 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -267,6 +267,7 @@ class ToolParameter(PluginParameter): :param options: the options of the parameter """ # convert options to ToolParameterOption + # FIXME fix the type error if options: option_objs = [ PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 
cfa8e6b8b2..702c4384ae 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -139,7 +139,7 @@ class ToolEngine: error_response = f"tool invoke error: {e}" agent_tool_callback.on_tool_error(e) except ToolEngineInvokeError as e: - meta = e.args[0] + meta = e.meta error_response = f"tool invoke error: {meta.error}" agent_tool_callback.on_tool_error(e) return error_response, [], meta diff --git a/api/core/tools/utils/text_processing_utils.py b/api/core/tools/utils/text_processing_utils.py index 6db9dfd0d9..105823f896 100644 --- a/api/core/tools/utils/text_processing_utils.py +++ b/api/core/tools/utils/text_processing_utils.py @@ -12,5 +12,6 @@ def remove_leading_symbols(text: str) -> str: str: The text with leading punctuation or symbols removed. """ # Match Unicode ranges for punctuation and symbols - pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,\-./:;<=>?@\[\]^_`{|}~]+" + # FIXME this pattern is confused quick fix for #11868 maybe refactor it later + pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+" return re.sub(pattern, "", text) diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index b3bcc3b2cc..5c672c985b 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -613,10 +613,10 @@ class Graph(BaseModel): for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items(): # check which node is after if cls._is_node2_after_node1(node1_id=node_id, node2_id=node_id2, edge_mapping=edge_mapping): - if node_id in merge_branch_node_ids: + if node_id in merge_branch_node_ids and node_id2 in merge_branch_node_ids: del merge_branch_node_ids[node_id2] elif cls._is_node2_after_node1(node1_id=node_id2, node2_id=node_id, edge_mapping=edge_mapping): - if node_id2 in merge_branch_node_ids: + if node_id in merge_branch_node_ids and node_id2 in merge_branch_node_ids: del merge_branch_node_ids[node_id] branches_merge_node_ids: dict[str, str] = {} diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py index 8ffb487ec1..f22ea078fb 100644 --- a/api/core/workflow/nodes/answer/base_stream_processor.py +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -48,9 +48,11 @@ class StreamProcessor(ABC): # we remove the node maybe shortcut the answer node, so comment this code for now # there is not effect on the answer node and the workflow, when we have a better solution # we can open this code. 
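The narrowed pattern in remove_leading_symbols above no longer strips leading hyphens, square brackets, braces, or pipes, only plain punctuation and the listed Unicode punctuation blocks. A quick standalone check of that behaviour:

import re

pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"

print(re.sub(pattern, "", "。Hello"))        # "Hello": leading CJK full stop is removed
print(re.sub(pattern, "", "[1] A heading"))  # "[1] A heading": brackets are kept now
print(re.sub(pattern, "", "- list item"))    # "- list item": leading hyphen is kept now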
Issues: #11542 #9560 #10638 #10564 - - # reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) - continue + ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id) + if "answer" in ids: + continue + else: + reachable_node_ids.extend(ids) else: unreachable_first_node_ids.append(edge.target_node_id) diff --git a/api/core/workflow/nodes/http_request/exc.py b/api/core/workflow/nodes/http_request/exc.py index a815f277be..46613c9e86 100644 --- a/api/core/workflow/nodes/http_request/exc.py +++ b/api/core/workflow/nodes/http_request/exc.py @@ -20,3 +20,7 @@ class ResponseSizeError(HttpRequestNodeError): class RequestBodyError(HttpRequestNodeError): """Raised when the request body is invalid.""" + + +class InvalidURLError(HttpRequestNodeError): + """Raised when the URL is invalid.""" diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index cdfdc6e6d5..fadd142e35 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -23,6 +23,7 @@ from .exc import ( FileFetchError, HttpRequestNodeError, InvalidHttpMethodError, + InvalidURLError, RequestBodyError, ResponseSizeError, ) @@ -66,6 +67,12 @@ class Executor: node_data.authorization.config.api_key ).text + # check if node_data.url is a valid URL + if not node_data.url: + raise InvalidURLError("url is required") + if not node_data.url.startswith(("http://", "https://")): + raise InvalidURLError("url should start with http:// or https://") + self.url: str = node_data.url self.method = node_data.method self.auth = node_data.authorization diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index bfd93c074d..0f239af51a 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -11,6 +11,7 @@ from core.entities.model_entities import ModelStatus from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.rag.datasource.retrieval_service import RetrievalService from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.variables import StringSegment @@ -18,7 +19,7 @@ from core.workflow.entities.node_entities import NodeRunResult from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from extensions.ext_database import db -from models.dataset import Dataset, Document, DocumentSegment +from models.dataset import Dataset, Document from models.workflow import WorkflowNodeExecutionStatus from .entities import KnowledgeRetrievalNodeData @@ -211,29 +212,12 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): "content": item.page_content, } retrieval_resource_list.append(source) - document_score_list: dict[str, float] = {} # deal with dify documents if dify_documents: - document_score_list = {} - for item in dify_documents: - if item.metadata.get("score"): - document_score_list[item.metadata["doc_id"]] = item.metadata["score"] - - index_node_ids = [document.metadata["doc_id"] for document in dify_documents] - segments = DocumentSegment.query.filter( - 
DocumentSegment.dataset_id.in_(dataset_ids), - DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == "completed", - DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids), - ).all() - if segments: - index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted( - segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) - ) - - for segment in sorted_segments: + records = RetrievalService.format_retrieval_documents(dify_documents) + if records: + for record in records: + segment = record.segment dataset = Dataset.query.filter_by(id=segment.dataset_id).first() document = Document.query.filter( Document.id == segment.document_id, @@ -251,7 +235,7 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): "document_data_source_type": document.data_source_type, "segment_id": segment.id, "retriever_from": "workflow", - "score": document_score_list.get(segment.index_node_id, None), + "score": record.score or 0.0, "segment_hit_count": segment.hit_count, "segment_word_count": segment.word_count, "segment_position": segment.position, @@ -270,10 +254,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0, reverse=True, ) - position = 1 - for item in retrieval_resource_list: + for position, item in enumerate(retrieval_resource_list, start=1): item["metadata"]["position"] = position - position += 1 return retrieval_resource_list @classmethod diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index dc919892e5..d19d6413ed 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -5,7 +5,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler -from core.file import File, FileTransferMethod, FileType +from core.file import File, FileTransferMethod from core.plugin.manager.exc import PluginDaemonClientSideError from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.tool_engine import ToolEngine @@ -189,10 +189,12 @@ class ToolNode(BaseNode[ToolNodeData]): conversation_id=None, ) - files: list[File] = [] text = "" + files: list[File] = [] json: list[dict] = [] + agent_logs: list[AgentLog] = [] + variables: dict[str, Any] = {} for message in message_stream: @@ -239,14 +241,16 @@ class ToolNode(BaseNode[ToolNodeData]): tool_file = session.scalar(stmt) if tool_file is None: raise ToolFileError(f"tool file {tool_file_id} not exists") + + mapping = { + "tool_file_id": tool_file_id, + "transfer_method": FileTransferMethod.TOOL_FILE, + } + files.append( - File( + file_factory.build_from_mapping( + mapping=mapping, tenant_id=self.tenant_id, - type=FileType.IMAGE, - transfer_method=FileTransferMethod.TOOL_FILE, - related_id=tool_file_id, - extension=None, - mime_type=message.meta.get("mime_type", "application/octet-stream"), ) ) elif message.type == ToolInvokeMessage.MessageType.TEXT: diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py index fcd1547a2f..316be12f5c 100644 --- a/api/extensions/ext_blueprints.py +++ b/api/extensions/ext_blueprints.py @@ -5,7 +5,7 @@ from dify_app import DifyApp def init_app(app: DifyApp): # register blueprint routers - from flask_cors import CORS + from flask_cors 
import CORS # type: ignore from controllers.console import bp as console_app_bp from controllers.files import bp as files_bp diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 30f216ff95..26bd6b3577 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -69,6 +69,7 @@ def init_app(app: DifyApp) -> Celery: "schedule.create_tidb_serverless_task", "schedule.update_tidb_serverless_status_task", "schedule.clean_messages", + "schedule.mail_clean_document_notify_task", ] day = dify_config.CELERY_BEAT_SCHEDULER_TIME beat_schedule = { @@ -92,6 +93,11 @@ def init_app(app: DifyApp) -> Celery: "task": "schedule.clean_messages.clean_messages", "schedule": timedelta(days=day), }, + # every Monday + "mail_clean_document_notify_task": { + "task": "schedule.mail_clean_document_notify_task.mail_clean_document_notify_task", + "schedule": crontab(minute="0", hour="10", day_of_week="1"), + }, } celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 99c7195b2c..f7b658a58f 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -1,4 +1,5 @@ import mimetypes +import uuid from collections.abc import Callable, Mapping, Sequence from typing import Any, cast @@ -119,6 +120,11 @@ def _build_from_local_file( upload_file_id = mapping.get("upload_file_id") if not upload_file_id: raise ValueError("Invalid upload file id") + # check if upload_file_id is a valid uuid + try: + uuid.UUID(upload_file_id) + except ValueError: + raise ValueError("Invalid upload file id format") stmt = select(UploadFile).where( UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id, diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index a74e6f54fb..bedab5750f 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -73,6 +73,7 @@ dataset_detail_fields = { "embedding_available": fields.Boolean, "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields), "tags": fields.List(fields.Nested(tag_fields)), + "doc_form": fields.String, "external_knowledge_info": fields.Nested(external_knowledge_info_fields), "external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True), } diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py index 2b2ac6243f..f2250d964a 100644 --- a/api/fields/document_fields.py +++ b/api/fields/document_fields.py @@ -34,6 +34,7 @@ document_with_segments_fields = { "data_source_info": fields.Raw(attribute="data_source_info_dict"), "data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"), "dataset_process_rule_id": fields.String, + "process_rule_dict": fields.Raw(attribute="process_rule_dict"), "name": fields.String, "created_from": fields.String, "created_by": fields.String, diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py index aaafcab8ab..b9f7e78c17 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -34,8 +34,16 @@ segment_fields = { "document": fields.Nested(document_fields), } +child_chunk_fields = { + "id": fields.String, + "content": fields.String, + "position": fields.Integer, + "score": fields.Float, +} + hit_testing_record_fields = { "segment": fields.Nested(segment_fields), + "child_chunks": fields.List(fields.Nested(child_chunk_fields)), "score": fields.Float, "tsne_position": fields.Raw, } diff --git a/api/fields/segment_fields.py 
b/api/fields/segment_fields.py index 4413af3160..52f89859c9 100644 --- a/api/fields/segment_fields.py +++ b/api/fields/segment_fields.py @@ -2,6 +2,17 @@ from flask_restful import fields # type: ignore from libs.helper import TimestampField +child_chunk_fields = { + "id": fields.String, + "segment_id": fields.String, + "content": fields.String, + "position": fields.Integer, + "word_count": fields.Integer, + "type": fields.String, + "created_at": TimestampField, + "updated_at": TimestampField, +} + segment_fields = { "id": fields.String, "position": fields.Integer, @@ -20,10 +31,13 @@ segment_fields = { "status": fields.String, "created_by": fields.String, "created_at": TimestampField, + "updated_at": TimestampField, + "updated_by": fields.String, "indexing_at": TimestampField, "completed_at": TimestampField, "error": fields.String, "stopped_at": TimestampField, + "child_chunks": fields.List(fields.Nested(child_chunk_fields)), } segment_list_response = { diff --git a/api/migrations/versions/2024_11_22_0701-e19037032219_parent_child_index.py b/api/migrations/versions/2024_11_22_0701-e19037032219_parent_child_index.py new file mode 100644 index 0000000000..9238e5a0a8 --- /dev/null +++ b/api/migrations/versions/2024_11_22_0701-e19037032219_parent_child_index.py @@ -0,0 +1,55 @@ +"""parent-child-index + +Revision ID: e19037032219 +Revises: 01d6889832f7 +Create Date: 2024-11-22 07:01:17.550037 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'e19037032219' +down_revision = 'd7999dfa4aae' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('child_chunks', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('document_id', models.types.StringUUID(), nullable=False), + sa.Column('segment_id', models.types.StringUUID(), nullable=False), + sa.Column('position', sa.Integer(), nullable=False), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('word_count', sa.Integer(), nullable=False), + sa.Column('index_node_id', sa.String(length=255), nullable=True), + sa.Column('index_node_hash', sa.String(length=255), nullable=True), + sa.Column('type', sa.String(length=255), server_default=sa.text("'automatic'::character varying"), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_by', models.types.StringUUID(), nullable=True), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('indexing_at', sa.DateTime(), nullable=True), + sa.Column('completed_at', sa.DateTime(), nullable=True), + sa.Column('error', sa.Text(), nullable=True), + sa.PrimaryKeyConstraint('id', name='child_chunk_pkey') + ) + with op.batch_alter_table('child_chunks', schema=None) as batch_op: + batch_op.create_index('child_chunk_dataset_id_idx', ['tenant_id', 'dataset_id', 'document_id', 'segment_id', 'index_node_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
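On the API side, each segment now carries its child chunks as a nested list via the child_chunk_fields added above. A tiny flask_restful illustration of how that nesting marshals, using a reduced field set and made-up sample data:

from flask_restful import fields, marshal

child_chunk_fields = {"id": fields.String, "content": fields.String, "position": fields.Integer}
segment_fields = {"id": fields.String, "child_chunks": fields.List(fields.Nested(child_chunk_fields))}

segment = {"id": "seg-1", "child_chunks": [{"id": "c-1", "content": "child text", "position": 1}]}
print(marshal(segment, segment_fields))  # an OrderedDict mirroring segment_fields, child_chunks as nested dicts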
### + with op.batch_alter_table('child_chunks', schema=None) as batch_op: + batch_op.drop_index('child_chunk_dataset_id_idx') + + op.drop_table('child_chunks') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_12_25_1137-923752d42eb6_add_auto_disabled_dataset_logs.py b/api/migrations/versions/2024_12_25_1137-923752d42eb6_add_auto_disabled_dataset_logs.py new file mode 100644 index 0000000000..6dadd4e4a8 --- /dev/null +++ b/api/migrations/versions/2024_12_25_1137-923752d42eb6_add_auto_disabled_dataset_logs.py @@ -0,0 +1,47 @@ +"""add_auto_disabled_dataset_logs + +Revision ID: 923752d42eb6 +Revises: e19037032219 +Create Date: 2024-12-25 11:37:55.467101 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '923752d42eb6' +down_revision = 'e19037032219' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('dataset_auto_disable_logs', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('document_id', models.types.StringUUID(), nullable=False), + sa.Column('notified', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='dataset_auto_disable_log_pkey') + ) + with op.batch_alter_table('dataset_auto_disable_logs', schema=None) as batch_op: + batch_op.create_index('dataset_auto_disable_log_created_atx', ['created_at'], unique=False) + batch_op.create_index('dataset_auto_disable_log_dataset_idx', ['dataset_id'], unique=False) + batch_op.create_index('dataset_auto_disable_log_tenant_idx', ['tenant_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('dataset_auto_disable_logs', schema=None) as batch_op: + batch_op.drop_index('dataset_auto_disable_log_tenant_idx') + batch_op.drop_index('dataset_auto_disable_log_dataset_idx') + batch_op.drop_index('dataset_auto_disable_log_created_atx') + + op.drop_table('dataset_auto_disable_logs') + # ### end Alembic commands ### diff --git a/api/models/account.py b/api/models/account.py index 4f8ca0530f..941dd54687 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -23,7 +23,7 @@ class Account(UserMixin, Base): __tablename__ = "accounts" __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) name = db.Column(db.String(255), nullable=False) email = db.Column(db.String(255), nullable=False) password = db.Column(db.String(255), nullable=True) diff --git a/api/models/dataset.py b/api/models/dataset.py index b9b41dcf47..567f7db432 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -17,6 +17,7 @@ from sqlalchemy.dialects.postgresql import JSONB from configs import dify_config from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_storage import storage +from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule from .account import Account from .engine import db @@ -215,7 +216,7 @@ class DatasetProcessRule(db.Model): # type: ignore[name-defined] created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - MODES = ["automatic", "custom"] + MODES = ["automatic", "custom", "hierarchical"] PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] AUTOMATIC_RULES: dict[str, Any] = { "pre_processing_rules": [ @@ -231,8 +232,6 @@ class DatasetProcessRule(db.Model): # type: ignore[name-defined] "dataset_id": self.dataset_id, "mode": self.mode, "rules": self.rules_dict, - "created_by": self.created_by, - "created_at": self.created_at, } @property @@ -396,6 +395,12 @@ class Document(db.Model): # type: ignore[name-defined] .scalar() ) + @property + def process_rule_dict(self): + if self.dataset_process_rule_id: + return self.dataset_process_rule.to_dict() + return None + def to_dict(self): return { "id": self.id, @@ -560,6 +565,24 @@ class DocumentSegment(db.Model): # type: ignore[name-defined] .first() ) + @property + def child_chunks(self): + process_rule = self.document.dataset_process_rule + if process_rule.mode == "hierarchical": + rules = Rule(**process_rule.rules_dict) + if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC: + child_chunks = ( + db.session.query(ChildChunk) + .filter(ChildChunk.segment_id == self.id) + .order_by(ChildChunk.position.asc()) + .all() + ) + return child_chunks or [] + else: + return [] + else: + return [] + def get_sign_content(self): signed_urls = [] text = self.content @@ -605,6 +628,47 @@ class DocumentSegment(db.Model): # type: ignore[name-defined] return text +class ChildChunk(db.Model): # type: ignore[name-defined] + __tablename__ = "child_chunks" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="child_chunk_pkey"), + db.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"), + ) + + # initial fields + id = db.Column(StringUUID, nullable=False, 
server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + dataset_id = db.Column(StringUUID, nullable=False) + document_id = db.Column(StringUUID, nullable=False) + segment_id = db.Column(StringUUID, nullable=False) + position = db.Column(db.Integer, nullable=False) + content = db.Column(db.Text, nullable=False) + word_count = db.Column(db.Integer, nullable=False) + # indexing fields + index_node_id = db.Column(db.String(255), nullable=True) + index_node_hash = db.Column(db.String(255), nullable=True) + type = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) + created_by = db.Column(StringUUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_by = db.Column(StringUUID, nullable=True) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + indexing_at = db.Column(db.DateTime, nullable=True) + completed_at = db.Column(db.DateTime, nullable=True) + error = db.Column(db.Text, nullable=True) + + @property + def dataset(self): + return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first() + + @property + def document(self): + return db.session.query(Document).filter(Document.id == self.document_id).first() + + @property + def segment(self): + return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first() + + class AppDatasetJoin(db.Model): # type: ignore[name-defined] __tablename__ = "app_dataset_joins" __table_args__ = ( @@ -844,3 +908,20 @@ class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined] created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class DatasetAutoDisableLog(db.Model): # type: ignore[name-defined] + __tablename__ = "dataset_auto_disable_logs" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"), + db.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"), + db.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"), + db.Index("dataset_auto_disable_log_created_atx", "created_at"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + dataset_id = db.Column(StringUUID, nullable=False) + document_id = db.Column(StringUUID, nullable=False) + notified = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/model.py b/api/models/model.py index 39b091b5c9..462fbb672e 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -611,13 +611,13 @@ class Conversation(Base): db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) app_model_config_id = db.Column(StringUUID, nullable=True) model_provider = db.Column(db.String(255), nullable=True) override_model_configs = db.Column(db.Text) model_id = db.Column(db.String(255), nullable=True) - mode = db.Column(db.String(255), nullable=False) + mode: Mapped[str] = 
mapped_column(db.String(255)) name = db.Column(db.String(255), nullable=False) summary = db.Column(db.Text) _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) @@ -851,7 +851,7 @@ class Message(Base): Index("message_created_at_idx", "created_at"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) model_provider = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True) @@ -878,7 +878,7 @@ class Message(Base): from_source = db.Column(db.String(255), nullable=False) from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID) from_account_id: Mapped[Optional[str]] = db.Column(StringUUID) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) workflow_run_id = db.Column(StringUUID) @@ -1403,7 +1403,7 @@ class EndUser(Base, UserMixin): external_user_id = db.Column(db.String(255), nullable=True) name = db.Column(db.String(255)) is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - session_id = db.Column(db.String(255), nullable=False) + session_id: Mapped[str] = mapped_column() created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/tools.py b/api/models/tools.py index 0fcd87d2b9..e9b52730d9 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,6 +1,7 @@ import json from datetime import datetime from typing import Optional +from typing import Any import sqlalchemy as sa from deprecated import deprecated @@ -256,8 +257,8 @@ class ToolConversationVariables(Base): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property - def variables(self) -> dict: - return dict(json.loads(self.variables_str)) + def variables(self) -> Any: + return json.loads(self.variables_str) class ToolFile(Base): diff --git a/api/models/workflow.py b/api/models/workflow.py index 6e2bdf2392..eba9c1b772 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -402,23 +402,23 @@ class WorkflowRun(Base): db.Index("workflow_run_tenant_app_sequence_idx", "tenant_id", "app_id", "sequence_number"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - app_id = db.Column(StringUUID, nullable=False) - sequence_number = db.Column(db.Integer, nullable=False) - workflow_id = db.Column(StringUUID, nullable=False) - type = db.Column(db.String(255), nullable=False) - triggered_from = db.Column(db.String(255), nullable=False) - version = db.Column(db.String(255), nullable=False) - graph = db.Column(db.Text) - inputs = db.Column(db.Text) - status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped, partial-succeeded + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID) + app_id: Mapped[str] = mapped_column(StringUUID) + sequence_number: Mapped[int] = mapped_column() + 
workflow_id: Mapped[str] = mapped_column(StringUUID) + type: Mapped[str] = mapped_column(db.String(255)) + triggered_from: Mapped[str] = mapped_column(db.String(255)) + version: Mapped[str] = mapped_column(db.String(255)) + graph: Mapped[Optional[str]] = mapped_column(db.Text) + inputs: Mapped[Optional[str]] = mapped_column(db.Text) + status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") - error = db.Column(db.Text) + error: Mapped[Optional[str]] = mapped_column(db.Text) elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) - total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + total_tokens: Mapped[int] = mapped_column(server_default=db.text("0")) total_steps = db.Column(db.Integer, server_default=db.text("0")) - created_by_role = db.Column(db.String(255), nullable=False) # account, end_user + created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) finished_at = db.Column(db.DateTime) @@ -631,29 +631,29 @@ class WorkflowNodeExecution(Base): ), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - app_id = db.Column(StringUUID, nullable=False) - workflow_id = db.Column(StringUUID, nullable=False) - triggered_from = db.Column(db.String(255), nullable=False) - workflow_run_id = db.Column(StringUUID) - index = db.Column(db.Integer, nullable=False) - predecessor_node_id = db.Column(db.String(255)) - node_execution_id = db.Column(db.String(255), nullable=True) - node_id = db.Column(db.String(255), nullable=False) - node_type = db.Column(db.String(255), nullable=False) - title = db.Column(db.String(255), nullable=False) - inputs = db.Column(db.Text) - process_data = db.Column(db.Text) - outputs = db.Column(db.Text) - status = db.Column(db.String(255), nullable=False) - error = db.Column(db.Text) - elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) - execution_metadata = db.Column(db.Text) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - created_by_role = db.Column(db.String(255), nullable=False) - created_by = db.Column(StringUUID, nullable=False) - finished_at = db.Column(db.DateTime) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID) + app_id: Mapped[str] = mapped_column(StringUUID) + workflow_id: Mapped[str] = mapped_column(StringUUID) + triggered_from: Mapped[str] = mapped_column(db.String(255)) + workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) + index: Mapped[int] = mapped_column(db.Integer) + predecessor_node_id: Mapped[Optional[str]] = mapped_column(db.String(255)) + node_execution_id: Mapped[Optional[str]] = mapped_column(db.String(255)) + node_id: Mapped[str] = mapped_column(db.String(255)) + node_type: Mapped[str] = mapped_column(db.String(255)) + title: Mapped[str] = mapped_column(db.String(255)) + inputs: Mapped[Optional[str]] = mapped_column(db.Text) + process_data: Mapped[Optional[str]] = mapped_column(db.Text) + outputs: Mapped[Optional[str]] = mapped_column(db.Text) + status: Mapped[str] = mapped_column(db.String(255)) + error: Mapped[Optional[str]] = mapped_column(db.Text) 
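(A minimal sketch of the SQLAlchemy 2.0 typed-declaration pattern these hunks migrate to, using a hypothetical standalone model; the class and table names here are illustrative and not part of this patch. Nullability is inferred from the annotation, which is why the explicit nullable= arguments disappear above.)

from datetime import datetime
from typing import Optional

from sqlalchemy import DateTime, String, create_engine, func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class ExampleRun(Base):
    __tablename__ = "example_runs"

    # Mapped[str] implies NOT NULL and Mapped[Optional[str]] implies NULL allowed,
    # so nullable=False/True no longer needs to be spelled out per column.
    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    status: Mapped[str] = mapped_column(String(255))
    error: Mapped[Optional[str]] = mapped_column(String(255))
    created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())


# Quick check that the mapping compiles against an in-memory database.
Base.metadata.create_all(create_engine("sqlite://"))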
+ elapsed_time: Mapped[float] = mapped_column(db.Float, server_default=db.text("0")) + execution_metadata: Mapped[Optional[str]] = mapped_column(db.Text) + created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) + created_by_role: Mapped[str] = mapped_column(db.String(255)) + created_by: Mapped[str] = mapped_column(StringUUID) + finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) @property def created_by_account(self): @@ -760,11 +760,11 @@ class WorkflowAppLog(Base): db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - app_id = db.Column(StringUUID, nullable=False) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID) + app_id: Mapped[str] = mapped_column(StringUUID) workflow_id = db.Column(StringUUID, nullable=False) - workflow_run_id = db.Column(StringUUID, nullable=False) + workflow_run_id: Mapped[str] = mapped_column(StringUUID) created_from = db.Column(db.String(255), nullable=False) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py index 48bdc872f4..5e4d3ec323 100644 --- a/api/schedule/clean_messages.py +++ b/api/schedule/clean_messages.py @@ -28,7 +28,6 @@ def clean_messages(): plan_sandbox_clean_message_day = datetime.datetime.now() - datetime.timedelta( days=dify_config.PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING ) - page = 1 while True: try: # Main query with join and filter @@ -79,4 +78,4 @@ def clean_messages(): db.session.query(Message).filter(Message.id == message.id).delete() db.session.commit() end_at = time.perf_counter() - click.echo(click.style("Cleaned unused dataset from db success latency: {}".format(end_at - start_at), fg="green")) + click.echo(click.style("Cleaned messages from db success latency: {}".format(end_at - start_at), fg="green")) diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index f66b3c4797..eb73cc285d 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -10,7 +10,7 @@ from configs import dify_config from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.dataset import Dataset, DatasetQuery, Document +from models.dataset import Dataset, DatasetAutoDisableLog, DatasetQuery, Document from services.feature_service import FeatureService @@ -75,6 +75,23 @@ def clean_unused_datasets_task(): ) if not dataset_query or len(dataset_query) == 0: try: + # add auto disable log + documents = ( + db.session.query(Document) + .filter( + Document.dataset_id == dataset.id, + Document.enabled == True, + Document.archived == False, + ) + .all() + ) + for document in documents: + dataset_auto_disable_log = DatasetAutoDisableLog( + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + document_id=document.id, + ) + db.session.add(dataset_auto_disable_log) # remove index index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() index_processor.clean(dataset, None) @@ -151,6 +168,23 @@ def clean_unused_datasets_task(): else: plan = plan_cache.decode() if plan == "sandbox": + # add auto disable log + 
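(The sandbox-plan branch below records one DatasetAutoDisableLog row per enabled, non-archived document before the dataset's index is cleaned, so that the new notification task can later email workspace owners. A rough sketch of that recording step in isolation, assuming the models imported at the top of this file; simplified and lifted out of the scheduler loop, without the commit handled elsewhere in the task.)

from extensions.ext_database import db
from models.dataset import Dataset, DatasetAutoDisableLog, Document


def record_auto_disable_logs(dataset: Dataset) -> int:
    # Collect the documents that are about to be disabled along with the dataset.
    documents = (
        db.session.query(Document)
        .filter(
            Document.dataset_id == dataset.id,
            Document.enabled == True,  # noqa: E712
            Document.archived == False,  # noqa: E712
        )
        .all()
    )
    # One log row per document; the notify task groups them by tenant later.
    for document in documents:
        db.session.add(
            DatasetAutoDisableLog(
                tenant_id=dataset.tenant_id,
                dataset_id=dataset.id,
                document_id=document.id,
            )
        )
    return len(documents)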
documents = ( + db.session.query(Document) + .filter( + Document.dataset_id == dataset.id, + Document.enabled == True, + Document.archived == False, + ) + .all() + ) + for document in documents: + dataset_auto_disable_log = DatasetAutoDisableLog( + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + document_id=document.id, + ) + db.session.add(dataset_auto_disable_log) # remove index index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() index_processor.clean(dataset, None) diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py new file mode 100644 index 0000000000..fe6839288d --- /dev/null +++ b/api/schedule/mail_clean_document_notify_task.py @@ -0,0 +1,90 @@ +import logging +import time +from collections import defaultdict + +import click +from flask import render_template # type: ignore + +import app +from configs import dify_config +from extensions.ext_database import db +from extensions.ext_mail import mail +from models.account import Account, Tenant, TenantAccountJoin +from models.dataset import Dataset, DatasetAutoDisableLog +from services.feature_service import FeatureService + + +@app.celery.task(queue="dataset") +def send_document_clean_notify_task(): + """ + Async Send document clean notify mail + + Usage: send_document_clean_notify_task.delay() + """ + if not mail.is_inited(): + return + + logging.info(click.style("Start send document clean notify mail", fg="green")) + start_at = time.perf_counter() + + # send document clean notify mail + try: + dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all() + # group by tenant_id + dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) + for dataset_auto_disable_log in dataset_auto_disable_logs: + if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map: + dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = [] + dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log) + url = f"{dify_config.CONSOLE_WEB_URL}/datasets" + for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items(): + features = FeatureService.get_features(tenant_id) + plan = features.billing.subscription.plan + if plan != "sandbox": + knowledge_details = [] + # check tenant + tenant = Tenant.query.filter(Tenant.id == tenant_id).first() + if not tenant: + continue + # check current owner + current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first() + if not current_owner_join: + continue + account = Account.query.filter(Account.id == current_owner_join.account_id).first() + if not account: + continue + + dataset_auto_dataset_map = {} # type: ignore + for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: + if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map: + dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = [] + dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append( + dataset_auto_disable_log.document_id + ) + + for dataset_id, document_ids in dataset_auto_dataset_map.items(): + dataset = Dataset.query.filter(Dataset.id == dataset_id).first() + if dataset: + document_count = len(document_ids) + knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents") + if knowledge_details: + html_content = render_template( + "clean_document_job_mail_template-US.html", + 
userName=account.email, + knowledge_details=knowledge_details, + url=url, + ) + mail.send( + to=account.email, subject="Dify Knowledge base auto disable notification", html=html_content + ) + + # update notified to True + for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: + dataset_auto_disable_log.notified = True + db.session.commit() + end_at = time.perf_counter() + logging.info( + click.style("Send document clean notify mail succeeded: latency: {}".format(end_at - start_at), fg="green") + ) + except Exception: + logging.exception("Send document clean notify mail failed") diff --git a/api/services/account_service.py b/api/services/account_service.py index 13b70db580..214fb88995 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -798,6 +798,7 @@ class RegisterService: language: Optional[str] = None, status: Optional[AccountStatus] = None, is_setup: Optional[bool] = False, + create_workspace_required: Optional[bool] = True, ) -> Account: db.session.begin_nested() """Register account""" @@ -815,7 +816,7 @@ class RegisterService: if open_id is not None and provider is not None: AccountService.link_account_integrate(provider, open_id, account) - if FeatureService.get_system_features().is_allow_create_workspace: + if FeatureService.get_system_features().is_allow_create_workspace and create_workspace_required: tenant = TenantService.create_tenant(f"{account.name}'s Workspace") TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 7793fdc4ff..932d68bea1 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -4,7 +4,7 @@ from enum import StrEnum from typing import Optional, cast from uuid import uuid4 -import yaml +import yaml # type: ignore from packaging import version from pydantic import BaseModel, Field from sqlalchemy import select @@ -196,6 +196,9 @@ class AppDslService: data["kind"] = "app" imported_version = data.get("version", "0.1.0") + # check if imported_version is a float-like string + if not isinstance(imported_version, str): + raise ValueError(f"Invalid version type, expected str, got {type(imported_version)}") status = _check_version_compatibility(imported_version) # Extract app data @@ -524,7 +527,7 @@ class AppDslService: else: cls._append_model_config_export_data(export_data, app_model) - return yaml.dump(export_data, allow_unicode=True) + return yaml.dump(export_data, allow_unicode=True) # type: ignore @classmethod def _append_workflow_export_data(cls, *, export_data: dict, app_model: App, include_secret: bool) -> None: diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 973110f515..ef52301c0a 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -1,5 +1,6 @@ import io import logging +import uuid from typing import Optional from werkzeug.datastructures import FileStorage @@ -122,6 +123,10 @@ class AudioService: raise e if message_id: + try: + uuid.UUID(message_id) + except ValueError: + return None message = db.session.query(Message).filter(Message.id == message_id).first() if message is None: return None diff --git a/api/services/billing_service.py b/api/services/billing_service.py index d980186488..ed611a8be4 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -2,7 +2,7 @@ import os from typing import Optional import httpx -from tenacity import retry, 
retry_if_not_exception_type, stop_before_delay, wait_fixed +from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed from extensions.ext_database import db from models.account import TenantAccountJoin, TenantAccountRole @@ -44,7 +44,7 @@ class BillingService: @retry( wait=wait_fixed(2), stop=stop_before_delay(10), - retry=retry_if_not_exception_type(httpx.RequestError), + retry=retry_if_exception_type(httpx.RequestError), reraise=True, ) def _send_request(cls, method, endpoint, json=None, params=None): diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index ca741f1935..b7ddd14025 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -14,6 +14,7 @@ from configs import dify_config from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType +from core.rag.index_processor.constant.index_type import IndexType from core.rag.retrieval.retrieval_methods import RetrievalMethod from events.dataset_event import dataset_was_deleted from events.document_event import document_was_deleted @@ -23,7 +24,9 @@ from libs import helper from models.account import Account, TenantAccountRole from models.dataset import ( AppDatasetJoin, + ChildChunk, Dataset, + DatasetAutoDisableLog, DatasetCollectionBinding, DatasetPermission, DatasetPermissionEnum, @@ -35,8 +38,15 @@ from models.dataset import ( ) from models.model import UploadFile from models.source import DataSourceOauthBinding -from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateEntity -from services.errors.account import NoPermissionError +from services.entities.knowledge_entities.knowledge_entities import ( + ChildChunkUpdateArgs, + KnowledgeConfig, + RerankingModel, + RetrievalModel, + SegmentUpdateArgs, +) +from services.errors.account import InvalidActionError, NoPermissionError +from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError from services.errors.dataset import DatasetNameDuplicateError from services.errors.document import DocumentIndexingError from services.errors.file import FileNotExistsError @@ -44,13 +54,16 @@ from services.external_knowledge_service import ExternalDatasetService from services.feature_service import FeatureModel, FeatureService from services.tag_service import TagService from services.vector_service import VectorService +from tasks.batch_clean_document_task import batch_clean_document_task from tasks.clean_notion_document_task import clean_notion_document_task from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task from tasks.delete_segment_from_index_task import delete_segment_from_index_task from tasks.disable_segment_from_index_task import disable_segment_from_index_task +from tasks.disable_segments_from_index_task import disable_segments_from_index_task from tasks.document_indexing_task import document_indexing_task from tasks.document_indexing_update_task import document_indexing_update_task from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task +from tasks.enable_segments_to_index_task import enable_segments_to_index_task from tasks.recover_document_indexing_task import recover_document_indexing_task from tasks.retry_document_indexing_task import retry_document_indexing_task from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task @@ -408,6 +421,24 @@ class 
DatasetService: .all() ) + @staticmethod + def get_dataset_auto_disable_logs(dataset_id: str) -> dict: + # get recent 30 days auto disable logs + start_date = datetime.datetime.now() - datetime.timedelta(days=30) + dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter( + DatasetAutoDisableLog.dataset_id == dataset_id, + DatasetAutoDisableLog.created_at >= start_date, + ).all() + if dataset_auto_disable_logs: + return { + "document_ids": [log.document_id for log in dataset_auto_disable_logs], + "count": len(dataset_auto_disable_logs), + } + return { + "document_ids": [], + "count": 0, + } + class DocumentService: DEFAULT_RULES = { @@ -518,12 +549,14 @@ class DocumentService: } @staticmethod - def get_document(dataset_id: str, document_id: str) -> Optional[Document]: - document = ( - db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() - ) - - return document + def get_document(dataset_id: str, document_id: Optional[str] = None) -> Optional[Document]: + if document_id: + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) + return document + else: + return None @staticmethod def get_document_by_id(document_id: str) -> Optional[Document]: @@ -588,6 +621,20 @@ class DocumentService: db.session.delete(document) db.session.commit() + @staticmethod + def delete_documents(dataset: Dataset, document_ids: list[str]): + documents = db.session.query(Document).filter(Document.id.in_(document_ids)).all() + file_ids = [ + document.data_source_info_dict["upload_file_id"] + for document in documents + if document.data_source_type == "upload_file" + ] + batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids) + + for document in documents: + db.session.delete(document) + db.session.commit() + @staticmethod def rename_document(dataset_id: str, document_id: str, name: str) -> Document: dataset = DatasetService.get_dataset(dataset_id) @@ -689,7 +736,7 @@ class DocumentService: @staticmethod def save_document_with_dataset_id( dataset: Dataset, - document_data: dict, + knowledge_config: KnowledgeConfig, account: Account | Any, dataset_process_rule: Optional[DatasetProcessRule] = None, created_from: str = "web", @@ -698,37 +745,35 @@ class DocumentService: features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: - if "original_document_id" not in document_data or not document_data["original_document_id"]: + if not knowledge_config.original_document_id: count = 0 - if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] - count = len(upload_file_list) - elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] - for notion_info in notion_info_list: - count = count + len(notion_info["pages"]) - elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]["info_list"]["website_info_list"] - count = len(website_info["urls"]) - batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) - if count > batch_upload_limit: - raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") + if knowledge_config.data_source: + if knowledge_config.data_source.info_list.data_source_type == "upload_file": + upload_file_list = 
knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore + count = len(upload_file_list) + elif knowledge_config.data_source.info_list.data_source_type == "notion_import": + notion_info_list = knowledge_config.data_source.info_list.notion_info_list + for notion_info in notion_info_list: # type: ignore + count = count + len(notion_info.pages) + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + website_info = knowledge_config.data_source.info_list.website_info_list + count = len(website_info.urls) # type: ignore + batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) + if count > batch_upload_limit: + raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") - DocumentService.check_documents_upload_quota(count, features) + DocumentService.check_documents_upload_quota(count, features) # if dataset is empty, update dataset data_source_type if not dataset.data_source_type: - dataset.data_source_type = document_data["data_source"]["type"] + dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore if not dataset.indexing_technique: - if ( - "indexing_technique" not in document_data - or document_data["indexing_technique"] not in Dataset.INDEXING_TECHNIQUE_LIST - ): - raise ValueError("Indexing technique is required") + if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: + raise ValueError("Indexing technique is invalid") - dataset.indexing_technique = document_data["indexing_technique"] - if document_data["indexing_technique"] == "high_quality": + dataset.indexing_technique = knowledge_config.indexing_technique + if knowledge_config.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_default_model_instance( tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING @@ -748,46 +793,47 @@ class DocumentService: "score_threshold_enabled": False, } - dataset.retrieval_model = document_data.get("retrieval_model") or default_retrieval_model + dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model # type: ignore documents = [] - if document_data.get("original_document_id"): - document = DocumentService.update_document_with_dataset_id(dataset, document_data, account) + if knowledge_config.original_document_id: + document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account) documents.append(document) batch = document.batch else: batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) # save process rule if not dataset_process_rule: - process_rule = document_data["process_rule"] - if process_rule["mode"] == "custom": - dataset_process_rule = DatasetProcessRule( - dataset_id=dataset.id, - mode=process_rule["mode"], - rules=json.dumps(process_rule["rules"]), - created_by=account.id, - ) - elif process_rule["mode"] == "automatic": - dataset_process_rule = DatasetProcessRule( - dataset_id=dataset.id, - mode=process_rule["mode"], - rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), - created_by=account.id, - ) - else: - logging.warn( - f"Invalid process rule mode: {process_rule['mode']}, can not find dataset process rule" - ) - return - db.session.add(dataset_process_rule) - db.session.commit() + process_rule = knowledge_config.process_rule + if process_rule: + if process_rule.mode in ("custom", "hierarchical"): + dataset_process_rule = DatasetProcessRule( + dataset_id=dataset.id, + 
mode=process_rule.mode, + rules=process_rule.rules.model_dump_json() if process_rule.rules else None, + created_by=account.id, + ) + elif process_rule.mode == "automatic": + dataset_process_rule = DatasetProcessRule( + dataset_id=dataset.id, + mode=process_rule.mode, + rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), + created_by=account.id, + ) + else: + logging.warn( + f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule" + ) + return + db.session.add(dataset_process_rule) + db.session.commit() lock_name = "add_document_lock_dataset_id_{}".format(dataset.id) with redis_client.lock(lock_name, timeout=600): position = DocumentService.get_documents_position(dataset.id) document_ids = [] duplicate_document_ids = [] - if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] + if knowledge_config.data_source.info_list.data_source_type == "upload_file": + upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore for file_id in upload_file_list: file = ( db.session.query(UploadFile) @@ -804,7 +850,7 @@ class DocumentService: "upload_file_id": file_id, } # check duplicate - if document_data.get("duplicate", False): + if knowledge_config.duplicate: document = Document.query.filter_by( dataset_id=dataset.id, tenant_id=current_user.current_tenant_id, @@ -813,11 +859,11 @@ class DocumentService: name=file_name, ).first() if document: - document.dataset_process_rule_id = dataset_process_rule.id - document.updated_at = datetime.datetime.utcnow() + document.dataset_process_rule_id = dataset_process_rule.id # type: ignore + document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) document.created_from = created_from - document.doc_form = document_data["doc_form"] - document.doc_language = document_data["doc_language"] + document.doc_form = knowledge_config.doc_form + document.doc_language = knowledge_config.doc_language document.data_source_info = json.dumps(data_source_info) document.batch = batch document.indexing_status = "waiting" @@ -827,10 +873,10 @@ class DocumentService: continue document = DocumentService.build_document( dataset, - dataset_process_rule.id, - document_data["data_source"]["type"], - document_data["doc_form"], - document_data["doc_language"], + dataset_process_rule.id, # type: ignore + knowledge_config.data_source.info_list.data_source_type, + knowledge_config.doc_form, + knowledge_config.doc_language, data_source_info, created_from, position, @@ -843,8 +889,10 @@ class DocumentService: document_ids.append(document.id) documents.append(document) position += 1 - elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] + elif knowledge_config.data_source.info_list.data_source_type == "notion_import": + notion_info_list = knowledge_config.data_source.info_list.notion_info_list + if not notion_info_list: + raise ValueError("No notion info list found.") exist_page_ids = [] exist_document = {} documents = Document.query.filter_by( @@ -859,7 +907,7 @@ class DocumentService: exist_page_ids.append(data_source_info["notion_page_id"]) exist_document[data_source_info["notion_page_id"]] = document.id for notion_info in notion_info_list: - workspace_id = notion_info["workspace_id"] + workspace_id = notion_info.workspace_id data_source_binding = DataSourceOauthBinding.query.filter( db.and_( 
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, @@ -870,25 +918,25 @@ class DocumentService: ).first() if not data_source_binding: raise ValueError("Data source binding not found.") - for page in notion_info["pages"]: - if page["page_id"] not in exist_page_ids: + for page in notion_info.pages: + if page.page_id not in exist_page_ids: data_source_info = { "notion_workspace_id": workspace_id, - "notion_page_id": page["page_id"], - "notion_page_icon": page["page_icon"], - "type": page["type"], + "notion_page_id": page.page_id, + "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, + "type": page.type, } document = DocumentService.build_document( dataset, - dataset_process_rule.id, - document_data["data_source"]["type"], - document_data["doc_form"], - document_data["doc_language"], + dataset_process_rule.id, # type: ignore + knowledge_config.data_source.info_list.data_source_type, + knowledge_config.doc_form, + knowledge_config.doc_language, data_source_info, created_from, position, account, - page["page_name"], + page.page_name, batch, ) db.session.add(document) @@ -897,19 +945,21 @@ class DocumentService: documents.append(document) position += 1 else: - exist_document.pop(page["page_id"]) + exist_document.pop(page.page_id) # delete not selected documents if len(exist_document) > 0: clean_notion_document_task.delay(list(exist_document.values()), dataset.id) - elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]["info_list"]["website_info_list"] - urls = website_info["urls"] + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + website_info = knowledge_config.data_source.info_list.website_info_list + if not website_info: + raise ValueError("No website info list found.") + urls = website_info.urls for url in urls: data_source_info = { "url": url, - "provider": website_info["provider"], - "job_id": website_info["job_id"], - "only_main_content": website_info.get("only_main_content", False), + "provider": website_info.provider, + "job_id": website_info.job_id, + "only_main_content": website_info.only_main_content, "mode": "crawl", } if len(url) > 255: @@ -918,10 +968,10 @@ class DocumentService: document_name = url document = DocumentService.build_document( dataset, - dataset_process_rule.id, - document_data["data_source"]["type"], - document_data["doc_form"], - document_data["doc_language"], + dataset_process_rule.id, # type: ignore + knowledge_config.data_source.info_list.data_source_type, + knowledge_config.doc_form, + knowledge_config.doc_language, data_source_info, created_from, position, @@ -995,31 +1045,31 @@ class DocumentService: @staticmethod def update_document_with_dataset_id( dataset: Dataset, - document_data: dict, + document_data: KnowledgeConfig, account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None, created_from: str = "web", ): DatasetService.check_dataset_model_setting(dataset) - document = DocumentService.get_document(dataset.id, document_data["original_document_id"]) + document = DocumentService.get_document(dataset.id, document_data.original_document_id) if document is None: raise NotFound("Document not found") if document.display_status != "available": raise ValueError("Document is not available") # save process rule - if document_data.get("process_rule"): - process_rule = document_data["process_rule"] - if process_rule["mode"] == "custom": + if document_data.process_rule: + process_rule = document_data.process_rule + if 
process_rule.mode in {"custom", "hierarchical"}: dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, - mode=process_rule["mode"], - rules=json.dumps(process_rule["rules"]), + mode=process_rule.mode, + rules=process_rule.rules.model_dump_json() if process_rule.rules else None, created_by=account.id, ) - elif process_rule["mode"] == "automatic": + elif process_rule.mode == "automatic": dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, - mode=process_rule["mode"], + mode=process_rule.mode, rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), created_by=account.id, ) @@ -1028,11 +1078,13 @@ class DocumentService: db.session.commit() document.dataset_process_rule_id = dataset_process_rule.id # update document data source - if document_data.get("data_source"): + if document_data.data_source: file_name = "" data_source_info = {} - if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] + if document_data.data_source.info_list.data_source_type == "upload_file": + if not document_data.data_source.info_list.file_info_list: + raise ValueError("No file info list found.") + upload_file_list = document_data.data_source.info_list.file_info_list.file_ids for file_id in upload_file_list: file = ( db.session.query(UploadFile) @@ -1048,10 +1100,12 @@ class DocumentService: data_source_info = { "upload_file_id": file_id, } - elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] + elif document_data.data_source.info_list.data_source_type == "notion_import": + if not document_data.data_source.info_list.notion_info_list: + raise ValueError("No notion info list found.") + notion_info_list = document_data.data_source.info_list.notion_info_list for notion_info in notion_info_list: - workspace_id = notion_info["workspace_id"] + workspace_id = notion_info.workspace_id data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, @@ -1062,31 +1116,32 @@ class DocumentService: ).first() if not data_source_binding: raise ValueError("Data source binding not found.") - for page in notion_info["pages"]: + for page in notion_info.pages: data_source_info = { "notion_workspace_id": workspace_id, - "notion_page_id": page["page_id"], - "notion_page_icon": page["page_icon"], - "type": page["type"], + "notion_page_id": page.page_id, + "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, # type: ignore + "type": page.type, } - elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]["info_list"]["website_info_list"] - urls = website_info["urls"] - for url in urls: - data_source_info = { - "url": url, - "provider": website_info["provider"], - "job_id": website_info["job_id"], - "only_main_content": website_info.get("only_main_content", False), - "mode": "crawl", - } - document.data_source_type = document_data["data_source"]["type"] + elif document_data.data_source.info_list.data_source_type == "website_crawl": + website_info = document_data.data_source.info_list.website_info_list + if website_info: + urls = website_info.urls + for url in urls: + data_source_info = { + "url": url, + "provider": website_info.provider, + "job_id": website_info.job_id, + "only_main_content": website_info.only_main_content, # type: ignore + "mode": "crawl", + } + document.data_source_type = 
document_data.data_source.info_list.data_source_type document.data_source_info = json.dumps(data_source_info) document.name = file_name # update document name - if document_data.get("name"): - document.name = document_data["name"] + if document_data.name: + document.name = document_data.name # update document to be waiting document.indexing_status = "waiting" document.completed_at = None @@ -1096,7 +1151,7 @@ class DocumentService: document.splitting_completed_at = None document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) document.created_from = created_from - document.doc_form = document_data["doc_form"] + document.doc_form = document_data.doc_form db.session.add(document) db.session.commit() # update document segment @@ -1108,21 +1163,27 @@ class DocumentService: return document @staticmethod - def save_document_without_dataset_id(tenant_id: str, document_data: dict, account: Account): + def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account): features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: count = 0 - if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] + if knowledge_config.data_source.info_list.data_source_type == "upload_file": + upload_file_list = ( + knowledge_config.data_source.info_list.file_info_list.file_ids + if knowledge_config.data_source.info_list.file_info_list + else [] + ) count = len(upload_file_list) - elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] - for notion_info in notion_info_list: - count = count + len(notion_info["pages"]) - elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]["info_list"]["website_info_list"] - count = len(website_info["urls"]) + elif knowledge_config.data_source.info_list.data_source_type == "notion_import": + notion_info_list = knowledge_config.data_source.info_list.notion_info_list + if notion_info_list: + for notion_info in notion_info_list: + count = count + len(notion_info.pages) + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + website_info = knowledge_config.data_source.info_list.website_info_list + if website_info: + count = len(website_info.urls) batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") @@ -1131,39 +1192,39 @@ class DocumentService: dataset_collection_binding_id = None retrieval_model = None - if document_data["indexing_technique"] == "high_quality": + if knowledge_config.indexing_technique == "high_quality": dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - document_data["embedding_model_provider"], document_data["embedding_model"] + knowledge_config.embedding_model_provider, # type: ignore + knowledge_config.embedding_model, # type: ignore ) dataset_collection_binding_id = dataset_collection_binding.id - if document_data.get("retrieval_model"): - retrieval_model = document_data["retrieval_model"] + if knowledge_config.retrieval_model: + retrieval_model = knowledge_config.retrieval_model else: - default_retrieval_model = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, - "reranking_enable": False, - "reranking_model": 
{"reranking_provider_name": "", "reranking_model_name": ""}, - "top_k": 2, - "score_threshold_enabled": False, - } - retrieval_model = default_retrieval_model + retrieval_model = RetrievalModel( + search_method=RetrievalMethod.SEMANTIC_SEARCH.value, + reranking_enable=False, + reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""), + top_k=2, + score_threshold_enabled=False, + ) # save dataset dataset = Dataset( tenant_id=tenant_id, name="", - data_source_type=document_data["data_source"]["type"], - indexing_technique=document_data.get("indexing_technique", "high_quality"), + data_source_type=knowledge_config.data_source.info_list.data_source_type, + indexing_technique=knowledge_config.indexing_technique, created_by=account.id, - embedding_model=document_data.get("embedding_model"), - embedding_model_provider=document_data.get("embedding_model_provider"), + embedding_model=knowledge_config.embedding_model, + embedding_model_provider=knowledge_config.embedding_model_provider, collection_binding_id=dataset_collection_binding_id, - retrieval_model=retrieval_model, + retrieval_model=retrieval_model.model_dump() if retrieval_model else None, ) - db.session.add(dataset) + db.session.add(dataset) # type: ignore db.session.flush() - documents, batch = DocumentService.save_document_with_dataset_id(dataset, document_data, account) + documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) cut_length = 18 cut_name = documents[0].name[:cut_length] @@ -1174,133 +1235,86 @@ class DocumentService: return dataset, documents, batch @classmethod - def document_create_args_validate(cls, args: dict): - if "original_document_id" not in args or not args["original_document_id"]: - DocumentService.data_source_args_validate(args) - DocumentService.process_rule_args_validate(args) + def document_create_args_validate(cls, knowledge_config: KnowledgeConfig): + if not knowledge_config.data_source and not knowledge_config.process_rule: + raise ValueError("Data source or Process rule is required") else: - if ("data_source" not in args or not args["data_source"]) and ( - "process_rule" not in args or not args["process_rule"] - ): - raise ValueError("Data source or Process rule is required") - else: - if args.get("data_source"): - DocumentService.data_source_args_validate(args) - if args.get("process_rule"): - DocumentService.process_rule_args_validate(args) + if knowledge_config.data_source: + DocumentService.data_source_args_validate(knowledge_config) + if knowledge_config.process_rule: + DocumentService.process_rule_args_validate(knowledge_config) @classmethod - def data_source_args_validate(cls, args: dict): - if "data_source" not in args or not args["data_source"]: + def data_source_args_validate(cls, knowledge_config: KnowledgeConfig): + if not knowledge_config.data_source: raise ValueError("Data source is required") - if not isinstance(args["data_source"], dict): - raise ValueError("Data source is invalid") - - if "type" not in args["data_source"] or not args["data_source"]["type"]: - raise ValueError("Data source type is required") - - if args["data_source"]["type"] not in Document.DATA_SOURCES: + if knowledge_config.data_source.info_list.data_source_type not in Document.DATA_SOURCES: raise ValueError("Data source type is invalid") - if "info_list" not in args["data_source"] or not args["data_source"]["info_list"]: + if not knowledge_config.data_source.info_list: raise ValueError("Data source info is required") - if 
args["data_source"]["type"] == "upload_file": - if ( - "file_info_list" not in args["data_source"]["info_list"] - or not args["data_source"]["info_list"]["file_info_list"] - ): + if knowledge_config.data_source.info_list.data_source_type == "upload_file": + if not knowledge_config.data_source.info_list.file_info_list: raise ValueError("File source info is required") - if args["data_source"]["type"] == "notion_import": - if ( - "notion_info_list" not in args["data_source"]["info_list"] - or not args["data_source"]["info_list"]["notion_info_list"] - ): + if knowledge_config.data_source.info_list.data_source_type == "notion_import": + if not knowledge_config.data_source.info_list.notion_info_list: raise ValueError("Notion source info is required") - if args["data_source"]["type"] == "website_crawl": - if ( - "website_info_list" not in args["data_source"]["info_list"] - or not args["data_source"]["info_list"]["website_info_list"] - ): + if knowledge_config.data_source.info_list.data_source_type == "website_crawl": + if not knowledge_config.data_source.info_list.website_info_list: raise ValueError("Website source info is required") @classmethod - def process_rule_args_validate(cls, args: dict): - if "process_rule" not in args or not args["process_rule"]: + def process_rule_args_validate(cls, knowledge_config: KnowledgeConfig): + if not knowledge_config.process_rule: raise ValueError("Process rule is required") - if not isinstance(args["process_rule"], dict): - raise ValueError("Process rule is invalid") - - if "mode" not in args["process_rule"] or not args["process_rule"]["mode"]: + if not knowledge_config.process_rule.mode: raise ValueError("Process rule mode is required") - if args["process_rule"]["mode"] not in DatasetProcessRule.MODES: + if knowledge_config.process_rule.mode not in DatasetProcessRule.MODES: raise ValueError("Process rule mode is invalid") - if args["process_rule"]["mode"] == "automatic": - args["process_rule"]["rules"] = {} + if knowledge_config.process_rule.mode == "automatic": + knowledge_config.process_rule.rules = None else: - if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]: + if not knowledge_config.process_rule.rules: raise ValueError("Process rule rules is required") - if not isinstance(args["process_rule"]["rules"], dict): - raise ValueError("Process rule rules is invalid") - - if ( - "pre_processing_rules" not in args["process_rule"]["rules"] - or args["process_rule"]["rules"]["pre_processing_rules"] is None - ): + if knowledge_config.process_rule.rules.pre_processing_rules is None: raise ValueError("Process rule pre_processing_rules is required") - if not isinstance(args["process_rule"]["rules"]["pre_processing_rules"], list): - raise ValueError("Process rule pre_processing_rules is invalid") - unique_pre_processing_rule_dicts = {} - for pre_processing_rule in args["process_rule"]["rules"]["pre_processing_rules"]: - if "id" not in pre_processing_rule or not pre_processing_rule["id"]: + for pre_processing_rule in knowledge_config.process_rule.rules.pre_processing_rules: + if not pre_processing_rule.id: raise ValueError("Process rule pre_processing_rules id is required") - if pre_processing_rule["id"] not in DatasetProcessRule.PRE_PROCESSING_RULES: - raise ValueError("Process rule pre_processing_rules id is invalid") - - if "enabled" not in pre_processing_rule or pre_processing_rule["enabled"] is None: - raise ValueError("Process rule pre_processing_rules enabled is required") - - if not isinstance(pre_processing_rule["enabled"], bool): + 
if not isinstance(pre_processing_rule.enabled, bool): raise ValueError("Process rule pre_processing_rules enabled is invalid") - unique_pre_processing_rule_dicts[pre_processing_rule["id"]] = pre_processing_rule + unique_pre_processing_rule_dicts[pre_processing_rule.id] = pre_processing_rule - args["process_rule"]["rules"]["pre_processing_rules"] = list(unique_pre_processing_rule_dicts.values()) + knowledge_config.process_rule.rules.pre_processing_rules = list(unique_pre_processing_rule_dicts.values()) - if ( - "segmentation" not in args["process_rule"]["rules"] - or args["process_rule"]["rules"]["segmentation"] is None - ): + if not knowledge_config.process_rule.rules.segmentation: raise ValueError("Process rule segmentation is required") - if not isinstance(args["process_rule"]["rules"]["segmentation"], dict): - raise ValueError("Process rule segmentation is invalid") - - if ( - "separator" not in args["process_rule"]["rules"]["segmentation"] - or not args["process_rule"]["rules"]["segmentation"]["separator"] - ): + if not knowledge_config.process_rule.rules.segmentation.separator: raise ValueError("Process rule segmentation separator is required") - if not isinstance(args["process_rule"]["rules"]["segmentation"]["separator"], str): + if not isinstance(knowledge_config.process_rule.rules.segmentation.separator, str): raise ValueError("Process rule segmentation separator is invalid") - if ( - "max_tokens" not in args["process_rule"]["rules"]["segmentation"] - or not args["process_rule"]["rules"]["segmentation"]["max_tokens"] + if not ( + knowledge_config.process_rule.mode == "hierarchical" + and knowledge_config.process_rule.rules.parent_mode == "full-doc" ): - raise ValueError("Process rule segmentation max_tokens is required") + if not knowledge_config.process_rule.rules.segmentation.max_tokens: + raise ValueError("Process rule segmentation max_tokens is required") - if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int): - raise ValueError("Process rule segmentation max_tokens is invalid") + if not isinstance(knowledge_config.process_rule.rules.segmentation.max_tokens, int): + raise ValueError("Process rule segmentation max_tokens is invalid") @classmethod def estimate_args_validate(cls, args: dict): @@ -1447,7 +1461,7 @@ class SegmentService: # save vector index try: - VectorService.create_segments_vector([args["keywords"]], [segment_document], dataset) + VectorService.create_segments_vector([args["keywords"]], [segment_document], dataset, document.doc_form) except Exception as e: logging.exception("create segment index failed") segment_document.enabled = False @@ -1528,7 +1542,7 @@ class SegmentService: db.session.add(document) try: # save vector index - VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset) + VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset, document.doc_form) except Exception as e: logging.exception("create segment index failed") for segment_document in segment_data_list: @@ -1540,14 +1554,13 @@ class SegmentService: return segment_data_list @classmethod - def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset): - segment_update_entity = SegmentUpdateEntity(**args) + def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset): indexing_cache_key = "segment_{}_indexing".format(segment.id) cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise 
ValueError("Segment is indexing, please try again later") - if segment_update_entity.enabled is not None: - action = segment_update_entity.enabled + if args.enabled is not None: + action = args.enabled if segment.enabled != action: if not action: segment.enabled = action @@ -1560,22 +1573,22 @@ class SegmentService: disable_segment_from_index_task.delay(segment.id) return segment if not segment.enabled: - if segment_update_entity.enabled is not None: - if not segment_update_entity.enabled: + if args.enabled is not None: + if not args.enabled: raise ValueError("Can't update disabled segment") else: raise ValueError("Can't update disabled segment") try: word_count_change = segment.word_count - content = segment_update_entity.content + content = args.content or segment.content if segment.content == content: segment.word_count = len(content) if document.doc_form == "qa_model": - segment.answer = segment_update_entity.answer - segment.word_count += len(segment_update_entity.answer or "") + segment.answer = args.answer + segment.word_count += len(args.answer) if args.answer else 0 word_count_change = segment.word_count - word_count_change - if segment_update_entity.keywords: - segment.keywords = segment_update_entity.keywords + if args.keywords: + segment.keywords = args.keywords segment.enabled = True segment.disabled_at = None segment.disabled_by = None @@ -1586,9 +1599,45 @@ class SegmentService: document.word_count = max(0, document.word_count + word_count_change) db.session.add(document) # update segment index task - if segment_update_entity.enabled: - keywords = segment_update_entity.keywords or [] - VectorService.create_segments_vector([keywords], [segment], dataset) + if args.enabled: + VectorService.create_segments_vector( + [args.keywords] if args.keywords else None, + [segment], + dataset, + document.doc_form, + ) + if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: + # regenerate child chunks + # get embedding model instance + if dataset.indexing_technique == "high_quality": + # check embedding model setting + model_manager = ModelManager() + + if dataset.embedding_model_provider: + embedding_model_instance = model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + else: + embedding_model_instance = model_manager.get_default_model_instance( + tenant_id=dataset.tenant_id, + model_type=ModelType.TEXT_EMBEDDING, + ) + else: + raise ValueError("The knowledge base index technique is not high quality!") + # get the process rule + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == document.dataset_process_rule_id) + .first() + ) + if not processing_rule: + raise ValueError("No processing rule found.") + VectorService.generate_child_chunks( + segment, document, dataset, embedding_model_instance, processing_rule, True + ) else: segment_hash = helper.generate_text_hash(content) tokens = 0 @@ -1619,8 +1668,8 @@ class SegmentService: segment.disabled_at = None segment.disabled_by = None if document.doc_form == "qa_model": - segment.answer = segment_update_entity.answer - segment.word_count += len(segment_update_entity.answer or "") + segment.answer = args.answer + segment.word_count += len(args.answer) if args.answer else 0 word_count_change = segment.word_count - word_count_change # update document word count if word_count_change != 0: @@ -1628,8 +1677,40 @@ class SegmentService: 
db.session.add(document) db.session.add(segment) db.session.commit() - # update segment vector index - VectorService.update_segment_vector(segment_update_entity.keywords, segment, dataset) + if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: + # get embedding model instance + if dataset.indexing_technique == "high_quality": + # check embedding model setting + model_manager = ModelManager() + + if dataset.embedding_model_provider: + embedding_model_instance = model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + else: + embedding_model_instance = model_manager.get_default_model_instance( + tenant_id=dataset.tenant_id, + model_type=ModelType.TEXT_EMBEDDING, + ) + else: + raise ValueError("The knowledge base index technique is not high quality!") + # get the process rule + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == document.dataset_process_rule_id) + .first() + ) + if not processing_rule: + raise ValueError("No processing rule found.") + VectorService.generate_child_chunks( + segment, document, dataset, embedding_model_instance, processing_rule, True + ) + elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX): + # update segment vector index + VectorService.update_segment_vector(args.keywords, segment, dataset) except Exception as e: logging.exception("update segment index failed") @@ -1652,13 +1733,265 @@ class SegmentService: if segment.enabled: # send delete segment index task redis_client.setex(indexing_cache_key, 600, 1) - delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id) + delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id) db.session.delete(segment) # update document word count document.word_count -= segment.word_count db.session.add(document) db.session.commit() + @classmethod + def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset): + index_node_ids = ( + DocumentSegment.query.with_entities(DocumentSegment.index_node_id) + .filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.tenant_id == current_user.current_tenant_id, + ) + .all() + ) + index_node_ids = [index_node_id[0] for index_node_id in index_node_ids] + + delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id) + db.session.query(DocumentSegment).filter(DocumentSegment.id.in_(segment_ids)).delete() + db.session.commit() + + @classmethod + def update_segments_status(cls, segment_ids: list, action: str, dataset: Dataset, document: Document): + if action == "enable": + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.enabled == False, + ) + .all() + ) + if not segments: + return + real_deal_segmment_ids = [] + for segment in segments: + indexing_cache_key = "segment_{}_indexing".format(segment.id) + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + continue + segment.enabled = True + segment.disabled_at = None + segment.disabled_by = None + db.session.add(segment) + real_deal_segmment_ids.append(segment.id) + db.session.commit() + + 
enable_segments_to_index_task.delay(real_deal_segmment_ids, dataset.id, document.id) + elif action == "disable": + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.enabled == True, + ) + .all() + ) + if not segments: + return + real_deal_segmment_ids = [] + for segment in segments: + indexing_cache_key = "segment_{}_indexing".format(segment.id) + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + continue + segment.enabled = False + segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + segment.disabled_by = current_user.id + db.session.add(segment) + real_deal_segmment_ids.append(segment.id) + db.session.commit() + + disable_segments_from_index_task.delay(real_deal_segmment_ids, dataset.id, document.id) + else: + raise InvalidActionError() + + @classmethod + def create_child_chunk( + cls, content: str, segment: DocumentSegment, document: Document, dataset: Dataset + ) -> ChildChunk: + lock_name = "add_child_lock_{}".format(segment.id) + with redis_client.lock(lock_name, timeout=20): + index_node_id = str(uuid.uuid4()) + index_node_hash = helper.generate_text_hash(content) + child_chunk_count = ( + db.session.query(ChildChunk) + .filter( + ChildChunk.tenant_id == current_user.current_tenant_id, + ChildChunk.dataset_id == dataset.id, + ChildChunk.document_id == document.id, + ChildChunk.segment_id == segment.id, + ) + .count() + ) + max_position = ( + db.session.query(func.max(ChildChunk.position)) + .filter( + ChildChunk.tenant_id == current_user.current_tenant_id, + ChildChunk.dataset_id == dataset.id, + ChildChunk.document_id == document.id, + ChildChunk.segment_id == segment.id, + ) + .scalar() + ) + child_chunk = ChildChunk( + tenant_id=current_user.current_tenant_id, + dataset_id=dataset.id, + document_id=document.id, + segment_id=segment.id, + position=max_position + 1, + index_node_id=index_node_id, + index_node_hash=index_node_hash, + content=content, + word_count=len(content), + type="customized", + created_by=current_user.id, + ) + db.session.add(child_chunk) + # save vector index + try: + VectorService.create_child_chunk_vector(child_chunk, dataset) + except Exception as e: + logging.exception("create child chunk index failed") + db.session.rollback() + raise ChildChunkIndexingError(str(e)) + db.session.commit() + + return child_chunk + + @classmethod + def update_child_chunks( + cls, + child_chunks_update_args: list[ChildChunkUpdateArgs], + segment: DocumentSegment, + document: Document, + dataset: Dataset, + ) -> list[ChildChunk]: + child_chunks = ( + db.session.query(ChildChunk) + .filter( + ChildChunk.dataset_id == dataset.id, + ChildChunk.document_id == document.id, + ChildChunk.segment_id == segment.id, + ) + .all() + ) + child_chunks_map = {chunk.id: chunk for chunk in child_chunks} + + new_child_chunks, update_child_chunks, delete_child_chunks, new_child_chunks_args = [], [], [], [] + + for child_chunk_update_args in child_chunks_update_args: + if child_chunk_update_args.id: + child_chunk = child_chunks_map.pop(child_chunk_update_args.id, None) + if child_chunk: + if child_chunk.content != child_chunk_update_args.content: + child_chunk.content = child_chunk_update_args.content + child_chunk.word_count = len(child_chunk.content) + child_chunk.updated_by = current_user.id + child_chunk.updated_at = 
datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + child_chunk.type = "customized" + update_child_chunks.append(child_chunk) + else: + new_child_chunks_args.append(child_chunk_update_args) + if child_chunks_map: + delete_child_chunks = list(child_chunks_map.values()) + try: + if update_child_chunks: + db.session.bulk_save_objects(update_child_chunks) + + if delete_child_chunks: + for child_chunk in delete_child_chunks: + db.session.delete(child_chunk) + if new_child_chunks_args: + child_chunk_count = len(child_chunks) + for position, args in enumerate(new_child_chunks_args, start=child_chunk_count + 1): + index_node_id = str(uuid.uuid4()) + index_node_hash = helper.generate_text_hash(args.content) + child_chunk = ChildChunk( + tenant_id=current_user.current_tenant_id, + dataset_id=dataset.id, + document_id=document.id, + segment_id=segment.id, + position=position, + index_node_id=index_node_id, + index_node_hash=index_node_hash, + content=args.content, + word_count=len(args.content), + type="customized", + created_by=current_user.id, + ) + + db.session.add(child_chunk) + db.session.flush() + new_child_chunks.append(child_chunk) + VectorService.update_child_chunk_vector(new_child_chunks, update_child_chunks, delete_child_chunks, dataset) + db.session.commit() + except Exception as e: + logging.exception("update child chunk index failed") + db.session.rollback() + raise ChildChunkIndexingError(str(e)) + return sorted(new_child_chunks + update_child_chunks, key=lambda x: x.position) + + @classmethod + def update_child_chunk( + cls, + content: str, + child_chunk: ChildChunk, + segment: DocumentSegment, + document: Document, + dataset: Dataset, + ) -> ChildChunk: + try: + child_chunk.content = content + child_chunk.word_count = len(content) + child_chunk.updated_by = current_user.id + child_chunk.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + child_chunk.type = "customized" + db.session.add(child_chunk) + VectorService.update_child_chunk_vector([], [child_chunk], [], dataset) + db.session.commit() + except Exception as e: + logging.exception("update child chunk index failed") + db.session.rollback() + raise ChildChunkIndexingError(str(e)) + return child_chunk + + @classmethod + def delete_child_chunk(cls, child_chunk: ChildChunk, dataset: Dataset): + db.session.delete(child_chunk) + try: + VectorService.delete_child_chunk_vector(child_chunk, dataset) + except Exception as e: + logging.exception("delete child chunk index failed") + db.session.rollback() + raise ChildChunkDeleteIndexError(str(e)) + db.session.commit() + + @classmethod + def get_child_chunks( + cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None + ): + query = ChildChunk.query.filter_by( + tenant_id=current_user.current_tenant_id, + dataset_id=dataset_id, + document_id=document_id, + segment_id=segment_id, + ).order_by(ChildChunk.position.asc()) + if keyword: + query = query.where(ChildChunk.content.ilike(f"%{keyword}%")) + return query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + class DatasetCollectionBindingService: @classmethod diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index 449b79f339..76d9c28812 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -1,4 +1,5 @@ -from typing import Optional +from enum 
import Enum +from typing import Literal, Optional from pydantic import BaseModel @@ -8,3 +9,112 @@ class SegmentUpdateEntity(BaseModel): answer: Optional[str] = None keywords: Optional[list[str]] = None enabled: Optional[bool] = None + + +class ParentMode(str, Enum): + FULL_DOC = "full-doc" + PARAGRAPH = "paragraph" + + +class NotionIcon(BaseModel): + type: str + url: Optional[str] = None + emoji: Optional[str] = None + + +class NotionPage(BaseModel): + page_id: str + page_name: str + page_icon: Optional[NotionIcon] = None + type: str + + +class NotionInfo(BaseModel): + workspace_id: str + pages: list[NotionPage] + + +class WebsiteInfo(BaseModel): + provider: str + job_id: str + urls: list[str] + only_main_content: bool = True + + +class FileInfo(BaseModel): + file_ids: list[str] + + +class InfoList(BaseModel): + data_source_type: Literal["upload_file", "notion_import", "website_crawl"] + notion_info_list: Optional[list[NotionInfo]] = None + file_info_list: Optional[FileInfo] = None + website_info_list: Optional[WebsiteInfo] = None + + +class DataSource(BaseModel): + info_list: InfoList + + +class PreProcessingRule(BaseModel): + id: str + enabled: bool + + +class Segmentation(BaseModel): + separator: str = "\n" + max_tokens: int + chunk_overlap: int = 0 + + +class Rule(BaseModel): + pre_processing_rules: Optional[list[PreProcessingRule]] = None + segmentation: Optional[Segmentation] = None + parent_mode: Optional[Literal["full-doc", "paragraph"]] = None + subchunk_segmentation: Optional[Segmentation] = None + + +class ProcessRule(BaseModel): + mode: Literal["automatic", "custom", "hierarchical"] + rules: Optional[Rule] = None + + +class RerankingModel(BaseModel): + reranking_provider_name: Optional[str] = None + reranking_model_name: Optional[str] = None + + +class RetrievalModel(BaseModel): + search_method: Literal["hybrid_search", "semantic_search", "full_text_search"] + reranking_enable: bool + reranking_model: Optional[RerankingModel] = None + top_k: int + score_threshold_enabled: bool + score_threshold: Optional[float] = None + + +class KnowledgeConfig(BaseModel): + original_document_id: Optional[str] = None + duplicate: bool = True + indexing_technique: Literal["high_quality", "economy"] + data_source: DataSource + process_rule: Optional[ProcessRule] = None + retrieval_model: Optional[RetrievalModel] = None + doc_form: str = "text_model" + doc_language: str = "English" + embedding_model: Optional[str] = None + embedding_model_provider: Optional[str] = None + name: Optional[str] = None + + +class SegmentUpdateArgs(BaseModel): + content: Optional[str] = None + answer: Optional[str] = None + keywords: Optional[list[str]] = None + regenerate_child_chunks: bool = False + enabled: Optional[bool] = None + + +class ChildChunkUpdateArgs(BaseModel): + id: Optional[str] = None + content: str diff --git a/api/services/errors/base.py b/api/services/errors/base.py index 4d39f956b8..35ea28468e 100644 --- a/api/services/errors/base.py +++ b/api/services/errors/base.py @@ -1,6 +1,6 @@ from typing import Optional -class BaseServiceError(Exception): +class BaseServiceError(ValueError): def __init__(self, description: Optional[str] = None): self.description = description diff --git a/api/services/errors/chunk.py b/api/services/errors/chunk.py new file mode 100644 index 0000000000..75bf4d5d5f --- /dev/null +++ b/api/services/errors/chunk.py @@ -0,0 +1,9 @@ +from services.errors.base import BaseServiceError + + +class ChildChunkIndexingError(BaseServiceError): + description = "{message}" + + +class 
ChildChunkDeleteIndexError(BaseServiceError): + description = "{message}" diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 36c79d7045..a42b3020cd 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -76,7 +76,7 @@ class FeatureService: cls._fulfill_params_from_env(features) - if dify_config.BILLING_ENABLED: + if dify_config.BILLING_ENABLED and tenant_id: cls._fulfill_params_from_billing_api(features, tenant_id) return features diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 41b4e1ec46..e9176fc1c6 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -7,7 +7,7 @@ from core.rag.models.document import Document from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from models.account import Account -from models.dataset import Dataset, DatasetQuery, DocumentSegment +from models.dataset import Dataset, DatasetQuery default_retrieval_model = { "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, @@ -69,7 +69,7 @@ class HitTestingService: db.session.add(dataset_query) db.session.commit() - return dict(cls.compact_retrieve_response(dataset, query, all_documents)) + return cls.compact_retrieve_response(query, all_documents) # type: ignore @classmethod def external_retrieve( @@ -106,41 +106,14 @@ class HitTestingService: return dict(cls.compact_external_retrieve_response(dataset, query, all_documents)) @classmethod - def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]): - records = [] - - for document in documents: - if document.metadata is None: - continue - - index_node_id = document.metadata["doc_id"] - - segment = ( - db.session.query(DocumentSegment) - .filter( - DocumentSegment.dataset_id == dataset.id, - DocumentSegment.enabled == True, - DocumentSegment.status == "completed", - DocumentSegment.index_node_id == index_node_id, - ) - .first() - ) - - if not segment: - continue - - record = { - "segment": segment, - "score": document.metadata.get("score", None), - } - - records.append(record) + def compact_retrieve_response(cls, query: str, documents: list[Document]): + records = RetrievalService.format_retrieval_documents(documents) return { "query": { "content": query, }, - "records": records, + "records": [record.model_dump() for record in records], } @classmethod diff --git a/api/services/message_service.py b/api/services/message_service.py index c4447a84da..c17122ef64 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -152,6 +152,7 @@ class MessageService: @classmethod def create_feedback( cls, + *, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]], diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index b5565e986d..c4b9db69ec 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -64,7 +64,10 @@ class ToolTransformService: ) elif isinstance(provider, ToolProviderApiEntity): if provider.plugin_id: - provider.icon = ToolTransformService.get_plugin_icon_url(tenant_id=tenant_id, filename=provider.icon) + if isinstance(provider.icon, str): + provider.icon = ToolTransformService.get_plugin_icon_url( + tenant_id=tenant_id, filename=provider.icon + ) else: provider.icon = ToolTransformService.get_tool_provider_icon_url( provider_type=provider.type.value, provider_name=provider.name, 
icon=provider.icon diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 3c67351335..92422bf29d 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -1,40 +1,70 @@ from typing import Optional +from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import Document -from models.dataset import Dataset, DocumentSegment +from extensions.ext_database import db +from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment +from models.dataset import Document as DatasetDocument +from services.entities.knowledge_entities.knowledge_entities import ParentMode class VectorService: @classmethod def create_segments_vector( - cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset + cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str ): documents = [] + for segment in segments: - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - }, - ) - documents.append(document) - if dataset.indexing_technique == "high_quality": - # save vector index - vector = Vector(dataset=dataset) - vector.add_texts(documents, duplicate_check=True) + if doc_form == IndexType.PARENT_CHILD_INDEX: + document = DatasetDocument.query.filter_by(id=segment.document_id).first() + # get the process rule + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == document.dataset_process_rule_id) + .first() + ) + if not processing_rule: + raise ValueError("No processing rule found.") + # get embedding model instance + if dataset.indexing_technique == "high_quality": + # check embedding model setting + model_manager = ModelManager() - # save keyword index - keyword = Keyword(dataset) - - if keywords_list and len(keywords_list) > 0: - keyword.add_texts(documents, keywords_list=keywords_list) - else: - keyword.add_texts(documents) + if dataset.embedding_model_provider: + embedding_model_instance = model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + else: + embedding_model_instance = model_manager.get_default_model_instance( + tenant_id=dataset.tenant_id, + model_type=ModelType.TEXT_EMBEDDING, + ) + else: + raise ValueError("The knowledge base index technique is not high quality!") + cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False) + else: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + documents.append(document) + if len(documents) > 0: + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list) @classmethod def 
update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset): @@ -65,3 +95,123 @@ class VectorService: keyword.add_texts([document], keywords_list=[keywords]) else: keyword.add_texts([document]) + + @classmethod + def generate_child_chunks( + cls, + segment: DocumentSegment, + dataset_document: DatasetDocument, + dataset: Dataset, + embedding_model_instance: ModelInstance, + processing_rule: DatasetProcessRule, + regenerate: bool = False, + ): + index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() + if regenerate: + # delete child chunks + index_processor.clean(dataset, [segment.index_node_id], with_keywords=True, delete_child_chunks=True) + + # generate child chunks + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + # use full doc mode to generate segment's child chunk + processing_rule_dict = processing_rule.to_dict() + processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC.value + documents = index_processor.transform( + [document], + embedding_model_instance=embedding_model_instance, + process_rule=processing_rule_dict, + tenant_id=dataset.tenant_id, + doc_language=dataset_document.doc_language, + ) + # save child chunks + if documents and documents[0].children: + index_processor.load(dataset, documents) + + for position, child_chunk in enumerate(documents[0].children, start=1): + child_segment = ChildChunk( + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + document_id=dataset_document.id, + segment_id=segment.id, + position=position, + index_node_id=child_chunk.metadata["doc_id"], + index_node_hash=child_chunk.metadata["doc_hash"], + content=child_chunk.page_content, + word_count=len(child_chunk.page_content), + type="automatic", + created_by=dataset_document.created_by, + ) + db.session.add(child_segment) + db.session.commit() + + @classmethod + def create_child_chunk_vector(cls, child_segment: ChildChunk, dataset: Dataset): + child_document = Document( + page_content=child_segment.content, + metadata={ + "doc_id": child_segment.index_node_id, + "doc_hash": child_segment.index_node_hash, + "document_id": child_segment.document_id, + "dataset_id": child_segment.dataset_id, + }, + ) + if dataset.indexing_technique == "high_quality": + # save vector index + vector = Vector(dataset=dataset) + vector.add_texts([child_document], duplicate_check=True) + + @classmethod + def update_child_chunk_vector( + cls, + new_child_chunks: list[ChildChunk], + update_child_chunks: list[ChildChunk], + delete_child_chunks: list[ChildChunk], + dataset: Dataset, + ): + documents = [] + delete_node_ids = [] + for new_child_chunk in new_child_chunks: + new_child_document = Document( + page_content=new_child_chunk.content, + metadata={ + "doc_id": new_child_chunk.index_node_id, + "doc_hash": new_child_chunk.index_node_hash, + "document_id": new_child_chunk.document_id, + "dataset_id": new_child_chunk.dataset_id, + }, + ) + documents.append(new_child_document) + for update_child_chunk in update_child_chunks: + child_document = Document( + page_content=update_child_chunk.content, + metadata={ + "doc_id": update_child_chunk.index_node_id, + "doc_hash": update_child_chunk.index_node_hash, + "document_id": update_child_chunk.document_id, + "dataset_id": update_child_chunk.dataset_id, + }, + ) + documents.append(child_document) + 
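# A self-contained sketch of the upsert batching that update_child_chunk_vector performs
# above (FakeVector is a stand-in keyed by index_node_id; the real code goes through the
# dataset's Vector store): new chunks are only added, updated chunks have their old index
# node deleted and are then re-added, and removed chunks are only deleted.

class FakeVector:
    def __init__(self) -> None:
        self.store: dict[str, str] = {}

    def delete_by_ids(self, node_ids: list[str]) -> None:
        for node_id in node_ids:
            self.store.pop(node_id, None)

    def add_texts(self, docs: list[tuple[str, str]]) -> None:  # (index_node_id, content) pairs
        self.store.update(docs)

vector = FakeVector()
vector.add_texts([("n1", "child 1"), ("n2", "child 2"), ("n3", "child 3")])

new_docs = [("n4", "brand new child 4")]       # created chunks: add only
updated_docs = [("n2", "child 2 rewritten")]   # updated chunks: delete old node id, re-add
deleted_ids = ["n3"]                           # removed chunks: delete only

vector.delete_by_ids([node_id for node_id, _ in updated_docs] + deleted_ids)
vector.add_texts(new_docs + updated_docs)
print(sorted(vector.store))  # ['n1', 'n2', 'n4'] -- n2 now holds the rewritten content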
delete_node_ids.append(update_child_chunk.index_node_id) + for delete_child_chunk in delete_child_chunks: + delete_node_ids.append(delete_child_chunk.index_node_id) + if dataset.indexing_technique == "high_quality": + # update vector index + vector = Vector(dataset=dataset) + if delete_node_ids: + vector.delete_by_ids(delete_node_ids) + if documents: + vector.add_texts(documents, duplicate_check=True) + + @classmethod + def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset): + vector = Vector(dataset=dataset) + vector.delete_by_ids([child_chunk.index_node_id]) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 95649106e2..2de3d0ac55 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -3,6 +3,7 @@ import time from collections.abc import Callable, Generator, Sequence from datetime import UTC, datetime from typing import Any, Optional +from uuid import uuid4 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager @@ -333,6 +334,7 @@ class WorkflowService: error = e.error workflow_node_execution = WorkflowNodeExecution() + workflow_node_execution.id = str(uuid4()) workflow_node_execution.tenant_id = tenant_id workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value workflow_node_execution.index = 1 diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index 50bb2b6e63..9a172b2d9d 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -6,12 +6,13 @@ import click from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound +from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import Document +from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client +from models.dataset import DatasetAutoDisableLog, DocumentSegment from models.dataset import Document as DatasetDocument -from models.dataset import DocumentSegment @shared_task(queue="dataset") @@ -53,7 +54,22 @@ def add_document_to_index_task(dataset_document_id: str): "dataset_id": segment.dataset_id, }, ) - + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_chunks = segment.child_chunks + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents documents.append(document) dataset = dataset_document.dataset @@ -65,6 +81,12 @@ def add_document_to_index_task(dataset_document_id: str): index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor.load(dataset, documents) + # delete auto disable log + db.session.query(DatasetAutoDisableLog).filter( + DatasetAutoDisableLog.document_id == dataset_document.id + ).delete() + db.session.commit() + end_at = time.perf_counter() logging.info( click.style( diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py new file mode 
100644 index 0000000000..3bae82a5e3 --- /dev/null +++ b/api/tasks/batch_clean_document_task.py @@ -0,0 +1,76 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore + +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.tools.utils.web_reader_tool import get_image_upload_file_ids +from extensions.ext_database import db +from extensions.ext_storage import storage +from models.dataset import Dataset, DocumentSegment +from models.model import UploadFile + + +@shared_task(queue="dataset") +def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str, file_ids: list[str]): + """ + Clean documents when documents are deleted. + :param document_ids: document ids + :param dataset_id: dataset id + :param doc_form: doc_form + :param file_ids: file ids + + Usage: batch_clean_document_task.delay(document_ids, dataset_id, doc_form, file_ids) + """ + logging.info(click.style("Start batch clean documents when documents deleted", fg="green")) + start_at = time.perf_counter() + + try: + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + + if not dataset: + raise Exception("Document has no dataset") + + segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids)).all() + # check if segments exist + if segments: + index_node_ids = [segment.index_node_id for segment in segments] + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + + for segment in segments: + image_upload_file_ids = get_image_upload_file_ids(segment.content) + for upload_file_id in image_upload_file_ids: + image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() + try: + if image_file and image_file.key: + storage.delete(image_file.key) + except Exception: + logging.exception( + "Delete image_files failed when storage deleted, \ + image_upload_file_id: {}".format(upload_file_id) + ) + db.session.delete(image_file) + db.session.delete(segment) + + db.session.commit() + if file_ids: + files = db.session.query(UploadFile).filter(UploadFile.id.in_(file_ids)).all() + for file in files: + try: + storage.delete(file.key) + except Exception: + logging.exception("Delete file failed when document deleted, file_id: {}".format(file.id)) + db.session.delete(file) + db.session.commit() + + end_at = time.perf_counter() + logging.info( + click.style( + "Cleaned documents when documents deleted latency: {}".format(end_at - start_at), + fg="green", + ) + ) + except Exception: + logging.exception("Cleaned documents when documents deleted failed") diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index ce3d65526c..3238842307 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -7,13 +7,13 @@ import click from celery import shared_task # type: ignore from sqlalchemy import func -from core.indexing_runner import IndexingRunner from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client from libs import helper from models.dataset import Dataset, Document, DocumentSegment +from services.vector_service import VectorService @shared_task(queue="dataset") @@ -98,8 +98,7 @@
dataset_document.word_count += word_count_change db.session.add(dataset_document) # add index to db - indexing_runner = IndexingRunner() - indexing_runner.batch_add_segments(document_segments, dataset) + VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form) db.session.commit() redis_client.setex(indexing_cache_key, 600, "completed") end_at = time.perf_counter() diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index c48eb2e320..4d77f1fb65 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -62,7 +62,7 @@ def clean_dataset_task( if doc_form is None: raise ValueError("Index type must be specified.") index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean(dataset, None) + index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True) for document in documents: db.session.delete(document) diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 05eb9fd625..5a4d7a52b2 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -38,7 +38,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i if segments: index_node_ids = [segment.index_node_id for segment in segments] index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index f5d6406d9c..5a6eb00a62 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -37,7 +37,7 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() index_node_ids = [segment.index_node_id for segment in segments] - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) for segment in segments: db.session.delete(segment) diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index b025509aeb..0efc924a77 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -4,8 +4,9 @@ import time import click from celery import shared_task # type: ignore +from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import Document +from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument @@ -105,7 +106,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): db.session.commit() # clean index - index_processor.clean(dataset, None, with_keywords=False) + index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False) for dataset_document in dataset_documents: # update from vector index @@ -128,7 +129,22 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): "dataset_id": segment.dataset_id, }, ) - + if dataset_document.doc_form == 
IndexType.PARENT_CHILD_INDEX: + child_chunks = segment.child_chunks + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents documents.append(document) # save vector index index_processor.load(dataset, documents, with_keywords=False) diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index 45a612c745..3b04143dd9 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -6,48 +6,38 @@ from celery import shared_task # type: ignore from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db -from extensions.ext_redis import redis_client from models.dataset import Dataset, Document @shared_task(queue="dataset") -def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_id: str, document_id: str): +def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, document_id: str): """ Async Remove segment from index - :param segment_id: - :param index_node_id: + :param index_node_ids: :param dataset_id: :param document_id: - Usage: delete_segment_from_index_task.delay(segment_id) + Usage: delete_segment_from_index_task.delay(segment_ids) """ - logging.info(click.style("Start delete segment from index: {}".format(segment_id), fg="green")) + logging.info(click.style("Start delete segment from index", fg="green")) start_at = time.perf_counter() - indexing_cache_key = "segment_{}_delete_indexing".format(segment_id) try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - logging.info(click.style("Segment {} has no dataset, pass.".format(segment_id), fg="cyan")) return dataset_document = db.session.query(Document).filter(Document.id == document_id).first() if not dataset_document: - logging.info(click.style("Segment {} has no document, pass.".format(segment_id), fg="cyan")) return if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": - logging.info(click.style("Segment {} document status is invalid, pass.".format(segment_id), fg="cyan")) return index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() - index_processor.clean(dataset, [index_node_id]) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) end_at = time.perf_counter() - logging.info( - click.style("Segment deleted from index: {} latency: {}".format(segment_id, end_at - start_at), fg="green") - ) + logging.info(click.style("Segment deleted from index latency: {}".format(end_at - start_at), fg="green")) except Exception: logging.exception("delete segment from index failed") - finally: - redis_client.delete(indexing_cache_key) diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py new file mode 100644 index 0000000000..67112666e7 --- /dev/null +++ b/api/tasks/disable_segments_from_index_task.py @@ -0,0 +1,76 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore + +from core.rag.index_processor.index_processor_factory 
import IndexProcessorFactory +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import Dataset, DocumentSegment +from models.dataset import Document as DatasetDocument + + +@shared_task(queue="dataset") +def disable_segments_from_index_task(segment_ids: list, dataset_id: str, document_id: str): + """ + Async disable segments from index + :param segment_ids: + + Usage: disable_segments_from_index_task.delay(segment_ids, dataset_id, document_id) + """ + start_at = time.perf_counter() + + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan")) + return + + dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() + + if not dataset_document: + logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan")) + return + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan")) + return + # sync index processor + index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() + + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + ) + .all() + ) + + if not segments: + return + + try: + index_node_ids = [segment.index_node_id for segment in segments] + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) + + end_at = time.perf_counter() + logging.info(click.style("Segments removed from index latency: {}".format(end_at - start_at), fg="green")) + except Exception: + # update segment error msg + db.session.query(DocumentSegment).filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + ).update( + { + "disabled_at": None, + "disabled_by": None, + "enabled": True, + } + ) + db.session.commit() + finally: + for segment in segments: + indexing_cache_key = "segment_{}_indexing".format(segment.id) + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index ac4e81f95d..d686698b9a 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -82,7 +82,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) for segment in segments: db.session.delete(segment) diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 5f1e9a892f..d8f14830c9 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -47,7 +47,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) for segment in segments: 
db.session.delete(segment) diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index 6db2620eb6..8e1d2b6b5d 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -51,7 +51,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): if document: document.indexing_status = "error" document.error = str(e) - document.stopped_at = datetime.datetime.utcnow() + document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() return @@ -73,14 +73,14 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) for segment in segments: db.session.delete(segment) db.session.commit() document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.utcnow() + document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) documents.append(document) db.session.add(document) db.session.commit() diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 2f6eb7b82a..76522f4720 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -6,8 +6,9 @@ import click from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound +from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import Document +from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import DocumentSegment @@ -61,6 +62,22 @@ def enable_segment_to_index_task(segment_id: str): return index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_chunks = segment.child_chunks + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents # save vector index index_processor.load(dataset, [document]) diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py new file mode 100644 index 0000000000..0864e05e25 --- /dev/null +++ b/api/tasks/enable_segments_to_index_task.py @@ -0,0 +1,108 @@ +import datetime +import logging +import time + +import click +from celery import shared_task # type: ignore + +from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.rag.models.document import ChildDocument, Document +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import Dataset, DocumentSegment +from models.dataset import Document as DatasetDocument + + 
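# A rough sketch of the document tree these indexing tasks rebuild for the parent-child
# index (ParentDoc/ChildDoc are local stand-ins for core.rag.models.document's Document
# and ChildDocument): every enabled segment becomes a parent document, its child chunks
# hang off .children, and the index processor receives the whole batch in one load() call.

from dataclasses import dataclass, field

@dataclass
class ChildDoc:  # stand-in for ChildDocument
    page_content: str
    metadata: dict

@dataclass
class ParentDoc:  # stand-in for Document
    page_content: str
    metadata: dict
    children: list[ChildDoc] = field(default_factory=list)

segment_rows = [  # pretend ORM rows: (segment content, child chunk contents)
    ("parent segment A", ["child A1", "child A2"]),
    ("parent segment B", ["child B1"]),
]

documents = []
for i, (content, child_contents) in enumerate(segment_rows):
    parent = ParentDoc(page_content=content, metadata={"doc_id": f"seg-{i}"})
    parent.children = [
        ChildDoc(page_content=text, metadata={"doc_id": f"seg-{i}-child-{j}"})
        for j, text in enumerate(child_contents)
    ]
    documents.append(parent)

print(len(documents), sum(len(doc.children) for doc in documents))  # 2 parents, 3 children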
+@shared_task(queue="dataset") +def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_id: str): + """ + Async enable segments to index + :param segment_ids: + + Usage: enable_segments_to_index_task.delay(segment_ids) + """ + start_at = time.perf_counter() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan")) + return + + dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() + + if not dataset_document: + logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan")) + return + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan")) + return + # sync index processor + index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() + + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + ) + .all() + ) + if not segments: + return + + try: + documents = [] + for segment in segments: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": document_id, + "dataset_id": dataset_id, + }, + ) + + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_chunks = segment.child_chunks + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": document_id, + "dataset_id": dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents + documents.append(document) + # save vector index + index_processor.load(dataset, documents) + + end_at = time.perf_counter() + logging.info(click.style("Segments enabled to index latency: {}".format(end_at - start_at), fg="green")) + except Exception as e: + logging.exception("enable segments to index failed") + # update segment error msg + db.session.query(DocumentSegment).filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + ).update( + { + "error": str(e), + "status": "error", + "disabled_at": datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + "enabled": False, + } + ) + db.session.commit() + finally: + for segment in segments: + indexing_cache_key = "segment_{}_indexing".format(segment.id) + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index 4ba6d1a83e..1d580b3802 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -43,7 +43,7 @@ def remove_document_from_index_task(document_id: str): index_node_ids = [segment.index_node_id for segment in segments] if index_node_ids: try: - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) except Exception: logging.exception(f"clean dataset {dataset.id} from index failed") diff --git 
a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py index 485caa5152..74fd542f6c 100644 --- a/api/tasks/retry_document_indexing_task.py +++ b/api/tasks/retry_document_indexing_task.py @@ -48,7 +48,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): if document: document.indexing_status = "error" document.error = str(e) - document.stopped_at = datetime.datetime.utcnow() + document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() redis_client.delete(retry_indexing_cache_key) @@ -69,14 +69,14 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): if segments: index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - for segment in segments: - db.session.delete(segment) - db.session.commit() + for segment in segments: + db.session.delete(segment) + db.session.commit() document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.utcnow() + document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() @@ -86,7 +86,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): except Exception as ex: document.indexing_status = "error" document.error = str(ex) - document.stopped_at = datetime.datetime.utcnow() + document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() logging.info(click.style(str(ex), fg="yellow")) diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py index 5d6b069cf4..8da050d0d1 100644 --- a/api/tasks/sync_website_document_indexing_task.py +++ b/api/tasks/sync_website_document_indexing_task.py @@ -46,7 +46,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): if document: document.indexing_status = "error" document.error = str(e) - document.stopped_at = datetime.datetime.utcnow() + document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() redis_client.delete(sync_indexing_cache_key) @@ -65,14 +65,14 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): if segments: index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - for segment in segments: - db.session.delete(segment) - db.session.commit() + for segment in segments: + db.session.delete(segment) + db.session.commit() document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.utcnow() + document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() @@ -82,7 +82,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): except Exception as ex: document.indexing_status = "error" document.error = str(ex) - document.stopped_at = datetime.datetime.utcnow() + document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() 
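# The repeated timestamp change in these task files swaps datetime.utcnow(), which is
# deprecated as of Python 3.12, for datetime.now(timezone.utc).replace(tzinfo=None),
# which yields the same naive UTC value the existing columns expect. A quick check:

import datetime

naive_utc = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
print(naive_utc.tzinfo)  # None -> still naive, so existing columns and comparisons behave the same
print(abs((datetime.datetime.utcnow() - naive_utc).total_seconds()) < 1.0)  # True: same wall clock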
logging.info(click.style(str(ex), fg="yellow")) diff --git a/api/templates/clean_document_job_mail_template-US.html b/api/templates/clean_document_job_mail_template-US.html new file mode 100644 index 0000000000..88e78f41c7 --- /dev/null +++ b/api/templates/clean_document_job_mail_template-US.html @@ -0,0 +1,100 @@ + + + + + + Documents Disabled Notification + + + + + + \ No newline at end of file diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 7122f4a6d0..e65ca45858 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -15,15 +15,15 @@ x-shared-env: &shared-api-worker-env LOG_FILE: ${LOG_FILE:-/app/logs/server.log} LOG_FILE_MAX_SIZE: ${LOG_FILE_MAX_SIZE:-20} LOG_FILE_BACKUP_COUNT: ${LOG_FILE_BACKUP_COUNT:-5} - LOG_DATEFORMAT: ${LOG_DATEFORMAT:-"%Y-%m-%d %H:%M:%S"} + LOG_DATEFORMAT: ${LOG_DATEFORMAT:-%Y-%m-%d %H:%M:%S} LOG_TZ: ${LOG_TZ:-UTC} DEBUG: ${DEBUG:-false} FLASK_DEBUG: ${FLASK_DEBUG:-false} SECRET_KEY: ${SECRET_KEY:-sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U} INIT_PASSWORD: ${INIT_PASSWORD:-} DEPLOY_ENV: ${DEPLOY_ENV:-PRODUCTION} - CHECK_UPDATE_URL: ${CHECK_UPDATE_URL:-"https://updates.dify.ai"} - OPENAI_API_BASE: ${OPENAI_API_BASE:-"https://api.openai.com/v1"} + CHECK_UPDATE_URL: ${CHECK_UPDATE_URL:-https://updates.dify.ai} + OPENAI_API_BASE: ${OPENAI_API_BASE:-https://api.openai.com/v1} MIGRATION_ENABLED: ${MIGRATION_ENABLED:-true} FILES_ACCESS_TIMEOUT: ${FILES_ACCESS_TIMEOUT:-300} ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60} @@ -69,7 +69,7 @@ x-shared-env: &shared-api-worker-env REDIS_USE_CLUSTERS: ${REDIS_USE_CLUSTERS:-false} REDIS_CLUSTERS: ${REDIS_CLUSTERS:-} REDIS_CLUSTERS_PASSWORD: ${REDIS_CLUSTERS_PASSWORD:-} - CELERY_BROKER_URL: ${CELERY_BROKER_URL:-"redis://:difyai123456@redis:6379/1"} + CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1} BROKER_USE_SSL: ${BROKER_USE_SSL:-false} CELERY_USE_SENTINEL: ${CELERY_USE_SENTINEL:-false} CELERY_SENTINEL_MASTER_NAME: ${CELERY_SENTINEL_MASTER_NAME:-} @@ -88,13 +88,13 @@ x-shared-env: &shared-api-worker-env AZURE_BLOB_ACCOUNT_NAME: ${AZURE_BLOB_ACCOUNT_NAME:-difyai} AZURE_BLOB_ACCOUNT_KEY: ${AZURE_BLOB_ACCOUNT_KEY:-difyai} AZURE_BLOB_CONTAINER_NAME: ${AZURE_BLOB_CONTAINER_NAME:-difyai-container} - AZURE_BLOB_ACCOUNT_URL: ${AZURE_BLOB_ACCOUNT_URL:-"https://.blob.core.windows.net"} + AZURE_BLOB_ACCOUNT_URL: ${AZURE_BLOB_ACCOUNT_URL:-https://.blob.core.windows.net} GOOGLE_STORAGE_BUCKET_NAME: ${GOOGLE_STORAGE_BUCKET_NAME:-your-bucket-name} GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: ${GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64:-your-google-service-account-json-base64-string} ALIYUN_OSS_BUCKET_NAME: ${ALIYUN_OSS_BUCKET_NAME:-your-bucket-name} ALIYUN_OSS_ACCESS_KEY: ${ALIYUN_OSS_ACCESS_KEY:-your-access-key} ALIYUN_OSS_SECRET_KEY: ${ALIYUN_OSS_SECRET_KEY:-your-secret-key} - ALIYUN_OSS_ENDPOINT: ${ALIYUN_OSS_ENDPOINT:-"https://oss-ap-southeast-1-internal.aliyuncs.com"} + ALIYUN_OSS_ENDPOINT: ${ALIYUN_OSS_ENDPOINT:-https://oss-ap-southeast-1-internal.aliyuncs.com} ALIYUN_OSS_REGION: ${ALIYUN_OSS_REGION:-ap-southeast-1} ALIYUN_OSS_AUTH_VERSION: ${ALIYUN_OSS_AUTH_VERSION:-v4} ALIYUN_OSS_PATH: ${ALIYUN_OSS_PATH:-your-path} @@ -103,7 +103,7 @@ x-shared-env: &shared-api-worker-env TENCENT_COS_SECRET_ID: ${TENCENT_COS_SECRET_ID:-your-secret-id} TENCENT_COS_REGION: ${TENCENT_COS_REGION:-your-region} TENCENT_COS_SCHEME: ${TENCENT_COS_SCHEME:-your-scheme} - OCI_ENDPOINT: ${OCI_ENDPOINT:-"https://objectstorage.us-ashburn-1.oraclecloud.com"} + 
OCI_ENDPOINT: ${OCI_ENDPOINT:-https://objectstorage.us-ashburn-1.oraclecloud.com} OCI_BUCKET_NAME: ${OCI_BUCKET_NAME:-your-bucket-name} OCI_ACCESS_KEY: ${OCI_ACCESS_KEY:-your-access-key} OCI_SECRET_KEY: ${OCI_SECRET_KEY:-your-secret-key} @@ -125,14 +125,14 @@ x-shared-env: &shared-api-worker-env SUPABASE_API_KEY: ${SUPABASE_API_KEY:-your-access-key} SUPABASE_URL: ${SUPABASE_URL:-your-server-url} VECTOR_STORE: ${VECTOR_STORE:-weaviate} - WEAVIATE_ENDPOINT: ${WEAVIATE_ENDPOINT:-"http://weaviate:8080"} + WEAVIATE_ENDPOINT: ${WEAVIATE_ENDPOINT:-http://weaviate:8080} WEAVIATE_API_KEY: ${WEAVIATE_API_KEY:-WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih} - QDRANT_URL: ${QDRANT_URL:-"http://qdrant:6333"} + QDRANT_URL: ${QDRANT_URL:-http://qdrant:6333} QDRANT_API_KEY: ${QDRANT_API_KEY:-difyai123456} QDRANT_CLIENT_TIMEOUT: ${QDRANT_CLIENT_TIMEOUT:-20} QDRANT_GRPC_ENABLED: ${QDRANT_GRPC_ENABLED:-false} QDRANT_GRPC_PORT: ${QDRANT_GRPC_PORT:-6334} - MILVUS_URI: ${MILVUS_URI:-"http://127.0.0.1:19530"} + MILVUS_URI: ${MILVUS_URI:-http://127.0.0.1:19530} MILVUS_TOKEN: ${MILVUS_TOKEN:-} MILVUS_USER: ${MILVUS_USER:-root} MILVUS_PASSWORD: ${MILVUS_PASSWORD:-Milvus} @@ -142,7 +142,7 @@ x-shared-env: &shared-api-worker-env MYSCALE_PASSWORD: ${MYSCALE_PASSWORD:-} MYSCALE_DATABASE: ${MYSCALE_DATABASE:-dify} MYSCALE_FTS_PARAMS: ${MYSCALE_FTS_PARAMS:-} - COUCHBASE_CONNECTION_STRING: ${COUCHBASE_CONNECTION_STRING:-"couchbase://couchbase-server"} + COUCHBASE_CONNECTION_STRING: ${COUCHBASE_CONNECTION_STRING:-couchbase://couchbase-server} COUCHBASE_USER: ${COUCHBASE_USER:-Administrator} COUCHBASE_PASSWORD: ${COUCHBASE_PASSWORD:-password} COUCHBASE_BUCKET_NAME: ${COUCHBASE_BUCKET_NAME:-Embeddings} @@ -176,15 +176,15 @@ x-shared-env: &shared-api-worker-env TIDB_VECTOR_USER: ${TIDB_VECTOR_USER:-} TIDB_VECTOR_PASSWORD: ${TIDB_VECTOR_PASSWORD:-} TIDB_VECTOR_DATABASE: ${TIDB_VECTOR_DATABASE:-dify} - TIDB_ON_QDRANT_URL: ${TIDB_ON_QDRANT_URL:-"http://127.0.0.1"} + TIDB_ON_QDRANT_URL: ${TIDB_ON_QDRANT_URL:-http://127.0.0.1} TIDB_ON_QDRANT_API_KEY: ${TIDB_ON_QDRANT_API_KEY:-dify} TIDB_ON_QDRANT_CLIENT_TIMEOUT: ${TIDB_ON_QDRANT_CLIENT_TIMEOUT:-20} TIDB_ON_QDRANT_GRPC_ENABLED: ${TIDB_ON_QDRANT_GRPC_ENABLED:-false} TIDB_ON_QDRANT_GRPC_PORT: ${TIDB_ON_QDRANT_GRPC_PORT:-6334} TIDB_PUBLIC_KEY: ${TIDB_PUBLIC_KEY:-dify} TIDB_PRIVATE_KEY: ${TIDB_PRIVATE_KEY:-dify} - TIDB_API_URL: ${TIDB_API_URL:-"http://127.0.0.1"} - TIDB_IAM_API_URL: ${TIDB_IAM_API_URL:-"http://127.0.0.1"} + TIDB_API_URL: ${TIDB_API_URL:-http://127.0.0.1} + TIDB_IAM_API_URL: ${TIDB_IAM_API_URL:-http://127.0.0.1} TIDB_REGION: ${TIDB_REGION:-regions/aws-us-east-1} TIDB_PROJECT_ID: ${TIDB_PROJECT_ID:-dify} TIDB_SPEND_LIMIT: ${TIDB_SPEND_LIMIT:-100} @@ -209,7 +209,7 @@ x-shared-env: &shared-api-worker-env OPENSEARCH_USER: ${OPENSEARCH_USER:-admin} OPENSEARCH_PASSWORD: ${OPENSEARCH_PASSWORD:-admin} OPENSEARCH_SECURE: ${OPENSEARCH_SECURE:-true} - TENCENT_VECTOR_DB_URL: ${TENCENT_VECTOR_DB_URL:-"http://127.0.0.1"} + TENCENT_VECTOR_DB_URL: ${TENCENT_VECTOR_DB_URL:-http://127.0.0.1} TENCENT_VECTOR_DB_API_KEY: ${TENCENT_VECTOR_DB_API_KEY:-dify} TENCENT_VECTOR_DB_TIMEOUT: ${TENCENT_VECTOR_DB_TIMEOUT:-30} TENCENT_VECTOR_DB_USERNAME: ${TENCENT_VECTOR_DB_USERNAME:-dify} @@ -221,7 +221,7 @@ x-shared-env: &shared-api-worker-env ELASTICSEARCH_USERNAME: ${ELASTICSEARCH_USERNAME:-elastic} ELASTICSEARCH_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic} KIBANA_PORT: ${KIBANA_PORT:-5601} - BAIDU_VECTOR_DB_ENDPOINT: ${BAIDU_VECTOR_DB_ENDPOINT:-"http://127.0.0.1:5287"} + BAIDU_VECTOR_DB_ENDPOINT: 
${BAIDU_VECTOR_DB_ENDPOINT:-http://127.0.0.1:5287} BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: ${BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS:-30000} BAIDU_VECTOR_DB_ACCOUNT: ${BAIDU_VECTOR_DB_ACCOUNT:-root} BAIDU_VECTOR_DB_API_KEY: ${BAIDU_VECTOR_DB_API_KEY:-dify} @@ -235,7 +235,7 @@ x-shared-env: &shared-api-worker-env VIKINGDB_SCHEMA: ${VIKINGDB_SCHEMA:-http} VIKINGDB_CONNECTION_TIMEOUT: ${VIKINGDB_CONNECTION_TIMEOUT:-30} VIKINGDB_SOCKET_TIMEOUT: ${VIKINGDB_SOCKET_TIMEOUT:-30} - LINDORM_URL: ${LINDORM_URL:-"http://lindorm:30070"} + LINDORM_URL: ${LINDORM_URL:-http://lindorm:30070} LINDORM_USERNAME: ${LINDORM_USERNAME:-lindorm} LINDORM_PASSWORD: ${LINDORM_PASSWORD:-lindorm} OCEANBASE_VECTOR_HOST: ${OCEANBASE_VECTOR_HOST:-oceanbase} @@ -245,7 +245,7 @@ x-shared-env: &shared-api-worker-env OCEANBASE_VECTOR_DATABASE: ${OCEANBASE_VECTOR_DATABASE:-test} OCEANBASE_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai} OCEANBASE_MEMORY_LIMIT: ${OCEANBASE_MEMORY_LIMIT:-6G} - UPSTASH_VECTOR_URL: ${UPSTASH_VECTOR_URL:-"https://xxx-vector.upstash.io"} + UPSTASH_VECTOR_URL: ${UPSTASH_VECTOR_URL:-https://xxx-vector.upstash.io} UPSTASH_VECTOR_TOKEN: ${UPSTASH_VECTOR_TOKEN:-dify} UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15} UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5} @@ -270,7 +270,7 @@ x-shared-env: &shared-api-worker-env NOTION_INTERNAL_SECRET: ${NOTION_INTERNAL_SECRET:-} MAIL_TYPE: ${MAIL_TYPE:-resend} MAIL_DEFAULT_SEND_FROM: ${MAIL_DEFAULT_SEND_FROM:-} - RESEND_API_URL: ${RESEND_API_URL:-"https://api.resend.com"} + RESEND_API_URL: ${RESEND_API_URL:-https://api.resend.com} RESEND_API_KEY: ${RESEND_API_KEY:-your-resend-api-key} SMTP_SERVER: ${SMTP_SERVER:-} SMTP_PORT: ${SMTP_PORT:-465} @@ -281,7 +281,7 @@ x-shared-env: &shared-api-worker-env INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-4000} INVITE_EXPIRY_HOURS: ${INVITE_EXPIRY_HOURS:-72} RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: ${RESET_PASSWORD_TOKEN_EXPIRY_MINUTES:-5} - CODE_EXECUTION_ENDPOINT: ${CODE_EXECUTION_ENDPOINT:-"http://sandbox:8194"} + CODE_EXECUTION_ENDPOINT: ${CODE_EXECUTION_ENDPOINT:-http://sandbox:8194} CODE_EXECUTION_API_KEY: ${CODE_EXECUTION_API_KEY:-dify-sandbox} CODE_MAX_NUMBER: ${CODE_MAX_NUMBER:-9223372036854775807} CODE_MIN_NUMBER: ${CODE_MIN_NUMBER:--9223372036854775808} @@ -303,8 +303,8 @@ x-shared-env: &shared-api-worker-env WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10} HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760} HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576} - SSRF_PROXY_HTTP_URL: ${SSRF_PROXY_HTTP_URL:-"http://ssrf_proxy:3128"} - SSRF_PROXY_HTTPS_URL: ${SSRF_PROXY_HTTPS_URL:-"http://ssrf_proxy:3128"} + SSRF_PROXY_HTTP_URL: ${SSRF_PROXY_HTTP_URL:-http://ssrf_proxy:3128} + SSRF_PROXY_HTTPS_URL: ${SSRF_PROXY_HTTPS_URL:-http://ssrf_proxy:3128} TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000} PGUSER: ${PGUSER:-${DB_USERNAME}} POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-${DB_PASSWORD}} @@ -314,8 +314,8 @@ x-shared-env: &shared-api-worker-env SANDBOX_GIN_MODE: ${SANDBOX_GIN_MODE:-release} SANDBOX_WORKER_TIMEOUT: ${SANDBOX_WORKER_TIMEOUT:-15} SANDBOX_ENABLE_NETWORK: ${SANDBOX_ENABLE_NETWORK:-true} - SANDBOX_HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-"http://ssrf_proxy:3128"} - SANDBOX_HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-"http://ssrf_proxy:3128"} + SANDBOX_HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128} + SANDBOX_HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128} SANDBOX_PORT: 
${SANDBOX_PORT:-8194} WEAVIATE_PERSISTENCE_DATA_PATH: ${WEAVIATE_PERSISTENCE_DATA_PATH:-/var/lib/weaviate} WEAVIATE_QUERY_DEFAULTS_LIMIT: ${WEAVIATE_QUERY_DEFAULTS_LIMIT:-25} @@ -338,8 +338,8 @@ x-shared-env: &shared-api-worker-env ETCD_SNAPSHOT_COUNT: ${ETCD_SNAPSHOT_COUNT:-50000} MINIO_ACCESS_KEY: ${MINIO_ACCESS_KEY:-minioadmin} MINIO_SECRET_KEY: ${MINIO_SECRET_KEY:-minioadmin} - ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-"etcd:2379"} - MINIO_ADDRESS: ${MINIO_ADDRESS:-"minio:9000"} + ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379} + MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000} MILVUS_AUTHORIZATION_ENABLED: ${MILVUS_AUTHORIZATION_ENABLED:-true} PGVECTOR_PGUSER: ${PGVECTOR_PGUSER:-postgres} PGVECTOR_POSTGRES_PASSWORD: ${PGVECTOR_POSTGRES_PASSWORD:-difyai123456} @@ -360,7 +360,7 @@ x-shared-env: &shared-api-worker-env NGINX_SSL_PORT: ${NGINX_SSL_PORT:-443} NGINX_SSL_CERT_FILENAME: ${NGINX_SSL_CERT_FILENAME:-dify.crt} NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key} - NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-"TLSv1.1 TLSv1.2 TLSv1.3"} + NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3} NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto} NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-15M} NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65} @@ -374,7 +374,7 @@ x-shared-env: &shared-api-worker-env SSRF_COREDUMP_DIR: ${SSRF_COREDUMP_DIR:-/var/spool/squid} SSRF_REVERSE_PROXY_PORT: ${SSRF_REVERSE_PROXY_PORT:-8194} SSRF_SANDBOX_HOST: ${SSRF_SANDBOX_HOST:-sandbox} - COMPOSE_PROFILES: ${COMPOSE_PROFILES:-"${VECTOR_STORE:-weaviate}"} + COMPOSE_PROFILES: ${COMPOSE_PROFILES:-${VECTOR_STORE:-weaviate}} EXPOSE_NGINX_PORT: ${EXPOSE_NGINX_PORT:-80} EXPOSE_NGINX_SSL_PORT: ${EXPOSE_NGINX_SSL_PORT:-443} POSITION_TOOL_PINS: ${POSITION_TOOL_PINS:-} diff --git a/docker/generate_docker_compose b/docker/generate_docker_compose index 54b6d55217..dc4460f96c 100755 --- a/docker/generate_docker_compose +++ b/docker/generate_docker_compose @@ -43,7 +43,7 @@ def generate_shared_env_block(env_vars, anchor_name="shared-api-worker-env"): else: # If default value contains special characters, wrap it in quotes if re.search(r"[:\s]", default): - default = f'"{default}"' + default = f"{default}" lines.append(f" {key}: ${{{key}:-{default}}}") return "\n".join(lines) diff --git a/web/app/(commonLayout)/apps/hooks/useAppsQueryState.ts b/web/app/(commonLayout)/apps/hooks/useAppsQueryState.ts index fae5357bfc..7f1f4ba659 100644 --- a/web/app/(commonLayout)/apps/hooks/useAppsQueryState.ts +++ b/web/app/(commonLayout)/apps/hooks/useAppsQueryState.ts @@ -37,7 +37,7 @@ function useAppsQueryState() { const syncSearchParams = useCallback((params: URLSearchParams) => { const search = params.toString() const query = search ? `?${search}` : '' - router.push(`${pathname}${query}`) + router.push(`${pathname}${query}`, { scroll: false }) }, [router, pathname]) // Update the URL search string whenever the query changes. 
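The docker-compose edits above drop the quotes around ${VAR:-default} fallbacks, and the generate_docker_compose tweak stops reintroducing them, because compose substitutes the default after the YAML has been parsed, so quoted defaults leak literal quote characters into the container environment. A toy expander (not compose's actual parser) makes the pitfall visible:

import re

def expand(value: str, env: dict[str, str]) -> str:
    # minimal ${VAR:-default} substitution, just to illustrate the quoting issue
    return re.sub(r"\$\{(\w+):-(.*?)\}", lambda m: env.get(m.group(1), m.group(2)), value)

print(expand('${CHECK_UPDATE_URL:-"https://updates.dify.ai"}', {}))
# "https://updates.dify.ai"  <- the quotes become part of the value the container sees
print(expand('${CHECK_UPDATE_URL:-https://updates.dify.ai}', {}))
# https://updates.dify.ai    <- unquoted default, clean value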
diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx index b416659a6a..a6fb116fa8 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx @@ -7,85 +7,36 @@ import { useTranslation } from 'react-i18next' import { useBoolean } from 'ahooks' import { Cog8ToothIcon, - // CommandLineIcon, - Squares2X2Icon, - // eslint-disable-next-line sort-imports - PuzzlePieceIcon, DocumentTextIcon, PaperClipIcon, - QuestionMarkCircleIcon, } from '@heroicons/react/24/outline' import { Cog8ToothIcon as Cog8ToothSolidIcon, // CommandLineIcon as CommandLineSolidIcon, DocumentTextIcon as DocumentTextSolidIcon, } from '@heroicons/react/24/solid' -import Link from 'next/link' +import { RiApps2AddLine, RiInformation2Line } from '@remixicon/react' import s from './style.module.css' import classNames from '@/utils/classnames' import { fetchDatasetDetail, fetchDatasetRelatedApps } from '@/service/datasets' -import type { RelatedApp, RelatedAppResponse } from '@/models/datasets' +import type { RelatedAppResponse } from '@/models/datasets' import AppSideBar from '@/app/components/app-sidebar' -import Divider from '@/app/components/base/divider' -import AppIcon from '@/app/components/base/app-icon' import Loading from '@/app/components/base/loading' -import FloatPopoverContainer from '@/app/components/base/float-popover-container' import DatasetDetailContext from '@/context/dataset-detail' import { DataSourceType } from '@/models/datasets' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import { LanguagesSupported } from '@/i18n/language' import { useStore } from '@/app/components/app/store' -import { AiText, ChatBot, CuteRobot } from '@/app/components/base/icons/src/vender/solid/communication' -import { Route } from '@/app/components/base/icons/src/vender/solid/mapsAndTravel' import { getLocaleOnClient } from '@/i18n' import { useAppContext } from '@/context/app-context' +import Tooltip from '@/app/components/base/tooltip' +import LinkedAppsPanel from '@/app/components/base/linked-apps-panel' export type IAppDetailLayoutProps = { children: React.ReactNode params: { datasetId: string } } -type ILikedItemProps = { - type?: 'plugin' | 'app' - appStatus?: boolean - detail: RelatedApp - isMobile: boolean -} - -const LikedItem = ({ - type = 'app', - detail, - isMobile, -}: ILikedItemProps) => { - return ( - -
- - {type === 'app' && ( - - {detail.mode === 'advanced-chat' && ( - - )} - {detail.mode === 'agent-chat' && ( - - )} - {detail.mode === 'chat' && ( - - )} - {detail.mode === 'completion' && ( - - )} - {detail.mode === 'workflow' && ( - - )} - - )} -
- {!isMobile &&
{detail?.name || '--'}
} - - ) -} - const TargetIcon = ({ className }: SVGProps) => { return @@ -117,65 +68,80 @@ const BookOpenIcon = ({ className }: SVGProps) => { type IExtraInfoProps = { isMobile: boolean relatedApps?: RelatedAppResponse + expand: boolean } -const ExtraInfo = ({ isMobile, relatedApps }: IExtraInfoProps) => { +const ExtraInfo = ({ isMobile, relatedApps, expand }: IExtraInfoProps) => { const locale = getLocaleOnClient() const [isShowTips, { toggle: toggleTips, set: setShowTips }] = useBoolean(!isMobile) const { t } = useTranslation() + const hasRelatedApps = relatedApps?.data && relatedApps?.data?.length > 0 + const relatedAppsTotal = relatedApps?.data?.length || 0 + useEffect(() => { setShowTips(!isMobile) }, [isMobile, setShowTips]) - return
- - {(relatedApps?.data && relatedApps?.data?.length > 0) && ( + return
+ {hasRelatedApps && ( <> - {!isMobile &&
{relatedApps?.total || '--'} {t('common.datasetMenus.relatedApp')}
} + {!isMobile && ( + + } + > +
+ {relatedAppsTotal || '--'} {t('common.datasetMenus.relatedApp')} + +
+
+ )} + {isMobile &&
- {relatedApps?.total || '--'} + {relatedAppsTotal || '--'}
} - {relatedApps?.data?.map((item, index) => ())} )} - {!relatedApps?.data?.length && ( - - + {!hasRelatedApps && !expand && ( + +
+ +
+
{t('common.datasetMenus.emptyTip')}
+ + + {t('common.datasetMenus.viewDoc')} +
} > -
-
-
- -
-
- -
-
-
{t('common.datasetMenus.emptyTip')}
- - - {t('common.datasetMenus.viewDoc')} - +
+ {t('common.datasetMenus.noRelatedApp')} +
- + )}
} @@ -235,7 +201,7 @@ const DatasetDetailLayout: FC = (props) => { }, [isMobile, setAppSiderbarExpand]) if (!datasetRes && !error) - return + return return (
@@ -246,7 +212,7 @@ const DatasetDetailLayout: FC = (props) => { desc={datasetRes?.description || '--'} isExternal={datasetRes?.provider === 'external'} navigation={navigation} - extraInfo={!isCurrentWorkspaceDatasetOperator ? mode => : undefined} + extraInfo={!isCurrentWorkspaceDatasetOperator ? mode => : undefined} iconType={datasetRes?.data_source_type === DataSourceType.NOTION ? 'notion' : 'dataset'} />} = (props) => { dataset: datasetRes, mutateDatasetRes: () => mutateDatasetRes(), }}> -
{children}
+
{children}
) diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx index df314ddafe..3a65f1d30f 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx @@ -7,10 +7,10 @@ const Settings = async () => { const { t } = await translate(locale, 'dataset-settings') return ( -
+
-
{t('title')}
-
{t('desc')}
+
{t('title')}
+
{t('desc')}
diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/style.module.css b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/style.module.css index 0ee64b4fcd..516b124809 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/style.module.css +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/style.module.css @@ -1,12 +1,3 @@ -.itemWrapper { - @apply flex items-center w-full h-10 rounded-lg hover:bg-gray-50 cursor-pointer; -} -.appInfo { - @apply truncate text-gray-700 text-sm font-normal; -} -.iconWrapper { - @apply relative w-6 h-6 rounded-lg; -} .statusPoint { @apply flex justify-center items-center absolute -right-0.5 -bottom-0.5 w-2.5 h-2.5 bg-white rounded; } diff --git a/web/app/(commonLayout)/datasets/Container.tsx b/web/app/(commonLayout)/datasets/Container.tsx index a30521d998..a0edb1cd61 100644 --- a/web/app/(commonLayout)/datasets/Container.tsx +++ b/web/app/(commonLayout)/datasets/Container.tsx @@ -17,7 +17,6 @@ import TagManagementModal from '@/app/components/base/tag-management' import TagFilter from '@/app/components/base/tag-management/filter' import Button from '@/app/components/base/button' import { ApiConnectionMod } from '@/app/components/base/icons/src/vender/solid/development' -import SearchInput from '@/app/components/base/search-input' // Services import { fetchDatasetApiBaseUrl } from '@/service/datasets' @@ -29,6 +28,7 @@ import { useAppContext } from '@/context/app-context' import { useExternalApiPanel } from '@/context/external-api-panel-context' // eslint-disable-next-line import/order import { useQuery } from '@tanstack/react-query' +import Input from '@/app/components/base/input' const Container = () => { const { t } = useTranslation() @@ -81,17 +81,24 @@ const Container = () => { }, [currentWorkspace, router]) return ( -
-
+
+
setActiveTab(newActiveTab)} options={options} /> {activeTab === 'dataset' && ( -
+
- + handleKeywordsChange(e.target.value)} + onClear={() => handleKeywordsChange('')} + />
+ +
+
+} diff --git a/web/app/components/base/linked-apps-panel/index.tsx b/web/app/components/base/linked-apps-panel/index.tsx new file mode 100644 index 0000000000..4320cb0fc6 --- /dev/null +++ b/web/app/components/base/linked-apps-panel/index.tsx @@ -0,0 +1,62 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import Link from 'next/link' +import { useTranslation } from 'react-i18next' +import { RiArrowRightUpLine } from '@remixicon/react' +import cn from '@/utils/classnames' +import AppIcon from '@/app/components/base/app-icon' +import type { RelatedApp } from '@/models/datasets' + +type ILikedItemProps = { + appStatus?: boolean + detail: RelatedApp + isMobile: boolean +} + +const appTypeMap = { + 'chat': 'Chatbot', + 'completion': 'Completion', + 'agent-chat': 'Agent', + 'advanced-chat': 'Chatflow', + 'workflow': 'Workflow', +} + +const LikedItem = ({ + detail, + isMobile, +}: ILikedItemProps) => { + return ( + +
+
+ +
+ {!isMobile &&
{detail?.name || '--'}
} +
+
{appTypeMap[detail.mode]}
+ + + ) +} + +type Props = { + relatedApps: RelatedApp[] + isMobile: boolean +} + +const LinkedAppsPanel: FC = ({ + relatedApps, + isMobile, +}) => { + const { t } = useTranslation() + return ( +
+
{relatedApps.length || '--'} {t('common.datasetMenus.relatedApp')}
+ {relatedApps.map((item, index) => ( + + ))} +
+ ) +} +export default React.memo(LinkedAppsPanel) diff --git a/web/app/components/base/pagination/index.tsx b/web/app/components/base/pagination/index.tsx index b64c712425..c0cc9f86ec 100644 --- a/web/app/components/base/pagination/index.tsx +++ b/web/app/components/base/pagination/index.tsx @@ -8,7 +8,7 @@ import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import cn from '@/utils/classnames' -type Props = { +export type Props = { className?: string current: number onChange: (cur: number) => void diff --git a/web/app/components/base/param-item/index.tsx b/web/app/components/base/param-item/index.tsx index 49acc81484..68c980ad09 100644 --- a/web/app/components/base/param-item/index.tsx +++ b/web/app/components/base/param-item/index.tsx @@ -1,5 +1,6 @@ 'use client' import type { FC } from 'react' +import { InputNumber } from '../input-number' import Tooltip from '@/app/components/base/tooltip' import Slider from '@/app/components/base/slider' import Switch from '@/app/components/base/switch' @@ -23,39 +24,44 @@ type Props = { const ParamItem: FC = ({ className, id, name, noTooltip, tip, step = 0.1, min = 0, max, value, enable, onChange, hasSwitch, onSwitchChange }) => { return (
-
-
+
+
{hasSwitch && ( { onSwitchChange?.(id, val) }} /> )} - {name} + {name} {!noTooltip && ( {tip}
} /> )} -
-
-
-
- { - const value = parseFloat(e.target.value) - if (value < min || value > max) - return - - onChange(id, value) - }} /> +
+
+ { + onChange(id, value) + }} + className='w-[72px]' + />
-
+
= ({ onChosen = () => { }, chosenConfig, chosenConfigWrapClassName, + className, }) => { return (
-
-
+
+
{icon}
-
{title}
-
{description}
+
{title}
+
{description}
{!noRadio && ( -
+
= ({ )}
{((isChosen && chosenConfig) || noRadio) && ( -
- {chosenConfig} +
+
+
+ {chosenConfig} +
)}
diff --git a/web/app/components/base/retry-button/index.tsx b/web/app/components/base/retry-button/index.tsx deleted file mode 100644 index 689827af7b..0000000000 --- a/web/app/components/base/retry-button/index.tsx +++ /dev/null @@ -1,85 +0,0 @@ -'use client' -import type { FC } from 'react' -import React, { useEffect, useReducer } from 'react' -import { useTranslation } from 'react-i18next' -import useSWR from 'swr' -import s from './style.module.css' -import classNames from '@/utils/classnames' -import Divider from '@/app/components/base/divider' -import { getErrorDocs, retryErrorDocs } from '@/service/datasets' -import type { IndexingStatusResponse } from '@/models/datasets' - -const WarningIcon = () => - - - - -type Props = { - datasetId: string -} -type IIndexState = { - value: string -} -type ActionType = 'retry' | 'success' | 'error' - -type IAction = { - type: ActionType -} -const indexStateReducer = (state: IIndexState, action: IAction) => { - const actionMap = { - retry: 'retry', - success: 'success', - error: 'error', - } - - return { - ...state, - value: actionMap[action.type] || state.value, - } -} - -const RetryButton: FC = ({ datasetId }) => { - const { t } = useTranslation() - const [indexState, dispatch] = useReducer(indexStateReducer, { value: 'success' }) - const { data: errorDocs } = useSWR({ datasetId }, getErrorDocs) - - const onRetryErrorDocs = async () => { - dispatch({ type: 'retry' }) - const document_ids = errorDocs?.data.map((doc: IndexingStatusResponse) => doc.id) || [] - const res = await retryErrorDocs({ datasetId, document_ids }) - if (res.result === 'success') - dispatch({ type: 'success' }) - else - dispatch({ type: 'error' }) - } - - useEffect(() => { - if (errorDocs?.total === 0) - dispatch({ type: 'success' }) - else - dispatch({ type: 'error' }) - }, [errorDocs?.total]) - - if (indexState.value === 'success') - return null - - return ( -
- - - {errorDocs?.total} {t('dataset.docsFailedNotice')} - - - - {t('dataset.retry')} - -
- ) -} -export default RetryButton diff --git a/web/app/components/base/retry-button/style.module.css b/web/app/components/base/retry-button/style.module.css deleted file mode 100644 index 99a0947576..0000000000 --- a/web/app/components/base/retry-button/style.module.css +++ /dev/null @@ -1,4 +0,0 @@ -.retryBtn { - @apply inline-flex justify-center items-center content-center h-9 leading-5 rounded-lg px-4 py-2 text-base; - @apply border-solid border border-gray-200 text-gray-500 hover:bg-white hover:shadow-sm hover:border-gray-300; -} diff --git a/web/app/components/base/simple-pie-chart/index.tsx b/web/app/components/base/simple-pie-chart/index.tsx index 7de539cbb1..4b987ab42d 100644 --- a/web/app/components/base/simple-pie-chart/index.tsx +++ b/web/app/components/base/simple-pie-chart/index.tsx @@ -10,10 +10,11 @@ export type SimplePieChartProps = { fill?: string stroke?: string size?: number + animationDuration?: number className?: string } -const SimplePieChart = ({ percentage = 80, fill = '#fdb022', stroke = '#f79009', size = 12, className }: SimplePieChartProps) => { +const SimplePieChart = ({ percentage = 80, fill = '#fdb022', stroke = '#f79009', size = 12, animationDuration, className }: SimplePieChartProps) => { const option: EChartsOption = useMemo(() => ({ series: [ { @@ -34,7 +35,7 @@ const SimplePieChart = ({ percentage = 80, fill = '#fdb022', stroke = '#f79009', { type: 'pie', radius: '83%', - animationDuration: 600, + animationDuration: animationDuration ?? 600, data: [ { value: percentage, itemStyle: { color: fill } }, { value: 100 - percentage, itemStyle: { color: '#fff' } }, @@ -48,7 +49,7 @@ const SimplePieChart = ({ percentage = 80, fill = '#fdb022', stroke = '#f79009', cursor: 'default', }, ], - }), [stroke, fill, percentage]) + }), [stroke, fill, percentage, animationDuration]) return ( -export const SkeletonContanier: FC = (props) => { +export const SkeletonContainer: FC = (props) => { const { className, children, ...rest } = props return (
@@ -30,11 +30,14 @@ export const SkeletonRectangle: FC = (props) => { ) } -export const SkeletonPoint: FC = () => -
·
- +export const SkeletonPoint: FC = (props) => { + const { className, ...rest } = props + return ( +
·
+ ) +} /** Usage - * + * * * * diff --git a/web/app/components/base/switch/index.tsx b/web/app/components/base/switch/index.tsx index f61c6f46ff..8bf32b1311 100644 --- a/web/app/components/base/switch/index.tsx +++ b/web/app/components/base/switch/index.tsx @@ -64,4 +64,7 @@ const Switch = ({ onChange, size = 'md', defaultValue = false, disabled = false, ) } + +Switch.displayName = 'Switch' + export default React.memo(Switch) diff --git a/web/app/components/base/tag-input/index.tsx b/web/app/components/base/tag-input/index.tsx index b26d0c6438..ec6c1cee34 100644 --- a/web/app/components/base/tag-input/index.tsx +++ b/web/app/components/base/tag-input/index.tsx @@ -3,8 +3,8 @@ import type { ChangeEvent, FC, KeyboardEvent } from 'react' import { } from 'use-context-selector' import { useTranslation } from 'react-i18next' import AutosizeInput from 'react-18-input-autosize' +import { RiAddLine, RiCloseLine } from '@remixicon/react' import cn from '@/utils/classnames' -import { X } from '@/app/components/base/icons/src/vender/line/general' import { useToastContext } from '@/app/components/base/toast' type TagInputProps = { @@ -75,14 +75,14 @@ const TagInput: FC = ({ (items || []).map((item, index) => (
+ className={cn('flex items-center mr-1 mt-1 pl-1.5 pr-1 py-1 system-xs-regular text-text-secondary border border-divider-deep bg-components-badge-white-to-dark rounded-md')} + > {item} { !disableRemove && ( - handleRemove(index)} - /> +
handleRemove(index)}> + +
) }
@@ -90,24 +90,27 @@ const TagInput: FC = ({ } { !disableAdd && ( - setFocused(true)} - onBlur={handleBlur} - value={value} - onChange={(e: ChangeEvent) => { - setValue(e.target.value) - }} - onKeyDown={handleKeyDown} - placeholder={t(placeholder || (isSpecialMode ? 'common.model.params.stop_sequencesPlaceholder' : 'datasetDocuments.segment.addKeyWord'))} - /> +
+ {!isSpecialMode && !focused && } + setFocused(true)} + onBlur={handleBlur} + value={value} + onChange={(e: ChangeEvent) => { + setValue(e.target.value) + }} + onKeyDown={handleKeyDown} + placeholder={t(placeholder || (isSpecialMode ? 'common.model.params.stop_sequencesPlaceholder' : 'datasetDocuments.segment.addKeyWord'))} + /> +
) }
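The tag-input rework above swaps the old X icon for `RiCloseLine` and shows a `RiAddLine` hint beside the field until it gains focus. A rough sketch of that focus-dependent add affordance, using a plain `<input>` instead of `AutosizeInput`; the Enter-to-commit handling is an assumption, since `handleKeyDown`'s body is not part of the hunk:

```tsx
'use client'
import { useState } from 'react'
import { RiAddLine } from '@remixicon/react'

// Simplified stand-in for the TagInput add field; names and classes are illustrative.
const AddTagField = ({ onAdd }: { onAdd: (tag: string) => void }) => {
  const [value, setValue] = useState('')
  const [focused, setFocused] = useState(false)

  return (
    <div className='mt-1 flex items-center gap-1'>
      {/* The plus icon is only rendered while the field is idle, as in the reworked component. */}
      {!focused && <RiAddLine className='h-3 w-3' />}
      <input
        value={value}
        placeholder='Add keyword'
        onFocus={() => setFocused(true)}
        onBlur={() => setFocused(false)}
        onChange={e => setValue(e.target.value)}
        onKeyDown={(e) => {
          // Assumed behavior: commit the keyword on Enter and clear the field.
          if (e.key === 'Enter' && value.trim()) {
            onAdd(value.trim())
            setValue('')
          }
        }}
      />
    </div>
  )
}

export default AddTagField
```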
diff --git a/web/app/components/base/toast/index.tsx b/web/app/components/base/toast/index.tsx index b9a6de9fe5..ba7d8af518 100644 --- a/web/app/components/base/toast/index.tsx +++ b/web/app/components/base/toast/index.tsx @@ -21,6 +21,7 @@ export type IToastProps = { children?: ReactNode onClose?: () => void className?: string + customComponent?: ReactNode } type IToastContext = { notify: (props: IToastProps) => void @@ -35,6 +36,7 @@ const Toast = ({ message, children, className, + customComponent, }: IToastProps) => { const { close } = useToastContext() // sometimes message is react node array. Not handle it. @@ -49,8 +51,7 @@ const Toast = ({ 'top-0', 'right-0', )}> - -
-
{message}
+
+
+
{message}
+ {customComponent} +
{children &&
{children}
}
- +
@@ -117,7 +121,8 @@ Toast.notify = ({ message, duration, className, -}: Pick) => { + customComponent, +}: Pick) => { const defaultDuring = (type === 'success' || type === 'info') ? 3000 : 6000 if (typeof window === 'object') { const holder = document.createElement('div') @@ -133,7 +138,7 @@ Toast.notify = ({ } }, }}> - + , ) document.body.appendChild(holder) diff --git a/web/app/components/base/tooltip/index.tsx b/web/app/components/base/tooltip/index.tsx index 8ec3cd8c7a..65b5a99077 100644 --- a/web/app/components/base/tooltip/index.tsx +++ b/web/app/components/base/tooltip/index.tsx @@ -14,6 +14,7 @@ export type TooltipProps = { popupContent?: React.ReactNode children?: React.ReactNode popupClassName?: string + noDecoration?: boolean offset?: OffsetOptions needsDelay?: boolean asChild?: boolean @@ -27,6 +28,7 @@ const Tooltip: FC = ({ popupContent, children, popupClassName, + noDecoration, offset, asChild = true, needsDelay = false, @@ -96,7 +98,7 @@ const Tooltip: FC = ({ > {popupContent && (
triggerMethod === 'hover' && setHoverPopup()} diff --git a/web/app/components/billing/priority-label/index.tsx b/web/app/components/billing/priority-label/index.tsx index 36338cf4a8..6ecac4a79e 100644 --- a/web/app/components/billing/priority-label/index.tsx +++ b/web/app/components/billing/priority-label/index.tsx @@ -4,6 +4,7 @@ import { DocumentProcessingPriority, Plan, } from '../type' +import cn from '@/utils/classnames' import { useProviderContext } from '@/context/provider-context' import { ZapFast, @@ -11,7 +12,11 @@ import { } from '@/app/components/base/icons/src/vender/solid/general' import Tooltip from '@/app/components/base/tooltip' -const PriorityLabel = () => { +type PriorityLabelProps = { + className?: string +} + +const PriorityLabel = ({ className }: PriorityLabelProps) => { const { t } = useTranslation() const { plan } = useProviderContext() @@ -37,18 +42,18 @@ const PriorityLabel = () => { }
}> - + { plan.type === Plan.professional && ( - + ) } { (plan.type === Plan.team || plan.type === Plan.enterprise) && ( - + ) } {t(`billing.plansCommon.priority.${priority}`)} diff --git a/web/app/components/datasets/chunk.tsx b/web/app/components/datasets/chunk.tsx new file mode 100644 index 0000000000..bf2835dbdb --- /dev/null +++ b/web/app/components/datasets/chunk.tsx @@ -0,0 +1,54 @@ +import type { FC, PropsWithChildren } from 'react' +import { SelectionMod } from '../base/icons/src/public/knowledge' +import type { QA } from '@/models/datasets' + +export type ChunkLabelProps = { + label: string + characterCount: number +} + +export const ChunkLabel: FC = (props) => { + const { label, characterCount } = props + return
+ +

+ {label} + + + · + + + {`${characterCount} characters`} +

+
+} + +export type ChunkContainerProps = ChunkLabelProps & PropsWithChildren + +export const ChunkContainer: FC = (props) => { + const { label, characterCount, children } = props + return
+ +
+ {children} +
+
+} + +export type QAPreviewProps = { + qa: QA +} + +export const QAPreview: FC = (props) => { + const { qa } = props + return
+
+ +

{qa.question}

+
+
+ +

{qa.answer}

+
+
+} diff --git a/web/app/components/datasets/common/chunking-mode-label.tsx b/web/app/components/datasets/common/chunking-mode-label.tsx new file mode 100644 index 0000000000..7c6e924009 --- /dev/null +++ b/web/app/components/datasets/common/chunking-mode-label.tsx @@ -0,0 +1,29 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import { useTranslation } from 'react-i18next' +import Badge from '@/app/components/base/badge' +import { GeneralType, ParentChildType } from '@/app/components/base/icons/src/public/knowledge' + +type Props = { + isGeneralMode: boolean + isQAMode: boolean +} + +const ChunkingModeLabel: FC = ({ + isGeneralMode, + isQAMode, +}) => { + const { t } = useTranslation() + const TypeIcon = isGeneralMode ? GeneralType : ParentChildType + + return ( + +
+ + {isGeneralMode ? `${t('dataset.chunkingMode.general')}${isQAMode ? ' · QA' : ''}` : t('dataset.chunkingMode.parentChild')} +
+
+ ) +} +export default React.memo(ChunkingModeLabel) diff --git a/web/app/components/datasets/common/document-file-icon.tsx b/web/app/components/datasets/common/document-file-icon.tsx new file mode 100644 index 0000000000..5842cbbc7c --- /dev/null +++ b/web/app/components/datasets/common/document-file-icon.tsx @@ -0,0 +1,40 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import FileTypeIcon from '../../base/file-uploader/file-type-icon' +import type { FileAppearanceType } from '@/app/components/base/file-uploader/types' +import { FileAppearanceTypeEnum } from '@/app/components/base/file-uploader/types' + +const extendToFileTypeMap: { [key: string]: FileAppearanceType } = { + pdf: FileAppearanceTypeEnum.pdf, + json: FileAppearanceTypeEnum.document, + html: FileAppearanceTypeEnum.document, + txt: FileAppearanceTypeEnum.document, + markdown: FileAppearanceTypeEnum.markdown, + md: FileAppearanceTypeEnum.markdown, + xlsx: FileAppearanceTypeEnum.excel, + xls: FileAppearanceTypeEnum.excel, + csv: FileAppearanceTypeEnum.excel, + doc: FileAppearanceTypeEnum.word, + docx: FileAppearanceTypeEnum.word, +} + +type Props = { + extension?: string + name?: string + size?: 'sm' | 'lg' | 'md' + className?: string +} + +const DocumentFileIcon: FC = ({ + extension, + name, + size = 'md', + className, +}) => { + const localExtension = extension?.toLowerCase() || name?.split('.')?.pop()?.toLowerCase() + return ( + + ) +} +export default React.memo(DocumentFileIcon) diff --git a/web/app/components/datasets/common/document-picker/document-list.tsx b/web/app/components/datasets/common/document-picker/document-list.tsx new file mode 100644 index 0000000000..3e320d7507 --- /dev/null +++ b/web/app/components/datasets/common/document-picker/document-list.tsx @@ -0,0 +1,42 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback } from 'react' +import FileIcon from '../document-file-icon' +import cn from '@/utils/classnames' +import type { DocumentItem } from '@/models/datasets' + +type Props = { + className?: string + list: DocumentItem[] + onChange: (value: DocumentItem) => void +} + +const DocumentList: FC = ({ + className, + list, + onChange, +}) => { + const handleChange = useCallback((item: DocumentItem) => { + return () => onChange(item) + }, [onChange]) + + return ( +
+ {list.map((item) => { + const { id, name, extension } = item + return ( +
+ +
{name}
+
+ ) + })} +
+ ) +} + +export default React.memo(DocumentList) diff --git a/web/app/components/datasets/common/document-picker/index.tsx b/web/app/components/datasets/common/document-picker/index.tsx new file mode 100644 index 0000000000..30690fca00 --- /dev/null +++ b/web/app/components/datasets/common/document-picker/index.tsx @@ -0,0 +1,118 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback, useState } from 'react' +import { useBoolean } from 'ahooks' +import { RiArrowDownSLine } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import FileIcon from '../document-file-icon' +import DocumentList from './document-list' +import type { DocumentItem, ParentMode, SimpleDocumentDetail } from '@/models/datasets' +import { ProcessMode } from '@/models/datasets' +import { + PortalToFollowElem, + PortalToFollowElemContent, + PortalToFollowElemTrigger, +} from '@/app/components/base/portal-to-follow-elem' +import cn from '@/utils/classnames' +import SearchInput from '@/app/components/base/search-input' +import { GeneralType, ParentChildType } from '@/app/components/base/icons/src/public/knowledge' +import { useDocumentList } from '@/service/knowledge/use-document' +import Loading from '@/app/components/base/loading' + +type Props = { + datasetId: string + value: { + name?: string + extension?: string + processMode?: ProcessMode + parentMode?: ParentMode + } + onChange: (value: SimpleDocumentDetail) => void +} + +const DocumentPicker: FC = ({ + datasetId, + value, + onChange, +}) => { + const { t } = useTranslation() + const { + name, + extension, + processMode, + parentMode, + } = value + const [query, setQuery] = useState('') + + const { data } = useDocumentList({ + datasetId, + query: { + keyword: query, + page: 1, + limit: 20, + }, + }) + const documentsList = data?.data + const isParentChild = processMode === ProcessMode.parentChild + const TypeIcon = isParentChild ? ParentChildType : GeneralType + + const [open, { + set: setOpen, + toggle: togglePopup, + }] = useBoolean(false) + const ArrowIcon = RiArrowDownSLine + + const handleChange = useCallback(({ id }: DocumentItem) => { + onChange(documentsList?.find(item => item.id === id) as SimpleDocumentDetail) + setOpen(false) + }, [documentsList, onChange, setOpen]) + + return ( + + +
+ +
+
+ {name || '--'} + +
+
+ + + {isParentChild ? t('dataset.chunkingMode.parentChild') : t('dataset.chunkingMode.general')} + {isParentChild && ` · ${!parentMode ? '--' : parentMode === 'paragraph' ? t('dataset.parentMode.paragraph') : t('dataset.parentMode.fullDoc')}`} + +
+
+
+
+ +
+ + {documentsList + ? ( + ({ + id: d.id, + name: d.name, + extension: d.data_source_detail_dict?.upload_file?.extension || '', + }))} + onChange={handleChange} + /> + ) + : (
+ +
)} +
+ +
+
+ ) +} +export default React.memo(DocumentPicker) diff --git a/web/app/components/datasets/common/document-picker/preview-document-picker.tsx b/web/app/components/datasets/common/document-picker/preview-document-picker.tsx new file mode 100644 index 0000000000..2a35b75471 --- /dev/null +++ b/web/app/components/datasets/common/document-picker/preview-document-picker.tsx @@ -0,0 +1,82 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback } from 'react' +import { useBoolean } from 'ahooks' +import { RiArrowDownSLine } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import FileIcon from '../document-file-icon' +import DocumentList from './document-list' +import { + PortalToFollowElem, + PortalToFollowElemContent, + PortalToFollowElemTrigger, +} from '@/app/components/base/portal-to-follow-elem' +import cn from '@/utils/classnames' +import Loading from '@/app/components/base/loading' +import type { DocumentItem } from '@/models/datasets' + +type Props = { + className?: string + value: DocumentItem + files: DocumentItem[] + onChange: (value: DocumentItem) => void +} + +const PreviewDocumentPicker: FC = ({ + className, + value, + files, + onChange, +}) => { + const { t } = useTranslation() + const { name, extension } = value + + const [open, { + set: setOpen, + toggle: togglePopup, + }] = useBoolean(false) + const ArrowIcon = RiArrowDownSLine + + const handleChange = useCallback((item: DocumentItem) => { + onChange(item) + setOpen(false) + }, [onChange, setOpen]) + + return ( + + +
+ +
+
+ {name || '--'} + +
+
+
+
+ +
+ {files?.length > 1 &&
{t('dataset.preprocessDocument', { num: files.length })}
} + {files?.length > 0 + ? ( + + ) + : (
+ +
)} +
+ +
+
+ ) +} +export default React.memo(PreviewDocumentPicker) diff --git a/web/app/components/datasets/common/document-status-with-action/auto-disabled-document.tsx b/web/app/components/datasets/common/document-status-with-action/auto-disabled-document.tsx new file mode 100644 index 0000000000..b687c004e5 --- /dev/null +++ b/web/app/components/datasets/common/document-status-with-action/auto-disabled-document.tsx @@ -0,0 +1,38 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback } from 'react' +import { useTranslation } from 'react-i18next' +import StatusWithAction from './status-with-action' +import { useAutoDisabledDocuments, useDocumentEnable, useInvalidDisabledDocument } from '@/service/knowledge/use-document' +import Toast from '@/app/components/base/toast' +type Props = { + datasetId: string +} + +const AutoDisabledDocument: FC = ({ + datasetId, +}) => { + const { t } = useTranslation() + const { data, isLoading } = useAutoDisabledDocuments(datasetId) + const invalidDisabledDocument = useInvalidDisabledDocument() + const documentIds = data?.document_ids + const hasDisabledDocument = documentIds && documentIds.length > 0 + const { mutateAsync: enableDocument } = useDocumentEnable() + const handleEnableDocuments = useCallback(async () => { + await enableDocument({ datasetId, documentIds }) + invalidDisabledDocument() + Toast.notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) + }, []) + if (!hasDisabledDocument || isLoading) + return null + + return ( + + ) +} +export default React.memo(AutoDisabledDocument) diff --git a/web/app/components/datasets/common/document-status-with-action/index-failed.tsx b/web/app/components/datasets/common/document-status-with-action/index-failed.tsx new file mode 100644 index 0000000000..37311768b9 --- /dev/null +++ b/web/app/components/datasets/common/document-status-with-action/index-failed.tsx @@ -0,0 +1,69 @@ +'use client' +import type { FC } from 'react' +import React, { useEffect, useReducer } from 'react' +import { useTranslation } from 'react-i18next' +import useSWR from 'swr' +import StatusWithAction from './status-with-action' +import { getErrorDocs, retryErrorDocs } from '@/service/datasets' +import type { IndexingStatusResponse } from '@/models/datasets' + +type Props = { + datasetId: string +} +type IIndexState = { + value: string +} +type ActionType = 'retry' | 'success' | 'error' + +type IAction = { + type: ActionType +} +const indexStateReducer = (state: IIndexState, action: IAction) => { + const actionMap = { + retry: 'retry', + success: 'success', + error: 'error', + } + + return { + ...state, + value: actionMap[action.type] || state.value, + } +} + +const RetryButton: FC = ({ datasetId }) => { + const { t } = useTranslation() + const [indexState, dispatch] = useReducer(indexStateReducer, { value: 'success' }) + const { data: errorDocs, isLoading } = useSWR({ datasetId }, getErrorDocs) + + const onRetryErrorDocs = async () => { + dispatch({ type: 'retry' }) + const document_ids = errorDocs?.data.map((doc: IndexingStatusResponse) => doc.id) || [] + const res = await retryErrorDocs({ datasetId, document_ids }) + if (res.result === 'success') + dispatch({ type: 'success' }) + else + dispatch({ type: 'error' }) + } + + useEffect(() => { + if (errorDocs?.total === 0) + dispatch({ type: 'success' }) + else + dispatch({ type: 'error' }) + }, [errorDocs?.total]) + + if (isLoading || indexState.value === 'success') + return null + + return ( + { }} + /> + ) +} +export default RetryButton diff 
--git a/web/app/components/datasets/common/document-status-with-action/status-with-action.tsx b/web/app/components/datasets/common/document-status-with-action/status-with-action.tsx new file mode 100644 index 0000000000..a8da9bf6cc --- /dev/null +++ b/web/app/components/datasets/common/document-status-with-action/status-with-action.tsx @@ -0,0 +1,65 @@ +'use client' +import { RiAlertFill, RiCheckboxCircleFill, RiErrorWarningFill, RiInformation2Fill } from '@remixicon/react' +import type { FC } from 'react' +import React from 'react' +import cn from '@/utils/classnames' +import Divider from '@/app/components/base/divider' + +type Status = 'success' | 'error' | 'warning' | 'info' +type Props = { + type?: Status + description: string + actionText: string + onAction: () => void + disabled?: boolean +} + +const IconMap = { + success: { + Icon: RiCheckboxCircleFill, + color: 'text-text-success', + }, + error: { + Icon: RiErrorWarningFill, + color: 'text-text-destructive', + }, + warning: { + Icon: RiAlertFill, + color: 'text-text-warning-secondary', + }, + info: { + Icon: RiInformation2Fill, + color: 'text-text-accent', + }, +} + +const getIcon = (type: Status) => { + return IconMap[type] +} + +const StatusAction: FC = ({ + type = 'info', + description, + actionText, + onAction, + disabled, +}) => { + const { Icon, color } = getIcon(type) + return ( +
+
+
+ +
{description}
+ +
{actionText}
+
+
+ ) +} +export default React.memo(StatusAction) diff --git a/web/app/components/datasets/common/economical-retrieval-method-config/index.tsx b/web/app/components/datasets/common/economical-retrieval-method-config/index.tsx index f3da67b92c..9236858ae4 100644 --- a/web/app/components/datasets/common/economical-retrieval-method-config/index.tsx +++ b/web/app/components/datasets/common/economical-retrieval-method-config/index.tsx @@ -2,10 +2,11 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' +import Image from 'next/image' import RetrievalParamConfig from '../retrieval-param-config' +import { OptionCard } from '../../create/step-two/option-card' +import { retrievalIcon } from '../../create/icons' import { RETRIEVE_METHOD } from '@/types/app' -import RadioCard from '@/app/components/base/radio-card' -import { HighPriority } from '@/app/components/base/icons/src/vender/solid/arrows' import type { RetrievalConfig } from '@/types/app' type Props = { @@ -21,19 +22,17 @@ const EconomicalRetrievalMethodConfig: FC = ({ return (
- } + } title={t('dataset.retrieval.invertedIndex.title')} - description={t('dataset.retrieval.invertedIndex.description')} - noRadio - chosenConfig={ - - } - /> + description={t('dataset.retrieval.invertedIndex.description')} isActive + activeHeaderClassName='bg-dataset-option-card-purple-gradient' + > + +
) } diff --git a/web/app/components/datasets/common/retrieval-method-config/index.tsx b/web/app/components/datasets/common/retrieval-method-config/index.tsx index 20d93568ad..9ab157571b 100644 --- a/web/app/components/datasets/common/retrieval-method-config/index.tsx +++ b/web/app/components/datasets/common/retrieval-method-config/index.tsx @@ -2,12 +2,13 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' +import Image from 'next/image' import RetrievalParamConfig from '../retrieval-param-config' +import { OptionCard } from '../../create/step-two/option-card' +import Effect from '../../create/assets/option-card-effect-purple.svg' +import { retrievalIcon } from '../../create/icons' import type { RetrievalConfig } from '@/types/app' import { RETRIEVE_METHOD } from '@/types/app' -import RadioCard from '@/app/components/base/radio-card' -import { PatternRecognition, Semantic } from '@/app/components/base/icons/src/vender/solid/development' -import { FileSearch02 } from '@/app/components/base/icons/src/vender/solid/files' import { useProviderContext } from '@/context/provider-context' import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' @@ -16,6 +17,7 @@ import { RerankingModeEnum, WeightedScoreEnum, } from '@/models/datasets' +import Badge from '@/app/components/base/badge' type Props = { value: RetrievalConfig @@ -56,67 +58,72 @@ const RetrievalMethodConfig: FC = ({ return (
{supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && ( - } + } title={t('dataset.retrieval.semantic_search.title')} description={t('dataset.retrieval.semantic_search.description')} - isChosen={value.search_method === RETRIEVE_METHOD.semantic} - onChosen={() => onChange({ + isActive={ + value.search_method === RETRIEVE_METHOD.semantic + } + onSwitched={() => onChange({ ...value, search_method: RETRIEVE_METHOD.semantic, })} - chosenConfig={ - - } - /> + effectImg={Effect.src} + activeHeaderClassName='bg-dataset-option-card-purple-gradient' + > + + )} {supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && ( - } + } title={t('dataset.retrieval.full_text_search.title')} description={t('dataset.retrieval.full_text_search.description')} - isChosen={value.search_method === RETRIEVE_METHOD.fullText} - onChosen={() => onChange({ + isActive={ + value.search_method === RETRIEVE_METHOD.fullText + } + onSwitched={() => onChange({ ...value, search_method: RETRIEVE_METHOD.fullText, })} - chosenConfig={ - - } - /> + effectImg={Effect.src} + activeHeaderClassName='bg-dataset-option-card-purple-gradient' + > + + )} {supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && ( - } + } title={
{t('dataset.retrieval.hybrid_search.title')}
-
{t('dataset.retrieval.hybrid_search.recommend')}
+
} - description={t('dataset.retrieval.hybrid_search.description')} - isChosen={value.search_method === RETRIEVE_METHOD.hybrid} - onChosen={() => onChange({ + description={t('dataset.retrieval.hybrid_search.description')} isActive={ + value.search_method === RETRIEVE_METHOD.hybrid + } + onSwitched={() => onChange({ ...value, search_method: RETRIEVE_METHOD.hybrid, reranking_enable: true, })} - chosenConfig={ - - } - /> + effectImg={Effect.src} + activeHeaderClassName='bg-dataset-option-card-purple-gradient' + > + +
)}
) diff --git a/web/app/components/datasets/common/retrieval-method-info/index.tsx b/web/app/components/datasets/common/retrieval-method-info/index.tsx index 7d9b999c53..fc3020d4a9 100644 --- a/web/app/components/datasets/common/retrieval-method-info/index.tsx +++ b/web/app/components/datasets/common/retrieval-method-info/index.tsx @@ -2,12 +2,11 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' +import Image from 'next/image' +import { retrievalIcon } from '../../create/icons' import type { RetrievalConfig } from '@/types/app' import { RETRIEVE_METHOD } from '@/types/app' import RadioCard from '@/app/components/base/radio-card' -import { HighPriority } from '@/app/components/base/icons/src/vender/solid/arrows' -import { PatternRecognition, Semantic } from '@/app/components/base/icons/src/vender/solid/development' -import { FileSearch02 } from '@/app/components/base/icons/src/vender/solid/files' type Props = { value: RetrievalConfig @@ -15,11 +14,12 @@ type Props = { export const getIcon = (type: RETRIEVE_METHOD) => { return ({ - [RETRIEVE_METHOD.semantic]: Semantic, - [RETRIEVE_METHOD.fullText]: FileSearch02, - [RETRIEVE_METHOD.hybrid]: PatternRecognition, - [RETRIEVE_METHOD.invertedIndex]: HighPriority, - })[type] || FileSearch02 + [RETRIEVE_METHOD.semantic]: retrievalIcon.vector, + [RETRIEVE_METHOD.fullText]: retrievalIcon.fullText, + [RETRIEVE_METHOD.hybrid]: retrievalIcon.hybrid, + [RETRIEVE_METHOD.invertedIndex]: retrievalIcon.vector, + [RETRIEVE_METHOD.keywordSearch]: retrievalIcon.vector, + })[type] || retrievalIcon.vector } const EconomicalRetrievalMethodConfig: FC = ({ @@ -28,11 +28,11 @@ const EconomicalRetrievalMethodConfig: FC = ({ }) => { const { t } = useTranslation() const type = value.search_method - const Icon = getIcon(type) + const icon = return (
} + icon={icon} title={t(`dataset.retrieval.${type}.title`)} description={t(`dataset.retrieval.${type}.description`)} noRadio diff --git a/web/app/components/datasets/common/retrieval-param-config/index.tsx b/web/app/components/datasets/common/retrieval-param-config/index.tsx index 9d48d56a8d..5136ac1659 100644 --- a/web/app/components/datasets/common/retrieval-param-config/index.tsx +++ b/web/app/components/datasets/common/retrieval-param-config/index.tsx @@ -3,6 +3,9 @@ import type { FC } from 'react' import React, { useCallback } from 'react' import { useTranslation } from 'react-i18next' +import Image from 'next/image' +import ProgressIndicator from '../../create/assets/progress-indicator.svg' +import Reranking from '../../create/assets/rerank.svg' import cn from '@/utils/classnames' import TopKItem from '@/app/components/base/param-item/top-k-item' import ScoreThresholdItem from '@/app/components/base/param-item/score-threshold-item' @@ -20,6 +23,7 @@ import { } from '@/models/datasets' import WeightedScore from '@/app/components/app/configuration/dataset-config/params-config/weighted-score' import Toast from '@/app/components/base/toast' +import RadioCard from '@/app/components/base/radio-card' type Props = { type: RETRIEVE_METHOD @@ -116,7 +120,7 @@ const RetrievalParamConfig: FC = ({
{!isEconomical && !isHybridSearch && (
-
+
{canToggleRerankModalEnable && (
= ({
)}
- {t('common.modelProvider.rerankModel.key')} + {t('common.modelProvider.rerankModel.key')} {t('common.modelProvider.rerankModel.tip')}
@@ -163,7 +167,7 @@ const RetrievalParamConfig: FC = ({ )} { !isHybridSearch && ( -
+
= ({ { isHybridSearch && ( <> -
+
{ rerankingModeOptions.map(option => ( -
handleChangeRerankMode(option.value)} - > -
{option.label}
- {option.tips}
} - triggerClassName='ml-0.5 w-3.5 h-3.5' - /> -
+ isChosen={value.reranking_mode === option.value} + onChosen={() => handleChangeRerankMode(option.value)} + icon={} + title={option.label} + description={option.tips} + className='flex-1' + /> )) }
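The two retrieval-method config components above move from `RadioCard` to the step-two `OptionCard`, driving selection through `isActive`/`onSwitched` and rendering `RetrievalParamConfig` as the card body. A hedged sketch of how a single method wires into that API: the `Image` markup, literal strings, and the props passed to `RetrievalParamConfig` stand in for JSX and `t()` calls that are not fully visible in the hunk.

```tsx
import Image from 'next/image'
import { OptionCard } from '../../create/step-two/option-card'
import { retrievalIcon } from '../../create/icons'
import RetrievalParamConfig from '../retrieval-param-config'
import { RETRIEVE_METHOD } from '@/types/app'
import type { RetrievalConfig } from '@/types/app'

type Props = {
  value: RetrievalConfig
  onChange: (value: RetrievalConfig) => void
}

// Illustrative single-option wrapper; the real component maps over all supported methods.
const SemanticSearchOption = ({ value, onChange }: Props) => {
  // Literal strings stand in for the t('dataset.retrieval.semantic_search.*') calls used in the hunk.
  return (
    <OptionCard
      icon={<Image className='h-4 w-4' src={retrievalIcon.vector} alt='' />}
      title='Vector Search'
      description='Embed the query and match it against chunk embeddings'
      isActive={value.search_method === RETRIEVE_METHOD.semantic}
      onSwitched={() => onChange({ ...value, search_method: RETRIEVE_METHOD.semantic })}
      activeHeaderClassName='bg-dataset-option-card-purple-gradient'
    >
      <RetrievalParamConfig type={RETRIEVE_METHOD.semantic} value={value} onChange={onChange} />
    </OptionCard>
  )
}

export default SemanticSearchOption
```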
diff --git a/web/app/components/datasets/create/assets/family-mod.svg b/web/app/components/datasets/create/assets/family-mod.svg new file mode 100644 index 0000000000..b1c4e6f566 --- /dev/null +++ b/web/app/components/datasets/create/assets/family-mod.svg @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/web/app/components/datasets/create/assets/file-list-3-fill.svg b/web/app/components/datasets/create/assets/file-list-3-fill.svg new file mode 100644 index 0000000000..a4e6c4da97 --- /dev/null +++ b/web/app/components/datasets/create/assets/file-list-3-fill.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/web/app/components/datasets/create/assets/gold.svg b/web/app/components/datasets/create/assets/gold.svg new file mode 100644 index 0000000000..b48ac0eae5 --- /dev/null +++ b/web/app/components/datasets/create/assets/gold.svg @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/web/app/components/datasets/create/assets/note-mod.svg b/web/app/components/datasets/create/assets/note-mod.svg new file mode 100644 index 0000000000..b9e81f6bd5 --- /dev/null +++ b/web/app/components/datasets/create/assets/note-mod.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/web/app/components/datasets/create/assets/option-card-effect-blue.svg b/web/app/components/datasets/create/assets/option-card-effect-blue.svg new file mode 100644 index 0000000000..00a8afad8b --- /dev/null +++ b/web/app/components/datasets/create/assets/option-card-effect-blue.svg @@ -0,0 +1,12 @@ + + + + + + + + + + + + diff --git a/web/app/components/datasets/create/assets/option-card-effect-orange.svg b/web/app/components/datasets/create/assets/option-card-effect-orange.svg new file mode 100644 index 0000000000..d833764f0c --- /dev/null +++ b/web/app/components/datasets/create/assets/option-card-effect-orange.svg @@ -0,0 +1,12 @@ + + + + + + + + + + + + diff --git a/web/app/components/datasets/create/assets/option-card-effect-purple.svg b/web/app/components/datasets/create/assets/option-card-effect-purple.svg new file mode 100644 index 0000000000..a7857f8e57 --- /dev/null +++ b/web/app/components/datasets/create/assets/option-card-effect-purple.svg @@ -0,0 +1,12 @@ + + + + + + + + + + + + diff --git a/web/app/components/datasets/create/assets/pattern-recognition-mod.svg b/web/app/components/datasets/create/assets/pattern-recognition-mod.svg new file mode 100644 index 0000000000..1083e888ed --- /dev/null +++ b/web/app/components/datasets/create/assets/pattern-recognition-mod.svg @@ -0,0 +1,12 @@ + + + + + + + + + + + + \ No newline at end of file diff --git a/web/app/components/datasets/create/assets/piggy-bank-mod.svg b/web/app/components/datasets/create/assets/piggy-bank-mod.svg new file mode 100644 index 0000000000..b1120ad9a9 --- /dev/null +++ b/web/app/components/datasets/create/assets/piggy-bank-mod.svg @@ -0,0 +1,7 @@ + + + + + + + \ No newline at end of file diff --git a/web/app/components/datasets/create/assets/progress-indicator.svg b/web/app/components/datasets/create/assets/progress-indicator.svg new file mode 100644 index 0000000000..3c99713636 --- /dev/null +++ b/web/app/components/datasets/create/assets/progress-indicator.svg @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/web/app/components/datasets/create/assets/rerank.svg b/web/app/components/datasets/create/assets/rerank.svg new file mode 100644 index 0000000000..409b52e6e2 --- /dev/null +++ b/web/app/components/datasets/create/assets/rerank.svg @@ -0,0 +1,13 @@ + + + + + + + + + + + + + diff --git 
a/web/app/components/datasets/create/assets/research-mod.svg b/web/app/components/datasets/create/assets/research-mod.svg new file mode 100644 index 0000000000..1f0bb34233 --- /dev/null +++ b/web/app/components/datasets/create/assets/research-mod.svg @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/web/app/components/datasets/create/assets/selection-mod.svg b/web/app/components/datasets/create/assets/selection-mod.svg new file mode 100644 index 0000000000..2d0dd3b5f7 --- /dev/null +++ b/web/app/components/datasets/create/assets/selection-mod.svg @@ -0,0 +1,12 @@ + + + + + + + + + + + + \ No newline at end of file diff --git a/web/app/components/datasets/create/assets/setting-gear-mod.svg b/web/app/components/datasets/create/assets/setting-gear-mod.svg new file mode 100644 index 0000000000..c782caade8 --- /dev/null +++ b/web/app/components/datasets/create/assets/setting-gear-mod.svg @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/web/app/components/datasets/create/embedding-process/index.module.css b/web/app/components/datasets/create/embedding-process/index.module.css index 1ebb006b54..f2ab4d85a2 100644 --- a/web/app/components/datasets/create/embedding-process/index.module.css +++ b/web/app/components/datasets/create/embedding-process/index.module.css @@ -14,24 +14,7 @@ border-radius: 6px; overflow: hidden; } -.sourceItem.error { - background: #FEE4E2; -} -.sourceItem.success { - background: #D1FADF; -} -.progressbar { - position: absolute; - top: 0; - left: 0; - height: 100%; - background-color: #B2CCFF; -} -.sourceItem .info { - display: flex; - align-items: center; - z-index: 1; -} + .sourceItem .info .name { font-weight: 500; font-size: 12px; @@ -55,13 +38,6 @@ color: #05603A; } - -.cost { - @apply flex justify-between items-center text-xs text-gray-700; -} -.embeddingStatus { - @apply flex items-center justify-between text-gray-900 font-medium text-sm mr-2; -} .commonIcon { @apply w-3 h-3 mr-1 inline-block align-middle; } @@ -81,35 +57,33 @@ @apply text-xs font-medium; } -.fileIcon { - @apply w-4 h-4 mr-1 bg-center bg-no-repeat; +.unknownFileIcon { background-image: url(../assets/unknown.svg); - background-size: 16px; } -.fileIcon.csv { +.csv { background-image: url(../assets/csv.svg); } -.fileIcon.docx { +.docx { background-image: url(../assets/docx.svg); } -.fileIcon.xlsx, -.fileIcon.xls { +.xlsx, +.xls { background-image: url(../assets/xlsx.svg); } -.fileIcon.pdf { +.pdf { background-image: url(../assets/pdf.svg); } -.fileIcon.html, -.fileIcon.htm { +.html, +.htm { background-image: url(../assets/html.svg); } -.fileIcon.md, -.fileIcon.markdown { +.md, +.markdown { background-image: url(../assets/md.svg); } -.fileIcon.txt { +.txt { background-image: url(../assets/txt.svg); } -.fileIcon.json { +.json { background-image: url(../assets/json.svg); } diff --git a/web/app/components/datasets/create/embedding-process/index.tsx b/web/app/components/datasets/create/embedding-process/index.tsx index 7786582085..201333ffce 100644 --- a/web/app/components/datasets/create/embedding-process/index.tsx +++ b/web/app/components/datasets/create/embedding-process/index.tsx @@ -6,32 +6,44 @@ import { useTranslation } from 'react-i18next' import { omit } from 'lodash-es' import { ArrowRightIcon } from '@heroicons/react/24/solid' import { + RiCheckboxCircleFill, RiErrorWarningFill, + RiLoader2Fill, + RiTerminalBoxLine, } from '@remixicon/react' -import s from './index.module.css' +import Image from 'next/image' +import { indexMethodIcon, retrievalIcon } from 
'../icons' +import { IndexingType } from '../step-two' +import DocumentFileIcon from '../../common/document-file-icon' import cn from '@/utils/classnames' import { FieldInfo } from '@/app/components/datasets/documents/detail/metadata' import Button from '@/app/components/base/button' import type { FullDocumentDetail, IndexingStatusResponse, ProcessRuleResponse } from '@/models/datasets' import { fetchIndexingStatusBatch as doFetchIndexingStatus, fetchProcessRule } from '@/service/datasets' -import { DataSourceType } from '@/models/datasets' +import { DataSourceType, ProcessMode } from '@/models/datasets' import NotionIcon from '@/app/components/base/notion-icon' import PriorityLabel from '@/app/components/billing/priority-label' import { Plan } from '@/app/components/billing/type' import { ZapFast } from '@/app/components/base/icons/src/vender/solid/general' import UpgradeBtn from '@/app/components/billing/upgrade-btn' import { useProviderContext } from '@/context/provider-context' -import Tooltip from '@/app/components/base/tooltip' import { sleep } from '@/utils' +import { RETRIEVE_METHOD } from '@/types/app' +import Tooltip from '@/app/components/base/tooltip' type Props = { datasetId: string batchId: string documents?: FullDocumentDetail[] indexingType?: string + retrievalMethod?: string } -const RuleDetail: FC<{ sourceData?: ProcessRuleResponse }> = ({ sourceData }) => { +const RuleDetail: FC<{ + sourceData?: ProcessRuleResponse + indexingType?: string + retrievalMethod?: string +}> = ({ sourceData, indexingType, retrievalMethod }) => { const { t } = useTranslation() const segmentationRuleMap = { @@ -51,29 +63,47 @@ const RuleDetail: FC<{ sourceData?: ProcessRuleResponse }> = ({ sourceData }) => return t('datasetCreation.stepTwo.removeStopwords') } + const isNumber = (value: unknown) => { + return typeof value === 'number' + } + const getValue = useCallback((field: string) => { let value: string | number | undefined = '-' + const maxTokens = isNumber(sourceData?.rules?.segmentation?.max_tokens) + ? sourceData.rules.segmentation.max_tokens + : value + const childMaxTokens = isNumber(sourceData?.rules?.subchunk_segmentation?.max_tokens) + ? sourceData.rules.subchunk_segmentation.max_tokens + : value switch (field) { case 'mode': - value = sourceData?.mode === 'automatic' ? (t('datasetDocuments.embedding.automatic') as string) : (t('datasetDocuments.embedding.custom') as string) + value = !sourceData?.mode + ? value + : sourceData.mode === ProcessMode.general + ? (t('datasetDocuments.embedding.custom') as string) + : `${t('datasetDocuments.embedding.hierarchical')} · ${sourceData?.rules?.parent_mode === 'paragraph' + ? t('dataset.parentMode.paragraph') + : t('dataset.parentMode.fullDoc')}` break case 'segmentLength': - value = sourceData?.rules?.segmentation?.max_tokens + value = !sourceData?.mode + ? value + : sourceData.mode === ProcessMode.general + ? maxTokens + : `${t('datasetDocuments.embedding.parentMaxTokens')} ${maxTokens}; ${t('datasetDocuments.embedding.childMaxTokens')} ${childMaxTokens}` break default: - value = sourceData?.mode === 'automatic' - ? (t('datasetDocuments.embedding.automatic') as string) - // eslint-disable-next-line array-callback-return - : sourceData?.rules?.pre_processing_rules?.map((rule) => { - if (rule.enabled) - return getRuleName(rule.id) - }).filter(Boolean).join(';') + value = !sourceData?.mode + ? 
value + : sourceData?.rules?.pre_processing_rules?.filter(rule => + rule.enabled).map(rule => getRuleName(rule.id)).join(',') break } return value + // eslint-disable-next-line react-hooks/exhaustive-deps }, [sourceData]) - return
+ return
{Object.keys(segmentationRuleMap).map((field) => { return = ({ sourceData }) => displayedValue={String(getValue(field))} /> })} + + } + /> + + } + />
} -const EmbeddingProcess: FC = ({ datasetId, batchId, documents = [], indexingType }) => { +const EmbeddingProcess: FC = ({ datasetId, batchId, documents = [], indexingType, retrievalMethod }) => { const { t } = useTranslation() const { enableBilling, plan } = useProviderContext() @@ -127,6 +190,7 @@ const EmbeddingProcess: FC = ({ datasetId, batchId, documents = [], index } useEffect(() => { + setIsStopQuery(false) startQueryStatus() return () => { stopQueryStatus() @@ -146,6 +210,9 @@ const EmbeddingProcess: FC = ({ datasetId, batchId, documents = [], index const navToDocumentList = () => { router.push(`/datasets/${datasetId}/documents`) } + const navToApiDocs = () => { + router.push('/datasets?category=api') + } const isEmbedding = useMemo(() => { return indexingStatusBatchDetail.some(indexingStatusDetail => ['indexing', 'splitting', 'parsing', 'cleaning'].includes(indexingStatusDetail?.indexing_status || '')) @@ -177,13 +244,17 @@ const EmbeddingProcess: FC = ({ datasetId, batchId, documents = [], index return doc?.data_source_info.notion_page_icon } - const isSourceEmbedding = (detail: IndexingStatusResponse) => ['indexing', 'splitting', 'parsing', 'cleaning', 'waiting'].includes(detail.indexing_status || '') + const isSourceEmbedding = (detail: IndexingStatusResponse) => + ['indexing', 'splitting', 'parsing', 'cleaning', 'waiting'].includes(detail.indexing_status || '') return ( <> -
-
- {isEmbedding && t('datasetDocuments.embedding.processing')} +
+
+ {isEmbedding &&
+ + {t('datasetDocuments.embedding.processing')} +
} {isEmbeddingCompleted && t('datasetDocuments.embedding.completed')}
@@ -200,69 +271,80 @@ const EmbeddingProcess: FC = ({ datasetId, batchId, documents = [], index
) } -
+
{indexingStatusBatchDetail.map(indexingStatusDetail => (
{isSourceEmbedding(indexingStatusDetail) && ( -
+
)} -
+
{getSourceType(indexingStatusDetail.id) === DataSourceType.FILE && ( -
+ //
+ )} {getSourceType(indexingStatusDetail.id) === DataSourceType.NOTION && ( )} -
{getSourceName(indexingStatusDetail.id)}
- { - enableBilling && ( - - ) - } -
-
+
+
+ {getSourceName(indexingStatusDetail.id)} +
+ { + enableBilling && ( + + ) + } +
{isSourceEmbedding(indexingStatusDetail) && ( -
{`${getSourcePercent(indexingStatusDetail)}%`}
+
{`${getSourcePercent(indexingStatusDetail)}%`}
)} - {indexingStatusDetail.indexing_status === 'error' && indexingStatusDetail.error && ( + {indexingStatusDetail.indexing_status === 'error' && ( - {indexingStatusDetail.error} -
- )} + popupClassName='px-4 py-[14px] max-w-60 text-sm leading-4 text-text-secondary border-[0.5px] border-components-panel-border rounded-xl' + offset={4} + popupContent={indexingStatusDetail.error} > -
- Error - -
+ + + )} - {indexingStatusDetail.indexing_status === 'error' && !indexingStatusDetail.error && ( -
- Error -
- )} {indexingStatusDetail.indexing_status === 'completed' && ( -
100%
+ )}
))}
- -
+
+ +
+
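The embedding-process changes above keep polling the batch indexing status (note the added `setIsStopQuery(false)` before `startQueryStatus`) and now surface the index method and retrieval method in `RuleDetail`. A rough sketch of the polling loop under stated assumptions: the exact `fetchIndexingStatusBatch` signature, response shape, and poll interval are guesses, and the real component tracks more state than this.

```tsx
import { useCallback, useEffect, useRef, useState } from 'react'
import { fetchIndexingStatusBatch } from '@/service/datasets'
import type { IndexingStatusResponse } from '@/models/datasets'
import { sleep } from '@/utils'

// Illustrative polling hook; not the component's actual implementation.
function useIndexingStatusPolling(datasetId: string, batchId: string) {
  const [details, setDetails] = useState<IndexingStatusResponse[]>([])
  const stopped = useRef(false)

  const poll = useCallback(async () => {
    while (!stopped.current) {
      const res = await fetchIndexingStatusBatch({ datasetId, batchId }) // signature assumed
      setDetails(res.data)
      const finished = res.data.every(d =>
        ['completed', 'error', 'paused'].includes(d.indexing_status))
      if (finished)
        break
      await sleep(2500) // interval assumed
    }
  }, [datasetId, batchId])

  useEffect(() => {
    // Mirrors the added setIsStopQuery(false): reset the stop flag before starting a new round.
    stopped.current = false
    poll()
    return () => {
      stopped.current = true
    }
  }, [poll])

  return details
}

export default useIndexingStatusPolling
```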
diff --git a/web/app/components/datasets/create/file-preview/index.module.css b/web/app/components/datasets/create/file-preview/index.module.css index d87522e6d0..929002e1e2 100644 --- a/web/app/components/datasets/create/file-preview/index.module.css +++ b/web/app/components/datasets/create/file-preview/index.module.css @@ -1,6 +1,6 @@ .filePreview { @apply flex flex-col border-l border-gray-200 shrink-0; - width: 528px; + width: 100%; background-color: #fcfcfd; } @@ -48,5 +48,6 @@ } .fileContent { white-space: pre-line; + word-break: break-all; } \ No newline at end of file diff --git a/web/app/components/datasets/create/file-preview/index.tsx b/web/app/components/datasets/create/file-preview/index.tsx index e20af64386..cb1f1d6908 100644 --- a/web/app/components/datasets/create/file-preview/index.tsx +++ b/web/app/components/datasets/create/file-preview/index.tsx @@ -44,7 +44,7 @@ const FilePreview = ({ }, [file]) return ( -
+
{t('datasetCreation.stepOne.filePreview')} @@ -59,7 +59,7 @@ const FilePreview = ({
{loading &&
} {!loading && ( -
{previewContent}
+
{previewContent}
)}
diff --git a/web/app/components/datasets/create/file-uploader/index.module.css b/web/app/components/datasets/create/file-uploader/index.module.css index bf5b7dcaf5..7d29f2ef9c 100644 --- a/web/app/components/datasets/create/file-uploader/index.module.css +++ b/web/app/components/datasets/create/file-uploader/index.module.css @@ -1,68 +1,3 @@ -.fileUploader { - @apply mb-6; -} - -.fileUploader .title { - @apply mb-2; - font-weight: 500; - font-size: 16px; - line-height: 24px; - color: #344054; -} - -.fileUploader .tip { - font-weight: 400; - font-size: 12px; - line-height: 18px; - color: #667085; -} - -.uploader { - @apply relative box-border flex justify-center items-center mb-2 p-3; - flex-direction: column; - max-width: 640px; - min-height: 80px; - background: #F9FAFB; - border: 1px dashed #EAECF0; - border-radius: 12px; - font-weight: 400; - font-size: 14px; - line-height: 20px; - color: #667085; -} - -.uploader.dragging { - background: #F5F8FF; - border: 1px dashed #B2CCFF; -} - -.uploader .draggingCover { - position: absolute; - top: 0; - left: 0; - width: 100%; - height: 100%; -} - -.uploader .uploadIcon { - content: ''; - display: block; - margin-right: 8px; - width: 24px; - height: 24px; - background: center no-repeat url(../assets/upload-cloud-01.svg); - background-size: contain; -} - -.uploader .browse { - @apply pl-1 cursor-pointer; - color: #155eef; -} - -.fileList { - @apply space-y-2; -} - .file { @apply box-border relative flex items-center justify-between; padding: 8px 12px 8px 8px; @@ -193,4 +128,4 @@ .file:hover .actionWrapper .remove { display: block; -} \ No newline at end of file +} diff --git a/web/app/components/datasets/create/file-uploader/index.tsx b/web/app/components/datasets/create/file-uploader/index.tsx index adb4bed0d1..e42a24cfef 100644 --- a/web/app/components/datasets/create/file-uploader/index.tsx +++ b/web/app/components/datasets/create/file-uploader/index.tsx @@ -3,10 +3,12 @@ import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import useSWR from 'swr' -import s from './index.module.css' +import { RiDeleteBinLine, RiUploadCloud2Line } from '@remixicon/react' +import DocumentFileIcon from '../../common/document-file-icon' import cn from '@/utils/classnames' import type { CustomFile as File, FileItem } from '@/models/datasets' import { ToastContext } from '@/app/components/base/toast' +import SimplePieChart from '@/app/components/base/simple-pie-chart' import { upload } from '@/service/base' import { fetchFileUploadConfig } from '@/service/common' @@ -14,6 +16,8 @@ import { fetchSupportFileTypes } from '@/service/datasets' import I18n from '@/context/i18n' import { LanguagesSupported } from '@/i18n/language' import { IS_CE_EDITION } from '@/config' +import { useAppContext } from '@/context/app-context' +import { Theme } from '@/types/app' const FILES_NUMBER_LIMIT = 20 @@ -222,6 +226,9 @@ const FileUploader = ({ initialUpload(files.filter(isValid)) }, [isValid, initialUpload]) + const { theme } = useAppContext() + const chartColor = useMemo(() => theme === Theme.dark ? '#5289ff' : '#296dff', [theme]) + useEffect(() => { dropRef.current?.addEventListener('dragenter', handleDragEnter) dropRef.current?.addEventListener('dragover', handleDragOver) @@ -236,12 +243,12 @@ const FileUploader = ({ }, [handleDrop]) return ( -
+
{!hideUpload && ( )} -
{t('datasetCreation.stepOne.uploader.title')}
- {!hideUpload && ( +
{t('datasetCreation.stepOne.uploader.title')}
+ + {!hideUpload && ( +
+
+ -
-
- {t('datasetCreation.stepOne.uploader.button')} - + {supportTypes.length > 0 && ( + + )}
-
{t('datasetCreation.stepOne.uploader.tip', { +
{t('datasetCreation.stepOne.uploader.tip', { size: fileUploadConfig.file_size_limit, supportTypes: supportTypesShowNames, })}
- {dragging &&
} + {dragging &&
}
)} -
+
+ {fileList.map((fileItem, index) => (
fileItem.file?.id && onPreview(fileItem.file)} className={cn( - s.file, - fileItem.progress < 100 && s.uploading, + 'flex items-center h-12 max-w-[640px] bg-components-panel-on-panel-item-bg text-xs leading-3 text-text-tertiary border border-components-panel-border rounded-lg shadow-xs', + // 'border-state-destructive-border bg-state-destructive-hover', )} > - {fileItem.progress < 100 && ( -
- )} -
-
-
{fileItem.file.name}
-
{getFileSize(fileItem.file.size)}
+
+
-
+
+
+
{fileItem.file.name}
+
+
+ {getFileType(fileItem.file)} + · + {getFileSize(fileItem.file.size)} + {/* · + 10k characters */} +
+
+
+ {/* + + */} {(fileItem.progress < 100 && fileItem.progress >= 0) && ( -
{`${fileItem.progress}%`}
- )} - {fileItem.progress === 100 && ( -
{ - e.stopPropagation() - removeFile(fileItem.fileID) - }} /> + //
{`${fileItem.progress}%`}
+ )} + { + e.stopPropagation() + removeFile(fileItem.fileID) + }}> + +
))} diff --git a/web/app/components/datasets/create/icons.ts b/web/app/components/datasets/create/icons.ts new file mode 100644 index 0000000000..80c4b6c944 --- /dev/null +++ b/web/app/components/datasets/create/icons.ts @@ -0,0 +1,16 @@ +import GoldIcon from './assets/gold.svg' +import Piggybank from './assets/piggy-bank-mod.svg' +import Selection from './assets/selection-mod.svg' +import Research from './assets/research-mod.svg' +import PatternRecognition from './assets/pattern-recognition-mod.svg' + +export const indexMethodIcon = { + high_quality: GoldIcon, + economical: Piggybank, +} + +export const retrievalIcon = { + vector: Selection, + fullText: Research, + hybrid: PatternRecognition, +} diff --git a/web/app/components/datasets/create/index.tsx b/web/app/components/datasets/create/index.tsx index 98098445c7..9556b9fad5 100644 --- a/web/app/components/datasets/create/index.tsx +++ b/web/app/components/datasets/create/index.tsx @@ -3,10 +3,10 @@ import React, { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import AppUnavailable from '../../base/app-unavailable' import { ModelTypeEnum } from '../../header/account-setting/model-provider-page/declarations' -import StepsNavBar from './steps-nav-bar' import StepOne from './step-one' import StepTwo from './step-two' import StepThree from './step-three' +import { Topbar } from './top-bar' import { DataSourceType } from '@/models/datasets' import type { CrawlOptions, CrawlResultItem, DataSet, FileItem, createDocumentResponse } from '@/models/datasets' import { fetchDataSource } from '@/service/common' @@ -36,6 +36,7 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => { const [dataSourceType, setDataSourceType] = useState(DataSourceType.FILE) const [step, setStep] = useState(1) const [indexingTypeCache, setIndexTypeCache] = useState('') + const [retrievalMethodCache, setRetrievalMethodCache] = useState('') const [fileList, setFiles] = useState([]) const [result, setResult] = useState() const [hasError, setHasError] = useState(false) @@ -80,6 +81,9 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => { const updateResultCache = (res?: createDocumentResponse) => { setResult(res) } + const updateRetrievalMethodCache = (method: string) => { + setRetrievalMethodCache(method) + } const nextStep = useCallback(() => { setStep(step + 1) @@ -118,33 +122,29 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => { return return ( -
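Editor's note on the new icons.ts module above: it maps the index methods (high_quality, economical) and retrieval methods (vector, fullText, hybrid) to their SVG assets in one place. A hedged sketch of how such a map can be consumed with next/image; the component names and sizes below are illustrative and not taken from the patch.

import Image from 'next/image'
import { indexMethodIcon, retrievalIcon } from './icons'

// Illustrative helpers: resolve the icon for the configured method and render it at a fixed size.
export const IndexMethodIcon = ({ method }: { method: keyof typeof indexMethodIcon }) => (
  <Image src={indexMethodIcon[method]} alt={method} width={16} height={16} />
)

export const RetrievalMethodIcon = ({ method }: { method: keyof typeof retrievalIcon }) => (
  <Image src={retrievalIcon[method]} alt={method} width={16} height={16} />
)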
-
- -
-
-
- setShowAccountSettingModal({ payload: 'data-source' })} - datasetId={datasetId} - dataSourceType={dataSourceType} - dataSourceTypeDisable={!!detail?.data_source_type} - changeType={setDataSourceType} - files={fileList} - updateFile={updateFile} - updateFileList={updateFileList} - notionPages={notionPages} - updateNotionPages={updateNotionPages} - onStepChange={nextStep} - websitePages={websitePages} - updateWebsitePages={setWebsitePages} - onWebsiteCrawlProviderChange={setWebsiteCrawlProvider} - onWebsiteCrawlJobIdChange={setWebsiteCrawlJobId} - crawlOptions={crawlOptions} - onCrawlOptionsChange={setCrawlOptions} - /> -
+
+ +
+ {step === 1 && setShowAccountSettingModal({ payload: 'data-source' })} + datasetId={datasetId} + dataSourceType={dataSourceType} + dataSourceTypeDisable={!!detail?.data_source_type} + changeType={setDataSourceType} + files={fileList} + updateFile={updateFile} + updateFileList={updateFileList} + notionPages={notionPages} + updateNotionPages={updateNotionPages} + onStepChange={nextStep} + websitePages={websitePages} + updateWebsitePages={setWebsitePages} + onWebsiteCrawlProviderChange={setWebsiteCrawlProvider} + onWebsiteCrawlJobIdChange={setWebsiteCrawlJobId} + crawlOptions={crawlOptions} + onCrawlOptionsChange={setCrawlOptions} + />} {(step === 2 && (!datasetId || (datasetId && !!detail))) && setShowAccountSettingModal({ payload: 'provider' })} @@ -158,6 +158,7 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => { websiteCrawlJobId={websiteCrawlJobId} onStepChange={changeStep} updateIndexingTypeCache={updateIndexingTypeCache} + updateRetrievalMethodCache={updateRetrievalMethodCache} updateResultCache={updateResultCache} crawlOptions={crawlOptions} />} @@ -165,6 +166,7 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => { datasetId={datasetId} datasetName={detail?.name} indexingType={detail?.indexing_technique || indexingTypeCache} + retrievalMethod={detail?.retrieval_model_dict?.search_method || retrievalMethodCache} creationCache={result} />}
diff --git a/web/app/components/datasets/create/notion-page-preview/index.tsx b/web/app/components/datasets/create/notion-page-preview/index.tsx index 8225e56f04..f658f213e8 100644 --- a/web/app/components/datasets/create/notion-page-preview/index.tsx +++ b/web/app/components/datasets/create/notion-page-preview/index.tsx @@ -44,7 +44,7 @@ const NotionPagePreview = ({ }, [currentPage]) return ( -
+
{t('datasetCreation.stepOne.pagePreview')} @@ -64,7 +64,7 @@ const NotionPagePreview = ({
{loading &&
} {!loading && ( -
{previewContent}
+
{previewContent}
)}
diff --git a/web/app/components/datasets/create/step-one/index.module.css b/web/app/components/datasets/create/step-one/index.module.css index 4e3cf67cd6..bb8dd9b895 100644 --- a/web/app/components/datasets/create/step-one/index.module.css +++ b/web/app/components/datasets/create/step-one/index.module.css @@ -2,21 +2,19 @@ position: sticky; top: 0; left: 0; - padding: 42px 64px 12px; + padding: 42px 64px 12px 0; font-weight: 600; font-size: 18px; line-height: 28px; - color: #101828; } .form { position: relative; padding: 12px 64px; - background-color: #fff; } .dataSourceItem { - @apply box-border relative shrink-0 flex items-center mr-3 p-3 h-14 bg-white rounded-xl cursor-pointer; + @apply box-border relative grow shrink-0 flex items-center p-3 h-14 bg-white rounded-xl cursor-pointer; border: 0.5px solid #EAECF0; box-shadow: 0px 1px 2px rgba(16, 24, 40, 0.05); font-weight: 500; @@ -24,27 +22,32 @@ line-height: 20px; color: #101828; } + .dataSourceItem:hover { background-color: #f5f8ff; border: 0.5px solid #B2CCFF; box-shadow: 0px 12px 16px -4px rgba(16, 24, 40, 0.08), 0px 4px 6px -2px rgba(16, 24, 40, 0.03); } + .dataSourceItem.active { background-color: #f5f8ff; border: 1.5px solid #528BFF; box-shadow: 0px 1px 3px rgba(16, 24, 40, 0.1), 0px 1px 2px rgba(16, 24, 40, 0.06); } + .dataSourceItem.disabled { background-color: #f9fafb; border: 0.5px solid #EAECF0; box-shadow: 0px 1px 2px rgba(16, 24, 40, 0.05); cursor: default; } + .dataSourceItem.disabled:hover { background-color: #f9fafb; border: 0.5px solid #EAECF0; box-shadow: 0px 1px 2px rgba(16, 24, 40, 0.05); } + .comingTag { @apply flex justify-center items-center bg-white; position: absolute; @@ -59,6 +62,7 @@ line-height: 18px; color: #444CE7; } + .datasetIcon { @apply flex mr-2 w-8 h-8 rounded-lg bg-center bg-no-repeat; background-color: #F5FAFF; @@ -66,15 +70,18 @@ background-size: 16px; border: 0.5px solid #D1E9FF; } + .dataSourceItem:active .datasetIcon, .dataSourceItem:hover .datasetIcon { background-color: #F5F8FF; border: 0.5px solid #E0EAFF; } + .datasetIcon.notion { background-image: url(../assets/notion.svg); background-size: 20px; } + .datasetIcon.web { background-image: url(../assets/web.svg); } @@ -90,29 +97,12 @@ background-color: #eaecf0; } -.OtherCreationOption { - @apply flex items-center cursor-pointer; - font-weight: 500; - font-size: 13px; - line-height: 18px; - color: #155EEF; -} -.OtherCreationOption::before { - content: ''; - display: block; - margin-right: 4px; - width: 16px; - height: 16px; - background: center no-repeat url(../assets/folder-plus.svg); - background-size: contain; -} - .notionConnectionTip { display: flex; flex-direction: column; align-items: flex-start; padding: 24px; - max-width: 640px; + width: 640px; background: #F9FAFB; border-radius: 16px; } @@ -138,6 +128,7 @@ line-height: 24px; color: #374151; } + .notionConnectionTip .title::after { content: ''; position: absolute; @@ -148,6 +139,7 @@ background: center no-repeat url(../assets/Icon-3-dots.svg); background-size: contain; } + .notionConnectionTip .tip { margin-bottom: 20px; font-style: normal; @@ -155,4 +147,4 @@ font-size: 13px; line-height: 18px; color: #6B7280; -} +} \ No newline at end of file diff --git a/web/app/components/datasets/create/step-one/index.tsx b/web/app/components/datasets/create/step-one/index.tsx index 643932e9ae..2cca003b39 100644 --- a/web/app/components/datasets/create/step-one/index.tsx +++ b/web/app/components/datasets/create/step-one/index.tsx @@ -1,6 +1,7 @@ 'use client' import React, { useMemo, useState } 
from 'react' import { useTranslation } from 'react-i18next' +import { RiArrowRightLine, RiFolder6Line } from '@remixicon/react' import FilePreview from '../file-preview' import FileUploader from '../file-uploader' import NotionPagePreview from '../notion-page-preview' @@ -17,6 +18,7 @@ import { NotionPageSelector } from '@/app/components/base/notion-page-selector' import { useDatasetDetailContext } from '@/context/dataset-detail' import { useProviderContext } from '@/context/provider-context' import VectorSpaceFull from '@/app/components/billing/vector-space-full' +import classNames from '@/utils/classnames' type IStepOneProps = { datasetId?: string @@ -120,143 +122,174 @@ const StepOne = ({ return true if (isShowVectorSpaceFull) return true - return false - }, [files]) + }, [files, isShowVectorSpaceFull]) + return (
-
- { - shouldShowDataSourceTypeList && ( -
{t('datasetCreation.steps.one')}
- ) - } -
- { - shouldShowDataSourceTypeList && ( -
-
{ - if (dataSourceTypeDisable) - return - changeType(DataSourceType.FILE) - hideFilePreview() - hideNotionPagePreview() - }} - > - - {t('datasetCreation.stepOne.dataSourceType.file')} -
-
{ - if (dataSourceTypeDisable) - return - changeType(DataSourceType.NOTION) - hideFilePreview() - hideNotionPagePreview() - }} - > - - {t('datasetCreation.stepOne.dataSourceType.notion')} -
-
changeType(DataSourceType.WEB)} - > - - {t('datasetCreation.stepOne.dataSourceType.web')} -
-
- ) - } - {dataSourceType === DataSourceType.FILE && ( - <> - - {isShowVectorSpaceFull && ( -
- -
- )} - - - )} - {dataSourceType === DataSourceType.NOTION && ( - <> - {!hasConnection && } - {hasConnection && ( - <> -
- page.page_id)} - onSelect={updateNotionPages} - onPreview={updateCurrentPage} - /> +
+
+
+ { + shouldShowDataSourceTypeList && ( +
{t('datasetCreation.steps.one')}
+ ) + } + { + shouldShowDataSourceTypeList && ( +
+
{ + if (dataSourceTypeDisable) + return + changeType(DataSourceType.FILE) + hideFilePreview() + hideNotionPagePreview() + }} + > + + {t('datasetCreation.stepOne.dataSourceType.file')} +
+
{ + if (dataSourceTypeDisable) + return + changeType(DataSourceType.NOTION) + hideFilePreview() + hideNotionPagePreview() + }} + > + + {t('datasetCreation.stepOne.dataSourceType.notion')} +
+
changeType(DataSourceType.WEB)} + > + + {t('datasetCreation.stepOne.dataSourceType.web')}
- {isShowVectorSpaceFull && ( -
- -
- )} - - - )} - - )} - {dataSourceType === DataSourceType.WEB && ( - <> -
- -
- {isShowVectorSpaceFull && ( -
-
- )} - - - )} - {!datasetId && ( - <> -
-
{t('datasetCreation.stepOne.emptyDatasetCreation')}
- - )} + ) + } + {dataSourceType === DataSourceType.FILE && ( + <> + + {isShowVectorSpaceFull && ( +
+ +
+ )} +
+ {/* */} + +
+ + )} + {dataSourceType === DataSourceType.NOTION && ( + <> + {!hasConnection && } + {hasConnection && ( + <> +
+ page.page_id)} + onSelect={updateNotionPages} + onPreview={updateCurrentPage} + /> +
+ {isShowVectorSpaceFull && ( +
+ +
+ )} +
+ {/* */} + +
+ + )} + + )} + {dataSourceType === DataSourceType.WEB && ( + <> +
+ +
+ {isShowVectorSpaceFull && ( +
+ +
+ )} +
+ {/* */} + +
+ + )} + {!datasetId && ( + <> +
+ + + {t('datasetCreation.stepOne.emptyDatasetCreation')} + + + )} +
+
-
- {currentFile && } - {currentNotionPage && } - {currentWebsite && } +
+ {currentFile && } + {currentNotionPage && } + {currentWebsite && } +
) } diff --git a/web/app/components/datasets/create/step-three/index.tsx b/web/app/components/datasets/create/step-three/index.tsx index 804a196ed5..8d979616d1 100644 --- a/web/app/components/datasets/create/step-three/index.tsx +++ b/web/app/components/datasets/create/step-three/index.tsx @@ -1,45 +1,51 @@ 'use client' import React from 'react' import { useTranslation } from 'react-i18next' +import { RiBookOpenLine } from '@remixicon/react' import EmbeddingProcess from '../embedding-process' -import s from './index.module.css' -import cn from '@/utils/classnames' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import type { FullDocumentDetail, createDocumentResponse } from '@/models/datasets' +import AppIcon from '@/app/components/base/app-icon' type StepThreeProps = { datasetId?: string datasetName?: string indexingType?: string + retrievalMethod?: string creationCache?: createDocumentResponse } -const StepThree = ({ datasetId, datasetName, indexingType, creationCache }: StepThreeProps) => { +const StepThree = ({ datasetId, datasetName, indexingType, creationCache, retrievalMethod }: StepThreeProps) => { const { t } = useTranslation() const media = useBreakpoints() const isMobile = media === MediaType.mobile return ( -
-
-
+
+
+
{!datasetId && ( <> -
-
{t('datasetCreation.stepThree.creationTitle')}
-
{t('datasetCreation.stepThree.creationContent')}
-
{t('datasetCreation.stepThree.label')}
-
{datasetName || creationCache?.dataset?.name}
+
+
{t('datasetCreation.stepThree.creationTitle')}
+
{t('datasetCreation.stepThree.creationContent')}
+
+ +
+
{t('datasetCreation.stepThree.label')}
+
{datasetName || creationCache?.dataset?.name}
+
+
-
+
)} {datasetId && ( -
-
{t('datasetCreation.stepThree.additionTitle')}
-
{`${t('datasetCreation.stepThree.additionP1')} ${datasetName || creationCache?.dataset?.name} ${t('datasetCreation.stepThree.additionP2')}`}
+
+
{t('datasetCreation.stepThree.additionTitle')}
+
{`${t('datasetCreation.stepThree.additionP1')} ${datasetName || creationCache?.dataset?.name} ${t('datasetCreation.stepThree.additionP2')}`}
)}
- {!isMobile &&
-
- -
{t('datasetCreation.stepThree.sideTipTitle')}
-
{t('datasetCreation.stepThree.sideTipContent')}
+ {!isMobile && ( +
+
+
+ +
+
{t('datasetCreation.stepThree.sideTipTitle')}
+
{t('datasetCreation.stepThree.sideTipContent')}
+
-
} + )}
) } diff --git a/web/app/components/datasets/create/step-two/index.module.css b/web/app/components/datasets/create/step-two/index.module.css index f89d6d67ea..178cbeba85 100644 --- a/web/app/components/datasets/create/step-two/index.module.css +++ b/web/app/components/datasets/create/step-two/index.module.css @@ -13,18 +13,6 @@ z-index: 10; } -.form { - @apply px-16 pb-8; -} - -.form .label { - @apply pt-6 pb-2 flex items-center; - font-weight: 500; - font-size: 16px; - line-height: 24px; - color: #344054; -} - .segmentationItem { min-height: 68px; } @@ -75,6 +63,10 @@ cursor: pointer; } +.disabled { + cursor: not-allowed !important; +} + .indexItem.disabled:hover { background-color: #fcfcfd; border-color: #f2f4f7; @@ -87,8 +79,7 @@ } .radioItem { - @apply relative mb-2 rounded-xl border border-gray-100 cursor-pointer; - background-color: #fcfcfd; + @apply relative mb-2 rounded-xl border border-components-option-card-option-border cursor-pointer bg-components-option-card-option-bg; } .radioItem.segmentationItem.custom { @@ -146,7 +137,7 @@ } .typeIcon.economical { - background-image: url(../assets/piggy-bank-01.svg); + background-image: url(../assets/piggy-bank-mod.svg); } .radioItem .radio { @@ -247,7 +238,7 @@ } .ruleItem { - @apply flex items-center; + @apply flex items-center py-1.5; } .formFooter { @@ -394,19 +385,6 @@ max-width: 524px; } -.previewHeader { - position: sticky; - top: 0; - left: 0; - padding-top: 42px; - background-color: #fff; - font-weight: 600; - font-size: 18px; - line-height: 28px; - color: #101828; - z-index: 10; -} - /* * `fixed` must under `previewHeader` because of style override would not work */ @@ -432,4 +410,4 @@ font-size: 12px; line-height: 18px; } -} \ No newline at end of file +} diff --git a/web/app/components/datasets/create/step-two/index.tsx b/web/app/components/datasets/create/step-two/index.tsx index f915c68fef..0d7202967a 100644 --- a/web/app/components/datasets/create/step-two/index.tsx +++ b/web/app/components/datasets/create/step-two/index.tsx @@ -1,65 +1,80 @@ 'use client' -import React, { useCallback, useEffect, useLayoutEffect, useRef, useState } from 'react' +import type { FC, PropsWithChildren } from 'react' +import React, { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' -import { useBoolean } from 'ahooks' -import { XMarkIcon } from '@heroicons/react/20/solid' -import { RocketLaunchIcon } from '@heroicons/react/24/outline' import { - RiCloseLine, + RiAlertFill, + RiArrowLeftLine, + RiSearchEyeLine, } from '@remixicon/react' import Link from 'next/link' -import { groupBy } from 'lodash-es' -import PreviewItem, { PreviewType } from './preview-item' -import LanguageSelect from './language-select' +import Image from 'next/image' +import { useHover } from 'ahooks' +import SettingCog from '../assets/setting-gear-mod.svg' +import OrangeEffect from '../assets/option-card-effect-orange.svg' +import FamilyMod from '../assets/family-mod.svg' +import Note from '../assets/note-mod.svg' +import FileList from '../assets/file-list-3-fill.svg' +import { indexMethodIcon } from '../icons' +import { PreviewContainer } from '../../preview/container' +import { ChunkContainer, QAPreview } from '../../chunk' +import { PreviewHeader } from '../../preview/header' +import { FormattedText } from '../../formatted-text/formatted' +import { PreviewSlice } from '../../formatted-text/flavours/preview-slice' +import PreviewDocumentPicker from 
'../../common/document-picker/preview-document-picker' import s from './index.module.css' import unescape from './unescape' import escape from './escape' +import { OptionCard } from './option-card' +import LanguageSelect from './language-select' +import { DelimiterInput, MaxLengthInput, OverlapInput } from './inputs' import cn from '@/utils/classnames' -import type { CrawlOptions, CrawlResultItem, CreateDocumentReq, CustomFile, FileIndexingEstimateResponse, FullDocumentDetail, IndexingEstimateParams, NotionInfo, PreProcessingRule, ProcessRule, Rules, createDocumentResponse } from '@/models/datasets' -import { - createDocument, - createFirstDocument, - fetchFileIndexingEstimate as didFetchFileIndexingEstimate, - fetchDefaultProcessRule, -} from '@/service/datasets' +import type { CrawlOptions, CrawlResultItem, CreateDocumentReq, CustomFile, DocumentItem, FullDocumentDetail, ParentMode, PreProcessingRule, ProcessRule, Rules, createDocumentResponse } from '@/models/datasets' + import Button from '@/app/components/base/button' -import Input from '@/app/components/base/input' -import Loading from '@/app/components/base/loading' import FloatRightContainer from '@/app/components/base/float-right-container' import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config' import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config' import { type RetrievalConfig } from '@/types/app' import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model' import Toast from '@/app/components/base/toast' -import { formatNumber } from '@/utils/format' import type { NotionPage } from '@/models/common' import { DataSourceProvider } from '@/models/common' -import { DataSourceType, DocForm } from '@/models/datasets' -import NotionIcon from '@/app/components/base/notion-icon' -import Switch from '@/app/components/base/switch' -import { MessageChatSquare } from '@/app/components/base/icons/src/public/common' +import { ChunkingMode, DataSourceType, RerankingModeEnum } from '@/models/datasets' import { useDatasetDetailContext } from '@/context/dataset-detail' import I18n from '@/context/i18n' -import { IS_CE_EDITION } from '@/config' import { RETRIEVE_METHOD } from '@/types/app' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' -import Tooltip from '@/app/components/base/tooltip' import { useDefaultModel, useModelList, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { LanguagesSupported } from '@/i18n/language' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' import type { DefaultModel } from '@/app/components/header/account-setting/model-provider-page/declarations' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' -import { Globe01 } from '@/app/components/base/icons/src/vender/line/mapsAndTravel' +import Checkbox from '@/app/components/base/checkbox' +import RadioCard from '@/app/components/base/radio-card' +import { IS_CE_EDITION } from '@/config' +import Divider from '@/app/components/base/divider' +import { getNotionInfo, getWebsiteInfo, useCreateDocument, useCreateFirstDocument, useFetchDefaultProcessRule, useFetchFileIndexingEstimateForFile, useFetchFileIndexingEstimateForNotion, useFetchFileIndexingEstimateForWeb } from '@/service/knowledge/use-create-dataset' +import 
Badge from '@/app/components/base/badge' +import { SkeletonContainer, SkeletonPoint, SkeletonRectangle, SkeletonRow } from '@/app/components/base/skeleton' +import Tooltip from '@/app/components/base/tooltip' +import CustomDialog from '@/app/components/base/dialog' +import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem' +import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' + +const TextLabel: FC = (props) => { + return +} -type ValueOf = T[keyof T] type StepTwoProps = { isSetting?: boolean documentDetail?: FullDocumentDetail isAPIKeySet: boolean onSetting: () => void datasetId?: string - indexingType?: ValueOf + indexingType?: IndexingType + retrievalMethod?: string dataSourceType: DataSourceType files: CustomFile[] notionPages?: NotionPage[] @@ -69,21 +84,48 @@ type StepTwoProps = { websiteCrawlJobId?: string onStepChange?: (delta: number) => void updateIndexingTypeCache?: (type: string) => void + updateRetrievalMethodCache?: (method: string) => void updateResultCache?: (res: createDocumentResponse) => void onSave?: () => void onCancel?: () => void } -enum SegmentType { +export enum SegmentType { AUTO = 'automatic', CUSTOM = 'custom', } -enum IndexingType { +export enum IndexingType { QUALIFIED = 'high_quality', ECONOMICAL = 'economy', } const DEFAULT_SEGMENT_IDENTIFIER = '\\n\\n' +const DEFAULT_MAXMIMUM_CHUNK_LENGTH = 500 +const DEFAULT_OVERLAP = 50 + +type ParentChildConfig = { + chunkForContext: ParentMode + parent: { + delimiter: string + maxLength: number + } + child: { + delimiter: string + maxLength: number + } +} + +const defaultParentChildConfig: ParentChildConfig = { + chunkForContext: 'paragraph', + parent: { + delimiter: '\\n\\n', + maxLength: 500, + }, + child: { + delimiter: '\\n', + maxLength: 200, + }, +} const StepTwo = ({ isSetting, @@ -104,6 +146,7 @@ const StepTwo = ({ updateResultCache, onSave, onCancel, + updateRetrievalMethodCache, }: StepTwoProps) => { const { t } = useTranslation() const { locale } = useContext(I18n) @@ -111,66 +154,166 @@ const StepTwo = ({ const isMobile = media === MediaType.mobile const { dataset: currentDataset, mutateDatasetRes } = useDatasetDetailContext() + + const isInUpload = Boolean(currentDataset) + const isUploadInEmptyDataset = isInUpload && !currentDataset?.doc_form + const isNotUploadInEmptyDataset = !isUploadInEmptyDataset + const isInInit = !isInUpload && !isSetting + const isInCreatePage = !datasetId || (datasetId && !currentDataset?.data_source_type) const dataSourceType = isInCreatePage ? inCreatePageDataSourceType : currentDataset?.data_source_type - const scrollRef = useRef(null) - const [scrolled, setScrolled] = useState(false) - const previewScrollRef = useRef(null) - const [previewScrolled, setPreviewScrolled] = useState(false) - const [segmentationType, setSegmentationType] = useState(SegmentType.AUTO) + const [segmentationType, setSegmentationType] = useState(SegmentType.CUSTOM) const [segmentIdentifier, doSetSegmentIdentifier] = useState(DEFAULT_SEGMENT_IDENTIFIER) - const setSegmentIdentifier = useCallback((value: string) => { - doSetSegmentIdentifier(value ? escape(value) : DEFAULT_SEGMENT_IDENTIFIER) + const setSegmentIdentifier = useCallback((value: string, canEmpty?: boolean) => { + doSetSegmentIdentifier(value ? escape(value) : (canEmpty ? 
'' : DEFAULT_SEGMENT_IDENTIFIER)) }, []) - const [maxChunkLength, setMaxChunkLength] = useState(4000) // default chunk length + const [maxChunkLength, setMaxChunkLength] = useState(DEFAULT_MAXMIMUM_CHUNK_LENGTH) // default chunk length const [limitMaxChunkLength, setLimitMaxChunkLength] = useState(4000) - const [overlap, setOverlap] = useState(50) + const [overlap, setOverlap] = useState(DEFAULT_OVERLAP) const [rules, setRules] = useState([]) const [defaultConfig, setDefaultConfig] = useState() const hasSetIndexType = !!indexingType - const [indexType, setIndexType] = useState>( + const [indexType, setIndexType] = useState( (indexingType || isAPIKeySet) ? IndexingType.QUALIFIED : IndexingType.ECONOMICAL, ) - const [isLanguageSelectDisabled, setIsLanguageSelectDisabled] = useState(false) - const [docForm, setDocForm] = useState( - (datasetId && documentDetail) ? documentDetail.doc_form : DocForm.TEXT, + + const [previewFile, setPreviewFile] = useState( + (datasetId && documentDetail) + ? documentDetail.file + : files[0], ) + const [previewNotionPage, setPreviewNotionPage] = useState( + (datasetId && documentDetail) + ? documentDetail.notion_page + : notionPages[0], + ) + + const [previewWebsitePage, setPreviewWebsitePage] = useState( + (datasetId && documentDetail) + ? documentDetail.website_page + : websitePages[0], + ) + + // QA Related + const [isLanguageSelectDisabled, _setIsLanguageSelectDisabled] = useState(false) + const [isQAConfirmDialogOpen, setIsQAConfirmDialogOpen] = useState(false) + const [docForm, setDocForm] = useState( + (datasetId && documentDetail) ? documentDetail.doc_form as ChunkingMode : ChunkingMode.text, + ) + const handleChangeDocform = (value: ChunkingMode) => { + if (value === ChunkingMode.qa && indexType === IndexingType.ECONOMICAL) { + setIsQAConfirmDialogOpen(true) + return + } + if (value === ChunkingMode.parentChild && indexType === IndexingType.ECONOMICAL) + setIndexType(IndexingType.QUALIFIED) + setDocForm(value) + // eslint-disable-next-line @typescript-eslint/no-use-before-define + currentEstimateMutation.reset() + } + const [docLanguage, setDocLanguage] = useState( (datasetId && documentDetail) ? documentDetail.doc_language : (locale !== LanguagesSupported[1] ? 'English' : 'Chinese'), ) - const [QATipHide, setQATipHide] = useState(false) - const [previewSwitched, setPreviewSwitched] = useState(false) - const [showPreview, { setTrue: setShowPreview, setFalse: hidePreview }] = useBoolean() - const [customFileIndexingEstimate, setCustomFileIndexingEstimate] = useState(null) - const [automaticFileIndexingEstimate, setAutomaticFileIndexingEstimate] = useState(null) - const fileIndexingEstimate = (() => { - return segmentationType === SegmentType.AUTO ? 
automaticFileIndexingEstimate : customFileIndexingEstimate - })() - const [isCreating, setIsCreating] = useState(false) + const [parentChildConfig, setParentChildConfig] = useState(defaultParentChildConfig) - const scrollHandle = (e: Event) => { - if ((e.target as HTMLDivElement).scrollTop > 0) - setScrolled(true) + const getIndexing_technique = () => indexingType || indexType + const currentDocForm = currentDataset?.doc_form || docForm - else - setScrolled(false) + const getProcessRule = (): ProcessRule => { + if (currentDocForm === ChunkingMode.parentChild) { + return { + rules: { + pre_processing_rules: rules, + segmentation: { + separator: unescape( + parentChildConfig.parent.delimiter, + ), + max_tokens: parentChildConfig.parent.maxLength, + }, + parent_mode: parentChildConfig.chunkForContext, + subchunk_segmentation: { + separator: unescape(parentChildConfig.child.delimiter), + max_tokens: parentChildConfig.child.maxLength, + }, + }, + mode: 'hierarchical', + } as ProcessRule + } + return { + rules: { + pre_processing_rules: rules, + segmentation: { + separator: unescape(segmentIdentifier), + max_tokens: maxChunkLength, + chunk_overlap: overlap, + }, + }, // api will check this. It will be removed after api refactored. + mode: segmentationType, + } as ProcessRule } - const previewScrollHandle = (e: Event) => { - if ((e.target as HTMLDivElement).scrollTop > 0) - setPreviewScrolled(true) + const fileIndexingEstimateQuery = useFetchFileIndexingEstimateForFile({ + docForm: currentDocForm, + docLanguage, + dataSourceType: DataSourceType.FILE, + files: previewFile + ? [files.find(file => file.name === previewFile.name)!] + : files, + indexingTechnique: getIndexing_technique() as any, + processRule: getProcessRule(), + dataset_id: datasetId!, + }) + const notionIndexingEstimateQuery = useFetchFileIndexingEstimateForNotion({ + docForm: currentDocForm, + docLanguage, + dataSourceType: DataSourceType.NOTION, + notionPages: [previewNotionPage], + indexingTechnique: getIndexing_technique() as any, + processRule: getProcessRule(), + dataset_id: datasetId || '', + }) - else - setPreviewScrolled(false) - } - const getFileName = (name: string) => { - const arr = name.split('.') - return arr.slice(0, -1).join('.') - } + const websiteIndexingEstimateQuery = useFetchFileIndexingEstimateForWeb({ + docForm: currentDocForm, + docLanguage, + dataSourceType: DataSourceType.WEB, + websitePages: [previewWebsitePage], + crawlOptions, + websiteCrawlProvider, + websiteCrawlJobId, + indexingTechnique: getIndexing_technique() as any, + processRule: getProcessRule(), + dataset_id: datasetId || '', + }) + + const currentEstimateMutation = dataSourceType === DataSourceType.FILE + ? fileIndexingEstimateQuery + : dataSourceType === DataSourceType.NOTION + ? notionIndexingEstimateQuery + : websiteIndexingEstimateQuery + + const fetchEstimate = useCallback(() => { + if (dataSourceType === DataSourceType.FILE) + fileIndexingEstimateQuery.mutate() + + if (dataSourceType === DataSourceType.NOTION) + notionIndexingEstimateQuery.mutate() + + if (dataSourceType === DataSourceType.WEB) + websiteIndexingEstimateQuery.mutate() + }, [dataSourceType, fileIndexingEstimateQuery, notionIndexingEstimateQuery, websiteIndexingEstimateQuery]) + + const estimate + = dataSourceType === DataSourceType.FILE + ? fileIndexingEstimateQuery.data + : dataSourceType === DataSourceType.NOTION + ? 
notionIndexingEstimateQuery.data + : websiteIndexingEstimateQuery.data const getRuleName = (key: string) => { if (key === 'remove_extra_spaces') @@ -198,128 +341,20 @@ const StepTwo = ({ if (defaultConfig) { setSegmentIdentifier(defaultConfig.segmentation.separator) setMaxChunkLength(defaultConfig.segmentation.max_tokens) - setOverlap(defaultConfig.segmentation.chunk_overlap) + setOverlap(defaultConfig.segmentation.chunk_overlap!) setRules(defaultConfig.pre_processing_rules) } + setParentChildConfig(defaultParentChildConfig) } - const fetchFileIndexingEstimate = async (docForm = DocForm.TEXT, language?: string) => { - // eslint-disable-next-line @typescript-eslint/no-use-before-define - const res = await didFetchFileIndexingEstimate(getFileIndexingEstimateParams(docForm, language)!) - if (segmentationType === SegmentType.CUSTOM) - setCustomFileIndexingEstimate(res) - else - setAutomaticFileIndexingEstimate(res) - } - - const confirmChangeCustomConfig = () => { - if (segmentationType === SegmentType.CUSTOM && maxChunkLength > limitMaxChunkLength) { - Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.maxLengthCheck', { limit: limitMaxChunkLength }) }) + const updatePreview = () => { + if (segmentationType === SegmentType.CUSTOM && maxChunkLength > 4000) { + Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.maxLengthCheck') }) return } - setCustomFileIndexingEstimate(null) - setShowPreview() - fetchFileIndexingEstimate() - setPreviewSwitched(false) + fetchEstimate() } - const getIndexing_technique = () => indexingType || indexType - - const getProcessRule = () => { - const processRule: ProcessRule = { - rules: {} as any, // api will check this. It will be removed after api refactored. - mode: segmentationType, - } - if (segmentationType === SegmentType.CUSTOM) { - const ruleObj = { - pre_processing_rules: rules, - segmentation: { - separator: unescape(segmentIdentifier), - max_tokens: maxChunkLength, - chunk_overlap: overlap, - }, - } - processRule.rules = ruleObj - } - return processRule - } - - const getNotionInfo = () => { - const workspacesMap = groupBy(notionPages, 'workspace_id') - const workspaces = Object.keys(workspacesMap).map((workspaceId) => { - return { - workspaceId, - pages: workspacesMap[workspaceId], - } - }) - return workspaces.map((workspace) => { - return { - workspace_id: workspace.workspaceId, - pages: workspace.pages.map((page) => { - const { page_id, page_name, page_icon, type } = page - return { - page_id, - page_name, - page_icon, - type, - } - }), - } - }) as NotionInfo[] - } - - const getWebsiteInfo = () => { - return { - provider: websiteCrawlProvider, - job_id: websiteCrawlJobId, - urls: websitePages.map(page => page.source_url), - only_main_content: crawlOptions?.only_main_content, - } - } - - const getFileIndexingEstimateParams = (docForm: DocForm, language?: string): IndexingEstimateParams | undefined => { - if (dataSourceType === DataSourceType.FILE) { - return { - info_list: { - data_source_type: dataSourceType, - file_info_list: { - file_ids: files.map(file => file.id) as string[], - }, - }, - indexing_technique: getIndexing_technique() as string, - process_rule: getProcessRule(), - doc_form: docForm, - doc_language: language || docLanguage, - dataset_id: datasetId as string, - } - } - if (dataSourceType === DataSourceType.NOTION) { - return { - info_list: { - data_source_type: dataSourceType, - notion_info_list: getNotionInfo(), - }, - indexing_technique: getIndexing_technique() as string, - process_rule: getProcessRule(), - 
doc_form: docForm, - doc_language: language || docLanguage, - dataset_id: datasetId as string, - } - } - if (dataSourceType === DataSourceType.WEB) { - return { - info_list: { - data_source_type: dataSourceType, - website_info_list: getWebsiteInfo(), - }, - indexing_technique: getIndexing_technique() as string, - process_rule: getProcessRule(), - doc_form: docForm, - doc_language: language || docLanguage, - dataset_id: datasetId as string, - } - } - } const { modelList: rerankModelList, defaultModel: rerankDefaultModel, @@ -351,13 +386,14 @@ const StepTwo = ({ if (isSetting) { params = { original_document_id: documentDetail?.id, - doc_form: docForm, + doc_form: currentDocForm, doc_language: docLanguage, process_rule: getProcessRule(), // eslint-disable-next-line @typescript-eslint/no-use-before-define retrieval_model: retrievalConfig, // Readonly. If want to changed, just go to settings page. embedding_model: embeddingModel.model, // Readonly embedding_model_provider: embeddingModel.provider, // Readonly + indexing_technique: getIndexing_technique(), } as CreateDocumentReq } else { // create @@ -377,8 +413,12 @@ const StepTwo = ({ } const postRetrievalConfig = ensureRerankModelSelected({ rerankDefaultModel: rerankDefaultModel!, - // eslint-disable-next-line @typescript-eslint/no-use-before-define - retrievalConfig, + retrievalConfig: { + // eslint-disable-next-line @typescript-eslint/no-use-before-define + ...retrievalConfig, + // eslint-disable-next-line @typescript-eslint/no-use-before-define + reranking_enable: retrievalConfig.reranking_mode === RerankingModeEnum.RerankingModel, + }, indexMethod: indexMethod as string, }) params = { @@ -390,7 +430,7 @@ const StepTwo = ({ }, indexing_technique: getIndexing_technique(), process_rule: getProcessRule(), - doc_form: docForm, + doc_form: currentDocForm, doc_language: docLanguage, retrieval_model: postRetrievalConfig, @@ -403,29 +443,36 @@ const StepTwo = ({ } } if (dataSourceType === DataSourceType.NOTION) - params.data_source.info_list.notion_info_list = getNotionInfo() + params.data_source.info_list.notion_info_list = getNotionInfo(notionPages) - if (dataSourceType === DataSourceType.WEB) - params.data_source.info_list.website_info_list = getWebsiteInfo() + if (dataSourceType === DataSourceType.WEB) { + params.data_source.info_list.website_info_list = getWebsiteInfo({ + websiteCrawlProvider, + websiteCrawlJobId, + websitePages, + }) + } } return params } - const getRules = async () => { - try { - const res = await fetchDefaultProcessRule({ url: '/datasets/process-rule' }) - const separator = res.rules.segmentation.separator + const fetchDefaultProcessRuleMutation = useFetchDefaultProcessRule({ + onSuccess(data) { + const separator = data.rules.segmentation.separator setSegmentIdentifier(separator) - setMaxChunkLength(res.rules.segmentation.max_tokens) - setLimitMaxChunkLength(res.limits.indexing_max_segmentation_tokens_length) - setOverlap(res.rules.segmentation.chunk_overlap) - setRules(res.rules.pre_processing_rules) - setDefaultConfig(res.rules) - } - catch (err) { - console.log(err) - } - } + setMaxChunkLength(data.rules.segmentation.max_tokens) + setOverlap(data.rules.segmentation.chunk_overlap!) 
+ setRules(data.rules.pre_processing_rules) + setDefaultConfig(data.rules) + setLimitMaxChunkLength(data.limits.indexing_max_segmentation_tokens_length) + }, + onError(error) { + Toast.notify({ + type: 'error', + message: `${error}`, + }) + }, + }) const getRulesFromDetail = () => { if (documentDetail) { @@ -435,7 +482,7 @@ const StepTwo = ({ const overlap = rules.segmentation.chunk_overlap setSegmentIdentifier(separator) setMaxChunkLength(max) - setOverlap(overlap) + setOverlap(overlap!) setRules(rules.pre_processing_rules) setDefaultConfig(rules) } @@ -443,119 +490,81 @@ const StepTwo = ({ const getDefaultMode = () => { if (documentDetail) + // @ts-expect-error fix after api refactored setSegmentationType(documentDetail.dataset_process_rule.mode) } - const createHandle = async () => { - if (isCreating) - return - setIsCreating(true) - try { - let res - const params = getCreationParams() - if (!params) - return false - - setIsCreating(true) - if (!datasetId) { - res = await createFirstDocument({ - body: params as CreateDocumentReq, - }) - updateIndexingTypeCache && updateIndexingTypeCache(indexType as string) - updateResultCache && updateResultCache(res) - } - else { - res = await createDocument({ - datasetId, - body: params as CreateDocumentReq, - }) - updateIndexingTypeCache && updateIndexingTypeCache(indexType as string) - updateResultCache && updateResultCache(res) - } - if (mutateDatasetRes) - mutateDatasetRes() - onStepChange && onStepChange(+1) - isSetting && onSave && onSave() - } - catch (err) { + const createFirstDocumentMutation = useCreateFirstDocument({ + onError(error) { Toast.notify({ type: 'error', - message: `${err}`, + message: `${error}`, + }) + }, + }) + const createDocumentMutation = useCreateDocument(datasetId!, { + onError(error) { + Toast.notify({ + type: 'error', + message: `${error}`, + }) + }, + }) + + const isCreating = createFirstDocumentMutation.isPending || createDocumentMutation.isPending + + const createHandle = async () => { + const params = getCreationParams() + if (!params) + return false + + if (!datasetId) { + await createFirstDocumentMutation.mutateAsync( + params, + { + onSuccess(data) { + updateIndexingTypeCache && updateIndexingTypeCache(indexType as string) + updateResultCache && updateResultCache(data) + // eslint-disable-next-line @typescript-eslint/no-use-before-define + updateRetrievalMethodCache && updateRetrievalMethodCache(retrievalConfig.search_method as string) + }, + }, + ) + } + else { + await createDocumentMutation.mutateAsync(params, { + onSuccess(data) { + updateIndexingTypeCache && updateIndexingTypeCache(indexType as string) + updateResultCache && updateResultCache(data) + }, }) } - finally { - setIsCreating(false) - } - } - - const handleSwitch = (state: boolean) => { - if (state) - setDocForm(DocForm.QA) - else - setDocForm(DocForm.TEXT) - } - - const previewSwitch = async (language?: string) => { - setPreviewSwitched(true) - setIsLanguageSelectDisabled(true) - if (segmentationType === SegmentType.AUTO) - setAutomaticFileIndexingEstimate(null) - else - setCustomFileIndexingEstimate(null) - try { - await fetchFileIndexingEstimate(DocForm.QA, language) - } - finally { - setIsLanguageSelectDisabled(false) - } - } - - const handleSelect = (language: string) => { - setDocLanguage(language) - // Switch language, re-cutter - if (docForm === DocForm.QA && previewSwitched) - previewSwitch(language) + if (mutateDatasetRes) + mutateDatasetRes() + onStepChange && onStepChange(+1) + isSetting && onSave && onSave() } const 
changeToEconomicalType = () => { - if (!hasSetIndexType) { + if (docForm !== ChunkingMode.text) + return + + if (!hasSetIndexType) setIndexType(IndexingType.ECONOMICAL) - setDocForm(DocForm.TEXT) - } } useEffect(() => { // fetch rules if (!isSetting) { - getRules() + fetchDefaultProcessRuleMutation.mutate('/datasets/process-rule') } else { getRulesFromDetail() getDefaultMode() } + // eslint-disable-next-line react-hooks/exhaustive-deps }, []) - useEffect(() => { - scrollRef.current?.addEventListener('scroll', scrollHandle) - return () => { - scrollRef.current?.removeEventListener('scroll', scrollHandle) - } - }, []) - - useLayoutEffect(() => { - if (showPreview) { - previewScrollRef.current?.addEventListener('scroll', previewScrollHandle) - return () => { - previewScrollRef.current?.removeEventListener('scroll', previewScrollHandle) - } - } - }, [showPreview]) - - useEffect(() => { - if (indexingType === IndexingType.ECONOMICAL && docForm === DocForm.QA) - setDocForm(DocForm.TEXT) - }, [indexingType, docForm]) - useEffect(() => { // get indexing type by props if (indexingType) @@ -565,20 +574,6 @@ const StepTwo = ({ setIndexType(isAPIKeySet ? IndexingType.QUALIFIED : IndexingType.ECONOMICAL) }, [isAPIKeySet, indexingType, datasetId]) - useEffect(() => { - if (segmentationType === SegmentType.AUTO) { - setAutomaticFileIndexingEstimate(null) - !isMobile && setShowPreview() - fetchFileIndexingEstimate() - setPreviewSwitched(false) - } - else { - hidePreview() - setCustomFileIndexingEstimate(null) - setPreviewSwitched(false) - } - }, [segmentationType, indexType]) - const [retrievalConfig, setRetrievalConfig] = useState(currentDataset?.retrieval_model_dict || { search_method: RETRIEVE_METHOD.semantic, reranking_enable: false, @@ -591,433 +586,589 @@ const StepTwo = ({ score_threshold: 0.5, } as RetrievalConfig) + const economyDomRef = useRef(null) + const isHoveringEconomy = useHover(economyDomRef) + return (
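Editor's note on getProcessRule() above: it now produces two payload shapes for the estimate and create-document calls — the flat general rule (mode 'automatic' or 'custom') and a 'hierarchical' rule that adds parent_mode and subchunk_segmentation, with user-entered delimiters unescaped before being sent. A sketch of the two shapes using the field names and defaults from this hunk; the concrete values are illustrative.

// General (text / Q&A) chunking; mode mirrors segmentationType.
const generalProcessRule = {
  mode: 'custom',
  rules: {
    pre_processing_rules: [{ id: 'remove_extra_spaces', enabled: true }],
    segmentation: { separator: '\n\n', max_tokens: 500, chunk_overlap: 50 },
  },
}

// Parent-child chunking; parent_mode is 'paragraph' or 'full-doc'.
const hierarchicalProcessRule = {
  mode: 'hierarchical',
  rules: {
    pre_processing_rules: [{ id: 'remove_extra_spaces', enabled: true }],
    segmentation: { separator: '\n\n', max_tokens: 500 },
    parent_mode: 'paragraph',
    subchunk_segmentation: { separator: '\n', max_tokens: 200 },
  },
}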
-
-
- {t('datasetCreation.steps.two')} - {(isMobile || !showPreview) && ( - - )} -
-
-
{t('datasetCreation.stepTwo.segmentation')}
-
-
setSegmentationType(SegmentType.AUTO)} - > - - -
-
{t('datasetCreation.stepTwo.auto')}
-
{t('datasetCreation.stepTwo.autoDescription')}
-
-
-
setSegmentationType(SegmentType.CUSTOM)} - > - - -
-
{t('datasetCreation.stepTwo.custom')}
-
{t('datasetCreation.stepTwo.customDescription')}
-
- {segmentationType === SegmentType.CUSTOM && ( -
-
-
-
- {t('datasetCreation.stepTwo.separator')} - - {t('datasetCreation.stepTwo.separatorTip')} -
- } - /> -
- setSegmentIdentifier(e.target.value)} - /> -
-
-
-
-
{t('datasetCreation.stepTwo.maxLength')}
- setMaxChunkLength(parseInt(e.target.value.replace(/^0+/, ''), 10))} - /> -
-
-
-
-
- {t('datasetCreation.stepTwo.overlap')} - - {t('datasetCreation.stepTwo.overlapTip')} -
- } - /> -
- setOverlap(parseInt(e.target.value.replace(/^0+/, ''), 10))} - /> -
-
-
-
-
{t('datasetCreation.stepTwo.rules')}
- {rules.map(rule => ( -
- ruleChangeHandle(rule.id)} className="w-4 h-4 rounded border-gray-300 text-blue-700 focus:ring-blue-700" /> - -
- ))} -
-
-
- - -
-
- )} -
-
-
{t('datasetCreation.stepTwo.indexMode')}
-
-
- {(!hasSetIndexType || (hasSetIndexType && indexingType === IndexingType.QUALIFIED)) && ( -
{ - if (isAPIKeySet) - setIndexType(IndexingType.QUALIFIED) - }} - > - - {!hasSetIndexType && } -
-
- {t('datasetCreation.stepTwo.qualified')} - {!hasSetIndexType && {t('datasetCreation.stepTwo.recommend')}} -
-
{t('datasetCreation.stepTwo.qualifiedTip')}
-
- {!isAPIKeySet && ( -
- {t('datasetCreation.stepTwo.warning')}  - {t('datasetCreation.stepTwo.click')} -
- )} -
- )} - - {(!hasSetIndexType || (hasSetIndexType && indexingType === IndexingType.ECONOMICAL)) && ( -
- - {!hasSetIndexType && } -
-
{t('datasetCreation.stepTwo.economical')}
-
{t('datasetCreation.stepTwo.economicalTip')}
-
-
- )} -
- {hasSetIndexType && indexType === IndexingType.ECONOMICAL && ( -
- {t('datasetCreation.stepTwo.indexSettingTip')} - {t('datasetCreation.stepTwo.datasetSettingLink')} -
- )} - {IS_CE_EDITION && indexType === IndexingType.QUALIFIED && ( -
-
-
- -
-
-
{t('datasetCreation.stepTwo.QATitle')}
-
- {t('datasetCreation.stepTwo.QALanguage')} - -
-
-
- -
-
- {docForm === DocForm.QA && !QATipHide && ( -
- {t('datasetCreation.stepTwo.QATip')} - setQATipHide(true)} /> -
- )} -
- )} - {/* Embedding model */} - {indexType === IndexingType.QUALIFIED && ( -
-
{t('datasetSettings.form.embeddingModel')}
- { - setEmbeddingModel(model) - }} +
+
{t('datasetCreation.stepTwo.segmentation')}
+ {((isInUpload && [ChunkingMode.text, ChunkingMode.qa].includes(currentDataset!.doc_form)) + || isUploadInEmptyDataset + || isInInit) + && } + activeHeaderClassName='bg-dataset-option-card-blue-gradient' + description={t('datasetCreation.stepTwo.generalTip')} + isActive={ + [ChunkingMode.text, ChunkingMode.qa].includes(currentDocForm) + } + onSwitched={() => + handleChangeDocform(ChunkingMode.text) + } + actions={ + <> + + + + } + noHighlight={isInUpload && isNotUploadInEmptyDataset} + > +
+
+ setSegmentIdentifier(e.target.value, true)} + /> + + - {!!datasetId && ( -
- {t('datasetCreation.stepTwo.indexSettingTip')} - {t('datasetCreation.stepTwo.datasetSettingLink')} -
- )}
- )} - {/* Retrieval Method Config */} -
- {!datasetId - ? ( -
-
{t('datasetSettings.form.retrievalSetting.title')}
-
- {t('datasetSettings.form.retrievalSetting.learnMore')} - {t('datasetSettings.form.retrievalSetting.longDescription')} +
+
+
+ {t('datasetCreation.stepTwo.rules')} +
+ +
+
+ {rules.map(rule => ( +
{ + ruleChangeHandle(rule.id) + }}> + +
-
- ) - : ( -
-
{t('datasetSettings.form.retrievalSetting.title')}
-
- )} - -
- { - getIndexing_technique() === IndexingType.QUALIFIED - ? ( - + +
+
{ + if (currentDataset?.doc_form) + return + if (docForm === ChunkingMode.qa) + handleChangeDocform(ChunkingMode.text) + else + handleChangeDocform(ChunkingMode.qa) + }}> + + +
+ - ) - : ( - - ) - } + +
+ {currentDocForm === ChunkingMode.qa && ( +
+ + + {t('datasetCreation.stepTwo.QATip')} + +
+ )} + } +
- -
-
- {dataSourceType === DataSourceType.FILE && ( - <> -
{t('datasetCreation.stepTwo.fileSource')}
-
- - {getFileName(files[0].name || '')} - {files.length > 1 && ( - - {t('datasetCreation.stepTwo.other')} - {files.length - 1} - {t('datasetCreation.stepTwo.fileUnit')} - - )} -
- - )} - {dataSourceType === DataSourceType.NOTION && ( - <> -
{t('datasetCreation.stepTwo.notionSource')}
-
- } + { + ( + (isInUpload && currentDataset!.doc_form === ChunkingMode.parentChild) + || isUploadInEmptyDataset + || isInInit + ) + && } + effectImg={OrangeEffect.src} + activeHeaderClassName='bg-dataset-option-card-orange-gradient' + description={t('datasetCreation.stepTwo.parentChildTip')} + isActive={currentDocForm === ChunkingMode.parentChild} + onSwitched={() => handleChangeDocform(ChunkingMode.parentChild)} + actions={ + <> + + + + } + noHighlight={isInUpload && isNotUploadInEmptyDataset} + > +
+
+
+
+ {t('datasetCreation.stepTwo.parentChunkForContext')} +
+ +
+ } + title={t('datasetCreation.stepTwo.paragraph')} + description={t('datasetCreation.stepTwo.paragraphTip')} + isChosen={parentChildConfig.chunkForContext === 'paragraph'} + onChosen={() => setParentChildConfig( + { + ...parentChildConfig, + chunkForContext: 'paragraph', + }, + )} + chosenConfig={ +
+ setParentChildConfig({ + ...parentChildConfig, + parent: { + ...parentChildConfig.parent, + delimiter: e.target.value ? escape(e.target.value) : '', + }, + })} + /> + setParentChildConfig({ + ...parentChildConfig, + parent: { + ...parentChildConfig.parent, + maxLength: value, + }, + })} /> - {notionPages[0]?.page_name} - {notionPages.length > 1 && ( - - {t('datasetCreation.stepTwo.other')} - {notionPages.length - 1} - {t('datasetCreation.stepTwo.notionUnit')} - - )}
- - )} - {dataSourceType === DataSourceType.WEB && ( - <> -
{t('datasetCreation.stepTwo.websiteSource')}
-
- - {websitePages[0].source_url} - {websitePages.length > 1 && ( - - {t('datasetCreation.stepTwo.other')} - {websitePages.length - 1} - {t('datasetCreation.stepTwo.webpageUnit')} - - )} -
- - )} -
-
-
-
{t('datasetCreation.stepTwo.estimateSegment')}
-
- { - fileIndexingEstimate - ? ( -
{formatNumber(fileIndexingEstimate.total_segments)}
- ) - : ( -
{t('datasetCreation.stepTwo.calculating')}
- ) } + /> + } + title={t('datasetCreation.stepTwo.fullDoc')} + description={t('datasetCreation.stepTwo.fullDocTip')} + onChosen={() => setParentChildConfig( + { + ...parentChildConfig, + chunkForContext: 'full-doc', + }, + )} + isChosen={parentChildConfig.chunkForContext === 'full-doc'} + /> +
+ +
+
+
+ {t('datasetCreation.stepTwo.childChunkForRetrieval')} +
+ +
+
+ setParentChildConfig({ + ...parentChildConfig, + child: { + ...parentChildConfig.child, + delimiter: e.target.value ? escape(e.target.value) : '', + }, + })} + /> + setParentChildConfig({ + ...parentChildConfig, + child: { + ...parentChildConfig.child, + maxLength: value, + }, + })} + /> +
+
+
+
+
+ {t('datasetCreation.stepTwo.rules')} +
+ +
+
+ {rules.map(rule => ( +
{ + ruleChangeHandle(rule.id) + }}> + + +
+ ))}
- {!isSetting - ? ( -
- -
- + } + +
{t('datasetCreation.stepTwo.indexMode')}
+
+ {(!hasSetIndexType || (hasSetIndexType && indexingType === IndexingType.QUALIFIED)) && ( + + {t('datasetCreation.stepTwo.qualified')} + + {t('datasetCreation.stepTwo.recommend')} + + + {!hasSetIndexType && } + +
} + description={t('datasetCreation.stepTwo.qualifiedTip')} + icon={} + isActive={!hasSetIndexType && indexType === IndexingType.QUALIFIED} + disabled={!isAPIKeySet || hasSetIndexType} + onSwitched={() => { + if (isAPIKeySet) + setIndexType(IndexingType.QUALIFIED) + }} + /> + )} + + {(!hasSetIndexType || (hasSetIndexType && indexingType === IndexingType.ECONOMICAL)) && ( + <> + setIsQAConfirmDialogOpen(false)} className='w-[432px]'> +
+

+ {t('datasetCreation.stepTwo.qaSwitchHighQualityTipTitle')} +

+

+ {t('datasetCreation.stepTwo.qaSwitchHighQualityTipContent')} +

+
+
+ +
- ) - : ( -
- - -
- )} -
+ + + + } + isActive={!hasSetIndexType && indexType === IndexingType.ECONOMICAL} + disabled={!isAPIKeySet || hasSetIndexType || docForm !== ChunkingMode.text} + ref={economyDomRef} + onSwitched={() => { + if (isAPIKeySet && docForm === ChunkingMode.text) + setIndexType(IndexingType.ECONOMICAL) + }} + /> + + +
+ { + docForm === ChunkingMode.qa + ? t('datasetCreation.stepTwo.notAvailableForQA') + : t('datasetCreation.stepTwo.notAvailableForParentChild') + } +
+
+
+ )}
-
- - {showPreview &&
-
-
-
-
{t('datasetCreation.stepTwo.previewTitle')}
- {docForm === DocForm.QA && !previewSwitched && ( - - )} -
-
- -
-
- {docForm === DocForm.QA && !previewSwitched && ( -
- {t('datasetCreation.stepTwo.previewSwitchTipStart')} - {t('datasetCreation.stepTwo.previewSwitchTipEnd')} -
- )} -
-
- {previewSwitched && docForm === DocForm.QA && fileIndexingEstimate?.qa_preview && ( - <> - {fileIndexingEstimate?.qa_preview.map((item, index) => ( - - ))} - - )} - {(docForm === DocForm.TEXT || !previewSwitched) && fileIndexingEstimate?.preview && ( - <> - {fileIndexingEstimate?.preview.map((item, index) => ( - - ))} - - )} - {previewSwitched && docForm === DocForm.QA && !fileIndexingEstimate?.qa_preview && ( -
- -
- )} - {!previewSwitched && !fileIndexingEstimate?.preview && ( -
- -
- )} -
-
} - {!showPreview && ( -
-
- -
{t('datasetCreation.stepTwo.sideTipTitle')}
-
-

{t('datasetCreation.stepTwo.sideTipP1')}

-

{t('datasetCreation.stepTwo.sideTipP2')}

-

{t('datasetCreation.stepTwo.sideTipP3')}

-

{t('datasetCreation.stepTwo.sideTipP4')}

-
+ {!hasSetIndexType && indexType === IndexingType.QUALIFIED && ( +
+
+
+
+ {t('datasetCreation.stepTwo.highQualityTip')}
)} + {hasSetIndexType && indexType === IndexingType.ECONOMICAL && ( +
+ {t('datasetCreation.stepTwo.indexSettingTip')} + {t('datasetCreation.stepTwo.datasetSettingLink')} +
+ )} + {/* Embedding model */} + {indexType === IndexingType.QUALIFIED && ( +
+
{t('datasetSettings.form.embeddingModel')}
+ { + setEmbeddingModel(model) + }} + /> + {!!datasetId && ( +
+ {t('datasetCreation.stepTwo.indexSettingTip')} + {t('datasetCreation.stepTwo.datasetSettingLink')} +
+ )} +
+ )} + + {/* Retrieval Method Config */} +
+ {!datasetId + ? ( +
+
{t('datasetSettings.form.retrievalSetting.title')}
+
+ {t('datasetSettings.form.retrievalSetting.learnMore')} + {t('datasetSettings.form.retrievalSetting.longDescription')} +
+
+ ) + : ( +
+
{t('datasetSettings.form.retrievalSetting.title')}
+
+ )} + +
+ { + getIndexing_technique() === IndexingType.QUALIFIED + ? ( + + ) + : ( + + ) + } +
+
+ + {!isSetting + ? ( +
+ + +
+ ) + : ( +
+ + +
+ )} +
+ { }} footer={null}> + +
+ {dataSourceType === DataSourceType.FILE + && >} + onChange={(selected) => { + currentEstimateMutation.reset() + setPreviewFile(selected) + currentEstimateMutation.mutate() + }} + // when it is from setting, it just has one file + value={isSetting ? (files[0]! as Required) : previewFile} + /> + } + {dataSourceType === DataSourceType.NOTION + && ({ + id: page.page_id, + name: page.page_name, + extension: 'md', + })) + } + onChange={(selected) => { + currentEstimateMutation.reset() + const selectedPage = notionPages.find(page => page.page_id === selected.id) + setPreviewNotionPage(selectedPage!) + currentEstimateMutation.mutate() + }} + value={{ + id: previewNotionPage?.page_id || '', + name: previewNotionPage?.page_name || '', + extension: 'md', + }} + /> + } + {dataSourceType === DataSourceType.WEB + && ({ + id: page.source_url, + name: page.title, + extension: 'md', + })) + } + onChange={(selected) => { + currentEstimateMutation.reset() + const selectedPage = websitePages.find(page => page.source_url === selected.id) + setPreviewWebsitePage(selectedPage!) + currentEstimateMutation.mutate() + }} + value={ + { + id: previewWebsitePage?.source_url || '', + name: previewWebsitePage?.title || '', + extension: 'md', + } + } + /> + } + { + currentDocForm !== ChunkingMode.qa + && + } +
+ } + className={cn('flex shrink-0 w-1/2 p-4 pr-0 relative h-full', isMobile && 'w-full max-w-[524px]')} + mainClassName='space-y-6' + > + {currentDocForm === ChunkingMode.qa && estimate?.qa_preview && ( + estimate?.qa_preview.map((item, index) => ( + + + + )) + )} + {currentDocForm === ChunkingMode.text && estimate?.preview && ( + estimate?.preview.map((item, index) => ( + + {item.content} + + )) + )} + {currentDocForm === ChunkingMode.parentChild && currentEstimateMutation.data?.preview && ( + estimate?.preview?.map((item, index) => { + const indexForLabel = index + 1 + return ( + + + {item.child_chunks.map((child, index) => { + const indexForLabel = index + 1 + return ( + + ) + })} + + + ) + }) + )} + {currentEstimateMutation.isIdle && ( +
+
+ +

+ {t('datasetCreation.stepTwo.previewChunkTip')} +

+
+
+ )} + {currentEstimateMutation.isPending && ( +
+ {Array.from({ length: 10 }, (_, i) => ( + + + + + + + + + + + ))} +
+ )} +
) diff --git a/web/app/components/datasets/create/step-two/inputs.tsx b/web/app/components/datasets/create/step-two/inputs.tsx new file mode 100644 index 0000000000..4231f6242d --- /dev/null +++ b/web/app/components/datasets/create/step-two/inputs.tsx @@ -0,0 +1,77 @@ +import type { FC, PropsWithChildren, ReactNode } from 'react' +import { useTranslation } from 'react-i18next' +import type { InputProps } from '@/app/components/base/input' +import Input from '@/app/components/base/input' +import Tooltip from '@/app/components/base/tooltip' +import type { InputNumberProps } from '@/app/components/base/input-number' +import { InputNumber } from '@/app/components/base/input-number' + +const TextLabel: FC = (props) => { + return +} + +const FormField: FC> = (props) => { + return
+ {props.label} + {props.children} +
+} + +export const DelimiterInput: FC = (props) => { + const { t } = useTranslation() + return + {t('datasetCreation.stepTwo.separator')} + + {props.tooltip || t('datasetCreation.stepTwo.separatorTip')} +
+ } + /> +
}> + + +} + +export const MaxLengthInput: FC = (props) => { + const { t } = useTranslation() + return + {t('datasetCreation.stepTwo.maxLength')} +
}> + + +} + +export const OverlapInput: FC = (props) => { + const { t } = useTranslation() + return + {t('datasetCreation.stepTwo.overlap')} + + {t('datasetCreation.stepTwo.overlapTip')} +
+ } + /> +
}> + + +} diff --git a/web/app/components/datasets/create/step-two/language-select/index.tsx b/web/app/components/datasets/create/step-two/language-select/index.tsx index 41f3e0abb5..9cbf1a40d1 100644 --- a/web/app/components/datasets/create/step-two/language-select/index.tsx +++ b/web/app/components/datasets/create/step-two/language-select/index.tsx @@ -1,7 +1,7 @@ 'use client' import type { FC } from 'react' import React from 'react' -import { RiArrowDownSLine } from '@remixicon/react' +import { RiArrowDownSLine, RiCheckLine } from '@remixicon/react' import cn from '@/utils/classnames' import Popover from '@/app/components/base/popover' import { languages } from '@/i18n/language' @@ -22,25 +22,40 @@ const LanguageSelect: FC = ({ manualClose trigger='click' disabled={disabled} + popupClassName='z-20' htmlContent={ -
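
A minimal usage sketch (illustrative only, not part of this patch) for the DelimiterInput, MaxLengthInput and OverlapInput fields added in step-two/inputs.tsx above. The import path via the '@/' alias and the assumption that the numeric inputs pass their value directly to onChange (as the step-two screen above appears to do) are editorial guesses, not confirmed by the patch.

import { useState } from 'react'
import { DelimiterInput, MaxLengthInput, OverlapInput } from '@/app/components/datasets/create/step-two/inputs'

const GeneralChunkSettings = () => {
  // The delimiter is stored escaped ('\\n\\n' for a blank line), mirroring the escape() handling in step two
  const [delimiter, setDelimiter] = useState('\\n\\n')
  const [maxLength, setMaxLength] = useState<number | undefined>(1024)
  const [overlap, setOverlap] = useState<number | undefined>(50)

  return (
    <div className='space-y-3'>
      <DelimiterInput value={delimiter} onChange={e => setDelimiter(e.target.value)} />
      <MaxLengthInput value={maxLength} onChange={value => setMaxLength(value)} />
      <OverlapInput value={overlap} onChange={value => setOverlap(value)} />
    </div>
  )
}

export default GeneralChunkSettings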
+
{languages.filter(language => language.supported).map(({ prompt_name }) => (
onSelect(prompt_name)}>{prompt_name} + className='w-full py-2 px-3 inline-flex items-center justify-between hover:bg-state-base-hover rounded-lg cursor-pointer' + onClick={() => onSelect(prompt_name)} + > + {prompt_name} + {(currentLanguage === prompt_name) && }
))}
} btnElement={ -
- {currentLanguage} - +
+ + {currentLanguage} + +
} - btnClassName={open => cn('!border-0 !px-0 !py-0 !bg-inherit !hover:bg-inherit', open ? 'text-blue-600' : 'text-gray-500')} - className='!w-[120px] h-fit !z-20 !translate-x-0 !left-[-16px]' + btnClassName={() => cn( + '!border-0 rounded-md !px-1.5 !py-1 !mx-1 !bg-components-button-tertiary-bg !hover:bg-components-button-tertiary-bg', + disabled ? 'bg-components-button-tertiary-bg-disabled' : '', + )} + className='!w-[140px] h-fit !z-20 !translate-x-0 !left-1' /> ) } diff --git a/web/app/components/datasets/create/step-two/option-card.tsx b/web/app/components/datasets/create/step-two/option-card.tsx new file mode 100644 index 0000000000..d0efdaabb1 --- /dev/null +++ b/web/app/components/datasets/create/step-two/option-card.tsx @@ -0,0 +1,98 @@ +import { type ComponentProps, type FC, type ReactNode, forwardRef } from 'react' +import Image from 'next/image' +import classNames from '@/utils/classnames' + +const TriangleArrow: FC> = props => ( + + + +) + +type OptionCardHeaderProps = { + icon: ReactNode + title: ReactNode + description: string + isActive?: boolean + activeClassName?: string + effectImg?: string +} + +export const OptionCardHeader: FC = (props) => { + const { icon, title, description, isActive, activeClassName, effectImg } = props + return
+
+ {isActive && effectImg && } +
+
+ {icon} +
+
+
+ +
+
{title}
+
{description}
+
+
+} + +type OptionCardProps = { + icon: ReactNode + className?: string + activeHeaderClassName?: string + title: ReactNode + description: string + isActive?: boolean + actions?: ReactNode + effectImg?: string + onSwitched?: () => void + noHighlight?: boolean + disabled?: boolean +} & Omit, 'title' | 'onClick'> + +export const OptionCard: FC = forwardRef((props, ref) => { + const { icon, className, title, description, isActive, children, actions, activeHeaderClassName, style, effectImg, onSwitched, noHighlight, disabled, ...rest } = props + return
{ + if (!isActive && !disabled) + onSwitched?.() + }} + {...rest} + ref={ref} + > + + {/** Body */} + {isActive && (children || actions) &&
+ {children} + {actions &&
+ {actions} +
+ } +
} +
+}) + +OptionCard.displayName = 'OptionCard' diff --git a/web/app/components/datasets/create/stepper/index.tsx b/web/app/components/datasets/create/stepper/index.tsx new file mode 100644 index 0000000000..317c1a76ee --- /dev/null +++ b/web/app/components/datasets/create/stepper/index.tsx @@ -0,0 +1,27 @@ +import { type FC, Fragment } from 'react' +import type { Step } from './step' +import { StepperStep } from './step' + +export type StepperProps = { + steps: Step[] + activeIndex: number +} + +export const Stepper: FC = (props) => { + const { steps, activeIndex } = props + return
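
A usage sketch (illustrative only, outside the patch) of the new OptionCard from option-card.tsx, toggling between two choices the way the indexing-mode picker above does. The import path and the placeholder icons, titles and class names are assumptions; the props used (icon, title, description, isActive, onSwitched, children) are taken from OptionCardProps as defined above.

import { useState } from 'react'
import { OptionCard } from '@/app/components/datasets/create/step-two/option-card'

const ModePicker = () => {
  const [mode, setMode] = useState<'a' | 'b'>('a')
  return (
    <div className='space-y-2'>
      <OptionCard
        icon={<span>A</span>}
        title='Option A'
        description='First choice'
        isActive={mode === 'a'}
        onSwitched={() => setMode('a')}
      >
        {/* The body only renders while the card is active */}
        <div>Extra settings for A</div>
      </OptionCard>
      <OptionCard
        icon={<span>B</span>}
        title='Option B'
        description='Second choice'
        isActive={mode === 'b'}
        onSwitched={() => setMode('b')}
      />
    </div>
  )
}

export default ModePicker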
+ {steps.map((step, index) => { + const isLast = index === steps.length - 1 + return ( + + + {!isLast &&
} + + ) + })} +
+} diff --git a/web/app/components/datasets/create/stepper/step.tsx b/web/app/components/datasets/create/stepper/step.tsx new file mode 100644 index 0000000000..c230de1a6e --- /dev/null +++ b/web/app/components/datasets/create/stepper/step.tsx @@ -0,0 +1,46 @@ +import type { FC } from 'react' +import classNames from '@/utils/classnames' + +export type Step = { + name: string +} + +export type StepperStepProps = Step & { + index: number + activeIndex: number +} + +export const StepperStep: FC = (props) => { + const { name, activeIndex, index } = props + const isActive = index === activeIndex + const isDisabled = activeIndex < index + const label = isActive ? `STEP ${index + 1}` : `${index + 1}` + return
+
+
+ {label} +
+
+
{name}
+
+} diff --git a/web/app/components/datasets/create/top-bar/index.tsx b/web/app/components/datasets/create/top-bar/index.tsx new file mode 100644 index 0000000000..20ba7158db --- /dev/null +++ b/web/app/components/datasets/create/top-bar/index.tsx @@ -0,0 +1,41 @@ +import type { FC } from 'react' +import { RiArrowLeftLine } from '@remixicon/react' +import Link from 'next/link' +import { useTranslation } from 'react-i18next' +import { Stepper, type StepperProps } from '../stepper' +import classNames from '@/utils/classnames' + +export type TopbarProps = Pick & { + className?: string +} + +const STEP_T_MAP: Record = { + 1: 'datasetCreation.steps.one', + 2: 'datasetCreation.steps.two', + 3: 'datasetCreation.steps.three', +} + +export const Topbar: FC = (props) => { + const { className, ...rest } = props + const { t } = useTranslation() + return
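
A usage sketch (illustrative only, outside the patch) for the Stepper added above: it only needs the step names and the active index, per the Step and StepperProps types defined in these new files. The import path and the literal step names are assumptions; in the patch the names come from i18n keys.

import { Stepper } from '@/app/components/datasets/create/stepper'

// activeIndex is zero-based: the active step shows the 'STEP n' label, and steps after it render as disabled
const CreationProgress = ({ current }: { current: number }) => (
  <Stepper
    steps={[
      { name: 'Choose data source' },
      { name: 'Text preprocessing and cleaning' },
      { name: 'Execute and finish' },
    ]}
    activeIndex={current}
  />
)

export default CreationProgress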
+ +
+ +
+

+ {t('datasetCreation.steps.header.creation')} +

+ +
+ ({ + name: t(STEP_T_MAP[i + 1]), + }))} + {...rest} + /> +
+
+} diff --git a/web/app/components/datasets/create/website/base/error-message.tsx b/web/app/components/datasets/create/website/base/error-message.tsx index aa337ec4bf..f061c4624e 100644 --- a/web/app/components/datasets/create/website/base/error-message.tsx +++ b/web/app/components/datasets/create/website/base/error-message.tsx @@ -18,7 +18,7 @@ const ErrorMessage: FC = ({ return (
- +
{title}
{errorMsg && ( diff --git a/web/app/components/datasets/create/website/jina-reader/index.tsx b/web/app/components/datasets/create/website/jina-reader/index.tsx index 51d77d7121..1c133f935c 100644 --- a/web/app/components/datasets/create/website/jina-reader/index.tsx +++ b/web/app/components/datasets/create/website/jina-reader/index.tsx @@ -94,7 +94,6 @@ const JinaReader: FC = ({ const waitForCrawlFinished = useCallback(async (jobId: string) => { try { const res = await checkJinaReaderTaskStatus(jobId) as any - console.log('res', res) if (res.status === 'completed') { return { isError: false, diff --git a/web/app/components/datasets/create/website/preview.tsx b/web/app/components/datasets/create/website/preview.tsx index 65abe83ed7..5180a83442 100644 --- a/web/app/components/datasets/create/website/preview.tsx +++ b/web/app/components/datasets/create/website/preview.tsx @@ -18,7 +18,7 @@ const WebsitePreview = ({ const { t } = useTranslation() return ( -
+
{t('datasetCreation.stepOne.pagePreview')} @@ -32,7 +32,7 @@ const WebsitePreview = ({
{payload.source_url}
-
{payload.markdown}
+
{payload.markdown}
) diff --git a/web/app/components/datasets/documents/detail/batch-modal/csv-downloader.tsx b/web/app/components/datasets/documents/detail/batch-modal/csv-downloader.tsx index 36216aa7c8..6602244a48 100644 --- a/web/app/components/datasets/documents/detail/batch-modal/csv-downloader.tsx +++ b/web/app/components/datasets/documents/detail/batch-modal/csv-downloader.tsx @@ -7,7 +7,7 @@ import { import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import { Download02 as DownloadIcon } from '@/app/components/base/icons/src/vender/solid/general' -import { DocForm } from '@/models/datasets' +import { ChunkingMode } from '@/models/datasets' import I18n from '@/context/i18n' import { LanguagesSupported } from '@/i18n/language' @@ -32,18 +32,18 @@ const CSV_TEMPLATE_CN = [ ['内容 2'], ] -const CSVDownload: FC<{ docForm: DocForm }> = ({ docForm }) => { +const CSVDownload: FC<{ docForm: ChunkingMode }> = ({ docForm }) => { const { t } = useTranslation() const { locale } = useContext(I18n) const { CSVDownloader, Type } = useCSVDownloader() const getTemplate = () => { if (locale === LanguagesSupported[1]) { - if (docForm === DocForm.QA) + if (docForm === ChunkingMode.qa) return CSV_TEMPLATE_QA_CN return CSV_TEMPLATE_CN } - if (docForm === DocForm.QA) + if (docForm === ChunkingMode.qa) return CSV_TEMPLATE_QA_EN return CSV_TEMPLATE_EN } @@ -52,7 +52,7 @@ const CSVDownload: FC<{ docForm: DocForm }> = ({ docForm }) => {
{t('share.generation.csvStructureTitle')}
- {docForm === DocForm.QA && ( + {docForm === ChunkingMode.qa && ( @@ -72,7 +72,7 @@ const CSVDownload: FC<{ docForm: DocForm }> = ({ docForm }) => {
)} - {docForm === DocForm.TEXT && ( + {docForm === ChunkingMode.text && ( @@ -97,7 +97,7 @@ const CSVDownload: FC<{ docForm: DocForm }> = ({ docForm }) => { bom={true} data={getTemplate()} > -
+
{t('datasetDocuments.list.batchModal.template')}
diff --git a/web/app/components/datasets/documents/detail/batch-modal/index.tsx b/web/app/components/datasets/documents/detail/batch-modal/index.tsx index 139a364cb4..c666ba6715 100644 --- a/web/app/components/datasets/documents/detail/batch-modal/index.tsx +++ b/web/app/components/datasets/documents/detail/batch-modal/index.tsx @@ -7,11 +7,11 @@ import CSVUploader from './csv-uploader' import CSVDownloader from './csv-downloader' import Button from '@/app/components/base/button' import Modal from '@/app/components/base/modal' -import type { DocForm } from '@/models/datasets' +import type { ChunkingMode } from '@/models/datasets' export type IBatchModalProps = { isShow: boolean - docForm: DocForm + docForm: ChunkingMode onCancel: () => void onConfirm: (file: File) => void } diff --git a/web/app/components/datasets/documents/detail/completed/InfiniteVirtualList.tsx b/web/app/components/datasets/documents/detail/completed/InfiniteVirtualList.tsx deleted file mode 100644 index 7b510bcf21..0000000000 --- a/web/app/components/datasets/documents/detail/completed/InfiniteVirtualList.tsx +++ /dev/null @@ -1,98 +0,0 @@ -import type { CSSProperties, FC } from 'react' -import React from 'react' -import { FixedSizeList as List } from 'react-window' -import InfiniteLoader from 'react-window-infinite-loader' -import SegmentCard from './SegmentCard' -import s from './style.module.css' -import type { SegmentDetailModel } from '@/models/datasets' - -type IInfiniteVirtualListProps = { - hasNextPage?: boolean // Are there more items to load? (This information comes from the most recent API request.) - isNextPageLoading: boolean // Are we currently loading a page of items? (This may be an in-flight flag in your Redux store for example.) - items: Array // Array of items loaded so far. - loadNextPage: () => Promise // Callback function responsible for loading the next page of items. - onClick: (detail: SegmentDetailModel) => void - onChangeSwitch: (segId: string, enabled: boolean) => Promise - onDelete: (segId: string) => Promise - archived?: boolean - embeddingAvailable: boolean -} - -const InfiniteVirtualList: FC = ({ - hasNextPage, - isNextPageLoading, - items, - loadNextPage, - onClick: onClickCard, - onChangeSwitch, - onDelete, - archived, - embeddingAvailable, -}) => { - // If there are more items to be loaded then add an extra row to hold a loading indicator. - const itemCount = hasNextPage ? items.length + 1 : items.length - - // Only load 1 page of items at a time. - // Pass an empty callback to InfiniteLoader in case it asks us to load more than once. - const loadMoreItems = isNextPageLoading ? () => { } : loadNextPage - - // Every row is loaded except for our loading indicator row. - const isItemLoaded = (index: number) => !hasNextPage || index < items.length - - // Render an item or a loading indicator. - const Item = ({ index, style }: { index: number; style: CSSProperties }) => { - let content - if (!isItemLoaded(index)) { - content = ( - <> - {[1, 2, 3].map(v => ( - - ))} - - ) - } - else { - content = items[index].map(segItem => ( - onClickCard(segItem)} - onChangeSwitch={onChangeSwitch} - onDelete={onDelete} - loading={false} - archived={archived} - embeddingAvailable={embeddingAvailable} - /> - )) - } - - return ( -
- {content} -
- ) - } - - return ( - - {({ onItemsRendered, ref }) => ( - - {Item} - - )} - - ) -} -export default InfiniteVirtualList diff --git a/web/app/components/datasets/documents/detail/completed/SegmentCard.tsx b/web/app/components/datasets/documents/detail/completed/SegmentCard.tsx index 5b76acc936..264d62b68a 100644 --- a/web/app/components/datasets/documents/detail/completed/SegmentCard.tsx +++ b/web/app/components/datasets/documents/detail/completed/SegmentCard.tsx @@ -6,9 +6,9 @@ import { RiDeleteBinLine, } from '@remixicon/react' import { StatusItem } from '../../list' -import { DocumentTitle } from '../index' +import style from '../../style.module.css' import s from './style.module.css' -import { SegmentIndexTag } from './index' +import { SegmentIndexTag } from './common/segment-index-tag' import cn from '@/utils/classnames' import Confirm from '@/app/components/base/confirm' import Switch from '@/app/components/base/switch' @@ -31,6 +31,22 @@ const ProgressBar: FC<{ percent: number; loading: boolean }> = ({ percent, loadi ) } +type DocumentTitleProps = { + extension?: string + name?: string + iconCls?: string + textCls?: string + wrapperCls?: string +} + +const DocumentTitle: FC = ({ extension, name, iconCls, textCls, wrapperCls }) => { + const localExtension = extension?.toLowerCase() || name?.split('.')?.pop()?.toLowerCase() + return
+
+ {name || '--'} +
+} + export type UsageScene = 'doc' | 'hitTesting' type ISegmentCardProps = { diff --git a/web/app/components/datasets/documents/detail/completed/child-segment-detail.tsx b/web/app/components/datasets/documents/detail/completed/child-segment-detail.tsx new file mode 100644 index 0000000000..085bfddc16 --- /dev/null +++ b/web/app/components/datasets/documents/detail/completed/child-segment-detail.tsx @@ -0,0 +1,134 @@ +import React, { type FC, useMemo, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { + RiCloseLine, + RiExpandDiagonalLine, +} from '@remixicon/react' +import ActionButtons from './common/action-buttons' +import ChunkContent from './common/chunk-content' +import Dot from './common/dot' +import { SegmentIndexTag } from './common/segment-index-tag' +import { useSegmentListContext } from './index' +import type { ChildChunkDetail, ChunkingMode } from '@/models/datasets' +import { useEventEmitterContextContext } from '@/context/event-emitter' +import { formatNumber } from '@/utils/format' +import classNames from '@/utils/classnames' +import Divider from '@/app/components/base/divider' +import { formatTime } from '@/utils/time' + +type IChildSegmentDetailProps = { + chunkId: string + childChunkInfo?: Partial & { id: string } + onUpdate: (segmentId: string, childChunkId: string, content: string) => void + onCancel: () => void + docForm: ChunkingMode +} + +/** + * Show all the contents of the segment + */ +const ChildSegmentDetail: FC = ({ + chunkId, + childChunkInfo, + onUpdate, + onCancel, + docForm, +}) => { + const { t } = useTranslation() + const [content, setContent] = useState(childChunkInfo?.content || '') + const { eventEmitter } = useEventEmitterContextContext() + const [loading, setLoading] = useState(false) + const fullScreen = useSegmentListContext(s => s.fullScreen) + const toggleFullScreen = useSegmentListContext(s => s.toggleFullScreen) + + eventEmitter?.useSubscription((v) => { + if (v === 'update-child-segment') + setLoading(true) + if (v === 'update-child-segment-done') + setLoading(false) + }) + + const handleCancel = () => { + onCancel() + setContent(childChunkInfo?.content || '') + } + + const handleSave = () => { + onUpdate(chunkId, childChunkInfo?.id || '', content) + } + + const wordCountText = useMemo(() => { + const count = content.length + return `${formatNumber(count)} ${t('datasetDocuments.segment.characters', { count })}` + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [content.length]) + + const EditTimeText = useMemo(() => { + const timeText = formatTime({ + date: (childChunkInfo?.updated_at ?? 0) * 1000, + dateFormat: 'MM/DD/YYYY h:mm:ss', + }) + return `${t('datasetDocuments.segment.editedAt')} ${timeText}` + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [childChunkInfo?.updated_at]) + + return ( +
+
+
+
{t('datasetDocuments.segment.editChildChunk')}
+
+ + + {wordCountText} + + + {EditTimeText} + +
+
+
+ {fullScreen && ( + <> + + + + )} +
+ +
+
+ +
+
+
+
+
+ setContent(content)} + isEditMode={true} + /> +
+
+ {!fullScreen && ( +
+ +
+ )} +
+ ) +} + +export default React.memo(ChildSegmentDetail) diff --git a/web/app/components/datasets/documents/detail/completed/child-segment-list.tsx b/web/app/components/datasets/documents/detail/completed/child-segment-list.tsx new file mode 100644 index 0000000000..1615ea98cf --- /dev/null +++ b/web/app/components/datasets/documents/detail/completed/child-segment-list.tsx @@ -0,0 +1,195 @@ +import { type FC, useMemo, useState } from 'react' +import { RiArrowDownSLine, RiArrowRightSLine } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { EditSlice } from '../../../formatted-text/flavours/edit-slice' +import { useDocumentContext } from '../index' +import { FormattedText } from '../../../formatted-text/formatted' +import Empty from './common/empty' +import FullDocListSkeleton from './skeleton/full-doc-list-skeleton' +import { useSegmentListContext } from './index' +import type { ChildChunkDetail } from '@/models/datasets' +import Input from '@/app/components/base/input' +import classNames from '@/utils/classnames' +import Divider from '@/app/components/base/divider' +import { formatNumber } from '@/utils/format' + +type IChildSegmentCardProps = { + childChunks: ChildChunkDetail[] + parentChunkId: string + handleInputChange?: (value: string) => void + handleAddNewChildChunk?: (parentChunkId: string) => void + enabled: boolean + onDelete?: (segId: string, childChunkId: string) => Promise + onClickSlice?: (childChunk: ChildChunkDetail) => void + total?: number + inputValue?: string + onClearFilter?: () => void + isLoading?: boolean + focused?: boolean +} + +const ChildSegmentList: FC = ({ + childChunks, + parentChunkId, + handleInputChange, + handleAddNewChildChunk, + enabled, + onDelete, + onClickSlice, + total, + inputValue, + onClearFilter, + isLoading, + focused = false, +}) => { + const { t } = useTranslation() + const parentMode = useDocumentContext(s => s.parentMode) + const currChildChunk = useSegmentListContext(s => s.currChildChunk) + + const [collapsed, setCollapsed] = useState(true) + + const toggleCollapse = () => { + setCollapsed(!collapsed) + } + + const isParagraphMode = useMemo(() => { + return parentMode === 'paragraph' + }, [parentMode]) + + const isFullDocMode = useMemo(() => { + return parentMode === 'full-doc' + }, [parentMode]) + + const contentOpacity = useMemo(() => { + return (enabled || focused) ? '' : 'opacity-50 group-hover/card:opacity-100' + }, [enabled, focused]) + + const totalText = useMemo(() => { + const isSearch = inputValue !== '' && isFullDocMode + if (!isSearch) { + const text = isFullDocMode + ? !total + ? '--' + : formatNumber(total) + : formatNumber(childChunks.length) + const count = isFullDocMode + ? text === '--' + ? 0 + : total + : childChunks.length + return `${text} ${t('datasetDocuments.segment.childChunks', { count })}` + } + else { + const text = !total ? '--' : formatNumber(total) + const count = text === '--' ? 0 : total + return `${count} ${t('datasetDocuments.segment.searchResults', { count })}` + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [isFullDocMode, total, childChunks.length, inputValue]) + + return ( +
+ {isFullDocMode ? : null} +
+
{ + event.stopPropagation() + toggleCollapse() + }} + > + { + isParagraphMode + ? collapsed + ? ( + + ) + : () + : null + } + {totalText} + · + +
+ {isFullDocMode + ? handleInputChange?.(e.target.value)} + onClear={() => handleInputChange?.('')} + /> + : null} +
+ {isLoading ? : null} + {((isFullDocMode && !isLoading) || !collapsed) + ?
+ {isParagraphMode && ( +
+ +
+ )} + {childChunks.length > 0 + ? + {childChunks.map((childChunk) => { + const edited = childChunk.updated_at !== childChunk.created_at + const focused = currChildChunk?.childChunkInfo?.id === childChunk.id + return onDelete?.(childChunk.segment_id, childChunk.id)} + labelClassName={focused ? 'bg-state-accent-solid text-text-primary-on-surface' : ''} + labelInnerClassName={'text-[10px] font-semibold align-bottom leading-6'} + contentClassName={classNames('!leading-6', focused ? 'bg-state-accent-hover-alt text-text-primary' : '')} + showDivider={false} + onClick={(e) => { + e.stopPropagation() + onClickSlice?.(childChunk) + }} + offsetOptions={({ rects }) => { + return { + mainAxis: isFullDocMode ? -rects.floating.width : 12 - rects.floating.width, + crossAxis: (20 - rects.floating.height) / 2, + } + }} + /> + })} + + : inputValue !== '' + ?
+ +
+ : null + } +
+ : null} +
+ ) +} + +export default ChildSegmentList diff --git a/web/app/components/datasets/documents/detail/completed/common/action-buttons.tsx b/web/app/components/datasets/documents/detail/completed/common/action-buttons.tsx new file mode 100644 index 0000000000..1238d98a9c --- /dev/null +++ b/web/app/components/datasets/documents/detail/completed/common/action-buttons.tsx @@ -0,0 +1,86 @@ +import React, { type FC, useMemo } from 'react' +import { useTranslation } from 'react-i18next' +import { useKeyPress } from 'ahooks' +import { useDocumentContext } from '../../index' +import Button from '@/app/components/base/button' +import { getKeyboardKeyCodeBySystem, getKeyboardKeyNameBySystem } from '@/app/components/workflow/utils' + +type IActionButtonsProps = { + handleCancel: () => void + handleSave: () => void + loading: boolean + actionType?: 'edit' | 'add' + handleRegeneration?: () => void + isChildChunk?: boolean +} + +const ActionButtons: FC = ({ + handleCancel, + handleSave, + loading, + actionType = 'edit', + handleRegeneration, + isChildChunk = false, +}) => { + const { t } = useTranslation() + const mode = useDocumentContext(s => s.mode) + const parentMode = useDocumentContext(s => s.parentMode) + + useKeyPress(['esc'], (e) => { + e.preventDefault() + handleCancel() + }) + + useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.s`, (e) => { + e.preventDefault() + if (loading) + return + handleSave() + } + , { exactMatch: true, useCapture: true }) + + const isParentChildParagraphMode = useMemo(() => { + return mode === 'hierarchical' && parentMode === 'paragraph' + }, [mode, parentMode]) + + return ( +
+ + {(isParentChildParagraphMode && actionType === 'edit' && !isChildChunk) + ? + : null + } + +
+ ) +} + +ActionButtons.displayName = 'ActionButtons' + +export default React.memo(ActionButtons) diff --git a/web/app/components/datasets/documents/detail/completed/common/add-another.tsx b/web/app/components/datasets/documents/detail/completed/common/add-another.tsx new file mode 100644 index 0000000000..444560e55f --- /dev/null +++ b/web/app/components/datasets/documents/detail/completed/common/add-another.tsx @@ -0,0 +1,32 @@ +import React, { type FC } from 'react' +import { useTranslation } from 'react-i18next' +import classNames from '@/utils/classnames' +import Checkbox from '@/app/components/base/checkbox' + +type AddAnotherProps = { + className?: string + isChecked: boolean + onCheck: () => void +} + +const AddAnother: FC = ({ + className, + isChecked, + onCheck, +}) => { + const { t } = useTranslation() + + return ( +
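
A usage sketch (illustrative only, outside the patch) for the new ActionButtons footer: it only needs cancel/save callbacks and a loading flag, and it binds Esc and Ctrl/Cmd+S internally. Note that it reads mode and parentMode from the surrounding document context, so it has to render inside the document detail tree; the import path is an assumption.

import { useState } from 'react'
import ActionButtons from '@/app/components/datasets/documents/detail/completed/common/action-buttons'

const EditFooter = ({ onSave, onClose }: { onSave: () => Promise<void>; onClose: () => void }) => {
  const [saving, setSaving] = useState(false)
  const handleSave = async () => {
    setSaving(true)
    try {
      await onSave()
    }
    finally {
      setSaving(false)
    }
  }
  // Esc cancels and Ctrl/Cmd+S saves via the key bindings inside ActionButtons
  return <ActionButtons handleCancel={onClose} handleSave={handleSave} loading={saving} />
}

export default EditFooter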
+ + {t('datasetDocuments.segment.addAnother')} +
+ ) +} + +export default React.memo(AddAnother) diff --git a/web/app/components/datasets/documents/detail/completed/common/batch-action.tsx b/web/app/components/datasets/documents/detail/completed/common/batch-action.tsx new file mode 100644 index 0000000000..3dd3689b64 --- /dev/null +++ b/web/app/components/datasets/documents/detail/completed/common/batch-action.tsx @@ -0,0 +1,103 @@ +import React, { type FC } from 'react' +import { RiArchive2Line, RiCheckboxCircleLine, RiCloseCircleLine, RiDeleteBinLine } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { useBoolean } from 'ahooks' +import Divider from '@/app/components/base/divider' +import classNames from '@/utils/classnames' +import Confirm from '@/app/components/base/confirm' + +const i18nPrefix = 'dataset.batchAction' +type IBatchActionProps = { + className?: string + selectedIds: string[] + onBatchEnable: () => void + onBatchDisable: () => void + onBatchDelete: () => Promise + onArchive?: () => void + onCancel: () => void +} + +const BatchAction: FC = ({ + className, + selectedIds, + onBatchEnable, + onBatchDisable, + onArchive, + onBatchDelete, + onCancel, +}) => { + const { t } = useTranslation() + const [isShowDeleteConfirm, { + setTrue: showDeleteConfirm, + setFalse: hideDeleteConfirm, + }] = useBoolean(false) + const [isDeleting, { + setTrue: setIsDeleting, + }] = useBoolean(false) + + const handleBatchDelete = async () => { + setIsDeleting() + await onBatchDelete() + hideDeleteConfirm() + } + return ( +
+
+
+ + {selectedIds.length} + + {t(`${i18nPrefix}.selected`)} +
+ +
+ + +
+
+ + +
+ {onArchive && ( +
+ + +
+ )} +
+ + +
+ + + +
+ { + isShowDeleteConfirm && ( + + ) + } +
+ ) +} + +export default React.memo(BatchAction) diff --git a/web/app/components/datasets/documents/detail/completed/common/chunk-content.tsx b/web/app/components/datasets/documents/detail/completed/common/chunk-content.tsx new file mode 100644 index 0000000000..e6403fa12f --- /dev/null +++ b/web/app/components/datasets/documents/detail/completed/common/chunk-content.tsx @@ -0,0 +1,192 @@ +import React, { useEffect, useRef, useState } from 'react' +import type { ComponentProps, FC } from 'react' +import { useTranslation } from 'react-i18next' +import { ChunkingMode } from '@/models/datasets' +import classNames from '@/utils/classnames' + +type IContentProps = ComponentProps<'textarea'> + +const Textarea: FC = React.memo(({ + value, + placeholder, + className, + disabled, + ...rest +}) => { + return ( +
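
A usage sketch (illustrative only, outside the patch) wiring the new BatchAction bar to a selection of segment ids. The import path and the handler bodies are placeholders; the prop names follow IBatchActionProps as defined above, and onBatchDelete must return a promise because BatchAction awaits it before hiding its delete confirm dialog.

import BatchAction from '@/app/components/datasets/documents/detail/completed/common/batch-action'

const SegmentToolbar = ({ selectedIds, onClear }: { selectedIds: string[]; onClear: () => void }) => {
  if (selectedIds.length === 0)
    return null
  return (
    <BatchAction
      selectedIds={selectedIds}
      onBatchEnable={() => { /* call the enable API for selectedIds, then refresh the list */ }}
      onBatchDisable={() => { /* call the disable API for selectedIds */ }}
      onBatchDelete={async () => { /* await the delete request; BatchAction closes its confirm dialog afterwards */ }}
      onCancel={onClear}
    />
  )
}

export default SegmentToolbar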