refactor: knowledge index node decouples business logic (#32274)

This commit is contained in:
wangxiaolei 2026-03-02 17:54:33 +08:00 committed by GitHub
parent 68647391e7
commit 707bf20c29
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 1196 additions and 445 deletions

View File

@ -52,7 +52,6 @@ forbidden_modules =
allow_indirect_imports = True
ignore_imports =
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database
@ -109,7 +108,6 @@ ignore_imports =
core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
core.workflow.nodes.llm.llm_utils -> core.model_manager
core.workflow.nodes.llm.protocols -> core.model_manager
core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
@ -154,18 +152,12 @@ ignore_imports =
core.workflow.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods
core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset
core.workflow.nodes.knowledge_index.knowledge_index_node -> services.summary_index_service
core.workflow.nodes.knowledge_index.knowledge_index_node -> tasks.generate_summary_index_task
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.processor.paragraph_index_processor
core.workflow.nodes.llm.node -> models.dataset
core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer
core.workflow.nodes.llm.file_saver -> core.tools.signature
core.workflow.nodes.llm.file_saver -> core.tools.tool_file_manager
core.workflow.nodes.tool.tool_node -> core.tools.errors
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database

View File

@ -19,7 +19,9 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.memory import PromptMessageMemory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.rag.index_processor.index_processor import IndexProcessor
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.summary_index.summary_index import SummaryIndex
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import NodeType, SystemVariableKey
@ -32,6 +34,7 @@ from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.nodes.datasource import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
from core.workflow.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
@ -202,6 +205,16 @@ class DifyNodeFactory(NodeFactory):
file_manager=self._http_request_file_manager,
)
if node_type == NodeType.KNOWLEDGE_INDEX:
return KnowledgeIndexNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
index_processor=IndexProcessor(),
summary_index_service=SummaryIndex(),
)
if node_type == NodeType.LLM:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)

View File

