diff --git a/api/.importlinter b/api/.importlinter index 2b4a3a5bd6..ff0577222e 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -227,6 +227,9 @@ ignore_imports = core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset + core.workflow.nodes.knowledge_index.knowledge_index_node -> services.summary_index_service + core.workflow.nodes.knowledge_index.knowledge_index_node -> tasks.generate_summary_index_task + core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.processor.paragraph_index_processor core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.retrieval_methods core.workflow.nodes.llm.node -> models.dataset core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 8fbbc51e21..30e4ed1119 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -148,6 +148,7 @@ class DatasetUpdatePayload(BaseModel): embedding_model: str | None = None embedding_model_provider: str | None = None retrieval_model: dict[str, Any] | None = None + summary_index_setting: dict[str, Any] | None = None partial_member_list: list[dict[str, str]] | None = None external_retrieval_model: dict[str, Any] | None = None external_knowledge_id: str | None = None @@ -288,7 +289,14 @@ class DatasetListApi(Resource): @enterprise_license_required def get(self): current_user, current_tenant_id = current_account_with_tenant() - query = ConsoleDatasetListQuery.model_validate(request.args.to_dict()) + # Convert query parameters to dict, handling list parameters correctly + query_params: dict[str, str | list[str]] = dict(request.args.to_dict()) + # Handle ids and tag_ids as lists (Flask request.args.getlist returns list even for single value) + if "ids" in request.args: + query_params["ids"] = request.args.getlist("ids") + if "tag_ids" in request.args: + query_params["tag_ids"] = request.args.getlist("tag_ids") + query = ConsoleDatasetListQuery.model_validate(query_params) # provider = request.args.get("provider", default="vendor") if query.ids: datasets, total = DatasetService.get_datasets_by_ids(query.ids, current_tenant_id) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 57fb9abf29..6e3c0db8a3 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -45,6 +45,7 @@ from models.dataset import DocumentPipelineExecutionLog from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel from services.file_service import FileService +from tasks.generate_summary_index_task import generate_summary_index_task from ..app.error import ( ProviderModelCurrentlyNotSupportError, @@ -103,6 +104,10 @@ class DocumentRenamePayload(BaseModel): name: str +class GenerateSummaryPayload(BaseModel): + document_list: list[str] + + class DocumentBatchDownloadZipPayload(BaseModel): """Request payload for bulk downloading documents as a zip archive.""" @@ -125,6 +130,7 @@ register_schema_models( RetrievalModel, DocumentRetryPayload, DocumentRenamePayload, + GenerateSummaryPayload, DocumentBatchDownloadZipPayload, ) @@ -312,6 +318,13 @@ class DatasetDocumentListApi(Resource): paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) documents = paginated_documents.items + + DocumentService.enrich_documents_with_summary_index_status( + documents=documents, + dataset=dataset, + tenant_id=current_tenant_id, + ) + if fetch: for document in documents: completed_segments = ( @@ -797,6 +810,7 @@ class DocumentApi(DocumentResource): "display_status": document.display_status, "doc_form": document.doc_form, "doc_language": document.doc_language, + "need_summary": document.need_summary if document.need_summary is not None else False, } else: dataset_process_rules = DatasetService.get_process_rules(dataset_id) @@ -832,6 +846,7 @@ class DocumentApi(DocumentResource): "display_status": document.display_status, "doc_form": document.doc_form, "doc_language": document.doc_language, + "need_summary": document.need_summary if document.need_summary is not None else False, } return response, 200 @@ -1255,3 +1270,137 @@ class DocumentPipelineExecutionLogApi(DocumentResource): "input_data": log.input_data, "datasource_node_id": log.datasource_node_id, }, 200 + + +@console_ns.route("/datasets//documents/generate-summary") +class DocumentGenerateSummaryApi(Resource): + @console_ns.doc("generate_summary_for_documents") + @console_ns.doc(description="Generate summary index for documents") + @console_ns.doc(params={"dataset_id": "Dataset ID"}) + @console_ns.expect(console_ns.models[GenerateSummaryPayload.__name__]) + @console_ns.response(200, "Summary generation started successfully") + @console_ns.response(400, "Invalid request or dataset configuration") + @console_ns.response(403, "Permission denied") + @console_ns.response(404, "Dataset not found") + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") + def post(self, dataset_id): + """ + Generate summary index for specified documents. + + This endpoint checks if the dataset configuration supports summary generation + (indexing_technique must be 'high_quality' and summary_index_setting.enable must be true), + then asynchronously generates summary indexes for the provided documents. + """ + current_user, _ = current_account_with_tenant() + dataset_id = str(dataset_id) + + # Get dataset + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + + # Check permissions + if not current_user.is_dataset_editor: + raise Forbidden() + + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + # Validate request payload + payload = GenerateSummaryPayload.model_validate(console_ns.payload or {}) + document_list = payload.document_list + + if not document_list: + from werkzeug.exceptions import BadRequest + + raise BadRequest("document_list cannot be empty.") + + # Check if dataset configuration supports summary generation + if dataset.indexing_technique != "high_quality": + raise ValueError( + f"Summary generation is only available for 'high_quality' indexing technique. " + f"Current indexing technique: {dataset.indexing_technique}" + ) + + summary_index_setting = dataset.summary_index_setting + if not summary_index_setting or not summary_index_setting.get("enable"): + raise ValueError("Summary index is not enabled for this dataset. Please enable it in the dataset settings.") + + # Verify all documents exist and belong to the dataset + documents = DocumentService.get_documents_by_ids(dataset_id, document_list) + + if len(documents) != len(document_list): + found_ids = {doc.id for doc in documents} + missing_ids = set(document_list) - found_ids + raise NotFound(f"Some documents not found: {list(missing_ids)}") + + # Dispatch async tasks for each document + for document in documents: + # Skip qa_model documents as they don't generate summaries + if document.doc_form == "qa_model": + logger.info("Skipping summary generation for qa_model document %s", document.id) + continue + + # Dispatch async task + generate_summary_index_task.delay(dataset_id, document.id) + logger.info( + "Dispatched summary generation task for document %s in dataset %s", + document.id, + dataset_id, + ) + + return {"result": "success"}, 200 + + +@console_ns.route("/datasets//documents//summary-status") +class DocumentSummaryStatusApi(DocumentResource): + @console_ns.doc("get_document_summary_status") + @console_ns.doc(description="Get summary index generation status for a document") + @console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @console_ns.response(200, "Summary status retrieved successfully") + @console_ns.response(404, "Document not found") + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id, document_id): + """ + Get summary index generation status for a document. + + Returns: + - total_segments: Total number of segments in the document + - summary_status: Dictionary with status counts + - completed: Number of summaries completed + - generating: Number of summaries being generated + - error: Number of summaries with errors + - not_started: Number of segments without summary records + - summaries: List of summary records with status and content preview + """ + current_user, _ = current_account_with_tenant() + dataset_id = str(dataset_id) + document_id = str(document_id) + + # Get dataset + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + + # Check permissions + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + # Get summary status detail from service + from services.summary_index_service import SummaryIndexService + + result = SummaryIndexService.get_document_summary_status_detail( + document_id=document_id, + dataset_id=dataset_id, + ) + + return result, 200 diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 08e1ddd3e0..23a668112d 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -41,6 +41,17 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task +def _get_segment_with_summary(segment, dataset_id): + """Helper function to marshal segment and add summary information.""" + from services.summary_index_service import SummaryIndexService + + segment_dict = dict(marshal(segment, segment_fields)) + # Query summary for this segment (only enabled summaries) + summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id) + segment_dict["summary"] = summary.summary_content if summary else None + return segment_dict + + class SegmentListQuery(BaseModel): limit: int = Field(default=20, ge=1, le=100) status: list[str] = Field(default_factory=list) @@ -63,6 +74,7 @@ class SegmentUpdatePayload(BaseModel): keywords: list[str] | None = None regenerate_child_chunks: bool = False attachment_ids: list[str] | None = None + summary: str | None = None # Summary content for summary index class BatchImportPayload(BaseModel): @@ -181,8 +193,25 @@ class DatasetDocumentSegmentListApi(Resource): segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) + # Query summaries for all segments in this page (batch query for efficiency) + segment_ids = [segment.id for segment in segments.items] + summaries = {} + if segment_ids: + from services.summary_index_service import SummaryIndexService + + summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id) + # Only include enabled summaries (already filtered by service) + summaries = {chunk_id: summary.summary_content for chunk_id, summary in summary_records.items()} + + # Add summary to each segment + segments_with_summary = [] + for segment in segments.items: + segment_dict = dict(marshal(segment, segment_fields)) + segment_dict["summary"] = summaries.get(segment.id) + segments_with_summary.append(segment_dict) + response = { - "data": marshal(segments.items, segment_fields), + "data": segments_with_summary, "limit": limit, "total": segments.total, "total_pages": segments.pages, @@ -328,7 +357,7 @@ class DatasetDocumentSegmentAddApi(Resource): payload_dict = payload.model_dump(exclude_none=True) SegmentService.segment_create_args_validate(payload_dict, document) segment = SegmentService.create_segment(payload_dict, document, dataset) - return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 + return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200 @console_ns.route("/datasets//documents//segments/") @@ -390,10 +419,12 @@ class DatasetDocumentSegmentUpdateApi(Resource): payload = SegmentUpdatePayload.model_validate(console_ns.payload or {}) payload_dict = payload.model_dump(exclude_none=True) SegmentService.segment_create_args_validate(payload_dict, document) + + # Update segment (summary update with change detection is handled in SegmentService.update_segment) segment = SegmentService.update_segment( SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset ) - return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 + return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200 @setup_required @login_required diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 932cb4fcce..e62be13c2f 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -1,6 +1,13 @@ -from flask_restx import Resource +from flask_restx import Resource, fields from controllers.common.schema import register_schema_model +from fields.hit_testing_fields import ( + child_chunk_fields, + document_fields, + files_fields, + hit_testing_record_fields, + segment_fields, +) from libs.login import login_required from .. import console_ns @@ -14,13 +21,45 @@ from ..wraps import ( register_schema_model(console_ns, HitTestingPayload) +def _get_or_create_model(model_name: str, field_def): + """Get or create a flask_restx model to avoid dict type issues in Swagger.""" + existing = console_ns.models.get(model_name) + if existing is None: + existing = console_ns.model(model_name, field_def) + return existing + + +# Register models for flask_restx to avoid dict type issues in Swagger +document_model = _get_or_create_model("HitTestingDocument", document_fields) + +segment_fields_copy = segment_fields.copy() +segment_fields_copy["document"] = fields.Nested(document_model) +segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy) + +child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields) +files_model = _get_or_create_model("HitTestingFile", files_fields) + +hit_testing_record_fields_copy = hit_testing_record_fields.copy() +hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model) +hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model)) +hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model)) +hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy) + +# Response model for hit testing API +hit_testing_response_fields = { + "query": fields.String, + "records": fields.List(fields.Nested(hit_testing_record_model)), +} +hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields) + + @console_ns.route("/datasets//hit-testing") class HitTestingApi(Resource, DatasetsHitTestingBase): @console_ns.doc("test_dataset_retrieval") @console_ns.doc(description="Test dataset knowledge retrieval") @console_ns.doc(params={"dataset_id": "Dataset ID"}) @console_ns.expect(console_ns.models[HitTestingPayload.__name__]) - @console_ns.response(200, "Hit testing completed successfully") + @console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model) @console_ns.response(404, "Dataset not found") @console_ns.response(400, "Invalid parameters") @setup_required diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 28864a140a..c11f64585a 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -46,6 +46,7 @@ class DatasetCreatePayload(BaseModel): retrieval_model: RetrievalModel | None = None embedding_model: str | None = None embedding_model_provider: str | None = None + summary_index_setting: dict | None = None class DatasetUpdatePayload(BaseModel): @@ -217,6 +218,7 @@ class DatasetListApi(DatasetApiResource): embedding_model_provider=payload.embedding_model_provider, embedding_model_name=payload.embedding_model, retrieval_model=payload.retrieval_model, + summary_index_setting=payload.summary_index_setting, ) except services.errors.dataset.DatasetNameDuplicateError: raise DatasetNameDuplicateError() diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index c85c1cf81e..a01524f1bc 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -45,6 +45,7 @@ from services.entities.knowledge_entities.knowledge_entities import ( Segmentation, ) from services.file_service import FileService +from services.summary_index_service import SummaryIndexService class DocumentTextCreatePayload(BaseModel): @@ -508,6 +509,12 @@ class DocumentListApi(DatasetApiResource): ) documents = paginated_documents.items + DocumentService.enrich_documents_with_summary_index_status( + documents=documents, + dataset=dataset, + tenant_id=tenant_id, + ) + response = { "data": marshal(documents, document_fields), "has_more": len(documents) == query_params.limit, @@ -612,6 +619,16 @@ class DocumentApi(DatasetApiResource): if metadata not in self.METADATA_CHOICES: raise InvalidMetadataError(f"Invalid metadata value: {metadata}") + # Calculate summary_index_status if needed + summary_index_status = None + has_summary_index = dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True + if has_summary_index and document.need_summary is True: + summary_index_status = SummaryIndexService.get_document_summary_index_status( + document_id=document_id, + dataset_id=dataset_id, + tenant_id=tenant_id, + ) + if metadata == "only": response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details} elif metadata == "without": @@ -646,6 +663,8 @@ class DocumentApi(DatasetApiResource): "display_status": document.display_status, "doc_form": document.doc_form, "doc_language": document.doc_language, + "summary_index_status": summary_index_status, + "need_summary": document.need_summary if document.need_summary is not None else False, } else: dataset_process_rules = DatasetService.get_process_rules(dataset_id) @@ -681,6 +700,8 @@ class DocumentApi(DatasetApiResource): "display_status": document.display_status, "doc_form": document.doc_form, "doc_language": document.doc_language, + "summary_index_status": summary_index_status, + "need_summary": document.need_summary if document.need_summary is not None else False, } return response diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index 74c6d2eca6..d1e2f16b6f 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -79,6 +79,7 @@ class AppGenerateResponseConverter(ABC): "document_name": resource["document_name"], "score": resource["score"], "content": resource["content"], + "summary": resource.get("summary"), } ) metadata["retriever_resources"] = updated_resources diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py index d4093b5245..b1ba3c3e2a 100644 --- a/api/core/entities/knowledge_entities.py +++ b/api/core/entities/knowledge_entities.py @@ -3,6 +3,7 @@ from pydantic import BaseModel, Field, field_validator class PreviewDetail(BaseModel): content: str + summary: str | None = None child_chunks: list[str] | None = None diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index f1b50f360b..e172e88298 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -311,14 +311,18 @@ class IndexingRunner: qa_preview_texts: list[QAPreviewDetail] = [] total_segments = 0 + # doc_form represents the segmentation method (general, parent-child, QA) index_type = doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() + # one extract_setting is one source document for extract_setting in extract_settings: # extract processing_rule = DatasetProcessRule( mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"]) ) + # Extract document content text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) + # Cleaning and segmentation documents = index_processor.transform( text_docs, current_user=None, @@ -361,6 +365,12 @@ class IndexingRunner: if doc_form and doc_form == "qa_model": return IndexingEstimate(total_segments=total_segments * 20, qa_preview=qa_preview_texts, preview=[]) + + # Generate summary preview + summary_index_setting = tmp_processing_rule.get("summary_index_setting") + if summary_index_setting and summary_index_setting.get("enable") and preview_texts: + preview_texts = index_processor.generate_summary_preview(tenant_id, preview_texts, summary_index_setting) + return IndexingEstimate(total_segments=total_segments, preview=preview_texts) def _extract( diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index ec2b7f2d44..d46cf049dd 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -434,3 +434,20 @@ INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as ex You should edit the prompt according to the IDEAL OUTPUT.""" INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}.""" + +DEFAULT_GENERATOR_SUMMARY_PROMPT = ( + """Summarize the following content. Extract only the key information and main points. """ + """Remove redundant details. + +Requirements: +1. Write a concise summary in plain text +2. Use the same language as the input content +3. Focus on important facts, concepts, and details +4. If images are included, describe their key information +5. Do not use words like "好的", "ok", "I understand", "This text discusses", "The content mentions" +6. Write directly without extra words + +Output only the summary text. Start summarizing now: + +""" +) diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 8ec1ce6242..91c16ce079 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -24,7 +24,13 @@ from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.signature import sign_upload_file from extensions.ext_database import db -from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding +from models.dataset import ( + ChildChunk, + Dataset, + DocumentSegment, + DocumentSegmentSummary, + SegmentAttachmentBinding, +) from models.dataset import Document as DatasetDocument from models.model import UploadFile from services.external_knowledge_service import ExternalDatasetService @@ -389,15 +395,15 @@ class RetrievalService: .all() } - records = [] - include_segment_ids = set() - segment_child_map = {} - valid_dataset_documents = {} image_doc_ids: list[Any] = [] child_index_node_ids = [] index_node_ids = [] doc_to_document_map = {} + summary_segment_ids = set() # Track segments retrieved via summary + summary_score_map: dict[str, float] = {} # Map original_chunk_id to summary score + + # First pass: collect all document IDs and identify summary documents for document in documents: document_id = document.metadata.get("document_id") if document_id not in dataset_documents: @@ -408,16 +414,39 @@ class RetrievalService: continue valid_dataset_documents[document_id] = dataset_document + doc_id = document.metadata.get("doc_id") or "" + doc_to_document_map[doc_id] = document + + # Check if this is a summary document + is_summary = document.metadata.get("is_summary", False) + if is_summary: + # For summary documents, find the original chunk via original_chunk_id + original_chunk_id = document.metadata.get("original_chunk_id") + if original_chunk_id: + summary_segment_ids.add(original_chunk_id) + # Save summary's score for later use + summary_score = document.metadata.get("score") + if summary_score is not None: + try: + summary_score_float = float(summary_score) + # If the same segment has multiple summary hits, take the highest score + if original_chunk_id not in summary_score_map: + summary_score_map[original_chunk_id] = summary_score_float + else: + summary_score_map[original_chunk_id] = max( + summary_score_map[original_chunk_id], summary_score_float + ) + except (ValueError, TypeError): + # Skip invalid score values + pass + continue # Skip adding to other lists for summary documents + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: - doc_id = document.metadata.get("doc_id") or "" - doc_to_document_map[doc_id] = document if document.metadata.get("doc_type") == DocType.IMAGE: image_doc_ids.append(doc_id) else: child_index_node_ids.append(doc_id) else: - doc_id = document.metadata.get("doc_id") or "" - doc_to_document_map[doc_id] = document if document.metadata.get("doc_type") == DocType.IMAGE: image_doc_ids.append(doc_id) else: @@ -433,6 +462,7 @@ class RetrievalService: attachment_map: dict[str, list[dict[str, Any]]] = {} child_chunk_map: dict[str, list[ChildChunk]] = {} doc_segment_map: dict[str, list[str]] = {} + segment_summary_map: dict[str, str] = {} # Map segment_id to summary content with session_factory.create_session() as session: attachments = cls.get_segment_attachment_infos(image_doc_ids, session) @@ -447,6 +477,7 @@ class RetrievalService: doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"]) else: doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]] + child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids)) child_index_nodes = session.execute(child_chunk_stmt).scalars().all() @@ -470,6 +501,7 @@ class RetrievalService: index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore for index_node_segment in index_node_segments: doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id] + if segment_ids: document_segment_stmt = select(DocumentSegment).where( DocumentSegment.enabled == True, @@ -481,6 +513,40 @@ class RetrievalService: if index_node_segments: segments.extend(index_node_segments) + # Handle summary documents: query segments by original_chunk_id + if summary_segment_ids: + summary_segment_ids_list = list(summary_segment_ids) + summary_segment_stmt = select(DocumentSegment).where( + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.id.in_(summary_segment_ids_list), + ) + summary_segments = session.execute(summary_segment_stmt).scalars().all() # type: ignore + segments.extend(summary_segments) + # Add summary segment IDs to segment_ids for summary query + for seg in summary_segments: + if seg.id not in segment_ids: + segment_ids.append(seg.id) + + # Batch query summaries for segments retrieved via summary (only enabled summaries) + if summary_segment_ids: + summaries = ( + session.query(DocumentSegmentSummary) + .filter( + DocumentSegmentSummary.chunk_id.in_(list(summary_segment_ids)), + DocumentSegmentSummary.status == "completed", + DocumentSegmentSummary.enabled == True, # Only retrieve enabled summaries + ) + .all() + ) + for summary in summaries: + if summary.summary_content: + segment_summary_map[summary.chunk_id] = summary.summary_content + + include_segment_ids = set() + segment_child_map: dict[str, dict[str, Any]] = {} + records: list[dict[str, Any]] = [] + for segment in segments: child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, []) attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, []) @@ -489,30 +555,44 @@ class RetrievalService: if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: if segment.id not in include_segment_ids: include_segment_ids.add(segment.id) + # Check if this segment was retrieved via summary + # Use summary score as base score if available, otherwise 0.0 + max_score = summary_score_map.get(segment.id, 0.0) + if child_chunks or attachment_infos: child_chunk_details = [] - max_score = 0.0 for child_chunk in child_chunks: - document = doc_to_document_map[child_chunk.index_node_id] + child_document: Document | None = doc_to_document_map.get(child_chunk.index_node_id) + if child_document: + child_score = child_document.metadata.get("score", 0.0) + else: + child_score = 0.0 child_chunk_detail = { "id": child_chunk.id, "content": child_chunk.content, "position": child_chunk.position, - "score": document.metadata.get("score", 0.0) if document else 0.0, + "score": child_score, } child_chunk_details.append(child_chunk_detail) - max_score = max(max_score, document.metadata.get("score", 0.0) if document else 0.0) + max_score = max(max_score, child_score) for attachment_info in attachment_infos: - file_document = doc_to_document_map[attachment_info["id"]] - max_score = max( - max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0 - ) + file_document = doc_to_document_map.get(attachment_info["id"]) + if file_document: + max_score = max(max_score, file_document.metadata.get("score", 0.0)) map_detail = { "max_score": max_score, "child_chunks": child_chunk_details, } segment_child_map[segment.id] = map_detail + else: + # No child chunks or attachments, use summary score if available + summary_score = summary_score_map.get(segment.id) + if summary_score is not None: + segment_child_map[segment.id] = { + "max_score": summary_score, + "child_chunks": [], + } record: dict[str, Any] = { "segment": segment, } @@ -520,14 +600,23 @@ class RetrievalService: else: if segment.id not in include_segment_ids: include_segment_ids.add(segment.id) - max_score = 0.0 - segment_document = doc_to_document_map.get(segment.index_node_id) - if segment_document: - max_score = max(max_score, segment_document.metadata.get("score", 0.0)) + + # Check if this segment was retrieved via summary + # Use summary score if available (summary retrieval takes priority) + max_score = summary_score_map.get(segment.id, 0.0) + + # If not retrieved via summary, use original segment's score + if segment.id not in summary_score_map: + segment_document = doc_to_document_map.get(segment.index_node_id) + if segment_document: + max_score = max(max_score, segment_document.metadata.get("score", 0.0)) + + # Also consider attachment scores for attachment_info in attachment_infos: file_doc = doc_to_document_map.get(attachment_info["id"]) if file_doc: max_score = max(max_score, file_doc.metadata.get("score", 0.0)) + record = { "segment": segment, "score": max_score, @@ -576,9 +665,16 @@ class RetrievalService: else None ) + # Extract summary if this segment was retrieved via summary + summary_content = segment_summary_map.get(segment.id) + # Create RetrievalSegments object retrieval_segment = RetrievalSegments( - segment=segment, child_chunks=child_chunks_list, score=score, files=files + segment=segment, + child_chunks=child_chunks_list, + score=score, + files=files, + summary=summary_content, ) result.append(retrieval_segment) diff --git a/api/core/rag/embedding/retrieval.py b/api/core/rag/embedding/retrieval.py index b54a37b49e..f6834ab87b 100644 --- a/api/core/rag/embedding/retrieval.py +++ b/api/core/rag/embedding/retrieval.py @@ -20,3 +20,4 @@ class RetrievalSegments(BaseModel): child_chunks: list[RetrievalChildChunk] | None = None score: float | None = None files: list[dict[str, str | int]] | None = None + summary: str | None = None # Summary content if retrieved via summary index diff --git a/api/core/rag/entities/citation_metadata.py b/api/core/rag/entities/citation_metadata.py index 9f66cd9a03..aec5c353f8 100644 --- a/api/core/rag/entities/citation_metadata.py +++ b/api/core/rag/entities/citation_metadata.py @@ -22,3 +22,4 @@ class RetrievalSourceMetadata(BaseModel): doc_metadata: dict[str, Any] | None = None title: str | None = None files: list[dict[str, Any]] | None = None + summary: str | None = None diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index e36b54eedd..151a3de7d9 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -13,6 +13,7 @@ from urllib.parse import unquote, urlparse import httpx from configs import dify_config +from core.entities.knowledge_entities import PreviewDetail from core.helper import ssrf_proxy from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.index_processor.constant.doc_type import DocType @@ -45,6 +46,17 @@ class BaseIndexProcessor(ABC): def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]: raise NotImplementedError + @abstractmethod + def generate_summary_preview( + self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + ) -> list[PreviewDetail]: + """ + For each segment in preview_texts, generate a summary using LLM and attach it to the segment. + The summary can be stored in a new attribute, e.g., summary. + This method should be implemented by subclasses. + """ + raise NotImplementedError + @abstractmethod def load( self, diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index cf68cff7dc..ab91e29145 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -1,9 +1,27 @@ """Paragraph index processor.""" +import logging +import re import uuid from collections.abc import Mapping -from typing import Any +from typing import Any, cast +logger = logging.getLogger(__name__) + +from core.entities.knowledge_entities import PreviewDetail +from core.file import File, FileTransferMethod, FileType, file_manager +from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT +from core.model_manager import ModelInstance +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from core.model_runtime.entities.message_entities import ( + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentUnionTypes, + TextPromptMessageContent, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ModelFeature, ModelType +from core.provider_manager import ProviderManager from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.retrieval_service import RetrievalService @@ -17,12 +35,17 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols +from core.workflow.nodes.llm import llm_utils +from extensions.ext_database import db +from factories.file_factory import build_from_mapping from libs import helper +from models import UploadFile from models.account import Account -from models.dataset import Dataset, DatasetProcessRule +from models.dataset import Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument from services.account_service import AccountService from services.entities.knowledge_entities.knowledge_entities import Rule +from services.summary_index_service import SummaryIndexService class ParagraphIndexProcessor(BaseIndexProcessor): @@ -108,6 +131,29 @@ class ParagraphIndexProcessor(BaseIndexProcessor): keyword.add_texts(documents) def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): + # Note: Summary indexes are now disabled (not deleted) when segments are disabled. + # This method is called for actual deletion scenarios (e.g., when segment is deleted). + # For disable operations, disable_summaries_for_segments is called directly in the task. + # Only delete summaries if explicitly requested (e.g., when segment is actually deleted) + delete_summaries = kwargs.get("delete_summaries", False) + if delete_summaries: + if node_ids: + # Find segments by index_node_id + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.index_node_id.in_(node_ids), + ) + .all() + ) + segment_ids = [segment.id for segment in segments] + if segment_ids: + SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids) + else: + # Delete all summaries for the dataset + SummaryIndexService.delete_summaries_for_segments(dataset, None) + if dataset.indexing_technique == "high_quality": vector = Vector(dataset) if node_ids: @@ -227,3 +273,322 @@ class ParagraphIndexProcessor(BaseIndexProcessor): } else: raise ValueError("Chunks is not a list") + + def generate_summary_preview( + self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + ) -> list[PreviewDetail]: + """ + For each segment, concurrently call generate_summary to generate a summary + and write it to the summary attribute of PreviewDetail. + In preview mode (indexing-estimate), if any summary generation fails, the method will raise an exception. + """ + import concurrent.futures + + from flask import current_app + + # Capture Flask app context for worker threads + flask_app = None + try: + flask_app = current_app._get_current_object() # type: ignore + except RuntimeError: + logger.warning("No Flask application context available, summary generation may fail") + + def process(preview: PreviewDetail) -> None: + """Generate summary for a single preview item.""" + if flask_app: + # Ensure Flask app context in worker thread + with flask_app.app_context(): + summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting) + preview.summary = summary + else: + # Fallback: try without app context (may fail) + summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting) + preview.summary = summary + + # Generate summaries concurrently using ThreadPoolExecutor + # Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total) + timeout_seconds = min(300, 60 * len(preview_texts)) + errors: list[Exception] = [] + + with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_texts))) as executor: + futures = [executor.submit(process, preview) for preview in preview_texts] + # Wait for all tasks to complete with timeout + done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds) + + # Cancel tasks that didn't complete in time + if not_done: + timeout_error_msg = ( + f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s" + ) + logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg) + # In preview mode, timeout is also an error + errors.append(TimeoutError(timeout_error_msg)) + for future in not_done: + future.cancel() + # Wait a bit for cancellation to take effect + concurrent.futures.wait(not_done, timeout=5) + + # Collect exceptions from completed futures + for future in done: + try: + future.result() # This will raise any exception that occurred + except Exception as e: + logger.exception("Error in summary generation future") + errors.append(e) + + # In preview mode (indexing-estimate), if there are any errors, fail the request + if errors: + error_messages = [str(e) for e in errors] + error_summary = ( + f"Failed to generate summaries for {len(errors)} chunk(s). " + f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors + ) + if len(errors) > 3: + error_summary += f" (and {len(errors) - 3} more)" + logger.error("Summary generation failed in preview mode: %s", error_summary) + raise ValueError(error_summary) + + return preview_texts + + @staticmethod + def generate_summary( + tenant_id: str, + text: str, + summary_index_setting: dict | None = None, + segment_id: str | None = None, + ) -> tuple[str, LLMUsage]: + """ + Generate summary for the given text using ModelInstance.invoke_llm and the default or custom summary prompt, + and supports vision models by including images from the segment attachments or text content. + + Args: + tenant_id: Tenant ID + text: Text content to summarize + summary_index_setting: Summary index configuration + segment_id: Optional segment ID to fetch attachments from SegmentAttachmentBinding table + + Returns: + Tuple of (summary_content, llm_usage) where llm_usage is LLMUsage object + """ + if not summary_index_setting or not summary_index_setting.get("enable"): + raise ValueError("summary_index_setting is required and must be enabled to generate summary.") + + model_name = summary_index_setting.get("model_name") + model_provider_name = summary_index_setting.get("model_provider_name") + summary_prompt = summary_index_setting.get("summary_prompt") + + if not model_name or not model_provider_name: + raise ValueError("model_name and model_provider_name are required in summary_index_setting") + + # Import default summary prompt + if not summary_prompt: + summary_prompt = DEFAULT_GENERATOR_SUMMARY_PROMPT + + provider_manager = ProviderManager() + provider_model_bundle = provider_manager.get_provider_model_bundle( + tenant_id, model_provider_name, ModelType.LLM + ) + model_instance = ModelInstance(provider_model_bundle, model_name) + + # Get model schema to check if vision is supported + model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials) + supports_vision = model_schema and model_schema.features and ModelFeature.VISION in model_schema.features + + # Extract images if model supports vision + image_files = [] + if supports_vision: + # First, try to get images from SegmentAttachmentBinding (preferred method) + if segment_id: + image_files = ParagraphIndexProcessor._extract_images_from_segment_attachments(tenant_id, segment_id) + + # If no images from attachments, fall back to extracting from text + if not image_files: + image_files = ParagraphIndexProcessor._extract_images_from_text(tenant_id, text) + + # Build prompt messages + prompt_messages = [] + + if image_files: + # If we have images, create a UserPromptMessage with both text and images + prompt_message_contents: list[PromptMessageContentUnionTypes] = [] + + # Add images first + for file in image_files: + try: + file_content = file_manager.to_prompt_message_content( + file, image_detail_config=ImagePromptMessageContent.DETAIL.LOW + ) + prompt_message_contents.append(file_content) + except Exception as e: + logger.warning("Failed to convert image file to prompt message content: %s", str(e)) + continue + + # Add text content + if prompt_message_contents: # Only add text if we successfully added images + prompt_message_contents.append(TextPromptMessageContent(data=f"{summary_prompt}\n{text}")) + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + else: + # If image conversion failed, fall back to text-only + prompt = f"{summary_prompt}\n{text}" + prompt_messages.append(UserPromptMessage(content=prompt)) + else: + # No images, use simple text prompt + prompt = f"{summary_prompt}\n{text}" + prompt_messages.append(UserPromptMessage(content=prompt)) + + result = model_instance.invoke_llm( + prompt_messages=cast(list[PromptMessage], prompt_messages), model_parameters={}, stream=False + ) + + # Type assertion: when stream=False, invoke_llm returns LLMResult, not Generator + if not isinstance(result, LLMResult): + raise ValueError("Expected LLMResult when stream=False") + + summary_content = getattr(result.message, "content", "") + usage = result.usage + + # Deduct quota for summary generation (same as workflow nodes) + try: + llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage) + except Exception as e: + # Log but don't fail summary generation if quota deduction fails + logger.warning("Failed to deduct quota for summary generation: %s", str(e)) + + return summary_content, usage + + @staticmethod + def _extract_images_from_text(tenant_id: str, text: str) -> list[File]: + """ + Extract images from markdown text and convert them to File objects. + + Args: + tenant_id: Tenant ID + text: Text content that may contain markdown image links + + Returns: + List of File objects representing images found in the text + """ + # Extract markdown images using regex pattern + pattern = r"!\[.*?\]\((.*?)\)" + images = re.findall(pattern, text) + + if not images: + return [] + + upload_file_id_list = [] + + for image in images: + # For data before v0.10.0 + pattern = r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?" + match = re.search(pattern, image) + if match: + upload_file_id = match.group(1) + upload_file_id_list.append(upload_file_id) + continue + + # For data after v0.10.0 + pattern = r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?" + match = re.search(pattern, image) + if match: + upload_file_id = match.group(1) + upload_file_id_list.append(upload_file_id) + continue + + # For tools directory - direct file formats (e.g., .png, .jpg, etc.) + pattern = r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?" + match = re.search(pattern, image) + if match: + # Tool files are handled differently, skip for now + continue + + if not upload_file_id_list: + return [] + + # Get unique IDs for database query + unique_upload_file_ids = list(set(upload_file_id_list)) + upload_files = ( + db.session.query(UploadFile) + .where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id) + .all() + ) + + # Create File objects from UploadFile records + file_objects = [] + for upload_file in upload_files: + # Only process image files + if not upload_file.mime_type or "image" not in upload_file.mime_type: + continue + + mapping = { + "upload_file_id": upload_file.id, + "transfer_method": FileTransferMethod.LOCAL_FILE.value, + "type": FileType.IMAGE.value, + } + + try: + file_obj = build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + ) + file_objects.append(file_obj) + except Exception as e: + logger.warning("Failed to create File object from UploadFile %s: %s", upload_file.id, str(e)) + continue + + return file_objects + + @staticmethod + def _extract_images_from_segment_attachments(tenant_id: str, segment_id: str) -> list[File]: + """ + Extract images from SegmentAttachmentBinding table (preferred method). + This matches how DatasetRetrieval gets segment attachments. + + Args: + tenant_id: Tenant ID + segment_id: Segment ID to fetch attachments for + + Returns: + List of File objects representing images found in segment attachments + """ + from sqlalchemy import select + + # Query attachments from SegmentAttachmentBinding table + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where( + SegmentAttachmentBinding.segment_id == segment_id, + SegmentAttachmentBinding.tenant_id == tenant_id, + ) + ).all() + + if not attachments_with_bindings: + return [] + + file_objects = [] + for _, upload_file in attachments_with_bindings: + # Only process image files + if not upload_file.mime_type or "image" not in upload_file.mime_type: + continue + + try: + # Create File object directly (similar to DatasetRetrieval) + file_obj = File( + id=upload_file.id, + filename=upload_file.name, + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + tenant_id=tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.LOCAL_FILE, + remote_url=upload_file.source_url, + related_id=upload_file.id, + size=upload_file.size, + storage_key=upload_file.key, + ) + file_objects.append(file_obj) + except Exception as e: + logger.warning("Failed to create File object from UploadFile %s: %s", upload_file.id, str(e)) + continue + + return file_objects diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 0366f3259f..961df2e50c 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -1,11 +1,14 @@ """Paragraph index processor.""" import json +import logging import uuid from collections.abc import Mapping from typing import Any from configs import dify_config +from core.db.session_factory import session_factory +from core.entities.knowledge_entities import PreviewDetail from core.model_manager import ModelInstance from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.retrieval_service import RetrievalService @@ -25,6 +28,9 @@ from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegm from models.dataset import Document as DatasetDocument from services.account_service import AccountService from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule +from services.summary_index_service import SummaryIndexService + +logger = logging.getLogger(__name__) class ParentChildIndexProcessor(BaseIndexProcessor): @@ -135,6 +141,30 @@ class ParentChildIndexProcessor(BaseIndexProcessor): def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): # node_ids is segment's node_ids + # Note: Summary indexes are now disabled (not deleted) when segments are disabled. + # This method is called for actual deletion scenarios (e.g., when segment is deleted). + # For disable operations, disable_summaries_for_segments is called directly in the task. + # Only delete summaries if explicitly requested (e.g., when segment is actually deleted) + delete_summaries = kwargs.get("delete_summaries", False) + if delete_summaries: + if node_ids: + # Find segments by index_node_id + with session_factory.create_session() as session: + segments = ( + session.query(DocumentSegment) + .filter( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.index_node_id.in_(node_ids), + ) + .all() + ) + segment_ids = [segment.id for segment in segments] + if segment_ids: + SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids) + else: + # Delete all summaries for the dataset + SummaryIndexService.delete_summaries_for_segments(dataset, None) + if dataset.indexing_technique == "high_quality": delete_child_chunks = kwargs.get("delete_child_chunks") or False precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids") @@ -326,3 +356,91 @@ class ParentChildIndexProcessor(BaseIndexProcessor): "preview": preview, "total_segments": len(parent_childs.parent_child_chunks), } + + def generate_summary_preview( + self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + ) -> list[PreviewDetail]: + """ + For each parent chunk in preview_texts, concurrently call generate_summary to generate a summary + and write it to the summary attribute of PreviewDetail. + In preview mode (indexing-estimate), if any summary generation fails, the method will raise an exception. + + Note: For parent-child structure, we only generate summaries for parent chunks. + """ + import concurrent.futures + + from flask import current_app + + # Capture Flask app context for worker threads + flask_app = None + try: + flask_app = current_app._get_current_object() # type: ignore + except RuntimeError: + logger.warning("No Flask application context available, summary generation may fail") + + def process(preview: PreviewDetail) -> None: + """Generate summary for a single preview item (parent chunk).""" + from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor + + if flask_app: + # Ensure Flask app context in worker thread + with flask_app.app_context(): + summary, _ = ParagraphIndexProcessor.generate_summary( + tenant_id=tenant_id, + text=preview.content, + summary_index_setting=summary_index_setting, + ) + preview.summary = summary + else: + # Fallback: try without app context (may fail) + summary, _ = ParagraphIndexProcessor.generate_summary( + tenant_id=tenant_id, + text=preview.content, + summary_index_setting=summary_index_setting, + ) + preview.summary = summary + + # Generate summaries concurrently using ThreadPoolExecutor + # Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total) + timeout_seconds = min(300, 60 * len(preview_texts)) + errors: list[Exception] = [] + + with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_texts))) as executor: + futures = [executor.submit(process, preview) for preview in preview_texts] + # Wait for all tasks to complete with timeout + done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds) + + # Cancel tasks that didn't complete in time + if not_done: + timeout_error_msg = ( + f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s" + ) + logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg) + # In preview mode, timeout is also an error + errors.append(TimeoutError(timeout_error_msg)) + for future in not_done: + future.cancel() + # Wait a bit for cancellation to take effect + concurrent.futures.wait(not_done, timeout=5) + + # Collect exceptions from completed futures + for future in done: + try: + future.result() # This will raise any exception that occurred + except Exception as e: + logger.exception("Error in summary generation future") + errors.append(e) + + # In preview mode (indexing-estimate), if there are any errors, fail the request + if errors: + error_messages = [str(e) for e in errors] + error_summary = ( + f"Failed to generate summaries for {len(errors)} chunk(s). " + f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors + ) + if len(errors) > 3: + error_summary += f" (and {len(errors) - 3} more)" + logger.error("Summary generation failed in preview mode: %s", error_summary) + raise ValueError(error_summary) + + return preview_texts diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 1183d5fbd7..272d2ed351 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -11,6 +11,8 @@ import pandas as pd from flask import Flask, current_app from werkzeug.datastructures import FileStorage +from core.db.session_factory import session_factory +from core.entities.knowledge_entities import PreviewDetail from core.llm_generator.llm_generator import LLMGenerator from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.retrieval_service import RetrievalService @@ -25,9 +27,10 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper from models.account import Account -from models.dataset import Dataset +from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument from services.entities.knowledge_entities.knowledge_entities import Rule +from services.summary_index_service import SummaryIndexService logger = logging.getLogger(__name__) @@ -144,6 +147,31 @@ class QAIndexProcessor(BaseIndexProcessor): vector.create_multimodal(multimodal_documents) def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): + # Note: Summary indexes are now disabled (not deleted) when segments are disabled. + # This method is called for actual deletion scenarios (e.g., when segment is deleted). + # For disable operations, disable_summaries_for_segments is called directly in the task. + # Note: qa_model doesn't generate summaries, but we clean them for completeness + # Only delete summaries if explicitly requested (e.g., when segment is actually deleted) + delete_summaries = kwargs.get("delete_summaries", False) + if delete_summaries: + if node_ids: + # Find segments by index_node_id + with session_factory.create_session() as session: + segments = ( + session.query(DocumentSegment) + .filter( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.index_node_id.in_(node_ids), + ) + .all() + ) + segment_ids = [segment.id for segment in segments] + if segment_ids: + SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids) + else: + # Delete all summaries for the dataset + SummaryIndexService.delete_summaries_for_segments(dataset, None) + vector = Vector(dataset) if node_ids: vector.delete_by_ids(node_ids) @@ -212,6 +240,17 @@ class QAIndexProcessor(BaseIndexProcessor): "total_segments": len(qa_chunks.qa_chunks), } + def generate_summary_preview( + self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + ) -> list[PreviewDetail]: + """ + QA model doesn't generate summaries, so this method returns preview_texts unchanged. + + Note: QA model uses question-answer pairs, which don't require summary generation. + """ + # QA model doesn't generate summaries, return as-is + return preview_texts + def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language): format_documents = [] if document_node.page_content is None or not document_node.page_content.strip(): diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index f8f85d141a..541c241ae5 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -236,20 +236,24 @@ class DatasetRetrieval: if records: for record in records: segment = record.segment + # Build content: if summary exists, add it before the segment content if segment.answer: - document_context_list.append( - DocumentContext( - content=f"question:{segment.get_sign_content()} answer:{segment.answer}", - score=record.score, - ) - ) + segment_content = f"question:{segment.get_sign_content()} answer:{segment.answer}" else: - document_context_list.append( - DocumentContext( - content=segment.get_sign_content(), - score=record.score, - ) + segment_content = segment.get_sign_content() + + # If summary exists, prepend it to the content + if record.summary: + final_content = f"{record.summary}\n{segment_content}" + else: + final_content = segment_content + + document_context_list.append( + DocumentContext( + content=final_content, + score=record.score, ) + ) if vision_enabled: attachments_with_bindings = db.session.execute( select(SegmentAttachmentBinding, UploadFile) @@ -316,6 +320,9 @@ class DatasetRetrieval: source.content = f"question:{segment.content} \nanswer:{segment.answer}" else: source.content = segment.content + # Add summary if this segment was retrieved via summary + if hasattr(record, "summary") and record.summary: + source.summary = record.summary retrieval_resource_list.append(source) if hit_callback and retrieval_resource_list: retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True) diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index f96510fb45..057ec41f65 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -169,20 +169,24 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): if records: for record in records: segment = record.segment + # Build content: if summary exists, add it before the segment content if segment.answer: - document_context_list.append( - DocumentContext( - content=f"question:{segment.get_sign_content()} answer:{segment.answer}", - score=record.score, - ) - ) + segment_content = f"question:{segment.get_sign_content()} answer:{segment.answer}" else: - document_context_list.append( - DocumentContext( - content=segment.get_sign_content(), - score=record.score, - ) + segment_content = segment.get_sign_content() + + # If summary exists, prepend it to the content + if record.summary: + final_content = f"{record.summary}\n{segment_content}" + else: + final_content = segment_content + + document_context_list.append( + DocumentContext( + content=final_content, + score=record.score, ) + ) if self.return_resource: for record in records: @@ -216,6 +220,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): source.content = f"question:{segment.content} \nanswer:{segment.answer}" else: source.content = segment.content + # Add summary if this segment was retrieved via summary + if hasattr(record, "summary") and record.summary: + source.summary = record.summary retrieval_resource_list.append(source) if self.return_resource and retrieval_resource_list: diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index 3daca90b9b..bfeb9b5b79 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -158,3 +158,5 @@ class KnowledgeIndexNodeData(BaseNodeData): type: str = "knowledge-index" chunk_structure: str index_chunk_variable_selector: list[str] + indexing_technique: str | None = None + summary_index_setting: dict | None = None diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 17ca4bef7b..b88c2d510f 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -1,9 +1,11 @@ +import concurrent.futures import datetime import logging import time from collections.abc import Mapping from typing import Any +from flask import current_app from sqlalchemy import func, select from core.app.entities.app_invoke_entities import InvokeFrom @@ -16,7 +18,9 @@ from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.template import Template from core.workflow.runtime import VariablePool from extensions.ext_database import db -from models.dataset import Dataset, Document, DocumentSegment +from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary +from services.summary_index_service import SummaryIndexService +from tasks.generate_summary_index_task import generate_summary_index_task from .entities import KnowledgeIndexNodeData from .exc import ( @@ -67,7 +71,20 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): # index knowledge try: if is_preview: - outputs = self._get_preview_output(node_data.chunk_structure, chunks) + # Preview mode: generate summaries for chunks directly without saving to database + # Format preview and generate summaries on-the-fly + # Get indexing_technique and summary_index_setting from node_data (workflow graph config) + # or fallback to dataset if not available in node_data + indexing_technique = node_data.indexing_technique or dataset.indexing_technique + summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting + + outputs = self._get_preview_output_with_summaries( + node_data.chunk_structure, + chunks, + dataset=dataset, + indexing_technique=indexing_technique, + summary_index_setting=summary_index_setting, + ) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, @@ -148,6 +165,11 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): ) .scalar() ) + # Update need_summary based on dataset's summary_index_setting + if dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True: + document.need_summary = True + else: + document.need_summary = False db.session.add(document) # update document segment status db.session.query(DocumentSegment).where( @@ -163,6 +185,9 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): db.session.commit() + # Generate summary index if enabled + self._handle_summary_index_generation(dataset, document, variable_pool) + return { "dataset_id": ds_id_value, "dataset_name": dataset_name_value, @@ -173,9 +198,304 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): "display_status": "completed", } - def _get_preview_output(self, chunk_structure: str, chunks: Any) -> Mapping[str, Any]: + def _handle_summary_index_generation( + self, + dataset: Dataset, + document: Document, + variable_pool: VariablePool, + ) -> None: + """ + Handle summary index generation based on mode (debug/preview or production). + + Args: + dataset: Dataset containing the document + document: Document to generate summaries for + variable_pool: Variable pool to check invoke_from + """ + # Only generate summary index for high_quality indexing technique + if dataset.indexing_technique != "high_quality": + return + + # Check if summary index is enabled + summary_index_setting = dataset.summary_index_setting + if not summary_index_setting or not summary_index_setting.get("enable"): + return + + # Skip qa_model documents + if document.doc_form == "qa_model": + return + + # Determine if in preview/debug mode + invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) + is_preview = invoke_from and invoke_from.value == InvokeFrom.DEBUGGER + + if is_preview: + try: + # Query segments that need summary generation + query = db.session.query(DocumentSegment).filter_by( + dataset_id=dataset.id, + document_id=document.id, + status="completed", + enabled=True, + ) + segments = query.all() + + if not segments: + logger.info("No segments found for document %s", document.id) + return + + # Filter segments based on mode + segments_to_process = [] + for segment in segments: + # Skip if summary already exists + existing_summary = ( + db.session.query(DocumentSegmentSummary) + .filter_by(chunk_id=segment.id, dataset_id=dataset.id, status="completed") + .first() + ) + if existing_summary: + continue + + # For parent-child mode, all segments are parent chunks, so process all + segments_to_process.append(segment) + + if not segments_to_process: + logger.info("No segments need summary generation for document %s", document.id) + return + + # Use ThreadPoolExecutor for concurrent generation + flask_app = current_app._get_current_object() # type: ignore + max_workers = min(10, len(segments_to_process)) # Limit to 10 workers + + def process_segment(segment: DocumentSegment) -> None: + """Process a single segment in a thread with Flask app context.""" + with flask_app.app_context(): + try: + SummaryIndexService.generate_and_vectorize_summary(segment, dataset, summary_index_setting) + except Exception: + logger.exception( + "Failed to generate summary for segment %s", + segment.id, + ) + # Continue processing other segments + + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [executor.submit(process_segment, segment) for segment in segments_to_process] + # Wait for all tasks to complete + concurrent.futures.wait(futures) + + logger.info( + "Successfully generated summary index for %s segments in document %s", + len(segments_to_process), + document.id, + ) + except Exception: + logger.exception("Failed to generate summary index for document %s", document.id) + # Don't fail the entire indexing process if summary generation fails + else: + # Production mode: asynchronous generation + logger.info( + "Queuing summary index generation task for document %s (production mode)", + document.id, + ) + try: + generate_summary_index_task.delay(dataset.id, document.id, None) + logger.info("Summary index generation task queued for document %s", document.id) + except Exception: + logger.exception( + "Failed to queue summary index generation task for document %s", + document.id, + ) + # Don't fail the entire indexing process if task queuing fails + + def _get_preview_output_with_summaries( + self, + chunk_structure: str, + chunks: Any, + dataset: Dataset, + indexing_technique: str | None = None, + summary_index_setting: dict | None = None, + ) -> Mapping[str, Any]: + """ + Generate preview output with summaries for chunks in preview mode. + This method generates summaries on-the-fly without saving to database. + + Args: + chunk_structure: Chunk structure type + chunks: Chunks to generate preview for + dataset: Dataset object (for tenant_id) + indexing_technique: Indexing technique from node config or dataset + summary_index_setting: Summary index setting from node config or dataset + """ index_processor = IndexProcessorFactory(chunk_structure).init_index_processor() - return index_processor.format_preview(chunks) + preview_output = index_processor.format_preview(chunks) + + # Check if summary index is enabled + if indexing_technique != "high_quality": + return preview_output + + if not summary_index_setting or not summary_index_setting.get("enable"): + return preview_output + + # Generate summaries for chunks + if "preview" in preview_output and isinstance(preview_output["preview"], list): + chunk_count = len(preview_output["preview"]) + logger.info( + "Generating summaries for %s chunks in preview mode (dataset: %s)", + chunk_count, + dataset.id, + ) + # Use ParagraphIndexProcessor's generate_summary method + from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor + + # Get Flask app for application context in worker threads + flask_app = None + try: + flask_app = current_app._get_current_object() # type: ignore + except RuntimeError: + logger.warning("No Flask application context available, summary generation may fail") + + def generate_summary_for_chunk(preview_item: dict) -> None: + """Generate summary for a single chunk.""" + if "content" in preview_item: + # Set Flask application context in worker thread + if flask_app: + with flask_app.app_context(): + summary, _ = ParagraphIndexProcessor.generate_summary( + tenant_id=dataset.tenant_id, + text=preview_item["content"], + summary_index_setting=summary_index_setting, + ) + if summary: + preview_item["summary"] = summary + else: + # Fallback: try without app context (may fail) + summary, _ = ParagraphIndexProcessor.generate_summary( + tenant_id=dataset.tenant_id, + text=preview_item["content"], + summary_index_setting=summary_index_setting, + ) + if summary: + preview_item["summary"] = summary + + # Generate summaries concurrently using ThreadPoolExecutor + # Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total) + timeout_seconds = min(300, 60 * len(preview_output["preview"])) + errors: list[Exception] = [] + + with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_output["preview"]))) as executor: + futures = [ + executor.submit(generate_summary_for_chunk, preview_item) + for preview_item in preview_output["preview"] + ] + # Wait for all tasks to complete with timeout + done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds) + + # Cancel tasks that didn't complete in time + if not_done: + timeout_error_msg = ( + f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s" + ) + logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg) + # In preview mode, timeout is also an error + errors.append(TimeoutError(timeout_error_msg)) + for future in not_done: + future.cancel() + # Wait a bit for cancellation to take effect + concurrent.futures.wait(not_done, timeout=5) + + # Collect exceptions from completed futures + for future in done: + try: + future.result() # This will raise any exception that occurred + except Exception as e: + logger.exception("Error in summary generation future") + errors.append(e) + + # In preview mode, if there are any errors, fail the request + if errors: + error_messages = [str(e) for e in errors] + error_summary = ( + f"Failed to generate summaries for {len(errors)} chunk(s). " + f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors + ) + if len(errors) > 3: + error_summary += f" (and {len(errors) - 3} more)" + logger.error("Summary generation failed in preview mode: %s", error_summary) + raise KnowledgeIndexNodeError(error_summary) + + completed_count = sum(1 for item in preview_output["preview"] if item.get("summary") is not None) + logger.info( + "Completed summary generation for preview chunks: %s/%s succeeded", + completed_count, + len(preview_output["preview"]), + ) + + return preview_output + + def _get_preview_output( + self, + chunk_structure: str, + chunks: Any, + dataset: Dataset | None = None, + variable_pool: VariablePool | None = None, + ) -> Mapping[str, Any]: + index_processor = IndexProcessorFactory(chunk_structure).init_index_processor() + preview_output = index_processor.format_preview(chunks) + + # If dataset is provided, try to enrich preview with summaries + if dataset and variable_pool: + document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) + if document_id: + document = db.session.query(Document).filter_by(id=document_id.value).first() + if document: + # Query summaries for this document + summaries = ( + db.session.query(DocumentSegmentSummary) + .filter_by( + dataset_id=dataset.id, + document_id=document.id, + status="completed", + enabled=True, + ) + .all() + ) + + if summaries: + # Create a map of segment content to summary for matching + # Use content matching as chunks in preview might not be indexed yet + summary_by_content = {} + for summary in summaries: + segment = ( + db.session.query(DocumentSegment) + .filter_by(id=summary.chunk_id, dataset_id=dataset.id) + .first() + ) + if segment: + # Normalize content for matching (strip whitespace) + normalized_content = segment.content.strip() + summary_by_content[normalized_content] = summary.summary_content + + # Enrich preview with summaries by content matching + if "preview" in preview_output and isinstance(preview_output["preview"], list): + matched_count = 0 + for preview_item in preview_output["preview"]: + if "content" in preview_item: + # Normalize content for matching + normalized_chunk_content = preview_item["content"].strip() + if normalized_chunk_content in summary_by_content: + preview_item["summary"] = summary_by_content[normalized_chunk_content] + matched_count += 1 + + if matched_count > 0: + logger.info( + "Enriched preview with %s existing summaries (dataset: %s, document: %s)", + matched_count, + dataset.id, + document.id, + ) + + return preview_output @classmethod def version(cls) -> str: diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 8670a71aa3..3c4850ebac 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -419,6 +419,9 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}" else: source["content"] = segment.get_sign_content() + # Add summary if available + if record.summary: + source["summary"] = record.summary retrieval_resource_list.append(source) if retrieval_resource_list: retrieval_resource_list = sorted( diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index dfb55dcd80..17d82c2118 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -685,6 +685,8 @@ class LLMNode(Node[LLMNodeData]): if "content" not in item: raise InvalidContextStructureError(f"Invalid context structure: {item}") + if item.get("summary"): + context_str += item["summary"] + "\n" context_str += item["content"] + "\n" retriever_resource = self._convert_to_original_retriever_resource(item) @@ -746,6 +748,7 @@ class LLMNode(Node[LLMNodeData]): page=metadata.get("page"), doc_metadata=metadata.get("doc_metadata"), files=context_dict.get("files"), + summary=context_dict.get("summary"), ) return source diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 08cf96c1c1..af983f6d87 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -102,6 +102,8 @@ def init_app(app: DifyApp) -> Celery: imports = [ "tasks.async_workflow_tasks", # trigger workers "tasks.trigger_processing_tasks", # async trigger processing + "tasks.generate_summary_index_task", # summary index generation + "tasks.regenerate_summary_index_task", # summary index regeneration ] day = dify_config.CELERY_BEAT_SCHEDULER_TIME diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 1e5ec7d200..ff6578098b 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -39,6 +39,14 @@ dataset_retrieval_model_fields = { "score_threshold_enabled": fields.Boolean, "score_threshold": fields.Float, } + +dataset_summary_index_fields = { + "enable": fields.Boolean, + "model_name": fields.String, + "model_provider_name": fields.String, + "summary_prompt": fields.String, +} + external_retrieval_model_fields = { "top_k": fields.Integer, "score_threshold": fields.Float, @@ -83,6 +91,7 @@ dataset_detail_fields = { "embedding_model_provider": fields.String, "embedding_available": fields.Boolean, "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields), + "summary_index_setting": fields.Nested(dataset_summary_index_fields), "tags": fields.List(fields.Nested(tag_fields)), "doc_form": fields.String, "external_knowledge_info": fields.Nested(external_knowledge_info_fields), diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py index 9be59f7454..35a2a04f3e 100644 --- a/api/fields/document_fields.py +++ b/api/fields/document_fields.py @@ -33,6 +33,11 @@ document_fields = { "hit_count": fields.Integer, "doc_form": fields.String, "doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"), + # Summary index generation status: + # "SUMMARIZING" (when task is queued and generating) + "summary_index_status": fields.String, + # Whether this document needs summary index generation + "need_summary": fields.Boolean, } document_with_segments_fields = { @@ -60,6 +65,10 @@ document_with_segments_fields = { "completed_segments": fields.Integer, "total_segments": fields.Integer, "doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"), + # Summary index generation status: + # "SUMMARIZING" (when task is queued and generating) + "summary_index_status": fields.String, + "need_summary": fields.Boolean, # Whether this document needs summary index generation } dataset_and_document_fields = { diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py index e70f9fa722..0b54992835 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -58,4 +58,5 @@ hit_testing_record_fields = { "score": fields.Float, "tsne_position": fields.Raw, "files": fields.List(fields.Nested(files_fields)), + "summary": fields.String, # Summary content if retrieved via summary index } diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index c81e482f73..e6c3b42f93 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -36,6 +36,7 @@ class RetrieverResource(ResponseModel): segment_position: int | None = None index_node_hash: str | None = None content: str | None = None + summary: str | None = None created_at: int | None = None @field_validator("created_at", mode="before") diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py index 56d6b68378..2ce9fb154c 100644 --- a/api/fields/segment_fields.py +++ b/api/fields/segment_fields.py @@ -49,4 +49,5 @@ segment_fields = { "stopped_at": TimestampField, "child_chunks": fields.List(fields.Nested(child_chunk_fields)), "attachments": fields.List(fields.Nested(attachment_fields)), + "summary": fields.String, # Summary content for the segment } diff --git a/api/migrations/versions/2026_01_27_1815-788d3099ae3a_add_summary_index_feature.py b/api/migrations/versions/2026_01_27_1815-788d3099ae3a_add_summary_index_feature.py new file mode 100644 index 0000000000..3c2e0822e1 --- /dev/null +++ b/api/migrations/versions/2026_01_27_1815-788d3099ae3a_add_summary_index_feature.py @@ -0,0 +1,107 @@ +"""add summary index feature + +Revision ID: 788d3099ae3a +Revises: 9d77545f524e +Create Date: 2026-01-27 18:15:45.277928 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + +def _is_pg(conn): + return conn.dialect.name == "postgresql" + +# revision identifiers, used by Alembic. +revision = '788d3099ae3a' +down_revision = '9d77545f524e' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + conn = op.get_bind() + if _is_pg(conn): + op.create_table('document_segment_summaries', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('document_id', models.types.StringUUID(), nullable=False), + sa.Column('chunk_id', models.types.StringUUID(), nullable=False), + sa.Column('summary_content', models.types.LongText(), nullable=True), + sa.Column('summary_index_node_id', sa.String(length=255), nullable=True), + sa.Column('summary_index_node_hash', sa.String(length=255), nullable=True), + sa.Column('tokens', sa.Integer(), nullable=True), + sa.Column('status', sa.String(length=32), server_default=sa.text("'generating'"), nullable=False), + sa.Column('error', models.types.LongText(), nullable=True), + sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False), + sa.Column('disabled_at', sa.DateTime(), nullable=True), + sa.Column('disabled_by', models.types.StringUUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='document_segment_summaries_pkey') + ) + with op.batch_alter_table('document_segment_summaries', schema=None) as batch_op: + batch_op.create_index('document_segment_summaries_chunk_id_idx', ['chunk_id'], unique=False) + batch_op.create_index('document_segment_summaries_dataset_id_idx', ['dataset_id'], unique=False) + batch_op.create_index('document_segment_summaries_document_id_idx', ['document_id'], unique=False) + batch_op.create_index('document_segment_summaries_status_idx', ['status'], unique=False) + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.add_column(sa.Column('summary_index_setting', models.types.AdjustedJSON(), nullable=True)) + + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=True)) + else: + # MySQL: Use compatible syntax + op.create_table( + 'document_segment_summaries', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('document_id', models.types.StringUUID(), nullable=False), + sa.Column('chunk_id', models.types.StringUUID(), nullable=False), + sa.Column('summary_content', models.types.LongText(), nullable=True), + sa.Column('summary_index_node_id', sa.String(length=255), nullable=True), + sa.Column('summary_index_node_hash', sa.String(length=255), nullable=True), + sa.Column('tokens', sa.Integer(), nullable=True), + sa.Column('status', sa.String(length=32), server_default=sa.text("'generating'"), nullable=False), + sa.Column('error', models.types.LongText(), nullable=True), + sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False), + sa.Column('disabled_at', sa.DateTime(), nullable=True), + sa.Column('disabled_by', models.types.StringUUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='document_segment_summaries_pkey'), + ) + with op.batch_alter_table('document_segment_summaries', schema=None) as batch_op: + batch_op.create_index('document_segment_summaries_chunk_id_idx', ['chunk_id'], unique=False) + batch_op.create_index('document_segment_summaries_dataset_id_idx', ['dataset_id'], unique=False) + batch_op.create_index('document_segment_summaries_document_id_idx', ['document_id'], unique=False) + batch_op.create_index('document_segment_summaries_status_idx', ['status'], unique=False) + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.add_column(sa.Column('summary_index_setting', models.types.AdjustedJSON(), nullable=True)) + + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.drop_column('need_summary') + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.drop_column('summary_index_setting') + + with op.batch_alter_table('document_segment_summaries', schema=None) as batch_op: + batch_op.drop_index('document_segment_summaries_status_idx') + batch_op.drop_index('document_segment_summaries_document_id_idx') + batch_op.drop_index('document_segment_summaries_dataset_id_idx') + batch_op.drop_index('document_segment_summaries_chunk_id_idx') + + op.drop_table('document_segment_summaries') + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index 62f11b8c72..6ab8f372bf 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -72,6 +72,7 @@ class Dataset(Base): keyword_number = mapped_column(sa.Integer, nullable=True, server_default=sa.text("10")) collection_binding_id = mapped_column(StringUUID, nullable=True) retrieval_model = mapped_column(AdjustedJSON, nullable=True) + summary_index_setting = mapped_column(AdjustedJSON, nullable=True) built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) icon_info = mapped_column(AdjustedJSON, nullable=True) runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'")) @@ -419,6 +420,7 @@ class Document(Base): doc_metadata = mapped_column(AdjustedJSON, nullable=True) doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'")) doc_language = mapped_column(String(255), nullable=True) + need_summary: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"] @@ -1575,3 +1577,36 @@ class SegmentAttachmentBinding(Base): segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False) attachment_id: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class DocumentSegmentSummary(Base): + __tablename__ = "document_segment_summaries" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="document_segment_summaries_pkey"), + sa.Index("document_segment_summaries_dataset_id_idx", "dataset_id"), + sa.Index("document_segment_summaries_document_id_idx", "document_id"), + sa.Index("document_segment_summaries_chunk_id_idx", "chunk_id"), + sa.Index("document_segment_summaries_status_idx", "status"), + ) + + id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4())) + dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + document_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + # corresponds to DocumentSegment.id or parent chunk id + chunk_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + summary_content: Mapped[str] = mapped_column(LongText, nullable=True) + summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True) + summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True) + tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) + status: Mapped[str] = mapped_column(String(32), nullable=False, server_default=sa.text("'generating'")) + error: Mapped[str] = mapped_column(LongText, nullable=True) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) + disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + disabled_by = mapped_column(StringUUID, nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) + + def __repr__(self): + return f"" diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index be9a0e9279..0b3fcbe4ae 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -89,6 +89,7 @@ from tasks.disable_segments_from_index_task import disable_segments_from_index_t from tasks.document_indexing_update_task import document_indexing_update_task from tasks.enable_segments_to_index_task import enable_segments_to_index_task from tasks.recover_document_indexing_task import recover_document_indexing_task +from tasks.regenerate_summary_index_task import regenerate_summary_index_task from tasks.remove_document_from_index_task import remove_document_from_index_task from tasks.retry_document_indexing_task import retry_document_indexing_task from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task @@ -211,6 +212,7 @@ class DatasetService: embedding_model_provider: str | None = None, embedding_model_name: str | None = None, retrieval_model: RetrievalModel | None = None, + summary_index_setting: dict | None = None, ): # check if dataset name already exists if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first(): @@ -253,6 +255,8 @@ class DatasetService: dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None dataset.permission = permission or DatasetPermissionEnum.ONLY_ME dataset.provider = provider + if summary_index_setting is not None: + dataset.summary_index_setting = summary_index_setting db.session.add(dataset) db.session.flush() @@ -476,6 +480,11 @@ class DatasetService: if external_retrieval_model: dataset.retrieval_model = external_retrieval_model + # Update summary index setting if provided + summary_index_setting = data.get("summary_index_setting", None) + if summary_index_setting is not None: + dataset.summary_index_setting = summary_index_setting + # Update basic dataset properties dataset.name = data.get("name", dataset.name) dataset.description = data.get("description", dataset.description) @@ -564,6 +573,9 @@ class DatasetService: # update Retrieval model if data.get("retrieval_model"): filtered_data["retrieval_model"] = data["retrieval_model"] + # update summary index setting + if data.get("summary_index_setting"): + filtered_data["summary_index_setting"] = data.get("summary_index_setting") # update icon info if data.get("icon_info"): filtered_data["icon_info"] = data.get("icon_info") @@ -572,12 +584,27 @@ class DatasetService: db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data) db.session.commit() + # Reload dataset to get updated values + db.session.refresh(dataset) + # update pipeline knowledge base node data DatasetService._update_pipeline_knowledge_base_node_data(dataset, user.id) # Trigger vector index task if indexing technique changed if action: deal_dataset_vector_index_task.delay(dataset.id, action) + # If embedding_model changed, also regenerate summary vectors + if action == "update": + regenerate_summary_index_task.delay( + dataset.id, + regenerate_reason="embedding_model_changed", + regenerate_vectors_only=True, + ) + + # Note: summary_index_setting changes do not trigger automatic regeneration of existing summaries. + # The new setting will only apply to: + # 1. New documents added after the setting change + # 2. Manual summary generation requests return dataset @@ -616,6 +643,7 @@ class DatasetService: knowledge_index_node_data["chunk_structure"] = dataset.chunk_structure knowledge_index_node_data["indexing_technique"] = dataset.indexing_technique # pyright: ignore[reportAttributeAccessIssue] knowledge_index_node_data["keyword_number"] = dataset.keyword_number + knowledge_index_node_data["summary_index_setting"] = dataset.summary_index_setting node["data"] = knowledge_index_node_data updated = True except Exception: @@ -854,6 +882,54 @@ class DatasetService: ) filtered_data["collection_binding_id"] = dataset_collection_binding.id + @staticmethod + def _check_summary_index_setting_model_changed(dataset: Dataset, data: dict[str, Any]) -> bool: + """ + Check if summary_index_setting model (model_name or model_provider_name) has changed. + + Args: + dataset: Current dataset object + data: Update data dictionary + + Returns: + bool: True if summary model changed, False otherwise + """ + # Check if summary_index_setting is being updated + if "summary_index_setting" not in data or data.get("summary_index_setting") is None: + return False + + new_summary_setting = data.get("summary_index_setting") + old_summary_setting = dataset.summary_index_setting + + # If new setting is disabled, no need to regenerate + if not new_summary_setting or not new_summary_setting.get("enable"): + return False + + # If old setting doesn't exist, no need to regenerate (no existing summaries to regenerate) + # Note: This task only regenerates existing summaries, not generates new ones + if not old_summary_setting: + return False + + # Compare model_name and model_provider_name + old_model_name = old_summary_setting.get("model_name") + old_model_provider = old_summary_setting.get("model_provider_name") + new_model_name = new_summary_setting.get("model_name") + new_model_provider = new_summary_setting.get("model_provider_name") + + # Check if model changed + if old_model_name != new_model_name or old_model_provider != new_model_provider: + logger.info( + "Summary index setting model changed for dataset %s: old=%s/%s, new=%s/%s", + dataset.id, + old_model_provider, + old_model_name, + new_model_provider, + new_model_name, + ) + return True + + return False + @staticmethod def update_rag_pipeline_dataset_settings( session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False @@ -889,6 +965,9 @@ class DatasetService: else: raise ValueError("Invalid index method") dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() + # Update summary_index_setting if provided + if knowledge_configuration.summary_index_setting is not None: + dataset.summary_index_setting = knowledge_configuration.summary_index_setting session.add(dataset) else: if dataset.chunk_structure and dataset.chunk_structure != knowledge_configuration.chunk_structure: @@ -994,6 +1073,9 @@ class DatasetService: if dataset.keyword_number != knowledge_configuration.keyword_number: dataset.keyword_number = knowledge_configuration.keyword_number dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() + # Update summary_index_setting if provided + if knowledge_configuration.summary_index_setting is not None: + dataset.summary_index_setting = knowledge_configuration.summary_index_setting session.add(dataset) session.commit() if action: @@ -1314,6 +1396,50 @@ class DocumentService: upload_file = DocumentService._get_upload_file_for_upload_file_document(document) return file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True) + @staticmethod + def enrich_documents_with_summary_index_status( + documents: Sequence[Document], + dataset: Dataset, + tenant_id: str, + ) -> None: + """ + Enrich documents with summary_index_status based on dataset summary index settings. + + This method calculates and sets the summary_index_status for each document that needs summary. + Documents that don't need summary or when summary index is disabled will have status set to None. + + Args: + documents: List of Document instances to enrich + dataset: Dataset instance containing summary_index_setting + tenant_id: Tenant ID for summary status lookup + """ + # Check if dataset has summary index enabled + has_summary_index = dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True + + # Filter documents that need summary calculation + documents_need_summary = [doc for doc in documents if doc.need_summary is True] + document_ids_need_summary = [str(doc.id) for doc in documents_need_summary] + + # Calculate summary_index_status for documents that need summary (only if dataset summary index is enabled) + summary_status_map: dict[str, str | None] = {} + if has_summary_index and document_ids_need_summary: + from services.summary_index_service import SummaryIndexService + + summary_status_map = SummaryIndexService.get_documents_summary_index_status( + document_ids=document_ids_need_summary, + dataset_id=dataset.id, + tenant_id=tenant_id, + ) + + # Add summary_index_status to each document + for document in documents: + if has_summary_index and document.need_summary is True: + # Get status from map, default to None (not queued yet) + document.summary_index_status = summary_status_map.get(str(document.id)) # type: ignore[attr-defined] + else: + # Return null if summary index is not enabled or document doesn't need summary + document.summary_index_status = None # type: ignore[attr-defined] + @staticmethod def prepare_document_batch_download_zip( *, @@ -1964,6 +2090,8 @@ class DocumentService: DuplicateDocumentIndexingTaskProxy( dataset.tenant_id, dataset.id, duplicate_document_ids ).delay() + # Note: Summary index generation is triggered in document_indexing_task after indexing completes + # to ensure segments are available. See tasks/document_indexing_task.py except LockNotOwnedError: pass @@ -2268,6 +2396,11 @@ class DocumentService: name: str, batch: str, ): + # Set need_summary based on dataset's summary_index_setting + need_summary = False + if dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True: + need_summary = True + document = Document( tenant_id=dataset.tenant_id, dataset_id=dataset.id, @@ -2281,6 +2414,7 @@ class DocumentService: created_by=account.id, doc_form=document_form, doc_language=document_language, + need_summary=need_summary, ) doc_metadata = {} if dataset.built_in_field_enabled: @@ -2505,6 +2639,7 @@ class DocumentService: embedding_model_provider=knowledge_config.embedding_model_provider, collection_binding_id=dataset_collection_binding_id, retrieval_model=retrieval_model.model_dump() if retrieval_model else None, + summary_index_setting=knowledge_config.summary_index_setting, is_multimodal=knowledge_config.is_multimodal, ) @@ -2686,6 +2821,14 @@ class DocumentService: if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int): raise ValueError("Process rule segmentation max_tokens is invalid") + # valid summary index setting + summary_index_setting = args["process_rule"].get("summary_index_setting") + if summary_index_setting and summary_index_setting.get("enable"): + if "model_name" not in summary_index_setting or not summary_index_setting["model_name"]: + raise ValueError("Summary index model name is required") + if "model_provider_name" not in summary_index_setting or not summary_index_setting["model_provider_name"]: + raise ValueError("Summary index model provider name is required") + @staticmethod def batch_update_document_status( dataset: Dataset, document_ids: list[str], action: Literal["enable", "disable", "archive", "un_archive"], user @@ -3154,6 +3297,35 @@ class SegmentService: if args.enabled or keyword_changed: # update segment vector index VectorService.update_segment_vector(args.keywords, segment, dataset) + # update summary index if summary is provided and has changed + if args.summary is not None: + # When user manually provides summary, allow saving even if summary_index_setting doesn't exist + # summary_index_setting is only needed for LLM generation, not for manual summary vectorization + # Vectorization uses dataset.embedding_model, which doesn't require summary_index_setting + if dataset.indexing_technique == "high_quality": + # Query existing summary from database + from models.dataset import DocumentSegmentSummary + + existing_summary = ( + db.session.query(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .first() + ) + + # Check if summary has changed + existing_summary_content = existing_summary.summary_content if existing_summary else None + if existing_summary_content != args.summary: + # Summary has changed, update it + from services.summary_index_service import SummaryIndexService + + try: + SummaryIndexService.update_summary_for_segment(segment, dataset, args.summary) + except Exception: + logger.exception("Failed to update summary for segment %s", segment.id) + # Don't fail the entire update if summary update fails else: segment_hash = helper.generate_text_hash(content) tokens = 0 @@ -3228,6 +3400,73 @@ class SegmentService: elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX): # update segment vector index VectorService.update_segment_vector(args.keywords, segment, dataset) + # Handle summary index when content changed + if dataset.indexing_technique == "high_quality": + from models.dataset import DocumentSegmentSummary + + existing_summary = ( + db.session.query(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .first() + ) + + if args.summary is None: + # User didn't provide summary, auto-regenerate if segment previously had summary + # Auto-regeneration only happens if summary_index_setting exists and enable is True + if ( + existing_summary + and dataset.summary_index_setting + and dataset.summary_index_setting.get("enable") is True + ): + # Segment previously had summary, regenerate it with new content + from services.summary_index_service import SummaryIndexService + + try: + SummaryIndexService.generate_and_vectorize_summary( + segment, dataset, dataset.summary_index_setting + ) + logger.info("Auto-regenerated summary for segment %s after content change", segment.id) + except Exception: + logger.exception("Failed to auto-regenerate summary for segment %s", segment.id) + # Don't fail the entire update if summary regeneration fails + else: + # User provided summary, check if it has changed + # Manual summary updates are allowed even if summary_index_setting doesn't exist + existing_summary_content = existing_summary.summary_content if existing_summary else None + if existing_summary_content != args.summary: + # Summary has changed, use user-provided summary + from services.summary_index_service import SummaryIndexService + + try: + SummaryIndexService.update_summary_for_segment(segment, dataset, args.summary) + logger.info("Updated summary for segment %s with user-provided content", segment.id) + except Exception: + logger.exception("Failed to update summary for segment %s", segment.id) + # Don't fail the entire update if summary update fails + else: + # Summary hasn't changed, regenerate based on new content + # Auto-regeneration only happens if summary_index_setting exists and enable is True + if ( + existing_summary + and dataset.summary_index_setting + and dataset.summary_index_setting.get("enable") is True + ): + from services.summary_index_service import SummaryIndexService + + try: + SummaryIndexService.generate_and_vectorize_summary( + segment, dataset, dataset.summary_index_setting + ) + logger.info( + "Regenerated summary for segment %s after content change (summary unchanged)", + segment.id, + ) + except Exception: + logger.exception("Failed to regenerate summary for segment %s", segment.id) + # Don't fail the entire update if summary regeneration fails # update multimodel vector index VectorService.update_multimodel_vector(segment, args.attachment_ids or [], dataset) except Exception as e: @@ -3616,6 +3855,39 @@ class SegmentService: ) return result if isinstance(result, DocumentSegment) else None + @classmethod + def get_segments_by_document_and_dataset( + cls, + document_id: str, + dataset_id: str, + status: str | None = None, + enabled: bool | None = None, + ) -> Sequence[DocumentSegment]: + """ + Get segments for a document in a dataset with optional filtering. + + Args: + document_id: Document ID + dataset_id: Dataset ID + status: Optional status filter (e.g., "completed") + enabled: Optional enabled filter (True/False) + + Returns: + Sequence of DocumentSegment instances + """ + query = select(DocumentSegment).where( + DocumentSegment.document_id == document_id, + DocumentSegment.dataset_id == dataset_id, + ) + + if status is not None: + query = query.where(DocumentSegment.status == status) + + if enabled is not None: + query = query.where(DocumentSegment.enabled == enabled) + + return db.session.scalars(query).all() + class DatasetCollectionBindingService: @classmethod diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index 7959734e89..8dc5b93501 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -119,6 +119,7 @@ class KnowledgeConfig(BaseModel): data_source: DataSource | None = None process_rule: ProcessRule | None = None retrieval_model: RetrievalModel | None = None + summary_index_setting: dict | None = None doc_form: str = "text_model" doc_language: str = "English" embedding_model: str | None = None @@ -141,6 +142,7 @@ class SegmentUpdateArgs(BaseModel): regenerate_child_chunks: bool = False enabled: bool | None = None attachment_ids: list[str] | None = None + summary: str | None = None # Summary content for summary index class ChildChunkUpdateArgs(BaseModel): diff --git a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py index cbb0efcc2a..041ae4edba 100644 --- a/api/services/entities/knowledge_entities/rag_pipeline_entities.py +++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py @@ -116,6 +116,8 @@ class KnowledgeConfiguration(BaseModel): embedding_model: str = "" keyword_number: int | None = 10 retrieval_model: RetrievalSetting + # add summary index setting + summary_index_setting: dict | None = None @field_validator("embedding_model_provider", mode="before") @classmethod diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index c1c6e204fb..be1ce834f6 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -343,6 +343,9 @@ class RagPipelineDslService: dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider elif knowledge_configuration.indexing_technique == "economy": dataset.keyword_number = knowledge_configuration.keyword_number + # Update summary_index_setting if provided + if knowledge_configuration.summary_index_setting is not None: + dataset.summary_index_setting = knowledge_configuration.summary_index_setting dataset.pipeline_id = pipeline.id self._session.add(dataset) self._session.commit() @@ -477,6 +480,9 @@ class RagPipelineDslService: dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider elif knowledge_configuration.indexing_technique == "economy": dataset.keyword_number = knowledge_configuration.keyword_number + # Update summary_index_setting if provided + if knowledge_configuration.summary_index_setting is not None: + dataset.summary_index_setting = knowledge_configuration.summary_index_setting dataset.pipeline_id = pipeline.id self._session.add(dataset) self._session.commit() diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py new file mode 100644 index 0000000000..b8e1f8bc3f --- /dev/null +++ b/api/services/summary_index_service.py @@ -0,0 +1,1432 @@ +"""Summary index service for generating and managing document segment summaries.""" + +import logging +import time +import uuid +from datetime import UTC, datetime +from typing import Any + +from sqlalchemy.orm import Session + +from core.db.session_factory import session_factory +from core.model_manager import ModelManager +from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.model_entities import ModelType +from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.models.document import Document +from libs import helper +from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary +from models.dataset import Document as DatasetDocument + +logger = logging.getLogger(__name__) + + +class SummaryIndexService: + """Service for generating and managing summary indexes.""" + + @staticmethod + def generate_summary_for_segment( + segment: DocumentSegment, + dataset: Dataset, + summary_index_setting: dict, + ) -> tuple[str, LLMUsage]: + """ + Generate summary for a single segment. + + Args: + segment: DocumentSegment to generate summary for + dataset: Dataset containing the segment + summary_index_setting: Summary index configuration + + Returns: + Tuple of (summary_content, llm_usage) where llm_usage is LLMUsage object + + Raises: + ValueError: If summary_index_setting is invalid or generation fails + """ + # Reuse the existing generate_summary method from ParagraphIndexProcessor + # Use lazy import to avoid circular import + from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor + + summary_content, usage = ParagraphIndexProcessor.generate_summary( + tenant_id=dataset.tenant_id, + text=segment.content, + summary_index_setting=summary_index_setting, + segment_id=segment.id, + ) + + if not summary_content: + raise ValueError("Generated summary is empty") + + return summary_content, usage + + @staticmethod + def create_summary_record( + segment: DocumentSegment, + dataset: Dataset, + summary_content: str, + status: str = "generating", + ) -> DocumentSegmentSummary: + """ + Create or update a DocumentSegmentSummary record. + If a summary record already exists for this segment, it will be updated instead of creating a new one. + + Args: + segment: DocumentSegment to create summary for + dataset: Dataset containing the segment + summary_content: Generated summary content + status: Summary status (default: "generating") + + Returns: + Created or updated DocumentSegmentSummary instance + """ + with session_factory.create_session() as session: + # Check if summary record already exists + existing_summary = ( + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + ) + + if existing_summary: + # Update existing record + existing_summary.summary_content = summary_content + existing_summary.status = status + existing_summary.error = None # type: ignore[assignment] # Clear any previous errors + # Re-enable if it was disabled + if not existing_summary.enabled: + existing_summary.enabled = True + existing_summary.disabled_at = None + existing_summary.disabled_by = None + session.add(existing_summary) + session.flush() + return existing_summary + else: + # Create new record (enabled by default) + summary_record = DocumentSegmentSummary( + dataset_id=dataset.id, + document_id=segment.document_id, + chunk_id=segment.id, + summary_content=summary_content, + status=status, + enabled=True, # Explicitly set enabled to True + ) + session.add(summary_record) + session.flush() + return summary_record + + @staticmethod + def vectorize_summary( + summary_record: DocumentSegmentSummary, + segment: DocumentSegment, + dataset: Dataset, + session: Session | None = None, + ) -> None: + """ + Vectorize summary and store in vector database. + + Args: + summary_record: DocumentSegmentSummary record + segment: Original DocumentSegment + dataset: Dataset containing the segment + session: Optional SQLAlchemy session. If provided, uses this session instead of creating a new one. + If not provided, creates a new session and commits automatically. + """ + if dataset.indexing_technique != "high_quality": + logger.warning( + "Summary vectorization skipped for dataset %s: indexing_technique is not high_quality", + dataset.id, + ) + return + + # Get summary_record_id for later session queries + summary_record_id = summary_record.id + # Save the original session parameter for use in error handling + original_session = session + logger.debug( + "Starting vectorization for segment %s, summary_record_id=%s, using_provided_session=%s", + segment.id, + summary_record_id, + original_session is not None, + ) + + # Reuse existing index_node_id if available (like segment does), otherwise generate new one + old_summary_node_id = summary_record.summary_index_node_id + if old_summary_node_id: + # Reuse existing index_node_id (like segment behavior) + summary_index_node_id = old_summary_node_id + logger.debug("Reusing existing index_node_id %s for segment %s", summary_index_node_id, segment.id) + else: + # Generate new index node ID only for new summaries + summary_index_node_id = str(uuid.uuid4()) + logger.debug("Generated new index_node_id %s for segment %s", summary_index_node_id, segment.id) + + # Always regenerate hash (in case summary content changed) + summary_content = summary_record.summary_content + if not summary_content or not summary_content.strip(): + raise ValueError(f"Summary content is empty for segment {segment.id}, cannot vectorize") + summary_hash = helper.generate_text_hash(summary_content) + + # Delete old vector only if we're reusing the same index_node_id (to overwrite) + # If index_node_id changed, the old vector should have been deleted elsewhere + if old_summary_node_id and old_summary_node_id == summary_index_node_id: + try: + vector = Vector(dataset) + vector.delete_by_ids([old_summary_node_id]) + except Exception as e: + logger.warning( + "Failed to delete old summary vector for segment %s: %s. Continuing with new vectorization.", + segment.id, + str(e), + ) + + # Calculate embedding tokens for summary (for logging and statistics) + embedding_tokens = 0 + try: + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + if embedding_model: + tokens_list = embedding_model.get_text_embedding_num_tokens([summary_content]) + embedding_tokens = tokens_list[0] if tokens_list else 0 + except Exception as e: + logger.warning("Failed to calculate embedding tokens for summary: %s", str(e)) + + # Create document with summary content and metadata + summary_document = Document( + page_content=summary_content, + metadata={ + "doc_id": summary_index_node_id, + "doc_hash": summary_hash, + "dataset_id": dataset.id, + "document_id": segment.document_id, + "original_chunk_id": segment.id, # Key: link to original chunk + "doc_type": DocType.TEXT, + "is_summary": True, # Identifier for summary documents + }, + ) + + # Vectorize and store with retry mechanism for connection errors + max_retries = 3 + retry_delay = 2.0 + + for attempt in range(max_retries): + try: + logger.debug( + "Attempting to vectorize summary for segment %s (attempt %s/%s)", + segment.id, + attempt + 1, + max_retries, + ) + vector = Vector(dataset) + # Use duplicate_check=False to ensure re-vectorization even if old vector still exists + # The old vector should have been deleted above, but if deletion failed, + # we still want to re-vectorize (upsert will overwrite) + vector.add_texts([summary_document], duplicate_check=False) + logger.debug( + "Successfully added summary vector to database for segment %s (attempt %s/%s)", + segment.id, + attempt + 1, + max_retries, + ) + + # Log embedding token usage + if embedding_tokens > 0: + logger.info( + "Summary embedding for segment %s used %s tokens", + segment.id, + embedding_tokens, + ) + + # Success - update summary record with index node info + # Use provided session if available, otherwise create a new one + use_provided_session = session is not None + if not use_provided_session: + logger.debug("Creating new session for vectorization of segment %s", segment.id) + session_context = session_factory.create_session() + session = session_context.__enter__() + else: + logger.debug("Using provided session for vectorization of segment %s", segment.id) + session_context = None # Don't use context manager for provided session + + # At this point, session is guaranteed to be not None + # Type narrowing: session is definitely not None after the if/else above + if session is None: + raise RuntimeError("Session should not be None at this point") + + try: + # Declare summary_record_in_session variable + summary_record_in_session: DocumentSegmentSummary | None + + # If using provided session, merge the summary_record into it + if use_provided_session: + # Merge the summary_record into the provided session + logger.debug( + "Merging summary_record (id=%s) into provided session for segment %s", + summary_record_id, + segment.id, + ) + summary_record_in_session = session.merge(summary_record) + logger.debug( + "Successfully merged summary_record for segment %s, merged_id=%s", + segment.id, + summary_record_in_session.id, + ) + else: + # Query the summary record in the new session + logger.debug( + "Querying summary_record by id=%s for segment %s in new session", + summary_record_id, + segment.id, + ) + summary_record_in_session = ( + session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first() + ) + + if not summary_record_in_session: + # Record not found - try to find by chunk_id and dataset_id instead + logger.debug( + "Summary record not found by id=%s, trying chunk_id=%s and dataset_id=%s " + "for segment %s", + summary_record_id, + segment.id, + dataset.id, + segment.id, + ) + summary_record_in_session = ( + session.query(DocumentSegmentSummary) + .filter_by(chunk_id=segment.id, dataset_id=dataset.id) + .first() + ) + + if not summary_record_in_session: + # Still not found - create a new one using the parameter data + logger.warning( + "Summary record not found in database for segment %s (id=%s), creating new one. " + "This may indicate a session isolation issue.", + segment.id, + summary_record_id, + ) + summary_record_in_session = DocumentSegmentSummary( + id=summary_record_id, # Use the same ID if available + dataset_id=dataset.id, + document_id=segment.document_id, + chunk_id=segment.id, + summary_content=summary_content, + summary_index_node_id=summary_index_node_id, + summary_index_node_hash=summary_hash, + tokens=embedding_tokens, + status="completed", + enabled=True, + ) + session.add(summary_record_in_session) + logger.info( + "Created new summary record (id=%s) for segment %s after vectorization", + summary_record_id, + segment.id, + ) + else: + # Found by chunk_id - update it + logger.info( + "Found summary record for segment %s by chunk_id " + "(id mismatch: expected %s, found %s). " + "This may indicate the record was created in a different session.", + segment.id, + summary_record_id, + summary_record_in_session.id, + ) + else: + logger.debug( + "Found summary_record (id=%s) for segment %s in new session", + summary_record_id, + segment.id, + ) + + # At this point, summary_record_in_session is guaranteed to be not None + if summary_record_in_session is None: + raise RuntimeError("summary_record_in_session should not be None at this point") + + # Update all fields including summary_content + # Always use the summary_content from the parameter (which is the latest from outer session) + # rather than relying on what's in the database, in case outer session hasn't committed yet + summary_record_in_session.summary_index_node_id = summary_index_node_id + summary_record_in_session.summary_index_node_hash = summary_hash + summary_record_in_session.tokens = embedding_tokens # Save embedding tokens + summary_record_in_session.status = "completed" + # Ensure summary_content is preserved (use the latest from summary_record parameter) + # This is critical: use the parameter value, not the database value + summary_record_in_session.summary_content = summary_content + # Explicitly update updated_at to ensure it's refreshed even if other fields haven't changed + summary_record_in_session.updated_at = datetime.now(UTC).replace(tzinfo=None) + session.add(summary_record_in_session) + + # Only commit if we created the session ourselves + if not use_provided_session: + logger.debug("Committing session for segment %s (self-created session)", segment.id) + session.commit() + logger.debug("Successfully committed session for segment %s", segment.id) + else: + # When using provided session, flush to ensure changes are written to database + # This prevents refresh() from overwriting our changes + logger.debug( + "Flushing session for segment %s (using provided session, caller will commit)", + segment.id, + ) + session.flush() + logger.debug("Successfully flushed session for segment %s", segment.id) + # If using provided session, let the caller handle commit + + logger.info( + "Successfully vectorized summary for segment %s, index_node_id=%s, index_node_hash=%s, " + "tokens=%s, summary_record_id=%s, use_provided_session=%s", + segment.id, + summary_index_node_id, + summary_hash, + embedding_tokens, + summary_record_in_session.id, + use_provided_session, + ) + # Update the original object for consistency + summary_record.summary_index_node_id = summary_index_node_id + summary_record.summary_index_node_hash = summary_hash + summary_record.tokens = embedding_tokens + summary_record.status = "completed" + summary_record.summary_content = summary_content + if summary_record_in_session.updated_at: + summary_record.updated_at = summary_record_in_session.updated_at + finally: + # Only close session if we created it ourselves + if not use_provided_session and session_context: + session_context.__exit__(None, None, None) + # Success, exit function + return + + except (ConnectionError, Exception) as e: + error_str = str(e).lower() + # Check if it's a connection-related error that might be transient + is_connection_error = any( + keyword in error_str + for keyword in [ + "connection", + "disconnected", + "timeout", + "network", + "could not connect", + "server disconnected", + "weaviate", + ] + ) + + if is_connection_error and attempt < max_retries - 1: + # Retry for connection errors + wait_time = retry_delay * (2**attempt) # Exponential backoff + logger.warning( + "Vectorization attempt %s/%s failed for segment %s (connection error): %s. " + "Retrying in %.1f seconds...", + attempt + 1, + max_retries, + segment.id, + str(e), + wait_time, + ) + time.sleep(wait_time) + continue + else: + # Final attempt failed or non-connection error - log and update status + logger.error( + "Failed to vectorize summary for segment %s after %s attempts: %s. " + "summary_record_id=%s, index_node_id=%s, use_provided_session=%s", + segment.id, + attempt + 1, + str(e), + summary_record_id, + summary_index_node_id, + session is not None, + exc_info=True, + ) + # Update error status in session + # Use the original_session saved at function start (the function parameter) + logger.debug( + "Updating error status for segment %s, summary_record_id=%s, has_original_session=%s", + segment.id, + summary_record_id, + original_session is not None, + ) + # Always create a new session for error handling to avoid issues with closed sessions + # Even if original_session was provided, we create a new one for safety + with session_factory.create_session() as error_session: + # Try to find the record by id first + # Note: Using assignment only (no type annotation) to avoid redeclaration error + summary_record_in_session = ( + error_session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first() + ) + if not summary_record_in_session: + # Try to find by chunk_id and dataset_id + logger.debug( + "Summary record not found by id=%s, trying chunk_id=%s and dataset_id=%s " + "for segment %s", + summary_record_id, + segment.id, + dataset.id, + segment.id, + ) + summary_record_in_session = ( + error_session.query(DocumentSegmentSummary) + .filter_by(chunk_id=segment.id, dataset_id=dataset.id) + .first() + ) + + if summary_record_in_session: + summary_record_in_session.status = "error" + summary_record_in_session.error = f"Vectorization failed: {str(e)}" + summary_record_in_session.updated_at = datetime.now(UTC).replace(tzinfo=None) + error_session.add(summary_record_in_session) + error_session.commit() + logger.info( + "Updated error status in new session for segment %s, record_id=%s", + segment.id, + summary_record_in_session.id, + ) + # Update the original object for consistency + summary_record.status = "error" + summary_record.error = summary_record_in_session.error + summary_record.updated_at = summary_record_in_session.updated_at + else: + logger.warning( + "Could not update error status: summary record not found for segment %s (id=%s). " + "This may indicate a session isolation issue.", + segment.id, + summary_record_id, + ) + raise + + @staticmethod + def batch_create_summary_records( + segments: list[DocumentSegment], + dataset: Dataset, + status: str = "not_started", + ) -> None: + """ + Batch create summary records for segments with specified status. + If a record already exists, update its status. + + Args: + segments: List of DocumentSegment instances + dataset: Dataset containing the segments + status: Initial status for the records (default: "not_started") + """ + segment_ids = [segment.id for segment in segments] + if not segment_ids: + return + + with session_factory.create_session() as session: + # Query existing summary records + existing_summaries = ( + session.query(DocumentSegmentSummary) + .filter( + DocumentSegmentSummary.chunk_id.in_(segment_ids), + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .all() + ) + existing_summary_map = {summary.chunk_id: summary for summary in existing_summaries} + + # Create or update records + for segment in segments: + existing_summary = existing_summary_map.get(segment.id) + if existing_summary: + # Update existing record + existing_summary.status = status + existing_summary.error = None # type: ignore[assignment] # Clear any previous errors + if not existing_summary.enabled: + existing_summary.enabled = True + existing_summary.disabled_at = None + existing_summary.disabled_by = None + session.add(existing_summary) + else: + # Create new record + summary_record = DocumentSegmentSummary( + dataset_id=dataset.id, + document_id=segment.document_id, + chunk_id=segment.id, + summary_content=None, # Will be filled later + status=status, + enabled=True, + ) + session.add(summary_record) + + @staticmethod + def update_summary_record_error( + segment: DocumentSegment, + dataset: Dataset, + error: str, + ) -> None: + """ + Update summary record with error status. + + Args: + segment: DocumentSegment + dataset: Dataset containing the segment + error: Error message + """ + with session_factory.create_session() as session: + summary_record = ( + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + ) + + if summary_record: + summary_record.status = "error" + summary_record.error = error + session.add(summary_record) + session.commit() + else: + logger.warning("Summary record not found for segment %s when updating error", segment.id) + + @staticmethod + def generate_and_vectorize_summary( + segment: DocumentSegment, + dataset: Dataset, + summary_index_setting: dict, + ) -> DocumentSegmentSummary: + """ + Generate summary for a segment and vectorize it. + Assumes summary record already exists (created by batch_create_summary_records). + + Args: + segment: DocumentSegment to generate summary for + dataset: Dataset containing the segment + summary_index_setting: Summary index configuration + + Returns: + Created DocumentSegmentSummary instance + + Raises: + ValueError: If summary generation fails + """ + with session_factory.create_session() as session: + try: + # Get or refresh summary record in this session + summary_record_in_session = ( + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + ) + + if not summary_record_in_session: + # If not found, create one + logger.warning("Summary record not found for segment %s, creating one", segment.id) + summary_record_in_session = DocumentSegmentSummary( + dataset_id=dataset.id, + document_id=segment.document_id, + chunk_id=segment.id, + summary_content="", + status="generating", + enabled=True, + ) + session.add(summary_record_in_session) + session.flush() + + # Update status to "generating" + summary_record_in_session.status = "generating" + summary_record_in_session.error = None # type: ignore[assignment] + session.add(summary_record_in_session) + # Don't flush here - wait until after vectorization succeeds + + # Generate summary (returns summary_content and llm_usage) + summary_content, llm_usage = SummaryIndexService.generate_summary_for_segment( + segment, dataset, summary_index_setting + ) + + # Update summary content + summary_record_in_session.summary_content = summary_content + session.add(summary_record_in_session) + # Flush to ensure summary_content is saved before vectorize_summary queries it + session.flush() + + # Log LLM usage for summary generation + if llm_usage and llm_usage.total_tokens > 0: + logger.info( + "Summary generation for segment %s used %s tokens (prompt: %s, completion: %s)", + segment.id, + llm_usage.total_tokens, + llm_usage.prompt_tokens, + llm_usage.completion_tokens, + ) + + # Vectorize summary (will delete old vector if exists before creating new one) + # Pass the session-managed record to vectorize_summary + # vectorize_summary will update status to "completed" and tokens in its own session + # vectorize_summary will also ensure summary_content is preserved + try: + # Pass the session to vectorize_summary to avoid session isolation issues + SummaryIndexService.vectorize_summary(summary_record_in_session, segment, dataset, session=session) + # Refresh the object from database to get the updated status and tokens from vectorize_summary + session.refresh(summary_record_in_session) + # Commit the session + # (summary_record_in_session should have status="completed" and tokens from refresh) + session.commit() + logger.info("Successfully generated and vectorized summary for segment %s", segment.id) + return summary_record_in_session + except Exception as vectorize_error: + # If vectorization fails, update status to error in current session + logger.exception("Failed to vectorize summary for segment %s", segment.id) + summary_record_in_session.status = "error" + summary_record_in_session.error = f"Vectorization failed: {str(vectorize_error)}" + session.add(summary_record_in_session) + session.commit() + raise + + except Exception as e: + logger.exception("Failed to generate summary for segment %s", segment.id) + # Update summary record with error status + summary_record_in_session = ( + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + ) + if summary_record_in_session: + summary_record_in_session.status = "error" + summary_record_in_session.error = str(e) + session.add(summary_record_in_session) + session.commit() + raise + + @staticmethod + def generate_summaries_for_document( + dataset: Dataset, + document: DatasetDocument, + summary_index_setting: dict, + segment_ids: list[str] | None = None, + only_parent_chunks: bool = False, + ) -> list[DocumentSegmentSummary]: + """ + Generate summaries for all segments in a document including vectorization. + + Args: + dataset: Dataset containing the document + document: DatasetDocument to generate summaries for + summary_index_setting: Summary index configuration + segment_ids: Optional list of specific segment IDs to process + only_parent_chunks: If True, only process parent chunks (for parent-child mode) + + Returns: + List of created DocumentSegmentSummary instances + """ + # Only generate summary index for high_quality indexing technique + if dataset.indexing_technique != "high_quality": + logger.info( + "Skipping summary generation for dataset %s: indexing_technique is %s, not 'high_quality'", + dataset.id, + dataset.indexing_technique, + ) + return [] + + if not summary_index_setting or not summary_index_setting.get("enable"): + logger.info("Summary index is disabled for dataset %s", dataset.id) + return [] + + # Skip qa_model documents + if document.doc_form == "qa_model": + logger.info("Skipping summary generation for qa_model document %s", document.id) + return [] + + logger.info( + "Starting summary generation for document %s in dataset %s, segment_ids: %s, only_parent_chunks: %s", + document.id, + dataset.id, + len(segment_ids) if segment_ids else "all", + only_parent_chunks, + ) + + with session_factory.create_session() as session: + # Query segments (only enabled segments) + query = session.query(DocumentSegment).filter_by( + dataset_id=dataset.id, + document_id=document.id, + status="completed", + enabled=True, # Only generate summaries for enabled segments + ) + + if segment_ids: + query = query.filter(DocumentSegment.id.in_(segment_ids)) + + segments = query.all() + + if not segments: + logger.info("No segments found for document %s", document.id) + return [] + + # Batch create summary records with "not_started" status before processing + # This ensures all records exist upfront, allowing status tracking + SummaryIndexService.batch_create_summary_records( + segments=segments, + dataset=dataset, + status="not_started", + ) + session.commit() # Commit initial records + + summary_records = [] + + for segment in segments: + # For parent-child mode, only process parent chunks + # In parent-child mode, all DocumentSegments are parent chunks, + # so we process all of them. Child chunks are stored in ChildChunk table + # and are not DocumentSegments, so they won't be in the segments list. + # This check is mainly for clarity and future-proofing. + if only_parent_chunks: + # In parent-child mode, all segments in the query are parent chunks + # Child chunks are not DocumentSegments, so they won't appear here + # We can process all segments + pass + + try: + summary_record = SummaryIndexService.generate_and_vectorize_summary( + segment, dataset, summary_index_setting + ) + summary_records.append(summary_record) + except Exception as e: + logger.exception("Failed to generate summary for segment %s", segment.id) + # Update summary record with error status + SummaryIndexService.update_summary_record_error( + segment=segment, + dataset=dataset, + error=str(e), + ) + # Continue with other segments + continue + + logger.info( + "Completed summary generation for document %s: %s summaries generated and vectorized", + document.id, + len(summary_records), + ) + return summary_records + + @staticmethod + def disable_summaries_for_segments( + dataset: Dataset, + segment_ids: list[str] | None = None, + disabled_by: str | None = None, + ) -> None: + """ + Disable summary records and remove vectors from vector database for segments. + Unlike delete, this preserves the summary records but marks them as disabled. + + Args: + dataset: Dataset containing the segments + segment_ids: List of segment IDs to disable summaries for. If None, disable all. + disabled_by: User ID who disabled the summaries + """ + from libs.datetime_utils import naive_utc_now + + with session_factory.create_session() as session: + query = session.query(DocumentSegmentSummary).filter_by( + dataset_id=dataset.id, + enabled=True, # Only disable enabled summaries + ) + + if segment_ids: + query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + + summaries = query.all() + + if not summaries: + return + + logger.info( + "Disabling %s summary records for dataset %s, segment_ids: %s", + len(summaries), + dataset.id, + len(segment_ids) if segment_ids else "all", + ) + + # Remove from vector database (but keep records) + if dataset.indexing_technique == "high_quality": + summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id] + if summary_node_ids: + try: + vector = Vector(dataset) + vector.delete_by_ids(summary_node_ids) + except Exception as e: + logger.warning("Failed to remove summary vectors: %s", str(e)) + + # Disable summary records (don't delete) + now = naive_utc_now() + for summary in summaries: + summary.enabled = False + summary.disabled_at = now + summary.disabled_by = disabled_by + session.add(summary) + + session.commit() + logger.info("Disabled %s summary records for dataset %s", len(summaries), dataset.id) + + @staticmethod + def enable_summaries_for_segments( + dataset: Dataset, + segment_ids: list[str] | None = None, + ) -> None: + """ + Enable summary records and re-add vectors to vector database for segments. + + Note: This method enables summaries based on chunk status, not summary_index_setting.enable. + The summary_index_setting.enable flag only controls automatic generation, + not whether existing summaries can be used. + Summary.enabled should always be kept in sync with chunk.enabled. + + Args: + dataset: Dataset containing the segments + segment_ids: List of segment IDs to enable summaries for. If None, enable all. + """ + # Only enable summary index for high_quality indexing technique + if dataset.indexing_technique != "high_quality": + return + + with session_factory.create_session() as session: + query = session.query(DocumentSegmentSummary).filter_by( + dataset_id=dataset.id, + enabled=False, # Only enable disabled summaries + ) + + if segment_ids: + query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + + summaries = query.all() + + if not summaries: + return + + logger.info( + "Enabling %s summary records for dataset %s, segment_ids: %s", + len(summaries), + dataset.id, + len(segment_ids) if segment_ids else "all", + ) + + # Re-vectorize and re-add to vector database + enabled_count = 0 + for summary in summaries: + # Get the original segment + segment = ( + session.query(DocumentSegment) + .filter_by( + id=summary.chunk_id, + dataset_id=dataset.id, + ) + .first() + ) + + # Summary.enabled stays in sync with chunk.enabled, + # only enable summary if the associated chunk is enabled. + if not segment or not segment.enabled or segment.status != "completed": + continue + + if not summary.summary_content: + continue + + try: + # Re-vectorize summary (this will update status and tokens in its own session) + # Pass the session to vectorize_summary to avoid session isolation issues + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=session) + + # Refresh the object from database to get the updated status and tokens from vectorize_summary + session.refresh(summary) + + # Enable summary record + summary.enabled = True + summary.disabled_at = None + summary.disabled_by = None + session.add(summary) + enabled_count += 1 + except Exception: + logger.exception("Failed to re-vectorize summary %s", summary.id) + # Keep it disabled if vectorization fails + continue + + session.commit() + logger.info("Enabled %s summary records for dataset %s", enabled_count, dataset.id) + + @staticmethod + def delete_summaries_for_segments( + dataset: Dataset, + segment_ids: list[str] | None = None, + ) -> None: + """ + Delete summary records and vectors for segments (used only for actual deletion scenarios). + For disable/enable operations, use disable_summaries_for_segments/enable_summaries_for_segments. + + Args: + dataset: Dataset containing the segments + segment_ids: List of segment IDs to delete summaries for. If None, delete all. + """ + with session_factory.create_session() as session: + query = session.query(DocumentSegmentSummary).filter_by(dataset_id=dataset.id) + + if segment_ids: + query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + + summaries = query.all() + + if not summaries: + return + + # Delete from vector database + if dataset.indexing_technique == "high_quality": + summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id] + if summary_node_ids: + vector = Vector(dataset) + vector.delete_by_ids(summary_node_ids) + + # Delete summary records + for summary in summaries: + session.delete(summary) + + session.commit() + logger.info("Deleted %s summary records for dataset %s", len(summaries), dataset.id) + + @staticmethod + def update_summary_for_segment( + segment: DocumentSegment, + dataset: Dataset, + summary_content: str, + ) -> DocumentSegmentSummary | None: + """ + Update summary for a segment and re-vectorize it. + + Args: + segment: DocumentSegment to update summary for + dataset: Dataset containing the segment + summary_content: New summary content + + Returns: + Updated DocumentSegmentSummary instance, or None if indexing technique is not high_quality + """ + # Only update summary index for high_quality indexing technique + if dataset.indexing_technique != "high_quality": + return None + + # When user manually provides summary, allow saving even if summary_index_setting doesn't exist + # summary_index_setting is only needed for LLM generation, not for manual summary vectorization + # Vectorization uses dataset.embedding_model, which doesn't require summary_index_setting + + # Skip qa_model documents + if segment.document and segment.document.doc_form == "qa_model": + return None + + with session_factory.create_session() as session: + try: + # Check if summary_content is empty (whitespace-only strings are considered empty) + if not summary_content or not summary_content.strip(): + # If summary is empty, only delete existing summary vector and record + summary_record = ( + session.query(DocumentSegmentSummary) + .filter_by(chunk_id=segment.id, dataset_id=dataset.id) + .first() + ) + + if summary_record: + # Delete old vector if exists + old_summary_node_id = summary_record.summary_index_node_id + if old_summary_node_id: + try: + vector = Vector(dataset) + vector.delete_by_ids([old_summary_node_id]) + except Exception as e: + logger.warning( + "Failed to delete old summary vector for segment %s: %s", + segment.id, + str(e), + ) + + # Delete summary record since summary is empty + session.delete(summary_record) + session.commit() + logger.info("Deleted summary for segment %s (empty content provided)", segment.id) + return None + else: + # No existing summary record, nothing to do + logger.info("No summary record found for segment %s, nothing to delete", segment.id) + return None + + # Find existing summary record + summary_record = ( + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + ) + + if summary_record: + # Update existing summary + old_summary_node_id = summary_record.summary_index_node_id + + # Update summary content + summary_record.summary_content = summary_content + summary_record.status = "generating" + summary_record.error = None # type: ignore[assignment] # Clear any previous errors + session.add(summary_record) + # Flush to ensure summary_content is saved before vectorize_summary queries it + session.flush() + + # Delete old vector if exists (before vectorization) + if old_summary_node_id: + try: + vector = Vector(dataset) + vector.delete_by_ids([old_summary_node_id]) + except Exception as e: + logger.warning( + "Failed to delete old summary vector for segment %s: %s", + segment.id, + str(e), + ) + + # Re-vectorize summary (this will update status to "completed" and tokens in its own session) + # vectorize_summary will also ensure summary_content is preserved + # Note: vectorize_summary may take time due to embedding API calls, but we need to complete it + # to ensure the summary is properly indexed + try: + # Pass the session to vectorize_summary to avoid session isolation issues + SummaryIndexService.vectorize_summary(summary_record, segment, dataset, session=session) + # Refresh the object from database to get the updated status and tokens from vectorize_summary + session.refresh(summary_record) + # Now commit the session (summary_record should have status="completed" and tokens from refresh) + session.commit() + logger.info("Successfully updated and re-vectorized summary for segment %s", segment.id) + return summary_record + except Exception as e: + # If vectorization fails, update status to error in current session + # Don't raise the exception - just log it and return the record with error status + # This allows the segment update to complete even if vectorization fails + summary_record.status = "error" + summary_record.error = f"Vectorization failed: {str(e)}" + session.commit() + logger.exception("Failed to vectorize summary for segment %s", segment.id) + # Return the record with error status instead of raising + # The caller can check the status if needed + return summary_record + else: + # Create new summary record if doesn't exist + summary_record = SummaryIndexService.create_summary_record( + segment, dataset, summary_content, status="generating" + ) + # Re-vectorize summary (this will update status to "completed" and tokens in its own session) + # Note: summary_record was created in a different session, + # so we need to merge it into current session + try: + # Merge the record into current session first (since it was created in a different session) + summary_record = session.merge(summary_record) + # Pass the session to vectorize_summary - it will update the merged record + SummaryIndexService.vectorize_summary(summary_record, segment, dataset, session=session) + # Refresh to get updated status and tokens from database + session.refresh(summary_record) + # Commit the session to persist the changes + session.commit() + logger.info("Successfully created and vectorized summary for segment %s", segment.id) + return summary_record + except Exception as e: + # If vectorization fails, update status to error in current session + # Merge the record into current session first + error_record = session.merge(summary_record) + error_record.status = "error" + error_record.error = f"Vectorization failed: {str(e)}" + session.commit() + logger.exception("Failed to vectorize summary for segment %s", segment.id) + # Return the record with error status instead of raising + return error_record + + except Exception as e: + logger.exception("Failed to update summary for segment %s", segment.id) + # Update summary record with error status if it exists + summary_record = ( + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + ) + if summary_record: + summary_record.status = "error" + summary_record.error = str(e) + session.add(summary_record) + session.commit() + raise + + @staticmethod + def get_segment_summary(segment_id: str, dataset_id: str) -> DocumentSegmentSummary | None: + """ + Get summary for a single segment. + + Args: + segment_id: Segment ID (chunk_id) + dataset_id: Dataset ID + + Returns: + DocumentSegmentSummary instance if found, None otherwise + """ + with session_factory.create_session() as session: + return ( + session.query(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment_id, + DocumentSegmentSummary.dataset_id == dataset_id, + DocumentSegmentSummary.enabled == True, # Only return enabled summaries + ) + .first() + ) + + @staticmethod + def get_segments_summaries(segment_ids: list[str], dataset_id: str) -> dict[str, DocumentSegmentSummary]: + """ + Get summaries for multiple segments. + + Args: + segment_ids: List of segment IDs (chunk_ids) + dataset_id: Dataset ID + + Returns: + Dictionary mapping segment_id to DocumentSegmentSummary (only enabled summaries) + """ + if not segment_ids: + return {} + + with session_factory.create_session() as session: + summary_records = ( + session.query(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id.in_(segment_ids), + DocumentSegmentSummary.dataset_id == dataset_id, + DocumentSegmentSummary.enabled == True, # Only return enabled summaries + ) + .all() + ) + + return {summary.chunk_id: summary for summary in summary_records} + + @staticmethod + def get_document_summaries( + document_id: str, dataset_id: str, segment_ids: list[str] | None = None + ) -> list[DocumentSegmentSummary]: + """ + Get all summary records for a document. + + Args: + document_id: Document ID + dataset_id: Dataset ID + segment_ids: Optional list of segment IDs to filter by + + Returns: + List of DocumentSegmentSummary instances (only enabled summaries) + """ + with session_factory.create_session() as session: + query = session.query(DocumentSegmentSummary).filter( + DocumentSegmentSummary.document_id == document_id, + DocumentSegmentSummary.dataset_id == dataset_id, + DocumentSegmentSummary.enabled == True, # Only return enabled summaries + ) + + if segment_ids: + query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + + return query.all() + + @staticmethod + def get_document_summary_index_status(document_id: str, dataset_id: str, tenant_id: str) -> str | None: + """ + Get summary_index_status for a single document. + + Args: + document_id: Document ID + dataset_id: Dataset ID + tenant_id: Tenant ID + + Returns: + "SUMMARIZING" if there are pending summaries, None otherwise + """ + # Get all segments for this document (excluding qa_model and re_segment) + with session_factory.create_session() as session: + segments = ( + session.query(DocumentSegment.id) + .where( + DocumentSegment.document_id == document_id, + DocumentSegment.status != "re_segment", + DocumentSegment.tenant_id == tenant_id, + ) + .all() + ) + segment_ids = [seg.id for seg in segments] + + if not segment_ids: + return None + + # Get all summary records for these segments + summaries = SummaryIndexService.get_segments_summaries(segment_ids, dataset_id) + summary_status_map = {chunk_id: summary.status for chunk_id, summary in summaries.items()} + + # Check if there are any "not_started" or "generating" status summaries + has_pending_summaries = any( + summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True) + and summary_status_map[segment_id] in ("not_started", "generating") + for segment_id in segment_ids + ) + + return "SUMMARIZING" if has_pending_summaries else None + + @staticmethod + def get_documents_summary_index_status( + document_ids: list[str], dataset_id: str, tenant_id: str + ) -> dict[str, str | None]: + """ + Get summary_index_status for multiple documents. + + Args: + document_ids: List of document IDs + dataset_id: Dataset ID + tenant_id: Tenant ID + + Returns: + Dictionary mapping document_id to summary_index_status ("SUMMARIZING" or None) + """ + if not document_ids: + return {} + + # Get all segments for these documents (excluding qa_model and re_segment) + with session_factory.create_session() as session: + segments = ( + session.query(DocumentSegment.id, DocumentSegment.document_id) + .where( + DocumentSegment.document_id.in_(document_ids), + DocumentSegment.status != "re_segment", + DocumentSegment.tenant_id == tenant_id, + ) + .all() + ) + + # Group segments by document_id + document_segments_map: dict[str, list[str]] = {} + for segment in segments: + doc_id = str(segment.document_id) + if doc_id not in document_segments_map: + document_segments_map[doc_id] = [] + document_segments_map[doc_id].append(segment.id) + + # Get all summary records for these segments + all_segment_ids = [seg.id for seg in segments] + summaries = SummaryIndexService.get_segments_summaries(all_segment_ids, dataset_id) + summary_status_map = {chunk_id: summary.status for chunk_id, summary in summaries.items()} + + # Calculate summary_index_status for each document + result: dict[str, str | None] = {} + for doc_id in document_ids: + segment_ids = document_segments_map.get(doc_id, []) + if not segment_ids: + # No segments, status is None (not started) + result[doc_id] = None + continue + + # Check if there are any "not_started" or "generating" status summaries + # Only check enabled=True summaries (already filtered in query) + # If segment has no summary record (summary_status_map.get returns None), + # it means the summary is disabled (enabled=False) or not created yet, ignore it + has_pending_summaries = any( + summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True) + and summary_status_map[segment_id] in ("not_started", "generating") + for segment_id in segment_ids + ) + + if has_pending_summaries: + # Task is still running (not started or generating) + result[doc_id] = "SUMMARIZING" + else: + # All enabled=True summaries are "completed" or "error", task finished + # Or no enabled=True summaries exist (all disabled) + result[doc_id] = None + + return result + + @staticmethod + def get_document_summary_status_detail( + document_id: str, + dataset_id: str, + ) -> dict[str, Any]: + """ + Get detailed summary status for a document. + + Args: + document_id: Document ID + dataset_id: Dataset ID + + Returns: + Dictionary containing: + - total_segments: Total number of segments in the document + - summary_status: Dictionary with status counts + - completed: Number of summaries completed + - generating: Number of summaries being generated + - error: Number of summaries with errors + - not_started: Number of segments without summary records + - summaries: List of summary records with status and content preview + """ + from services.dataset_service import SegmentService + + # Get all segments for this document + segments = SegmentService.get_segments_by_document_and_dataset( + document_id=document_id, + dataset_id=dataset_id, + status="completed", + enabled=True, + ) + + total_segments = len(segments) + + # Get all summary records for these segments + segment_ids = [segment.id for segment in segments] + summaries = [] + if segment_ids: + summaries = SummaryIndexService.get_document_summaries( + document_id=document_id, + dataset_id=dataset_id, + segment_ids=segment_ids, + ) + + # Create a mapping of chunk_id to summary + summary_map = {summary.chunk_id: summary for summary in summaries} + + # Count statuses + status_counts = { + "completed": 0, + "generating": 0, + "error": 0, + "not_started": 0, + } + + summary_list = [] + for segment in segments: + summary = summary_map.get(segment.id) + if summary: + status = summary.status + status_counts[status] = status_counts.get(status, 0) + 1 + summary_list.append( + { + "segment_id": segment.id, + "segment_position": segment.position, + "status": summary.status, + "summary_preview": ( + summary.summary_content[:100] + "..." + if summary.summary_content and len(summary.summary_content) > 100 + else summary.summary_content + ), + "error": summary.error, + "created_at": int(summary.created_at.timestamp()) if summary.created_at else None, + "updated_at": int(summary.updated_at.timestamp()) if summary.updated_at else None, + } + ) + else: + status_counts["not_started"] += 1 + summary_list.append( + { + "segment_id": segment.id, + "segment_position": segment.position, + "status": "not_started", + "summary_preview": None, + "error": None, + "created_at": None, + "updated_at": None, + } + ) + + return { + "total_segments": total_segments, + "summary_status": status_counts, + "summaries": summary_list, + } diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index 62e6497e9d..2d3d00cd50 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -118,6 +118,19 @@ def add_document_to_index_task(dataset_document_id: str): ) session.commit() + # Enable summary indexes for all segments in this document + from services.summary_index_service import SummaryIndexService + + segment_ids_list = [segment.id for segment in segments] + if segment_ids_list: + try: + SummaryIndexService.enable_summaries_for_segments( + dataset=dataset, + segment_ids=segment_ids_list, + ) + except Exception as e: + logger.warning("Failed to enable summaries for document %s: %s", dataset_document.id, str(e)) + end_at = time.perf_counter() logger.info( click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green") diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py index 74b939e84d..d388284980 100644 --- a/api/tasks/batch_clean_document_task.py +++ b/api/tasks/batch_clean_document_task.py @@ -50,7 +50,9 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form if segments: index_node_ids = [segment.index_node_id for segment in segments] index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + index_processor.clean( + dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True + ) for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 86e7cc7160..91ace6be02 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -51,7 +51,9 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i if segments: index_node_ids = [segment.index_node_id for segment in segments] index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + index_processor.clean( + dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True + ) for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index bcca1bf49f..4214f043e0 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -42,7 +42,9 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): ).all() index_node_ids = [segment.index_node_id for segment in segments] - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + index_processor.clean( + dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True + ) segment_ids = [segment.id for segment in segments] segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) session.execute(segment_delete_stmt) diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index bfa709502c..764c635d83 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -47,6 +47,7 @@ def delete_segment_from_index_task( doc_form = dataset_document.doc_form # Proceed with index cleanup using the index_node_ids directly + # For actual deletion, we should delete summaries (not just disable them) index_processor = IndexProcessorFactory(doc_form).init_index_processor() index_processor.clean( dataset, @@ -54,6 +55,7 @@ def delete_segment_from_index_task( with_keywords=True, delete_child_chunks=True, precomputed_child_node_ids=child_node_ids, + delete_summaries=True, # Actually delete summaries when segment is deleted ) if dataset.is_multimodal: # delete segment attachment binding diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py index 0ce6429a94..bc45171623 100644 --- a/api/tasks/disable_segment_from_index_task.py +++ b/api/tasks/disable_segment_from_index_task.py @@ -60,6 +60,18 @@ def disable_segment_from_index_task(segment_id: str): index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor.clean(dataset, [segment.index_node_id]) + # Disable summary index for this segment + from services.summary_index_service import SummaryIndexService + + try: + SummaryIndexService.disable_summaries_for_segments( + dataset=dataset, + segment_ids=[segment.id], + disabled_by=segment.disabled_by, + ) + except Exception as e: + logger.warning("Failed to disable summary for segment %s: %s", segment.id, str(e)) + end_at = time.perf_counter() logger.info( click.style( diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py index 03635902d1..3cc267e821 100644 --- a/api/tasks/disable_segments_from_index_task.py +++ b/api/tasks/disable_segments_from_index_task.py @@ -68,6 +68,21 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen index_node_ids.extend(attachment_ids) index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) + # Disable summary indexes for these segments + from services.summary_index_service import SummaryIndexService + + segment_ids_list = [segment.id for segment in segments] + try: + # Get disabled_by from first segment (they should all have the same disabled_by) + disabled_by = segments[0].disabled_by if segments else None + SummaryIndexService.disable_summaries_for_segments( + dataset=dataset, + segment_ids=segment_ids_list, + disabled_by=disabled_by, + ) + except Exception as e: + logger.warning("Failed to disable summaries for segments: %s", str(e)) + end_at = time.perf_counter() logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green")) except Exception: diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 3bdff60196..34496e9c6f 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -14,6 +14,7 @@ from enums.cloud_plan import CloudPlan from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document from services.feature_service import FeatureService +from tasks.generate_summary_index_task import generate_summary_index_task logger = logging.getLogger(__name__) @@ -99,6 +100,78 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): indexing_runner.run(documents) end_at = time.perf_counter() logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) + + # Trigger summary index generation for completed documents if enabled + # Only generate for high_quality indexing technique and when summary_index_setting is enabled + # Re-query dataset to get latest summary_index_setting (in case it was updated) + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + logger.warning("Dataset %s not found after indexing", dataset_id) + return + + if dataset.indexing_technique == "high_quality": + summary_index_setting = dataset.summary_index_setting + if summary_index_setting and summary_index_setting.get("enable"): + # expire all session to get latest document's indexing status + session.expire_all() + # Check each document's indexing status and trigger summary generation if completed + for document_id in document_ids: + # Re-query document to get latest status (IndexingRunner may have updated it) + document = ( + session.query(Document) + .where(Document.id == document_id, Document.dataset_id == dataset_id) + .first() + ) + if document: + logger.info( + "Checking document %s for summary generation: status=%s, doc_form=%s, need_summary=%s", + document_id, + document.indexing_status, + document.doc_form, + document.need_summary, + ) + if ( + document.indexing_status == "completed" + and document.doc_form != "qa_model" + and document.need_summary is True + ): + try: + generate_summary_index_task.delay(dataset.id, document_id, None) + logger.info( + "Queued summary index generation task for document %s in dataset %s " + "after indexing completed", + document_id, + dataset.id, + ) + except Exception: + logger.exception( + "Failed to queue summary index generation task for document %s", + document_id, + ) + # Don't fail the entire indexing process if summary task queuing fails + else: + logger.info( + "Skipping summary generation for document %s: " + "status=%s, doc_form=%s, need_summary=%s", + document_id, + document.indexing_status, + document.doc_form, + document.need_summary, + ) + else: + logger.warning("Document %s not found after indexing", document_id) + else: + logger.info( + "Summary index generation skipped for dataset %s: summary_index_setting.enable=%s", + dataset.id, + summary_index_setting.get("enable") if summary_index_setting else None, + ) + else: + logger.info( + "Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')", + dataset.id, + dataset.indexing_technique, + ) except DocumentIsPausedError as ex: logger.info(click.style(str(ex), fg="yellow")) except Exception: diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 1f9f21aa7e..41ebb0b076 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -106,6 +106,17 @@ def enable_segment_to_index_task(segment_id: str): # save vector index index_processor.load(dataset, [document], multimodal_documents=multimodel_documents) + # Enable summary index for this segment + from services.summary_index_service import SummaryIndexService + + try: + SummaryIndexService.enable_summaries_for_segments( + dataset=dataset, + segment_ids=[segment.id], + ) + except Exception as e: + logger.warning("Failed to enable summary for segment %s: %s", segment.id, str(e)) + end_at = time.perf_counter() logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green")) except Exception as e: diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py index 48d3c8e178..d90eb4c39f 100644 --- a/api/tasks/enable_segments_to_index_task.py +++ b/api/tasks/enable_segments_to_index_task.py @@ -106,6 +106,18 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i # save vector index index_processor.load(dataset, documents, multimodal_documents=multimodal_documents) + # Enable summary indexes for these segments + from services.summary_index_service import SummaryIndexService + + segment_ids_list = [segment.id for segment in segments] + try: + SummaryIndexService.enable_summaries_for_segments( + dataset=dataset, + segment_ids=segment_ids_list, + ) + except Exception as e: + logger.warning("Failed to enable summaries for segments: %s", str(e)) + end_at = time.perf_counter() logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green")) except Exception as e: diff --git a/api/tasks/generate_summary_index_task.py b/api/tasks/generate_summary_index_task.py new file mode 100644 index 0000000000..e4273e16b5 --- /dev/null +++ b/api/tasks/generate_summary_index_task.py @@ -0,0 +1,119 @@ +"""Async task for generating summary indexes.""" + +import logging +import time + +import click +from celery import shared_task + +from core.db.session_factory import session_factory +from models.dataset import Dataset, DocumentSegment +from models.dataset import Document as DatasetDocument +from services.summary_index_service import SummaryIndexService + +logger = logging.getLogger(__name__) + + +@shared_task(queue="dataset") +def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids: list[str] | None = None): + """ + Async generate summary index for document segments. + + Args: + dataset_id: Dataset ID + document_id: Document ID + segment_ids: Optional list of specific segment IDs to process. If None, process all segments. + + Usage: + generate_summary_index_task.delay(dataset_id, document_id) + generate_summary_index_task.delay(dataset_id, document_id, segment_ids) + """ + logger.info( + click.style( + f"Start generating summary index for document {document_id} in dataset {dataset_id}", + fg="green", + ) + ) + start_at = time.perf_counter() + + try: + with session_factory.create_session() as session: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + logger.error(click.style(f"Dataset not found: {dataset_id}", fg="red")) + return + + document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() + if not document: + logger.error(click.style(f"Document not found: {document_id}", fg="red")) + return + + # Check if document needs summary + if not document.need_summary: + logger.info( + click.style( + f"Skipping summary generation for document {document_id}: need_summary is False", + fg="cyan", + ) + ) + return + + # Only generate summary index for high_quality indexing technique + if dataset.indexing_technique != "high_quality": + logger.info( + click.style( + f"Skipping summary generation for dataset {dataset_id}: " + f"indexing_technique is {dataset.indexing_technique}, not 'high_quality'", + fg="cyan", + ) + ) + return + + # Check if summary index is enabled + summary_index_setting = dataset.summary_index_setting + if not summary_index_setting or not summary_index_setting.get("enable"): + logger.info( + click.style( + f"Summary index is disabled for dataset {dataset_id}", + fg="cyan", + ) + ) + return + + # Determine if only parent chunks should be processed + only_parent_chunks = dataset.chunk_structure == "parent_child_index" + + # Generate summaries + summary_records = SummaryIndexService.generate_summaries_for_document( + dataset=dataset, + document=document, + summary_index_setting=summary_index_setting, + segment_ids=segment_ids, + only_parent_chunks=only_parent_chunks, + ) + + end_at = time.perf_counter() + logger.info( + click.style( + f"Summary index generation completed for document {document_id}: " + f"{len(summary_records)} summaries generated, latency: {end_at - start_at}", + fg="green", + ) + ) + + except Exception as e: + logger.exception("Failed to generate summary index for document %s", document_id) + # Update document segments with error status if needed + if segment_ids: + error_message = f"Summary generation failed: {str(e)}" + with session_factory.create_session() as session: + session.query(DocumentSegment).filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + ).update( + { + DocumentSegment.error: error_message, + }, + synchronize_session=False, + ) + session.commit() diff --git a/api/tasks/regenerate_summary_index_task.py b/api/tasks/regenerate_summary_index_task.py new file mode 100644 index 0000000000..cf8988d13e --- /dev/null +++ b/api/tasks/regenerate_summary_index_task.py @@ -0,0 +1,315 @@ +"""Task for regenerating summary indexes when dataset settings change.""" + +import logging +import time +from collections import defaultdict + +import click +from celery import shared_task +from sqlalchemy import or_, select + +from core.db.session_factory import session_factory +from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary +from models.dataset import Document as DatasetDocument +from services.summary_index_service import SummaryIndexService + +logger = logging.getLogger(__name__) + + +@shared_task(queue="dataset") +def regenerate_summary_index_task( + dataset_id: str, + regenerate_reason: str = "summary_model_changed", + regenerate_vectors_only: bool = False, +): + """ + Regenerate summary indexes for all documents in a dataset. + + This task is triggered when: + 1. summary_index_setting model changes (regenerate_reason="summary_model_changed") + - Regenerates summary content and vectors for all existing summaries + 2. embedding_model changes (regenerate_reason="embedding_model_changed") + - Only regenerates vectors for existing summaries (keeps summary content) + + Args: + dataset_id: Dataset ID + regenerate_reason: Reason for regeneration ("summary_model_changed" or "embedding_model_changed") + regenerate_vectors_only: If True, only regenerate vectors without regenerating summary content + """ + logger.info( + click.style( + f"Start regenerate summary index for dataset {dataset_id}, reason: {regenerate_reason}", + fg="green", + ) + ) + start_at = time.perf_counter() + + try: + with session_factory.create_session() as session: + dataset = session.query(Dataset).filter_by(id=dataset_id).first() + if not dataset: + logger.error(click.style(f"Dataset not found: {dataset_id}", fg="red")) + return + + # Only regenerate summary index for high_quality indexing technique + if dataset.indexing_technique != "high_quality": + logger.info( + click.style( + f"Skipping summary regeneration for dataset {dataset_id}: " + f"indexing_technique is {dataset.indexing_technique}, not 'high_quality'", + fg="cyan", + ) + ) + return + + # Check if summary index is enabled (only for summary_model change) + # For embedding_model change, we still re-vectorize existing summaries even if setting is disabled + summary_index_setting = dataset.summary_index_setting + if not regenerate_vectors_only: + # For summary_model change, require summary_index_setting to be enabled + if not summary_index_setting or not summary_index_setting.get("enable"): + logger.info( + click.style( + f"Summary index is disabled for dataset {dataset_id}", + fg="cyan", + ) + ) + return + + total_segments_processed = 0 + total_segments_failed = 0 + + if regenerate_vectors_only: + # For embedding_model change: directly query all segments with existing summaries + # Don't require document indexing_status == "completed" + # Include summaries with status "completed" or "error" (if they have content) + segments_with_summaries = ( + session.query(DocumentSegment, DocumentSegmentSummary) + .join( + DocumentSegmentSummary, + DocumentSegment.id == DocumentSegmentSummary.chunk_id, + ) + .join( + DatasetDocument, + DocumentSegment.document_id == DatasetDocument.id, + ) + .where( + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.status == "completed", # Segment must be completed + DocumentSegment.enabled == True, + DocumentSegmentSummary.dataset_id == dataset_id, + DocumentSegmentSummary.summary_content.isnot(None), # Must have summary content + # Include completed summaries or error summaries (with content) + or_( + DocumentSegmentSummary.status == "completed", + DocumentSegmentSummary.status == "error", + ), + DatasetDocument.enabled == True, # Document must be enabled + DatasetDocument.archived == False, # Document must not be archived + DatasetDocument.doc_form != "qa_model", # Skip qa_model documents + ) + .order_by(DocumentSegment.document_id.asc(), DocumentSegment.position.asc()) + .all() + ) + + if not segments_with_summaries: + logger.info( + click.style( + f"No segments with summaries found for re-vectorization in dataset {dataset_id}", + fg="cyan", + ) + ) + return + + logger.info( + "Found %s segments with summaries for re-vectorization in dataset %s", + len(segments_with_summaries), + dataset_id, + ) + + # Group by document for logging + segments_by_document = defaultdict(list) + for segment, summary_record in segments_with_summaries: + segments_by_document[segment.document_id].append((segment, summary_record)) + + logger.info( + "Segments grouped into %s documents for re-vectorization", + len(segments_by_document), + ) + + for document_id, segment_summary_pairs in segments_by_document.items(): + logger.info( + "Re-vectorizing summaries for %s segments in document %s", + len(segment_summary_pairs), + document_id, + ) + + for segment, summary_record in segment_summary_pairs: + try: + # Delete old vector + if summary_record.summary_index_node_id: + try: + from core.rag.datasource.vdb.vector_factory import Vector + + vector = Vector(dataset) + vector.delete_by_ids([summary_record.summary_index_node_id]) + except Exception as e: + logger.warning( + "Failed to delete old summary vector for segment %s: %s", + segment.id, + str(e), + ) + + # Re-vectorize with new embedding model + SummaryIndexService.vectorize_summary(summary_record, segment, dataset) + session.commit() + total_segments_processed += 1 + + except Exception as e: + logger.error( + "Failed to re-vectorize summary for segment %s: %s", + segment.id, + str(e), + exc_info=True, + ) + total_segments_failed += 1 + # Update summary record with error status + summary_record.status = "error" + summary_record.error = f"Re-vectorization failed: {str(e)}" + session.add(summary_record) + session.commit() + continue + + else: + # For summary_model change: require document indexing_status == "completed" + # Get all documents with completed indexing status + dataset_documents = session.scalars( + select(DatasetDocument).where( + DatasetDocument.dataset_id == dataset_id, + DatasetDocument.indexing_status == "completed", + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + ).all() + + if not dataset_documents: + logger.info( + click.style( + f"No documents found for summary regeneration in dataset {dataset_id}", + fg="cyan", + ) + ) + return + + logger.info( + "Found %s documents for summary regeneration in dataset %s", + len(dataset_documents), + dataset_id, + ) + + for dataset_document in dataset_documents: + # Skip qa_model documents + if dataset_document.doc_form == "qa_model": + continue + + try: + # Get all segments with existing summaries + segments = ( + session.query(DocumentSegment) + .join( + DocumentSegmentSummary, + DocumentSegment.id == DocumentSegmentSummary.chunk_id, + ) + .where( + DocumentSegment.document_id == dataset_document.id, + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + DocumentSegmentSummary.dataset_id == dataset_id, + ) + .order_by(DocumentSegment.position.asc()) + .all() + ) + + if not segments: + continue + + logger.info( + "Regenerating summaries for %s segments in document %s", + len(segments), + dataset_document.id, + ) + + for segment in segments: + summary_record = None + try: + # Get existing summary record + summary_record = ( + session.query(DocumentSegmentSummary) + .filter_by( + chunk_id=segment.id, + dataset_id=dataset_id, + ) + .first() + ) + + if not summary_record: + logger.warning("Summary record not found for segment %s, skipping", segment.id) + continue + + # Regenerate both summary content and vectors (for summary_model change) + SummaryIndexService.generate_and_vectorize_summary( + segment, dataset, summary_index_setting + ) + session.commit() + total_segments_processed += 1 + + except Exception as e: + logger.error( + "Failed to regenerate summary for segment %s: %s", + segment.id, + str(e), + exc_info=True, + ) + total_segments_failed += 1 + # Update summary record with error status + if summary_record: + summary_record.status = "error" + summary_record.error = f"Regeneration failed: {str(e)}" + session.add(summary_record) + session.commit() + continue + + except Exception as e: + logger.error( + "Failed to process document %s for summary regeneration: %s", + dataset_document.id, + str(e), + exc_info=True, + ) + continue + + end_at = time.perf_counter() + if regenerate_vectors_only: + logger.info( + click.style( + f"Summary re-vectorization completed for dataset {dataset_id}: " + f"{total_segments_processed} segments processed successfully, " + f"{total_segments_failed} segments failed, " + f"latency: {end_at - start_at:.2f}s", + fg="green", + ) + ) + else: + logger.info( + click.style( + f"Summary index regeneration completed for dataset {dataset_id}: " + f"{total_segments_processed} segments processed successfully, " + f"{total_segments_failed} segments failed, " + f"latency: {end_at - start_at:.2f}s", + fg="green", + ) + ) + + except Exception: + logger.exception("Regenerate summary index failed for dataset %s", dataset_id) diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index c3c255fb17..55259ab527 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -46,6 +46,21 @@ def remove_document_from_index_task(document_id: str): index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all() + + # Disable summary indexes for all segments in this document + from services.summary_index_service import SummaryIndexService + + segment_ids_list = [segment.id for segment in segments] + if segment_ids_list: + try: + SummaryIndexService.disable_summaries_for_segments( + dataset=dataset, + segment_ids=segment_ids_list, + disabled_by=document.disabled_by, + ) + except Exception as e: + logger.warning("Failed to disable summaries for document %s: %s", document.id, str(e)) + index_node_ids = [segment.index_node_id for segment in segments] if index_node_ids: try: diff --git a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py index 0aabe2fc30..08818945e3 100644 --- a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py @@ -138,6 +138,7 @@ class TestDatasetServiceUpdateDataset: "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" ) as mock_get_binding, patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task, + patch("services.dataset_service.regenerate_summary_index_task") as mock_regenerate_task, patch( "services.dataset_service.current_user", create_autospec(Account, instance=True) ) as mock_current_user, @@ -147,6 +148,7 @@ class TestDatasetServiceUpdateDataset: "model_manager": mock_model_manager, "get_binding": mock_get_binding, "task": mock_task, + "regenerate_task": mock_regenerate_task, "current_user": mock_current_user, } @@ -549,6 +551,13 @@ class TestDatasetServiceUpdateDataset: # Verify vector index task was triggered mock_internal_provider_dependencies["task"].delay.assert_called_once_with("dataset-123", "update") + # Verify regenerate summary index task was triggered (when embedding_model changes) + mock_internal_provider_dependencies["regenerate_task"].delay.assert_called_once_with( + "dataset-123", + regenerate_reason="embedding_model_changed", + regenerate_vectors_only=True, + ) + # Verify return value assert result == dataset