@ -0,0 +1,252 @@
import concurrent.futures
import datetime
import logging
import time
from collections.abc import Mapping
from typing import Any
from flask import current_app
from sqlalchemy import delete, func, select
from core.db.session_factory import session_factory
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
from core.workflow.repositories.index_processor_protocol import Preview, PreviewItem, QaPreview
from models.dataset import Dataset, Document, DocumentSegment
from .index_processor_factory import IndexProcessorFactory
from .processor.paragraph_index_processor import ParagraphIndexProcessor
logger = logging.getLogger(__name__)
class IndexProcessor:
    """Business-logic adapter used by the knowledge-index workflow node.

    Wraps the RAG ``IndexProcessorFactory`` and the DB session layer so the
    workflow node itself stays free of persistence concerns: preview
    formatting, chunk indexing, and post-index document bookkeeping all live
    here.
    """

    def format_preview(self, chunk_structure: str, chunks: Any) -> Preview:
        """Convert a processor's raw preview dict into a typed ``Preview``.

        :param chunk_structure: key selecting the concrete index processor
            via ``IndexProcessorFactory``.
        :param chunks: raw chunk payload accepted by the processor's own
            ``format_preview``.
        :return: ``Preview`` with content items collected in ``preview`` and
            Q&A items in ``qa_preview``.
        """
        index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
        preview = index_processor.format_preview(chunks)
        data = Preview(
            chunk_structure=preview["chunk_structure"],
            total_segments=preview["total_segments"],
            preview=[],
            parent_mode=None,
            qa_preview=[],
        )
        # parent_mode is only present for parent-child chunk structures.
        if "parent_mode" in preview:
            data.parent_mode = preview["parent_mode"]
        for item in preview["preview"]:
            # Branch order matters: a content item carrying child chunks must
            # be matched before the plain-content fallback below.
            if "content" in item and "child_chunks" in item:
                data.preview.append(
                    PreviewItem(content=item["content"], child_chunks=item["child_chunks"], summary=None)
                )
            elif "question" in item and "answer" in item:
                data.qa_preview.append(QaPreview(question=item["question"], answer=item["answer"]))
            elif "content" in item:
                data.preview.append(PreviewItem(content=item["content"], child_chunks=None, summary=None))
        return data

    def index_and_clean(
        self,
        dataset_id: str,
        document_id: str,
        original_document_id: str,
        chunks: Mapping[str, Any],
        batch: Any,
        summary_index_setting: dict | None = None,
    ):
        """Index *chunks* into the dataset and finalize document state.

        When *original_document_id* is non-empty (re-indexing), the previous
        document's segments are first removed from the vector index and then
        deleted from the database.

        :param summary_index_setting: overrides the dataset's own setting;
            falls back to ``dataset.summary_index_setting`` when ``None``.
        :raises KnowledgeIndexNodeError: if the document or dataset is missing.
        :return: plain dict summarizing the completed document (ids, names,
            batch, creation timestamp, display status).
        """
        with session_factory.create_session() as session:
            document = session.query(Document).filter_by(id=document_id).first()
            if not document:
                raise KnowledgeIndexNodeError(f"Document {document_id} not found.")
            dataset = session.query(Dataset).filter_by(id=dataset_id).first()
            if not dataset:
                raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.")
            # Capture plain values now so they stay usable after this session
            # closes and the ORM objects become detached.
            dataset_name_value = dataset.name
            document_name_value = document.name
            created_at_value = document.created_at
            if summary_index_setting is None:
                summary_index_setting = dataset.summary_index_setting
            index_node_ids = []
            index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor()
            if original_document_id:
                segments = session.scalars(
                    select(DocumentSegment).where(DocumentSegment.document_id == original_document_id)
                ).all()
                if segments:
                    index_node_ids = [segment.index_node_id for segment in segments]
            indexing_start_at = time.perf_counter()
            # delete from vector index
            if index_node_ids:
                index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
        with session_factory.create_session() as session, session.begin():
            if index_node_ids:
                segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == original_document_id)
                session.execute(segment_delete_stmt)
            # NOTE(review): `dataset` and `document` are detached here (loaded in
            # the previous, now-closed session) — confirm the processor only
            # reads plain attributes from them.
            index_processor.index(dataset, document, chunks)
        indexing_end_at = time.perf_counter()
        with session_factory.create_session() as session, session.begin():
            document.indexing_latency = indexing_end_at - indexing_start_at
            document.indexing_status = "completed"
            document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            document.word_count = (
                session.query(func.sum(DocumentSegment.word_count))
                .where(
                    DocumentSegment.document_id == document_id,
                    DocumentSegment.dataset_id == dataset_id,
                )
                .scalar()
            ) or 0  # scalar() is None when the document has no segments
            # Update need_summary based on dataset's summary_index_setting
            if summary_index_setting and summary_index_setting.get("enable") is True:
                document.need_summary = True
            else:
                document.need_summary = False
            # Re-attach the detached document so the field updates above are
            # persisted by this transaction.
            session.add(document)
            # update document segment status
            session.query(DocumentSegment).where(
                DocumentSegment.document_id == document_id,
                DocumentSegment.dataset_id == dataset_id,
            ).update(
                {
                    DocumentSegment.status: "completed",
                    DocumentSegment.enabled: True,
                    DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
                }
            )
        return {
            "dataset_id": dataset_id,
            "dataset_name": dataset_name_value,
            "batch": batch,
            "document_id": document_id,
            "document_name": document_name_value,
            "created_at": created_at_value.timestamp(),
            "display_status": "completed",
        }

    def get_preview_output(
        self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None
    ) -> Preview:
        """Build a preview of *chunks*, optionally enriched with summaries.

        Summaries are generated on the fly (never persisted) and only when the
        dataset's indexing technique is ``"high_quality"`` and the summary
        index setting is enabled.

        :param document_id: may be empty; when present, the document's
            language steers summary generation.
        :raises KnowledgeIndexNodeError: if the dataset is missing, or if any
            summary generation fails or times out (preview mode is strict).
        """
        doc_language = None
        with session_factory.create_session() as session:
            if document_id:
                document = session.query(Document).filter_by(id=document_id).first()
            else:
                document = None
            dataset = session.query(Dataset).filter_by(id=dataset_id).first()
            if not dataset:
                raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.")
            if summary_index_setting is None:
                summary_index_setting = dataset.summary_index_setting
            if document:
                doc_language = document.doc_language
            # Copy plain values before the session closes.
            indexing_technique = dataset.indexing_technique
            tenant_id = dataset.tenant_id
        preview_output = self.format_preview(chunk_structure, chunks)
        # Summaries only apply to high-quality datasets with the feature on.
        if indexing_technique != "high_quality":
            return preview_output
        if not summary_index_setting or not summary_index_setting.get("enable"):
            return preview_output
        if preview_output.preview is not None:
            chunk_count = len(preview_output.preview)
            logger.info(
                "Generating summaries for %s chunks in preview mode (dataset: %s)",
                chunk_count,
                dataset_id,
            )
            flask_app = None
            try:
                flask_app = current_app._get_current_object()  # type: ignore
            except RuntimeError:
                logger.warning("No Flask application context available, summary generation may fail")

            def generate_summary_for_chunk(preview_item: PreviewItem) -> None:
                """Generate summary for a single chunk."""
                if flask_app:
                    with flask_app.app_context():
                        if preview_item.content is not None:
                            # Set Flask application context in worker thread
                            summary, _ = ParagraphIndexProcessor.generate_summary(
                                tenant_id=tenant_id,
                                text=preview_item.content,
                                summary_index_setting=summary_index_setting,
                                document_language=doc_language,
                            )
                            if summary:
                                preview_item.summary = summary
                else:
                    # Fallback path without an app context — may fail inside
                    # the processor if it needs Flask state.
                    summary, _ = ParagraphIndexProcessor.generate_summary(
                        tenant_id=tenant_id,
                        text=preview_item.content if preview_item.content is not None else "",
                        summary_index_setting=summary_index_setting,
                        document_language=doc_language,
                    )
                    if summary:
                        preview_item.summary = summary

            # Generate summaries concurrently using ThreadPoolExecutor
            # Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
            timeout_seconds = min(300, 60 * len(preview_output.preview))
            errors: list[Exception] = []
            with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_output.preview))) as executor:
                futures = [
                    executor.submit(generate_summary_for_chunk, preview_item) for preview_item in preview_output.preview
                ]
                # Wait for all tasks to complete with timeout
                done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)
                # Cancel tasks that didn't complete in time
                if not_done:
                    timeout_error_msg = (
                        f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s"
                    )
                    logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg)
                    # In preview mode, timeout is also an error
                    errors.append(TimeoutError(timeout_error_msg))
                    for future in not_done:
                        future.cancel()
                    # Wait a bit for cancellation to take effect
                    concurrent.futures.wait(not_done, timeout=5)
                # Collect exceptions from completed futures
                for future in done:
                    try:
                        future.result()  # This will raise any exception that occurred
                    except Exception as e:
                        logger.exception("Error in summary generation future")
                        errors.append(e)
            # In preview mode, if there are any errors, fail the request
            if errors:
                error_messages = [str(e) for e in errors]
                error_summary = (
                    f"Failed to generate summaries for {len(errors)} chunk(s). "
                    f"Errors: {'; '.join(error_messages[:3])}"  # Show first 3 errors
                )
                if len(errors) > 3:
                    error_summary += f" (and {len(errors) - 3} more)"
                logger.error("Summary generation failed in preview mode: %s", error_summary)
                raise KnowledgeIndexNodeError(error_summary)
            completed_count = sum(1 for item in preview_output.preview if item.summary is not None)
            logger.info(
                "Completed summary generation for preview chunks: %s/%s succeeded",
                completed_count,
                len(preview_output.preview),
            )
        return preview_output

View File

View File

@ -0,0 +1,86 @@
import concurrent.futures
import logging
from core.db.session_factory import session_factory
from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
from services.summary_index_service import SummaryIndexService
from tasks.generate_summary_index_task import generate_summary_index_task
logger = logging.getLogger(__name__)
class SummaryIndex:
    """Summary-index generation adapter for the knowledge-index node.

    In preview (debug) mode, summaries are generated synchronously with a
    thread pool; in production mode, generation is delegated to the celery
    task ``generate_summary_index_task``.
    """

    def generate_and_vectorize_summary(
        self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None
    ) -> None:
        """Generate and vectorize summaries for a document's segments.

        :param is_preview: when True, process segments inline in threads;
            when False, enqueue the async celery task and return.
        :param summary_index_setting: overrides the dataset's own setting;
            falls back to ``dataset.summary_index_setting`` when ``None``.

        Returns early (no-op) when the dataset is missing or not
        ``high_quality``, the setting is disabled, the document is missing or
        a ``qa_model`` document, or every segment already has a completed
        summary. Per-segment failures are logged and do not abort the batch.
        """
        if is_preview:
            with session_factory.create_session() as session:
                dataset = session.query(Dataset).filter_by(id=dataset_id).first()
                if not dataset or dataset.indexing_technique != "high_quality":
                    return
                if summary_index_setting is None:
                    summary_index_setting = dataset.summary_index_setting
                if not summary_index_setting or not summary_index_setting.get("enable"):
                    return
                if not document_id:
                    return
                document = session.query(Document).filter_by(id=document_id).first()
                # Skip qa_model documents
                if document is None or document.doc_form == "qa_model":
                    return
                query = session.query(DocumentSegment).filter_by(
                    dataset_id=dataset_id,
                    document_id=document_id,
                    status="completed",
                    enabled=True,
                )
                segments = query.all()
                segment_ids = [segment.id for segment in segments]
                if not segment_ids:
                    return
                existing_summaries = (
                    session.query(DocumentSegmentSummary)
                    .filter(
                        DocumentSegmentSummary.chunk_id.in_(segment_ids),
                        DocumentSegmentSummary.dataset_id == dataset_id,
                        DocumentSegmentSummary.status == "completed",
                    )
                    .all()
                )
                completed_summary_segment_ids = {i.chunk_id for i in existing_summaries}
                # Preview mode should process segments that are MISSING completed summaries
                pending_segment_ids = [sid for sid in segment_ids if sid not in completed_summary_segment_ids]
                # If all segments already have completed summaries, nothing to do in preview mode
                if not pending_segment_ids:
                    return
                max_workers = min(10, len(pending_segment_ids))

                def process_segment(segment_id: str) -> None:
                    """Process a single segment in a thread with a fresh DB session."""
                    with session_factory.create_session() as session:
                        segment = session.query(DocumentSegment).filter_by(id=segment_id).first()
                        if segment is None:
                            return
                        try:
                            # NOTE(review): `dataset` is an ORM instance from the
                            # outer session shared across worker threads — confirm
                            # the service only reads plain attributes from it.
                            SummaryIndexService.generate_and_vectorize_summary(segment, dataset, summary_index_setting)
                        except Exception:
                            logger.exception(
                                "Failed to generate summary for segment %s",
                                segment_id,
                            )
                            # Continue processing other segments

                with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                    futures = [executor.submit(process_segment, segment_id) for segment_id in pending_segment_ids]
                    concurrent.futures.wait(futures)
        else:
            # Production mode: hand off to the async celery task.
            generate_summary_index_task.delay(dataset_id, document_id, None)

View File

@ -1,66 +1,66 @@
import concurrent.futures
import datetime
import logging
import time
from collections.abc import Mapping
from typing import Any
from flask import current_app
from sqlalchemy import func, select
from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
from services.summary_index_service import SummaryIndexService
from tasks.generate_summary_index_task import generate_summary_index_task
from core.workflow.repositories.index_processor_protocol import IndexProcessorProtocol
from core.workflow.repositories.summary_index_service_protocol import SummaryIndexServiceProtocol
from .entities import KnowledgeIndexNodeData
from .exc import (
KnowledgeIndexNodeError,
)
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"score_threshold_enabled": False,
}
logger = logging.getLogger(__name__)
class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
node_type = NodeType.KNOWLEDGE_INDEX
execution_type = NodeExecutionType.RESPONSE
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
index_processor: IndexProcessorProtocol,
summary_index_service: SummaryIndexServiceProtocol,
) -> None:
super().__init__(id, config, graph_init_params, graph_runtime_state)
self.index_processor = index_processor
self.summary_index_service = summary_index_service
def _run(self) -> NodeRunResult: # type: ignore
node_data = self.node_data
variable_pool = self.graph_runtime_state.variable_pool
dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
if not dataset_id:
# get dataset id as string
dataset_id_segment = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
if not dataset_id_segment:
raise KnowledgeIndexNodeError("Dataset ID is required.")
dataset = db.session.query(Dataset).filter_by(id=dataset_id.value).first()
if not dataset:
raise KnowledgeIndexNodeError(f"Dataset {dataset_id.value} not found.")
dataset_id: str = dataset_id_segment.value
# get document id as string (may be empty when not provided)
document_id_segment = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
document_id: str = document_id_segment.value if document_id_segment else ""
# extract variables
variable = variable_pool.get(node_data.index_chunk_variable_selector)
if not variable:
raise KnowledgeIndexNodeError("Index chunk variable is required.")
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
if invoke_from:
is_preview = invoke_from.value == InvokeFrom.DEBUGGER
else:
is_preview = False
is_preview = invoke_from.value == InvokeFrom.DEBUGGER if invoke_from else False
chunks = variable.value
variables = {"chunks": chunks}
if not chunks:
@ -68,52 +68,49 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required."
)
# index knowledge
try:
summary_index_setting = node_data.summary_index_setting
if is_preview:
# Preview mode: generate summaries for chunks directly without saving to database
# Format preview and generate summaries on-the-fly
# Get indexing_technique and summary_index_setting from node_data (workflow graph config)
# or fallback to dataset if not available in node_data
indexing_technique = node_data.indexing_technique or dataset.indexing_technique
summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting
# Try to get document language if document_id is available
doc_language = None
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
if document_id:
document = db.session.query(Document).filter_by(id=document_id.value).first()
if document and document.doc_language:
doc_language = document.doc_language
outputs = self._get_preview_output_with_summaries(
node_data.chunk_structure,
chunks,
dataset=dataset,
indexing_technique=indexing_technique,
summary_index_setting=summary_index_setting,
doc_language=doc_language,
outputs = self.index_processor.get_preview_output(
chunks, dataset_id, document_id, node_data.chunk_structure, summary_index_setting
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
outputs=outputs,
outputs=outputs.model_dump(exclude_none=True),
)
original_document_id_segment = variable_pool.get(["sys", SystemVariableKey.ORIGINAL_DOCUMENT_ID])
batch = variable_pool.get(["sys", SystemVariableKey.BATCH])
if not batch:
raise KnowledgeIndexNodeError("Batch is required.")
results = self._invoke_knowledge_index(
dataset=dataset, node_data=node_data, chunks=chunks, variable_pool=variable_pool
dataset_id=dataset_id,
document_id=document_id,
original_document_id=original_document_id_segment.value if original_document_id_segment else "",
is_preview=is_preview,
batch=batch.value,
chunks=chunks,
summary_index_setting=summary_index_setting,
)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=results)
except KnowledgeIndexNodeError as e:
logger.warning("Error when running knowledge index node")
logger.warning("Error when running knowledge index node", exc_info=True)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
error_type=type(e).__name__,
)
# Temporary handle all exceptions from DatasetRetrieval class here.
except Exception as e:
logger.error(e, exc_info=True)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
@ -123,392 +120,23 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
def _invoke_knowledge_index(
self,
dataset: Dataset,
node_data: KnowledgeIndexNodeData,
dataset_id: str,
document_id: str,
original_document_id: str,
is_preview: bool,
batch: Any,
chunks: Mapping[str, Any],
variable_pool: VariablePool,
) -> Any:
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
if not document_id:
raise KnowledgeIndexNodeError("Document ID is required.")
original_document_id = variable_pool.get(["sys", SystemVariableKey.ORIGINAL_DOCUMENT_ID])
batch = variable_pool.get(["sys", SystemVariableKey.BATCH])
if not batch:
raise KnowledgeIndexNodeError("Batch is required.")
document = db.session.query(Document).filter_by(id=document_id.value).first()
if not document:
raise KnowledgeIndexNodeError(f"Document {document_id.value} not found.")
doc_id_value = document.id
ds_id_value = dataset.id
dataset_name_value = dataset.name
document_name_value = document.name
created_at_value = document.created_at
# chunk nodes by chunk size
indexing_start_at = time.perf_counter()
index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor()
if original_document_id:
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == original_document_id.value)
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
db.session.commit()
index_processor.index(dataset, document, chunks)
indexing_end_at = time.perf_counter()
document.indexing_latency = indexing_end_at - indexing_start_at
# update document status
document.indexing_status = "completed"
document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
document.word_count = (
db.session.query(func.sum(DocumentSegment.word_count))
.where(
DocumentSegment.document_id == doc_id_value,
DocumentSegment.dataset_id == ds_id_value,
)
.scalar()
)
# Update need_summary based on dataset's summary_index_setting
if dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True:
document.need_summary = True
else:
document.need_summary = False
db.session.add(document)
# update document segment status
db.session.query(DocumentSegment).where(
DocumentSegment.document_id == doc_id_value,
DocumentSegment.dataset_id == ds_id_value,
).update(
{
DocumentSegment.status: "completed",
DocumentSegment.enabled: True,
DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
}
)
db.session.commit()
# Generate summary index if enabled
self._handle_summary_index_generation(dataset, document, variable_pool)
return {
"dataset_id": ds_id_value,
"dataset_name": dataset_name_value,
"batch": batch.value,
"document_id": doc_id_value,
"document_name": document_name_value,
"created_at": created_at_value.timestamp(),
"display_status": "completed",
}
def _handle_summary_index_generation(
self,
dataset: Dataset,
document: Document,
variable_pool: VariablePool,
) -> None:
"""
Handle summary index generation based on mode (debug/preview or production).
Args:
dataset: Dataset containing the document
document: Document to generate summaries for
variable_pool: Variable pool to check invoke_from
"""
# Only generate summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
return
# Check if summary index is enabled
summary_index_setting = dataset.summary_index_setting
if not summary_index_setting or not summary_index_setting.get("enable"):
return
# Skip qa_model documents
if document.doc_form == "qa_model":
return
# Determine if in preview/debug mode
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
is_preview = invoke_from and invoke_from.value == InvokeFrom.DEBUGGER
if is_preview:
try:
# Query segments that need summary generation
query = db.session.query(DocumentSegment).filter_by(
dataset_id=dataset.id,
document_id=document.id,
status="completed",
enabled=True,
)
segments = query.all()
if not segments:
logger.info("No segments found for document %s", document.id)
return
# Filter segments based on mode
segments_to_process = []
for segment in segments:
# Skip if summary already exists
existing_summary = (
db.session.query(DocumentSegmentSummary)
.filter_by(chunk_id=segment.id, dataset_id=dataset.id, status="completed")
.first()
)
if existing_summary:
continue
# For parent-child mode, all segments are parent chunks, so process all
segments_to_process.append(segment)
if not segments_to_process:
logger.info("No segments need summary generation for document %s", document.id)
return
# Use ThreadPoolExecutor for concurrent generation
flask_app = current_app._get_current_object() # type: ignore
max_workers = min(10, len(segments_to_process)) # Limit to 10 workers
def process_segment(segment: DocumentSegment) -> None:
"""Process a single segment in a thread with Flask app context."""
with flask_app.app_context():
try:
SummaryIndexService.generate_and_vectorize_summary(segment, dataset, summary_index_setting)
except Exception:
logger.exception(
"Failed to generate summary for segment %s",
segment.id,
)
# Continue processing other segments
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [executor.submit(process_segment, segment) for segment in segments_to_process]
# Wait for all tasks to complete
concurrent.futures.wait(futures)
logger.info(
"Successfully generated summary index for %s segments in document %s",
len(segments_to_process),
document.id,
)
except Exception:
logger.exception("Failed to generate summary index for document %s", document.id)
# Don't fail the entire indexing process if summary generation fails
else:
# Production mode: asynchronous generation
logger.info(
"Queuing summary index generation task for document %s (production mode)",
document.id,
)
try:
generate_summary_index_task.delay(dataset.id, document.id, None)
logger.info("Summary index generation task queued for document %s", document.id)
except Exception:
logger.exception(
"Failed to queue summary index generation task for document %s",
document.id,
)
# Don't fail the entire indexing process if task queuing fails
def _get_preview_output_with_summaries(
self,
chunk_structure: str,
chunks: Any,
dataset: Dataset,
indexing_technique: str | None = None,
summary_index_setting: dict | None = None,
doc_language: str | None = None,
) -> Mapping[str, Any]:
"""
Generate preview output with summaries for chunks in preview mode.
This method generates summaries on-the-fly without saving to database.
Args:
chunk_structure: Chunk structure type
chunks: Chunks to generate preview for
dataset: Dataset object (for tenant_id)
indexing_technique: Indexing technique from node config or dataset
summary_index_setting: Summary index setting from node config or dataset
doc_language: Optional document language to ensure summary is generated in the correct language
"""
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
preview_output = index_processor.format_preview(chunks)
# Check if summary index is enabled
if indexing_technique != "high_quality":
return preview_output
if not summary_index_setting or not summary_index_setting.get("enable"):
return preview_output
# Generate summaries for chunks
if "preview" in preview_output and isinstance(preview_output["preview"], list):
chunk_count = len(preview_output["preview"])
logger.info(
"Generating summaries for %s chunks in preview mode (dataset: %s)",
chunk_count,
dataset.id,
)
# Use ParagraphIndexProcessor's generate_summary method
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
# Get Flask app for application context in worker threads
flask_app = None
try:
flask_app = current_app._get_current_object() # type: ignore
except RuntimeError:
logger.warning("No Flask application context available, summary generation may fail")
def generate_summary_for_chunk(preview_item: dict) -> None:
"""Generate summary for a single chunk."""
if "content" in preview_item:
# Set Flask application context in worker thread
if flask_app:
with flask_app.app_context():
summary, _ = ParagraphIndexProcessor.generate_summary(
tenant_id=dataset.tenant_id,
text=preview_item["content"],
summary_index_setting=summary_index_setting,
document_language=doc_language,
)
if summary:
preview_item["summary"] = summary
else:
# Fallback: try without app context (may fail)
summary, _ = ParagraphIndexProcessor.generate_summary(
tenant_id=dataset.tenant_id,
text=preview_item["content"],
summary_index_setting=summary_index_setting,
document_language=doc_language,
)
if summary:
preview_item["summary"] = summary
# Generate summaries concurrently using ThreadPoolExecutor
# Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
timeout_seconds = min(300, 60 * len(preview_output["preview"]))
errors: list[Exception] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_output["preview"]))) as executor:
futures = [
executor.submit(generate_summary_for_chunk, preview_item)
for preview_item in preview_output["preview"]
]
# Wait for all tasks to complete with timeout
done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)
# Cancel tasks that didn't complete in time
if not_done:
timeout_error_msg = (
f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s"
)
logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg)
# In preview mode, timeout is also an error
errors.append(TimeoutError(timeout_error_msg))
for future in not_done:
future.cancel()
# Wait a bit for cancellation to take effect
concurrent.futures.wait(not_done, timeout=5)
# Collect exceptions from completed futures
for future in done:
try:
future.result() # This will raise any exception that occurred
except Exception as e:
logger.exception("Error in summary generation future")
errors.append(e)
# In preview mode, if there are any errors, fail the request
if errors:
error_messages = [str(e) for e in errors]
error_summary = (
f"Failed to generate summaries for {len(errors)} chunk(s). "
f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors
)
if len(errors) > 3:
error_summary += f" (and {len(errors) - 3} more)"
logger.error("Summary generation failed in preview mode: %s", error_summary)
raise KnowledgeIndexNodeError(error_summary)
completed_count = sum(1 for item in preview_output["preview"] if item.get("summary") is not None)
logger.info(
"Completed summary generation for preview chunks: %s/%s succeeded",
completed_count,
len(preview_output["preview"]),
)
return preview_output
def _get_preview_output(
    self,
    chunk_structure: str,
    chunks: Any,
    dataset: Dataset | None = None,
    variable_pool: VariablePool | None = None,
) -> Mapping[str, Any]:
    """Build the preview payload for *chunks*, enriched with existing summaries.

    Formats the chunks through the index processor for the given chunk
    structure. When both *dataset* and *variable_pool* are supplied and the
    pool carries a document id, any completed, enabled segment summaries for
    that document are attached to matching preview items. Matching is done by
    normalized (stripped) segment content because preview chunks may not be
    indexed yet and therefore have no stable segment id.

    :param chunk_structure: chunk structure key used to pick the index processor
    :param chunks: raw chunk payload to format for preview
    :param dataset: dataset to look up summaries in (optional)
    :param variable_pool: pool providing the sys document_id variable (optional)
    :return: mapping with at least a "preview" list when the processor emits one
    """
    index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
    preview_output = index_processor.format_preview(chunks)
    # Guard clauses: summary enrichment is best-effort and requires a
    # dataset, a variable pool, a document id, and a persisted document.
    if not (dataset and variable_pool):
        return preview_output
    document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
    if not document_id:
        return preview_output
    document = db.session.query(Document).filter_by(id=document_id.value).first()
    if not document:
        return preview_output
    summaries = (
        db.session.query(DocumentSegmentSummary)
        .filter_by(
            dataset_id=dataset.id,
            document_id=document.id,
            status="completed",
            enabled=True,
        )
        .all()
    )
    if not summaries:
        return preview_output
    # Fetch all referenced segments in ONE query instead of one query per
    # summary (the original loop was an N+1 query pattern).
    summary_by_chunk_id = {summary.chunk_id: summary.summary_content for summary in summaries}
    segments = (
        db.session.query(DocumentSegment)
        .filter(
            DocumentSegment.id.in_(list(summary_by_chunk_id)),
            DocumentSegment.dataset_id == dataset.id,
        )
        .all()
    )
    # Map normalized segment content -> summary text for content-based matching.
    summary_by_content = {segment.content.strip(): summary_by_chunk_id[segment.id] for segment in segments}
    if "preview" in preview_output and isinstance(preview_output["preview"], list):
        matched_count = 0
        for preview_item in preview_output["preview"]:
            if "content" not in preview_item:
                continue
            normalized_chunk_content = preview_item["content"].strip()
            if normalized_chunk_content in summary_by_content:
                preview_item["summary"] = summary_by_content[normalized_chunk_content]
                matched_count += 1
        if matched_count > 0:
            logger.info(
                "Enriched preview with %s existing summaries (dataset: %s, document: %s)",
                matched_count,
                dataset.id,
                document.id,
            )
    return preview_output
):
if not document_id:
raise KnowledgeIndexNodeError("document_id is required.")
rst = self.index_processor.index_and_clean(
dataset_id, document_id, original_document_id, chunks, batch, summary_index_setting
)
self.summary_index_service.generate_and_vectorize_summary(
dataset_id, document_id, is_preview, summary_index_setting
)
return rst
@classmethod
def version(cls) -> str:

View File

@ -0,0 +1,41 @@
from collections.abc import Mapping
from typing import Any, Protocol
from pydantic import BaseModel, Field
class PreviewItem(BaseModel):
    """One chunk as shown in an indexing preview.

    Every field is optional; a missing value is represented as ``None``.
    """

    # Main chunk text.
    content: str | None = None
    # Child chunk texts when the parent-child structure is in use.
    child_chunks: list[str] | None = None
    # Generated summary for the chunk, when summary indexing produced one.
    summary: str | None = None
class QaPreview(BaseModel):
    """A question/answer pair as shown in a QA-structure preview."""

    # Both halves are optional; absent values are ``None``.
    answer: str | None = None
    question: str | None = None
class Preview(BaseModel):
    """Aggregated preview of how a document will be chunked and indexed.

    Carries the chunk structure, the formatted chunk/QA previews, and the
    total segment count reported by the index processor.
    """

    chunk_structure: str
    parent_mode: str | None = Field(None)
    # default_factory is the documented pydantic idiom for mutable defaults:
    # it guarantees each instance gets a fresh list rather than relying on
    # pydantic's implicit copying of a shared `[]` default.
    preview: list[PreviewItem] = Field(default_factory=list)
    qa_preview: list[QaPreview] = Field(default_factory=list)
    total_segments: int
class IndexProcessorProtocol(Protocol):
    """Structural contract for the index processor used by the knowledge-index node.

    Decouples the workflow node from concrete RAG index-processor
    implementations: any object exposing these methods satisfies the protocol.
    """

    def format_preview(self, chunk_structure: str, chunks: Any) -> Preview:
        """Shape raw *chunks* into a :class:`Preview` for the given chunk structure."""
        ...

    def index_and_clean(
        self,
        dataset_id: str,
        document_id: str,
        original_document_id: str,
        chunks: Mapping[str, Any],
        batch: Any,
        summary_index_setting: dict | None = None,
    ) -> dict[str, Any]:
        """Index *chunks* for the given document and clean up prior index data.

        NOTE(review): the exact cleanup semantics (what *original_document_id*
        replaces) are implementation-defined — not visible from this protocol.
        """
        ...

    def get_preview_output(
        self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None
    ) -> Preview:
        """Build the preview payload for *chunks*, honoring *summary_index_setting*."""
        ...

View File

@ -0,0 +1,7 @@
from typing import Protocol
class SummaryIndexServiceProtocol(Protocol):
    """Structural contract for the summary-index service used by the knowledge-index node."""

    def generate_and_vectorize_summary(
        self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None
    ):
        """Generate summaries for the document's segments and vectorize them.

        NOTE(review): behavior when ``is_preview`` is True is
        implementation-defined — not visible from this protocol.
        """
        ...

View File

@ -0,0 +1,69 @@
"""
Integration tests for KnowledgeIndexNode.
This module provides integration tests for KnowledgeIndexNode with real database interactions.
Note: These tests require database setup and are more complex than unit tests.
For now, we focus on unit tests which provide better coverage for the node logic.
"""
import pytest
class TestKnowledgeIndexNodeIntegration:
    """End-to-end test scaffolding for KnowledgeIndexNode.

    Running these cases for real requires a database with datasets and
    documents, a vector store, model providers, and concrete IndexProcessor /
    SummaryIndexService implementations. Until that harness exists they are
    skipped; the unit-test suite exercises the node logic with mocks.
    """

    @pytest.mark.skip(reason="Integration tests require full database and vector store setup")
    def test_end_to_end_knowledge_index_preview(self):
        """Full preview-mode flow: dataset + document + chunks -> preview output.

        TODO: create a dataset and document, prepare chunks, run the node in
        preview mode, then assert on the preview payload.
        """

    @pytest.mark.skip(reason="Integration tests require full database and vector store setup")
    def test_end_to_end_knowledge_index_production(self):
        """Full production-mode flow ending in indexing plus summary generation.

        TODO: create a dataset and document, prepare chunks, run the node in
        production mode, then verify indexing and summary generation.
        """

    @pytest.mark.skip(reason="Integration tests require full database and vector store setup")
    def test_knowledge_index_with_summary_enabled(self):
        """Summary-index setting enabled: summaries must be generated and indexed.

        TODO: create a dataset and document, prepare chunks, configure the
        summary index setting, run the node, then verify the stored summaries.
        """

    @pytest.mark.skip(reason="Integration tests require full database and vector store setup")
    def test_knowledge_index_parent_child_structure(self):
        """Parent-child chunk structure is indexed with its hierarchy intact.

        TODO: create a dataset and document, prepare parent-child chunks, run
        the node, then verify the parent-child index layout.
        """

View File

@ -0,0 +1,663 @@
import time
import uuid
from unittest.mock import Mock
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams
from core.workflow.enums import SystemVariableKey, WorkflowNodeExecutionStatus
from core.workflow.nodes.knowledge_index.entities import KnowledgeIndexNodeData
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
from core.workflow.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode
from core.workflow.repositories.index_processor_protocol import IndexProcessorProtocol, Preview, PreviewItem
from core.workflow.repositories.summary_index_service_protocol import SummaryIndexServiceProtocol
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variables.segments import StringSegment
from models.enums import UserFrom
@pytest.fixture
def mock_graph_init_params():
    """GraphInitParams populated with random ids and debugger invocation."""
    params = {
        "tenant_id": str(uuid.uuid4()),
        "app_id": str(uuid.uuid4()),
        "workflow_id": str(uuid.uuid4()),
        "graph_config": {},
        "user_id": str(uuid.uuid4()),
        "user_from": UserFrom.ACCOUNT,
        "invoke_from": InvokeFrom.DEBUGGER,
        "call_depth": 0,
    }
    return GraphInitParams(**params)
@pytest.fixture
def mock_graph_runtime_state():
    """GraphRuntimeState backed by a variable pool with only system defaults."""
    pool = VariablePool(
        system_variables=SystemVariable(user_id=str(uuid.uuid4()), files=[]),
        user_inputs={},
        environment_variables=[],
        conversation_variables=[],
    )
    return GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
@pytest.fixture
def mock_index_processor():
    """Mock satisfying IndexProcessorProtocol (spec-restricted)."""
    return Mock(spec=IndexProcessorProtocol)
@pytest.fixture
def mock_summary_index_service():
    """Mock satisfying SummaryIndexServiceProtocol (spec-restricted)."""
    return Mock(spec=SummaryIndexServiceProtocol)
@pytest.fixture
def sample_node_data():
    """KnowledgeIndexNodeData configured for the general chunk structure."""
    data = {
        "title": "Knowledge Index",
        "type": "knowledge-index",
        "chunk_structure": "general_structure",
        "index_chunk_variable_selector": ["start", "chunks"],
        "indexing_technique": "high_quality",
        "summary_index_setting": None,
    }
    return KnowledgeIndexNodeData(**data)
@pytest.fixture
def sample_chunks():
    """Minimal chunk payload: two general chunks plus data-source info."""
    file_id = str(uuid.uuid4())
    return {
        "general_chunks": ["Chunk 1 content", "Chunk 2 content"],
        "data_source_info": {"file_id": file_id},
    }
class TestKnowledgeIndexNode:
    """Unit tests for KnowledgeIndexNode with mocked collaborators.

    The node-construction and variable-pool boilerplate that was repeated in
    every test lives in the private helpers `_build_node` / `_add_sys_var`,
    so each test states only what is unique to its scenario.
    """

    @staticmethod
    def _build_node(
        node_data,
        graph_init_params,
        graph_runtime_state,
        index_processor,
        summary_index_service,
    ):
        """Create a KnowledgeIndexNode from *node_data* with the given mocks."""
        node_id = str(uuid.uuid4())
        return KnowledgeIndexNode(
            id=node_id,
            config={"id": node_id, "data": node_data.model_dump()},
            graph_init_params=graph_init_params,
            graph_runtime_state=graph_runtime_state,
            index_processor=index_processor,
            summary_index_service=summary_index_service,
        )

    @staticmethod
    def _add_sys_var(graph_runtime_state, key, value):
        """Register a string-valued system variable in the state's variable pool."""
        graph_runtime_state.variable_pool.add(["sys", key], StringSegment(value=value))

    def test_node_initialization(
        self, mock_graph_init_params, mock_graph_runtime_state, mock_index_processor, mock_summary_index_service
    ):
        """Test KnowledgeIndexNode initialization."""
        # Built explicitly (not via _build_node) so the raw config shape is covered too.
        node_id = str(uuid.uuid4())
        config = {
            "id": node_id,
            "data": {
                "title": "Knowledge Index",
                "type": "knowledge-index",
                "chunk_structure": "general_structure",
                "index_chunk_variable_selector": ["start", "chunks"],
            },
        }
        # Act
        node = KnowledgeIndexNode(
            id=node_id,
            config=config,
            graph_init_params=mock_graph_init_params,
            graph_runtime_state=mock_graph_runtime_state,
            index_processor=mock_index_processor,
            summary_index_service=mock_summary_index_service,
        )
        # Assert
        assert node.id == node_id
        assert node.index_processor == mock_index_processor
        assert node.summary_index_service == mock_summary_index_service

    def test_run_without_dataset_id(
        self,
        mock_graph_init_params,
        mock_graph_runtime_state,
        mock_index_processor,
        mock_summary_index_service,
        sample_node_data,
    ):
        """Test _run raises KnowledgeIndexNodeError when dataset_id is not provided."""
        node = self._build_node(
            sample_node_data,
            mock_graph_init_params,
            mock_graph_runtime_state,
            mock_index_processor,
            mock_summary_index_service,
        )
        with pytest.raises(KnowledgeIndexNodeError, match="Dataset ID is required"):
            node._run()

    def test_run_without_index_chunk_variable(
        self,
        mock_graph_init_params,
        mock_graph_runtime_state,
        mock_index_processor,
        mock_summary_index_service,
        sample_node_data,
    ):
        """Test _run raises KnowledgeIndexNodeError when index chunk variable is not provided."""
        self._add_sys_var(mock_graph_runtime_state, SystemVariableKey.DATASET_ID, str(uuid.uuid4()))
        node = self._build_node(
            sample_node_data,
            mock_graph_init_params,
            mock_graph_runtime_state,
            mock_index_processor,
            mock_summary_index_service,
        )
        with pytest.raises(KnowledgeIndexNodeError, match="Index chunk variable is required"):
            node._run()

    def test_run_with_empty_chunks(
        self,
        mock_graph_init_params,
        mock_graph_runtime_state,
        mock_index_processor,
        mock_summary_index_service,
        sample_node_data,
    ):
        """Test _run fails when chunks is empty."""
        self._add_sys_var(mock_graph_runtime_state, SystemVariableKey.DATASET_ID, str(uuid.uuid4()))
        # An empty string segment in the chunk selector counts as "no chunks".
        mock_graph_runtime_state.variable_pool.add(["start", "chunks"], StringSegment(value=""))
        node = self._build_node(
            sample_node_data,
            mock_graph_init_params,
            mock_graph_runtime_state,
            mock_index_processor,
            mock_summary_index_service,
        )
        result = node._run()
        assert result.status == WorkflowNodeExecutionStatus.FAILED
        assert "Chunks is required" in result.error

    def test_run_preview_mode_success(
        self,
        mock_graph_init_params,
        mock_graph_runtime_state,
        mock_index_processor,
        mock_summary_index_service,
        sample_node_data,
        sample_chunks,
    ):
        """Test _run succeeds in preview mode."""
        self._add_sys_var(mock_graph_runtime_state, SystemVariableKey.DATASET_ID, str(uuid.uuid4()))
        self._add_sys_var(mock_graph_runtime_state, SystemVariableKey.DOCUMENT_ID, str(uuid.uuid4()))
        # DEBUGGER invocation selects the preview code path.
        self._add_sys_var(mock_graph_runtime_state, SystemVariableKey.INVOKE_FROM, InvokeFrom.DEBUGGER)
        mock_graph_runtime_state.variable_pool.add(["start", "chunks"], sample_chunks)
        mock_index_processor.get_preview_output.return_value = Preview(
            chunk_structure="general_structure",
            preview=[PreviewItem(content="Chunk 1"), PreviewItem(content="Chunk 2")],
            total_segments=2,
        )
        node = self._build_node(
            sample_node_data,
            mock_graph_init_params,
            mock_graph_runtime_state,
            mock_index_processor,
            mock_summary_index_service,
        )
        result = node._run()
        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
        assert result.outputs is not None
        assert mock_index_processor.get_preview_output.called

    def test_run_production_mode_success(
        self,
        mock_graph_init_params,
        mock_graph_runtime_state,
        mock_index_processor,
        mock_summary_index_service,
        sample_node_data,
        sample_chunks,
    ):
        """Test _run succeeds in production mode."""
        self._add_sys_var(mock_graph_runtime_state, SystemVariableKey.DATASET_ID, str(uuid.uuid4()))
        self._add_sys_var(mock_graph_runtime_state, SystemVariableKey.DOCUMENT_ID, str(uuid.uuid4()))
        self._add_sys_var(mock_graph_runtime_state, SystemVariableKey.ORIGINAL_DOCUMENT_ID, str(uuid.uuid4()))
        self._add_sys_var(mock_graph_runtime_state, SystemVariableKey.BATCH, "batch_123")
        # SERVICE_API invocation selects the production (indexing) code path.
        self._add_sys_var(mock_graph_runtime_state, SystemVariableKey.INVOKE_FROM, InvokeFrom.SERVICE_API)
        mock_graph_runtime_state.variable_pool.add(["start", "chunks"], sample_chunks)
        mock_index_processor.index_and_clean.return_value = {"status": "indexed"}
        node = self._build_node(
            sample_node_data,
            mock_graph_init_params,
            mock_graph_runtime_state,
            mock_index_processor,
            mock_summary_index_service,
        )
        result = node._run()
        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
        assert result.outputs is not None
        assert mock_summary_index_service.generate_and_vectorize_summary.called
        assert mock_index_processor.index_and_clean.called

    def test_run_production_mode_without_batch(
        self,
        mock_graph_init_params,
        mock_graph_runtime_state,
        mock_index_processor,
        mock_summary_index_service,
        sample_node_data,
        sample_chunks,
    ):
        """Test _run fails when batch is not provided in production mode."""
        self._add_sys_var(mock_graph_runtime_state, SystemVariableKey.DATASET_ID, str(uuid.uuid4()))
        self._add_sys_var(mock_graph_runtime_state, SystemVariableKey.DOCUMENT_ID, str(uuid.uuid4()))
        self._add_sys_var(mock_graph_runtime_state, SystemVariableKey.INVOKE_FROM, InvokeFrom.SERVICE_API)
        mock_graph_runtime_state.variable_pool.add(["start", "chunks"], sample_chunks)
        node = self._build_node(
            sample_node_data,
            mock_graph_init_params,
            mock_graph_runtime_state,
            mock_index_processor,
            mock_summary_index_service,
        )
        result = node._run()
        assert result.status == WorkflowNodeExecutionStatus.FAILED
        assert "Batch is required" in result.error

    def test_run_with_knowledge_index_node_error(
        self,
        mock_graph_init_params,
        mock_graph_runtime_state,
        mock_index_processor,
        mock_summary_index_service,
        sample_node_data,
        sample_chunks,
    ):
        """Test _run handles KnowledgeIndexNodeError properly."""
        self._add_sys_var(mock_graph_runtime_state, SystemVariableKey.DATASET_ID, str(uuid.uuid4()))
        self._add_sys_var(mock_graph_runtime_state, SystemVariableKey.DOCUMENT_ID, str(uuid.uuid4()))
        self._add_sys_var(mock_graph_runtime_state, SystemVariableKey.BATCH, "batch_123")
        self._add_sys_var(mock_graph_runtime_state, SystemVariableKey.INVOKE_FROM, InvokeFrom.SERVICE_API)
        mock_graph_runtime_state.variable_pool.add(["start", "chunks"], sample_chunks)
        mock_index_processor.index_and_clean.side_effect = KnowledgeIndexNodeError("Indexing failed")
        node = self._build_node(
            sample_node_data,
            mock_graph_init_params,
            mock_graph_runtime_state,
            mock_index_processor,
            mock_summary_index_service,
        )
        result = node._run()
        assert result.status == WorkflowNodeExecutionStatus.FAILED
        assert "Indexing failed" in result.error
        assert result.error_type == "KnowledgeIndexNodeError"

    def test_run_with_generic_exception(
        self,
        mock_graph_init_params,
        mock_graph_runtime_state,
        mock_index_processor,
        mock_summary_index_service,
        sample_node_data,
        sample_chunks,
    ):
        """Test _run handles generic exceptions properly."""
        self._add_sys_var(mock_graph_runtime_state, SystemVariableKey.DATASET_ID, str(uuid.uuid4()))
        self._add_sys_var(mock_graph_runtime_state, SystemVariableKey.DOCUMENT_ID, str(uuid.uuid4()))
        self._add_sys_var(mock_graph_runtime_state, SystemVariableKey.BATCH, "batch_123")
        self._add_sys_var(mock_graph_runtime_state, SystemVariableKey.INVOKE_FROM, InvokeFrom.SERVICE_API)
        mock_graph_runtime_state.variable_pool.add(["start", "chunks"], sample_chunks)
        mock_index_processor.index_and_clean.side_effect = Exception("Unexpected error")
        node = self._build_node(
            sample_node_data,
            mock_graph_init_params,
            mock_graph_runtime_state,
            mock_index_processor,
            mock_summary_index_service,
        )
        result = node._run()
        assert result.status == WorkflowNodeExecutionStatus.FAILED
        assert "Unexpected error" in result.error
        assert result.error_type == "Exception"

    def test_invoke_knowledge_index(
        self,
        mock_graph_init_params,
        mock_graph_runtime_state,
        mock_index_processor,
        mock_summary_index_service,
        sample_node_data,
    ):
        """_invoke_knowledge_index delegates to both collaborators and returns the index result."""
        dataset_id = str(uuid.uuid4())
        document_id = str(uuid.uuid4())
        original_document_id = str(uuid.uuid4())
        mock_index_processor.index_and_clean.return_value = {"status": "indexed"}
        node = self._build_node(
            sample_node_data,
            mock_graph_init_params,
            mock_graph_runtime_state,
            mock_index_processor,
            mock_summary_index_service,
        )
        result = node._invoke_knowledge_index(
            dataset_id=dataset_id,
            document_id=document_id,
            original_document_id=original_document_id,
            is_preview=False,
            batch="batch_123",
            chunks={"general_chunks": ["content"]},
            summary_index_setting=None,
        )
        assert mock_summary_index_service.generate_and_vectorize_summary.called
        assert mock_index_processor.index_and_clean.called
        assert result == {"status": "indexed"}

    def test_version_method(self):
        """Test version class method."""
        assert KnowledgeIndexNode.version() == "1"

    def test_get_streaming_template(
        self,
        mock_graph_init_params,
        mock_graph_runtime_state,
        mock_index_processor,
        mock_summary_index_service,
        sample_node_data,
    ):
        """Test get_streaming_template method."""
        node = self._build_node(
            sample_node_data,
            mock_graph_init_params,
            mock_graph_runtime_state,
            mock_index_processor,
            mock_summary_index_service,
        )
        template = node.get_streaming_template()
        assert template is not None
        assert template.segments == []
class TestInvokeKnowledgeIndex:
def test_invoke_with_summary_index_setting(
    self,
    mock_graph_init_params,
    mock_graph_runtime_state,
    mock_index_processor,
    mock_summary_index_service,
    sample_node_data,
):
    """Both collaborators must receive the summary index setting verbatim."""
    # Arrange: identifiers, payloads, and a stubbed indexing result.
    ds_id, doc_id, orig_doc_id = (str(uuid.uuid4()) for _ in range(3))
    batch_label = "batch_123"
    chunk_payload = {"general_chunks": ["content"]}
    summary_cfg = {"enabled": True}
    mock_index_processor.index_and_clean.return_value = {"status": "indexed"}

    nid = str(uuid.uuid4())
    node = KnowledgeIndexNode(
        id=nid,
        config={"id": nid, "data": sample_node_data.model_dump()},
        graph_init_params=mock_graph_init_params,
        graph_runtime_state=mock_graph_runtime_state,
        index_processor=mock_index_processor,
        summary_index_service=mock_summary_index_service,
    )

    # Act
    outcome = node._invoke_knowledge_index(
        dataset_id=ds_id,
        document_id=doc_id,
        original_document_id=orig_doc_id,
        is_preview=False,
        batch=batch_label,
        chunks=chunk_payload,
        summary_index_setting=summary_cfg,
    )

    # Assert: the setting is forwarded unchanged to both services.
    mock_summary_index_service.generate_and_vectorize_summary.assert_called_once_with(
        ds_id, doc_id, False, summary_cfg
    )
    mock_index_processor.index_and_clean.assert_called_once_with(
        ds_id, doc_id, orig_doc_id, chunk_payload, batch_label, summary_cfg
    )
    assert outcome == {"status": "indexed"}