mirror of
https://github.com/langgenius/dify.git
synced 2026-03-10 11:10:19 +08:00
refactor: knowledge index node decouples business logic (#32274)
This commit is contained in:
parent
68647391e7
commit
707bf20c29
@ -52,7 +52,6 @@ forbidden_modules =
|
||||
allow_indirect_imports = True
|
||||
ignore_imports =
|
||||
core.workflow.nodes.agent.agent_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
|
||||
core.workflow.nodes.llm.file_saver -> extensions.ext_database
|
||||
core.workflow.nodes.llm.node -> extensions.ext_database
|
||||
core.workflow.nodes.tool.tool_node -> extensions.ext_database
|
||||
@ -109,7 +108,6 @@ ignore_imports =
|
||||
core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
|
||||
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
|
||||
core.workflow.nodes.llm.llm_utils -> core.model_manager
|
||||
core.workflow.nodes.llm.protocols -> core.model_manager
|
||||
core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
|
||||
@ -154,18 +152,12 @@ ignore_imports =
|
||||
core.workflow.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
|
||||
core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> services.summary_index_service
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> tasks.generate_summary_index_task
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.processor.paragraph_index_processor
|
||||
core.workflow.nodes.llm.node -> models.dataset
|
||||
core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer
|
||||
core.workflow.nodes.llm.file_saver -> core.tools.signature
|
||||
core.workflow.nodes.llm.file_saver -> core.tools.tool_file_manager
|
||||
core.workflow.nodes.tool.tool_node -> core.tools.errors
|
||||
core.workflow.nodes.agent.agent_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
|
||||
core.workflow.nodes.llm.file_saver -> extensions.ext_database
|
||||
core.workflow.nodes.llm.node -> extensions.ext_database
|
||||
core.workflow.nodes.tool.tool_node -> extensions.ext_database
|
||||
|
||||
@ -19,7 +19,9 @@ from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.memory import PromptMessageMemory
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.rag.index_processor.index_processor import IndexProcessor
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.summary_index.summary_index import SummaryIndex
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import NodeType, SystemVariableKey
|
||||
@ -32,6 +34,7 @@ from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
from core.workflow.nodes.datasource import DatasourceNode
|
||||
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
|
||||
from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
|
||||
from core.workflow.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode
|
||||
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
from core.workflow.nodes.llm.entities import ModelConfig
|
||||
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
|
||||
@ -202,6 +205,16 @@ class DifyNodeFactory(NodeFactory):
|
||||
file_manager=self._http_request_file_manager,
|
||||
)
|
||||
|
||||
if node_type == NodeType.KNOWLEDGE_INDEX:
|
||||
return KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
index_processor=IndexProcessor(),
|
||||
summary_index_service=SummaryIndex(),
|
||||
)
|
||||
|
||||
if node_type == NodeType.LLM:
|
||||
model_instance = self._build_model_instance_for_llm_node(node_data)
|
||||
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
|
||||
|
||||
252
api/core/rag/index_processor/index_processor.py
Normal file
252
api/core/rag/index_processor/index_processor.py
Normal file
@ -0,0 +1,252 @@
|
||||
import concurrent.futures
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask import current_app
|
||||
from sqlalchemy import delete, func, select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
|
||||
from core.workflow.repositories.index_processor_protocol import Preview, PreviewItem, QaPreview
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
from .index_processor_factory import IndexProcessorFactory
|
||||
from .processor.paragraph_index_processor import ParagraphIndexProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IndexProcessor:
|
||||
def format_preview(self, chunk_structure: str, chunks: Any) -> Preview:
|
||||
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
|
||||
preview = index_processor.format_preview(chunks)
|
||||
data = Preview(
|
||||
chunk_structure=preview["chunk_structure"],
|
||||
total_segments=preview["total_segments"],
|
||||
preview=[],
|
||||
parent_mode=None,
|
||||
qa_preview=[],
|
||||
)
|
||||
if "parent_mode" in preview:
|
||||
data.parent_mode = preview["parent_mode"]
|
||||
|
||||
for item in preview["preview"]:
|
||||
if "content" in item and "child_chunks" in item:
|
||||
data.preview.append(
|
||||
PreviewItem(content=item["content"], child_chunks=item["child_chunks"], summary=None)
|
||||
)
|
||||
elif "question" in item and "answer" in item:
|
||||
data.qa_preview.append(QaPreview(question=item["question"], answer=item["answer"]))
|
||||
elif "content" in item:
|
||||
data.preview.append(PreviewItem(content=item["content"], child_chunks=None, summary=None))
|
||||
return data
|
||||
|
||||
def index_and_clean(
|
||||
self,
|
||||
dataset_id: str,
|
||||
document_id: str,
|
||||
original_document_id: str,
|
||||
chunks: Mapping[str, Any],
|
||||
batch: Any,
|
||||
summary_index_setting: dict | None = None,
|
||||
):
|
||||
with session_factory.create_session() as session:
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
if not document:
|
||||
raise KnowledgeIndexNodeError(f"Document {document_id} not found.")
|
||||
|
||||
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
if not dataset:
|
||||
raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.")
|
||||
|
||||
dataset_name_value = dataset.name
|
||||
document_name_value = document.name
|
||||
created_at_value = document.created_at
|
||||
if summary_index_setting is None:
|
||||
summary_index_setting = dataset.summary_index_setting
|
||||
index_node_ids = []
|
||||
|
||||
index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor()
|
||||
if original_document_id:
|
||||
segments = session.scalars(
|
||||
select(DocumentSegment).where(DocumentSegment.document_id == original_document_id)
|
||||
).all()
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
indexing_start_at = time.perf_counter()
|
||||
# delete from vector index
|
||||
if index_node_ids:
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
if index_node_ids:
|
||||
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == original_document_id)
|
||||
session.execute(segment_delete_stmt)
|
||||
|
||||
index_processor.index(dataset, document, chunks)
|
||||
indexing_end_at = time.perf_counter()
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
document.indexing_latency = indexing_end_at - indexing_start_at
|
||||
document.indexing_status = "completed"
|
||||
document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
document.word_count = (
|
||||
session.query(func.sum(DocumentSegment.word_count))
|
||||
.where(
|
||||
DocumentSegment.document_id == document_id,
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
)
|
||||
.scalar()
|
||||
) or 0
|
||||
# Update need_summary based on dataset's summary_index_setting
|
||||
if summary_index_setting and summary_index_setting.get("enable") is True:
|
||||
document.need_summary = True
|
||||
else:
|
||||
document.need_summary = False
|
||||
session.add(document)
|
||||
# update document segment status
|
||||
session.query(DocumentSegment).where(
|
||||
DocumentSegment.document_id == document_id,
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
).update(
|
||||
{
|
||||
DocumentSegment.status: "completed",
|
||||
DocumentSegment.enabled: True,
|
||||
DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"dataset_id": dataset_id,
|
||||
"dataset_name": dataset_name_value,
|
||||
"batch": batch,
|
||||
"document_id": document_id,
|
||||
"document_name": document_name_value,
|
||||
"created_at": created_at_value.timestamp(),
|
||||
"display_status": "completed",
|
||||
}
|
||||
|
||||
def get_preview_output(
|
||||
self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None
|
||||
) -> Preview:
|
||||
doc_language = None
|
||||
with session_factory.create_session() as session:
|
||||
if document_id:
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
else:
|
||||
document = None
|
||||
|
||||
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
if not dataset:
|
||||
raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.")
|
||||
|
||||
if summary_index_setting is None:
|
||||
summary_index_setting = dataset.summary_index_setting
|
||||
|
||||
if document:
|
||||
doc_language = document.doc_language
|
||||
indexing_technique = dataset.indexing_technique
|
||||
tenant_id = dataset.tenant_id
|
||||
|
||||
preview_output = self.format_preview(chunk_structure, chunks)
|
||||
if indexing_technique != "high_quality":
|
||||
return preview_output
|
||||
|
||||
if not summary_index_setting or not summary_index_setting.get("enable"):
|
||||
return preview_output
|
||||
|
||||
if preview_output.preview is not None:
|
||||
chunk_count = len(preview_output.preview)
|
||||
logger.info(
|
||||
"Generating summaries for %s chunks in preview mode (dataset: %s)",
|
||||
chunk_count,
|
||||
dataset_id,
|
||||
)
|
||||
|
||||
flask_app = None
|
||||
try:
|
||||
flask_app = current_app._get_current_object() # type: ignore
|
||||
except RuntimeError:
|
||||
logger.warning("No Flask application context available, summary generation may fail")
|
||||
|
||||
def generate_summary_for_chunk(preview_item: PreviewItem) -> None:
|
||||
"""Generate summary for a single chunk."""
|
||||
if flask_app:
|
||||
with flask_app.app_context():
|
||||
if preview_item.content is not None:
|
||||
# Set Flask application context in worker thread
|
||||
summary, _ = ParagraphIndexProcessor.generate_summary(
|
||||
tenant_id=tenant_id,
|
||||
text=preview_item.content,
|
||||
summary_index_setting=summary_index_setting,
|
||||
document_language=doc_language,
|
||||
)
|
||||
if summary:
|
||||
preview_item.summary = summary
|
||||
|
||||
else:
|
||||
summary, _ = ParagraphIndexProcessor.generate_summary(
|
||||
tenant_id=tenant_id,
|
||||
text=preview_item.content if preview_item.content is not None else "",
|
||||
summary_index_setting=summary_index_setting,
|
||||
document_language=doc_language,
|
||||
)
|
||||
if summary:
|
||||
preview_item.summary = summary
|
||||
|
||||
# Generate summaries concurrently using ThreadPoolExecutor
|
||||
# Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
|
||||
timeout_seconds = min(300, 60 * len(preview_output.preview))
|
||||
errors: list[Exception] = []
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_output.preview))) as executor:
|
||||
futures = [
|
||||
executor.submit(generate_summary_for_chunk, preview_item) for preview_item in preview_output.preview
|
||||
]
|
||||
# Wait for all tasks to complete with timeout
|
||||
done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)
|
||||
|
||||
# Cancel tasks that didn't complete in time
|
||||
if not_done:
|
||||
timeout_error_msg = (
|
||||
f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s"
|
||||
)
|
||||
logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg)
|
||||
# In preview mode, timeout is also an error
|
||||
errors.append(TimeoutError(timeout_error_msg))
|
||||
for future in not_done:
|
||||
future.cancel()
|
||||
# Wait a bit for cancellation to take effect
|
||||
concurrent.futures.wait(not_done, timeout=5)
|
||||
|
||||
# Collect exceptions from completed futures
|
||||
for future in done:
|
||||
try:
|
||||
future.result() # This will raise any exception that occurred
|
||||
except Exception as e:
|
||||
logger.exception("Error in summary generation future")
|
||||
errors.append(e)
|
||||
|
||||
# In preview mode, if there are any errors, fail the request
|
||||
if errors:
|
||||
error_messages = [str(e) for e in errors]
|
||||
error_summary = (
|
||||
f"Failed to generate summaries for {len(errors)} chunk(s). "
|
||||
f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors
|
||||
)
|
||||
if len(errors) > 3:
|
||||
error_summary += f" (and {len(errors) - 3} more)"
|
||||
logger.error("Summary generation failed in preview mode: %s", error_summary)
|
||||
raise KnowledgeIndexNodeError(error_summary)
|
||||
|
||||
completed_count = sum(1 for item in preview_output.preview if item.summary is not None)
|
||||
logger.info(
|
||||
"Completed summary generation for preview chunks: %s/%s succeeded",
|
||||
completed_count,
|
||||
len(preview_output.preview),
|
||||
)
|
||||
return preview_output
|
||||
0
api/core/rag/summary_index/__init__.py
Normal file
0
api/core/rag/summary_index/__init__.py
Normal file
86
api/core/rag/summary_index/summary_index.py
Normal file
86
api/core/rag/summary_index/summary_index.py
Normal file
@ -0,0 +1,86 @@
|
||||
import concurrent.futures
|
||||
import logging
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
from tasks.generate_summary_index_task import generate_summary_index_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SummaryIndex:
|
||||
def generate_and_vectorize_summary(
|
||||
self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None
|
||||
) -> None:
|
||||
if is_preview:
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
if not dataset or dataset.indexing_technique != "high_quality":
|
||||
return
|
||||
|
||||
if summary_index_setting is None:
|
||||
summary_index_setting = dataset.summary_index_setting
|
||||
|
||||
if not summary_index_setting or not summary_index_setting.get("enable"):
|
||||
return
|
||||
|
||||
if not document_id:
|
||||
return
|
||||
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
# Skip qa_model documents
|
||||
if document is None or document.doc_form == "qa_model":
|
||||
return
|
||||
|
||||
query = session.query(DocumentSegment).filter_by(
|
||||
dataset_id=dataset_id,
|
||||
document_id=document_id,
|
||||
status="completed",
|
||||
enabled=True,
|
||||
)
|
||||
segments = query.all()
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
|
||||
if not segment_ids:
|
||||
return
|
||||
|
||||
existing_summaries = (
|
||||
session.query(DocumentSegmentSummary)
|
||||
.filter(
|
||||
DocumentSegmentSummary.chunk_id.in_(segment_ids),
|
||||
DocumentSegmentSummary.dataset_id == dataset_id,
|
||||
DocumentSegmentSummary.status == "completed",
|
||||
)
|
||||
.all()
|
||||
)
|
||||
completed_summary_segment_ids = {i.chunk_id for i in existing_summaries}
|
||||
# Preview mode should process segments that are MISSING completed summaries
|
||||
pending_segment_ids = [sid for sid in segment_ids if sid not in completed_summary_segment_ids]
|
||||
|
||||
# If all segments already have completed summaries, nothing to do in preview mode
|
||||
if not pending_segment_ids:
|
||||
return
|
||||
|
||||
max_workers = min(10, len(pending_segment_ids))
|
||||
|
||||
def process_segment(segment_id: str) -> None:
|
||||
"""Process a single segment in a thread with a fresh DB session."""
|
||||
with session_factory.create_session() as session:
|
||||
segment = session.query(DocumentSegment).filter_by(id=segment_id).first()
|
||||
if segment is None:
|
||||
return
|
||||
try:
|
||||
SummaryIndexService.generate_and_vectorize_summary(segment, dataset, summary_index_setting)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to generate summary for segment %s",
|
||||
segment_id,
|
||||
)
|
||||
# Continue processing other segments
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = [executor.submit(process_segment, segment_id) for segment_id in pending_segment_ids]
|
||||
concurrent.futures.wait(futures)
|
||||
else:
|
||||
generate_summary_index_task.delay(dataset_id, document_id, None)
|
||||
@ -1,66 +1,66 @@
|
||||
import concurrent.futures
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask import current_app
|
||||
from sqlalchemy import func, select
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.template import Template
|
||||
from core.workflow.runtime import VariablePool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
from tasks.generate_summary_index_task import generate_summary_index_task
|
||||
from core.workflow.repositories.index_processor_protocol import IndexProcessorProtocol
|
||||
from core.workflow.repositories.summary_index_service_protocol import SummaryIndexServiceProtocol
|
||||
|
||||
from .entities import KnowledgeIndexNodeData
|
||||
from .exc import (
|
||||
KnowledgeIndexNodeError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
node_type = NodeType.KNOWLEDGE_INDEX
|
||||
execution_type = NodeExecutionType.RESPONSE
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
index_processor: IndexProcessorProtocol,
|
||||
summary_index_service: SummaryIndexServiceProtocol,
|
||||
) -> None:
|
||||
super().__init__(id, config, graph_init_params, graph_runtime_state)
|
||||
self.index_processor = index_processor
|
||||
self.summary_index_service = summary_index_service
|
||||
|
||||
def _run(self) -> NodeRunResult: # type: ignore
|
||||
node_data = self.node_data
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
|
||||
if not dataset_id:
|
||||
|
||||
# get dataset id as string
|
||||
dataset_id_segment = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
|
||||
if not dataset_id_segment:
|
||||
raise KnowledgeIndexNodeError("Dataset ID is required.")
|
||||
dataset = db.session.query(Dataset).filter_by(id=dataset_id.value).first()
|
||||
if not dataset:
|
||||
raise KnowledgeIndexNodeError(f"Dataset {dataset_id.value} not found.")
|
||||
dataset_id: str = dataset_id_segment.value
|
||||
|
||||
# get document id as string (may be empty when not provided)
|
||||
document_id_segment = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
|
||||
document_id: str = document_id_segment.value if document_id_segment else ""
|
||||
|
||||
# extract variables
|
||||
variable = variable_pool.get(node_data.index_chunk_variable_selector)
|
||||
if not variable:
|
||||
raise KnowledgeIndexNodeError("Index chunk variable is required.")
|
||||
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
|
||||
if invoke_from:
|
||||
is_preview = invoke_from.value == InvokeFrom.DEBUGGER
|
||||
else:
|
||||
is_preview = False
|
||||
is_preview = invoke_from.value == InvokeFrom.DEBUGGER if invoke_from else False
|
||||
|
||||
chunks = variable.value
|
||||
variables = {"chunks": chunks}
|
||||
if not chunks:
|
||||
@ -68,52 +68,49 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required."
|
||||
)
|
||||
|
||||
# index knowledge
|
||||
try:
|
||||
summary_index_setting = node_data.summary_index_setting
|
||||
if is_preview:
|
||||
# Preview mode: generate summaries for chunks directly without saving to database
|
||||
# Format preview and generate summaries on-the-fly
|
||||
# Get indexing_technique and summary_index_setting from node_data (workflow graph config)
|
||||
# or fallback to dataset if not available in node_data
|
||||
indexing_technique = node_data.indexing_technique or dataset.indexing_technique
|
||||
summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting
|
||||
|
||||
# Try to get document language if document_id is available
|
||||
doc_language = None
|
||||
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
|
||||
if document_id:
|
||||
document = db.session.query(Document).filter_by(id=document_id.value).first()
|
||||
if document and document.doc_language:
|
||||
doc_language = document.doc_language
|
||||
|
||||
outputs = self._get_preview_output_with_summaries(
|
||||
node_data.chunk_structure,
|
||||
chunks,
|
||||
dataset=dataset,
|
||||
indexing_technique=indexing_technique,
|
||||
summary_index_setting=summary_index_setting,
|
||||
doc_language=doc_language,
|
||||
outputs = self.index_processor.get_preview_output(
|
||||
chunks, dataset_id, document_id, node_data.chunk_structure, summary_index_setting
|
||||
)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
outputs=outputs,
|
||||
outputs=outputs.model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
original_document_id_segment = variable_pool.get(["sys", SystemVariableKey.ORIGINAL_DOCUMENT_ID])
|
||||
batch = variable_pool.get(["sys", SystemVariableKey.BATCH])
|
||||
if not batch:
|
||||
raise KnowledgeIndexNodeError("Batch is required.")
|
||||
|
||||
results = self._invoke_knowledge_index(
|
||||
dataset=dataset, node_data=node_data, chunks=chunks, variable_pool=variable_pool
|
||||
dataset_id=dataset_id,
|
||||
document_id=document_id,
|
||||
original_document_id=original_document_id_segment.value if original_document_id_segment else "",
|
||||
is_preview=is_preview,
|
||||
batch=batch.value,
|
||||
chunks=chunks,
|
||||
summary_index_setting=summary_index_setting,
|
||||
)
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=results)
|
||||
|
||||
except KnowledgeIndexNodeError as e:
|
||||
logger.warning("Error when running knowledge index node")
|
||||
logger.warning("Error when running knowledge index node", exc_info=True)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
# Temporary handle all exceptions from DatasetRetrieval class here.
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
@ -123,392 +120,23 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
|
||||
def _invoke_knowledge_index(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
node_data: KnowledgeIndexNodeData,
|
||||
dataset_id: str,
|
||||
document_id: str,
|
||||
original_document_id: str,
|
||||
is_preview: bool,
|
||||
batch: Any,
|
||||
chunks: Mapping[str, Any],
|
||||
variable_pool: VariablePool,
|
||||
) -> Any:
|
||||
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
|
||||
if not document_id:
|
||||
raise KnowledgeIndexNodeError("Document ID is required.")
|
||||
original_document_id = variable_pool.get(["sys", SystemVariableKey.ORIGINAL_DOCUMENT_ID])
|
||||
|
||||
batch = variable_pool.get(["sys", SystemVariableKey.BATCH])
|
||||
if not batch:
|
||||
raise KnowledgeIndexNodeError("Batch is required.")
|
||||
document = db.session.query(Document).filter_by(id=document_id.value).first()
|
||||
if not document:
|
||||
raise KnowledgeIndexNodeError(f"Document {document_id.value} not found.")
|
||||
doc_id_value = document.id
|
||||
ds_id_value = dataset.id
|
||||
dataset_name_value = dataset.name
|
||||
document_name_value = document.name
|
||||
created_at_value = document.created_at
|
||||
# chunk nodes by chunk size
|
||||
indexing_start_at = time.perf_counter()
|
||||
index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor()
|
||||
if original_document_id:
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(DocumentSegment.document_id == original_document_id.value)
|
||||
).all()
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
# delete from vector index
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
|
||||
for segment in segments:
|
||||
db.session.delete(segment)
|
||||
db.session.commit()
|
||||
index_processor.index(dataset, document, chunks)
|
||||
indexing_end_at = time.perf_counter()
|
||||
document.indexing_latency = indexing_end_at - indexing_start_at
|
||||
# update document status
|
||||
document.indexing_status = "completed"
|
||||
document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
document.word_count = (
|
||||
db.session.query(func.sum(DocumentSegment.word_count))
|
||||
.where(
|
||||
DocumentSegment.document_id == doc_id_value,
|
||||
DocumentSegment.dataset_id == ds_id_value,
|
||||
)
|
||||
.scalar()
|
||||
)
|
||||
# Update need_summary based on dataset's summary_index_setting
|
||||
if dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True:
|
||||
document.need_summary = True
|
||||
else:
|
||||
document.need_summary = False
|
||||
db.session.add(document)
|
||||
# update document segment status
|
||||
db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.document_id == doc_id_value,
|
||||
DocumentSegment.dataset_id == ds_id_value,
|
||||
).update(
|
||||
{
|
||||
DocumentSegment.status: "completed",
|
||||
DocumentSegment.enabled: True,
|
||||
DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
|
||||
}
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
# Generate summary index if enabled
|
||||
self._handle_summary_index_generation(dataset, document, variable_pool)
|
||||
|
||||
return {
|
||||
"dataset_id": ds_id_value,
|
||||
"dataset_name": dataset_name_value,
|
||||
"batch": batch.value,
|
||||
"document_id": doc_id_value,
|
||||
"document_name": document_name_value,
|
||||
"created_at": created_at_value.timestamp(),
|
||||
"display_status": "completed",
|
||||
}
|
||||
|
||||
def _handle_summary_index_generation(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
document: Document,
|
||||
variable_pool: VariablePool,
|
||||
) -> None:
|
||||
"""
|
||||
Handle summary index generation based on mode (debug/preview or production).
|
||||
|
||||
Args:
|
||||
dataset: Dataset containing the document
|
||||
document: Document to generate summaries for
|
||||
variable_pool: Variable pool to check invoke_from
|
||||
"""
|
||||
# Only generate summary index for high_quality indexing technique
|
||||
if dataset.indexing_technique != "high_quality":
|
||||
return
|
||||
|
||||
# Check if summary index is enabled
|
||||
summary_index_setting = dataset.summary_index_setting
|
||||
if not summary_index_setting or not summary_index_setting.get("enable"):
|
||||
return
|
||||
|
||||
# Skip qa_model documents
|
||||
if document.doc_form == "qa_model":
|
||||
return
|
||||
|
||||
# Determine if in preview/debug mode
|
||||
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
|
||||
is_preview = invoke_from and invoke_from.value == InvokeFrom.DEBUGGER
|
||||
|
||||
if is_preview:
|
||||
try:
|
||||
# Query segments that need summary generation
|
||||
query = db.session.query(DocumentSegment).filter_by(
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
status="completed",
|
||||
enabled=True,
|
||||
)
|
||||
segments = query.all()
|
||||
|
||||
if not segments:
|
||||
logger.info("No segments found for document %s", document.id)
|
||||
return
|
||||
|
||||
# Filter segments based on mode
|
||||
segments_to_process = []
|
||||
for segment in segments:
|
||||
# Skip if summary already exists
|
||||
existing_summary = (
|
||||
db.session.query(DocumentSegmentSummary)
|
||||
.filter_by(chunk_id=segment.id, dataset_id=dataset.id, status="completed")
|
||||
.first()
|
||||
)
|
||||
if existing_summary:
|
||||
continue
|
||||
|
||||
# For parent-child mode, all segments are parent chunks, so process all
|
||||
segments_to_process.append(segment)
|
||||
|
||||
if not segments_to_process:
|
||||
logger.info("No segments need summary generation for document %s", document.id)
|
||||
return
|
||||
|
||||
# Use ThreadPoolExecutor for concurrent generation
|
||||
flask_app = current_app._get_current_object() # type: ignore
|
||||
max_workers = min(10, len(segments_to_process)) # Limit to 10 workers
|
||||
|
||||
def process_segment(segment: DocumentSegment) -> None:
|
||||
"""Process a single segment in a thread with Flask app context."""
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
SummaryIndexService.generate_and_vectorize_summary(segment, dataset, summary_index_setting)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to generate summary for segment %s",
|
||||
segment.id,
|
||||
)
|
||||
# Continue processing other segments
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = [executor.submit(process_segment, segment) for segment in segments_to_process]
|
||||
# Wait for all tasks to complete
|
||||
concurrent.futures.wait(futures)
|
||||
|
||||
logger.info(
|
||||
"Successfully generated summary index for %s segments in document %s",
|
||||
len(segments_to_process),
|
||||
document.id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to generate summary index for document %s", document.id)
|
||||
# Don't fail the entire indexing process if summary generation fails
|
||||
else:
|
||||
# Production mode: asynchronous generation
|
||||
logger.info(
|
||||
"Queuing summary index generation task for document %s (production mode)",
|
||||
document.id,
|
||||
)
|
||||
try:
|
||||
generate_summary_index_task.delay(dataset.id, document.id, None)
|
||||
logger.info("Summary index generation task queued for document %s", document.id)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to queue summary index generation task for document %s",
|
||||
document.id,
|
||||
)
|
||||
# Don't fail the entire indexing process if task queuing fails
|
||||
|
||||
def _get_preview_output_with_summaries(
|
||||
self,
|
||||
chunk_structure: str,
|
||||
chunks: Any,
|
||||
dataset: Dataset,
|
||||
indexing_technique: str | None = None,
|
||||
summary_index_setting: dict | None = None,
|
||||
doc_language: str | None = None,
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Generate preview output with summaries for chunks in preview mode.
|
||||
This method generates summaries on-the-fly without saving to database.
|
||||
|
||||
Args:
|
||||
chunk_structure: Chunk structure type
|
||||
chunks: Chunks to generate preview for
|
||||
dataset: Dataset object (for tenant_id)
|
||||
indexing_technique: Indexing technique from node config or dataset
|
||||
summary_index_setting: Summary index setting from node config or dataset
|
||||
doc_language: Optional document language to ensure summary is generated in the correct language
|
||||
"""
|
||||
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
|
||||
preview_output = index_processor.format_preview(chunks)
|
||||
|
||||
# Check if summary index is enabled
|
||||
if indexing_technique != "high_quality":
|
||||
return preview_output
|
||||
|
||||
if not summary_index_setting or not summary_index_setting.get("enable"):
|
||||
return preview_output
|
||||
|
||||
# Generate summaries for chunks
|
||||
if "preview" in preview_output and isinstance(preview_output["preview"], list):
|
||||
chunk_count = len(preview_output["preview"])
|
||||
logger.info(
|
||||
"Generating summaries for %s chunks in preview mode (dataset: %s)",
|
||||
chunk_count,
|
||||
dataset.id,
|
||||
)
|
||||
# Use ParagraphIndexProcessor's generate_summary method
|
||||
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
|
||||
|
||||
# Get Flask app for application context in worker threads
|
||||
flask_app = None
|
||||
try:
|
||||
flask_app = current_app._get_current_object() # type: ignore
|
||||
except RuntimeError:
|
||||
logger.warning("No Flask application context available, summary generation may fail")
|
||||
|
||||
def generate_summary_for_chunk(preview_item: dict) -> None:
|
||||
"""Generate summary for a single chunk."""
|
||||
if "content" in preview_item:
|
||||
# Set Flask application context in worker thread
|
||||
if flask_app:
|
||||
with flask_app.app_context():
|
||||
summary, _ = ParagraphIndexProcessor.generate_summary(
|
||||
tenant_id=dataset.tenant_id,
|
||||
text=preview_item["content"],
|
||||
summary_index_setting=summary_index_setting,
|
||||
document_language=doc_language,
|
||||
)
|
||||
if summary:
|
||||
preview_item["summary"] = summary
|
||||
else:
|
||||
# Fallback: try without app context (may fail)
|
||||
summary, _ = ParagraphIndexProcessor.generate_summary(
|
||||
tenant_id=dataset.tenant_id,
|
||||
text=preview_item["content"],
|
||||
summary_index_setting=summary_index_setting,
|
||||
document_language=doc_language,
|
||||
)
|
||||
if summary:
|
||||
preview_item["summary"] = summary
|
||||
|
||||
# Generate summaries concurrently using ThreadPoolExecutor
|
||||
# Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
|
||||
timeout_seconds = min(300, 60 * len(preview_output["preview"]))
|
||||
errors: list[Exception] = []
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_output["preview"]))) as executor:
|
||||
futures = [
|
||||
executor.submit(generate_summary_for_chunk, preview_item)
|
||||
for preview_item in preview_output["preview"]
|
||||
]
|
||||
# Wait for all tasks to complete with timeout
|
||||
done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)
|
||||
|
||||
# Cancel tasks that didn't complete in time
|
||||
if not_done:
|
||||
timeout_error_msg = (
|
||||
f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s"
|
||||
)
|
||||
logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg)
|
||||
# In preview mode, timeout is also an error
|
||||
errors.append(TimeoutError(timeout_error_msg))
|
||||
for future in not_done:
|
||||
future.cancel()
|
||||
# Wait a bit for cancellation to take effect
|
||||
concurrent.futures.wait(not_done, timeout=5)
|
||||
|
||||
# Collect exceptions from completed futures
|
||||
for future in done:
|
||||
try:
|
||||
future.result() # This will raise any exception that occurred
|
||||
except Exception as e:
|
||||
logger.exception("Error in summary generation future")
|
||||
errors.append(e)
|
||||
|
||||
# In preview mode, if there are any errors, fail the request
|
||||
if errors:
|
||||
error_messages = [str(e) for e in errors]
|
||||
error_summary = (
|
||||
f"Failed to generate summaries for {len(errors)} chunk(s). "
|
||||
f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors
|
||||
)
|
||||
if len(errors) > 3:
|
||||
error_summary += f" (and {len(errors) - 3} more)"
|
||||
logger.error("Summary generation failed in preview mode: %s", error_summary)
|
||||
raise KnowledgeIndexNodeError(error_summary)
|
||||
|
||||
completed_count = sum(1 for item in preview_output["preview"] if item.get("summary") is not None)
|
||||
logger.info(
|
||||
"Completed summary generation for preview chunks: %s/%s succeeded",
|
||||
completed_count,
|
||||
len(preview_output["preview"]),
|
||||
)
|
||||
|
||||
return preview_output
|
||||
|
||||
def _get_preview_output(
|
||||
self,
|
||||
chunk_structure: str,
|
||||
chunks: Any,
|
||||
dataset: Dataset | None = None,
|
||||
variable_pool: VariablePool | None = None,
|
||||
) -> Mapping[str, Any]:
|
||||
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
|
||||
preview_output = index_processor.format_preview(chunks)
|
||||
|
||||
# If dataset is provided, try to enrich preview with summaries
|
||||
if dataset and variable_pool:
|
||||
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
|
||||
if document_id:
|
||||
document = db.session.query(Document).filter_by(id=document_id.value).first()
|
||||
if document:
|
||||
# Query summaries for this document
|
||||
summaries = (
|
||||
db.session.query(DocumentSegmentSummary)
|
||||
.filter_by(
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
status="completed",
|
||||
enabled=True,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if summaries:
|
||||
# Create a map of segment content to summary for matching
|
||||
# Use content matching as chunks in preview might not be indexed yet
|
||||
summary_by_content = {}
|
||||
for summary in summaries:
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter_by(id=summary.chunk_id, dataset_id=dataset.id)
|
||||
.first()
|
||||
)
|
||||
if segment:
|
||||
# Normalize content for matching (strip whitespace)
|
||||
normalized_content = segment.content.strip()
|
||||
summary_by_content[normalized_content] = summary.summary_content
|
||||
|
||||
# Enrich preview with summaries by content matching
|
||||
if "preview" in preview_output and isinstance(preview_output["preview"], list):
|
||||
matched_count = 0
|
||||
for preview_item in preview_output["preview"]:
|
||||
if "content" in preview_item:
|
||||
# Normalize content for matching
|
||||
normalized_chunk_content = preview_item["content"].strip()
|
||||
if normalized_chunk_content in summary_by_content:
|
||||
preview_item["summary"] = summary_by_content[normalized_chunk_content]
|
||||
matched_count += 1
|
||||
|
||||
if matched_count > 0:
|
||||
logger.info(
|
||||
"Enriched preview with %s existing summaries (dataset: %s, document: %s)",
|
||||
matched_count,
|
||||
dataset.id,
|
||||
document.id,
|
||||
)
|
||||
|
||||
return preview_output
|
||||
):
|
||||
if not document_id:
|
||||
raise KnowledgeIndexNodeError("document_id is required.")
|
||||
rst = self.index_processor.index_and_clean(
|
||||
dataset_id, document_id, original_document_id, chunks, batch, summary_index_setting
|
||||
)
|
||||
self.summary_index_service.generate_and_vectorize_summary(
|
||||
dataset_id, document_id, is_preview, summary_index_setting
|
||||
)
|
||||
return rst
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
|
||||
41
api/core/workflow/repositories/index_processor_protocol.py
Normal file
41
api/core/workflow/repositories/index_processor_protocol.py
Normal file
@ -0,0 +1,41 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class PreviewItem(BaseModel):
|
||||
content: str | None = Field(None)
|
||||
child_chunks: list[str] | None = Field(None)
|
||||
summary: str | None = Field(None)
|
||||
|
||||
|
||||
class QaPreview(BaseModel):
|
||||
answer: str | None = Field(None)
|
||||
question: str | None = Field(None)
|
||||
|
||||
|
||||
class Preview(BaseModel):
|
||||
chunk_structure: str
|
||||
parent_mode: str | None = Field(None)
|
||||
preview: list[PreviewItem] = Field([])
|
||||
qa_preview: list[QaPreview] = Field([])
|
||||
total_segments: int
|
||||
|
||||
|
||||
class IndexProcessorProtocol(Protocol):
|
||||
def format_preview(self, chunk_structure: str, chunks: Any) -> Preview: ...
|
||||
|
||||
def index_and_clean(
|
||||
self,
|
||||
dataset_id: str,
|
||||
document_id: str,
|
||||
original_document_id: str,
|
||||
chunks: Mapping[str, Any],
|
||||
batch: Any,
|
||||
summary_index_setting: dict | None = None,
|
||||
) -> dict[str, Any]: ...
|
||||
|
||||
def get_preview_output(
|
||||
self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None
|
||||
) -> Preview: ...
|
||||
@ -0,0 +1,7 @@
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class SummaryIndexServiceProtocol(Protocol):
|
||||
def generate_and_vectorize_summary(
|
||||
self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None
|
||||
): ...
|
||||
@ -0,0 +1,69 @@
|
||||
"""
|
||||
Integration tests for KnowledgeIndexNode.
|
||||
|
||||
This module provides integration tests for KnowledgeIndexNode with real database interactions.
|
||||
|
||||
Note: These tests require database setup and are more complex than unit tests.
|
||||
For now, we focus on unit tests which provide better coverage for the node logic.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestKnowledgeIndexNodeIntegration:
|
||||
"""
|
||||
Integration test suite for KnowledgeIndexNode.
|
||||
|
||||
Note: Full integration tests require:
|
||||
- Database setup with datasets and documents
|
||||
- Vector store for embeddings
|
||||
- Model providers for indexing and summarization
|
||||
- IndexProcessor and SummaryIndexService implementations
|
||||
|
||||
For now, unit tests provide comprehensive coverage of the node logic.
|
||||
"""
|
||||
|
||||
@pytest.mark.skip(reason="Integration tests require full database and vector store setup")
|
||||
def test_end_to_end_knowledge_index_preview(self):
|
||||
"""Test end-to-end knowledge index workflow in preview mode."""
|
||||
# TODO: Implement with real database
|
||||
# 1. Create a dataset
|
||||
# 2. Create a document
|
||||
# 3. Prepare chunks
|
||||
# 4. Run KnowledgeIndexNode in preview mode
|
||||
# 5. Verify preview output
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(reason="Integration tests require full database and vector store setup")
|
||||
def test_end_to_end_knowledge_index_production(self):
|
||||
"""Test end-to-end knowledge index workflow in production mode."""
|
||||
# TODO: Implement with real database
|
||||
# 1. Create a dataset
|
||||
# 2. Create a document
|
||||
# 3. Prepare chunks
|
||||
# 4. Run KnowledgeIndexNode in production mode
|
||||
# 5. Verify indexing and summary generation
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(reason="Integration tests require full database and vector store setup")
|
||||
def test_knowledge_index_with_summary_enabled(self):
|
||||
"""Test knowledge index with summary index setting enabled."""
|
||||
# TODO: Implement with real database
|
||||
# 1. Create a dataset
|
||||
# 2. Create a document
|
||||
# 3. Prepare chunks
|
||||
# 4. Configure summary index setting
|
||||
# 5. Run KnowledgeIndexNode
|
||||
# 6. Verify summaries are generated and indexed
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(reason="Integration tests require full database and vector store setup")
|
||||
def test_knowledge_index_parent_child_structure(self):
|
||||
"""Test knowledge index with parent-child chunk structure."""
|
||||
# TODO: Implement with real database
|
||||
# 1. Create a dataset
|
||||
# 2. Create a document
|
||||
# 3. Prepare parent-child chunks
|
||||
# 4. Run KnowledgeIndexNode
|
||||
# 5. Verify parent-child indexing
|
||||
pass
|
||||
@ -0,0 +1,663 @@
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import SystemVariableKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.knowledge_index.entities import KnowledgeIndexNodeData
|
||||
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
|
||||
from core.workflow.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode
|
||||
from core.workflow.repositories.index_processor_protocol import IndexProcessorProtocol, Preview, PreviewItem
|
||||
from core.workflow.repositories.summary_index_service_protocol import SummaryIndexServiceProtocol
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variables.segments import StringSegment
|
||||
from models.enums import UserFrom
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_init_params():
|
||||
"""Create mock GraphInitParams."""
|
||||
return GraphInitParams(
|
||||
tenant_id=str(uuid.uuid4()),
|
||||
app_id=str(uuid.uuid4()),
|
||||
workflow_id=str(uuid.uuid4()),
|
||||
graph_config={},
|
||||
user_id=str(uuid.uuid4()),
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_runtime_state():
|
||||
"""Create mock GraphRuntimeState."""
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(user_id=str(uuid.uuid4()), files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_index_processor():
|
||||
"""Create mock IndexProcessorProtocol."""
|
||||
mock_processor = Mock(spec=IndexProcessorProtocol)
|
||||
return mock_processor
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_summary_index_service():
|
||||
"""Create mock SummaryIndexServiceProtocol."""
|
||||
mock_service = Mock(spec=SummaryIndexServiceProtocol)
|
||||
return mock_service
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_node_data():
|
||||
"""Create sample KnowledgeIndexNodeData."""
|
||||
return KnowledgeIndexNodeData(
|
||||
title="Knowledge Index",
|
||||
type="knowledge-index",
|
||||
chunk_structure="general_structure",
|
||||
index_chunk_variable_selector=["start", "chunks"],
|
||||
indexing_technique="high_quality",
|
||||
summary_index_setting=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_chunks():
|
||||
"""Create sample chunks data."""
|
||||
return {
|
||||
"general_chunks": ["Chunk 1 content", "Chunk 2 content"],
|
||||
"data_source_info": {"file_id": str(uuid.uuid4())},
|
||||
}
|
||||
|
||||
|
||||
class TestKnowledgeIndexNode:
|
||||
"""
|
||||
Test suite for KnowledgeIndexNode.
|
||||
"""
|
||||
|
||||
def test_node_initialization(
|
||||
self, mock_graph_init_params, mock_graph_runtime_state, mock_index_processor, mock_summary_index_service
|
||||
):
|
||||
"""Test KnowledgeIndexNode initialization."""
|
||||
# Arrange
|
||||
node_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"id": node_id,
|
||||
"data": {
|
||||
"title": "Knowledge Index",
|
||||
"type": "knowledge-index",
|
||||
"chunk_structure": "general_structure",
|
||||
"index_chunk_variable_selector": ["start", "chunks"],
|
||||
},
|
||||
}
|
||||
|
||||
# Act
|
||||
node = KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
index_processor=mock_index_processor,
|
||||
summary_index_service=mock_summary_index_service,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert node.id == node_id
|
||||
assert node.index_processor == mock_index_processor
|
||||
assert node.summary_index_service == mock_summary_index_service
|
||||
|
||||
def test_run_without_dataset_id(
|
||||
self,
|
||||
mock_graph_init_params,
|
||||
mock_graph_runtime_state,
|
||||
mock_index_processor,
|
||||
mock_summary_index_service,
|
||||
sample_node_data,
|
||||
):
|
||||
"""Test _run raises KnowledgeIndexNodeError when dataset_id is not provided."""
|
||||
# Arrange
|
||||
node_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"id": node_id,
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
node = KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
index_processor=mock_index_processor,
|
||||
summary_index_service=mock_summary_index_service,
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(KnowledgeIndexNodeError, match="Dataset ID is required"):
|
||||
node._run()
|
||||
|
||||
def test_run_without_index_chunk_variable(
|
||||
self,
|
||||
mock_graph_init_params,
|
||||
mock_graph_runtime_state,
|
||||
mock_index_processor,
|
||||
mock_summary_index_service,
|
||||
sample_node_data,
|
||||
):
|
||||
"""Test _run raises KnowledgeIndexNodeError when index chunk variable is not provided."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid.uuid4())
|
||||
mock_graph_runtime_state.variable_pool.add(
|
||||
["sys", SystemVariableKey.DATASET_ID],
|
||||
StringSegment(value=dataset_id),
|
||||
)
|
||||
|
||||
node_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"id": node_id,
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
node = KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
index_processor=mock_index_processor,
|
||||
summary_index_service=mock_summary_index_service,
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(KnowledgeIndexNodeError, match="Index chunk variable is required"):
|
||||
node._run()
|
||||
|
||||
def test_run_with_empty_chunks(
|
||||
self,
|
||||
mock_graph_init_params,
|
||||
mock_graph_runtime_state,
|
||||
mock_index_processor,
|
||||
mock_summary_index_service,
|
||||
sample_node_data,
|
||||
):
|
||||
"""Test _run fails when chunks is empty."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid.uuid4())
|
||||
chunks_selector = ["start", "chunks"]
|
||||
|
||||
mock_graph_runtime_state.variable_pool.add(
|
||||
["sys", SystemVariableKey.DATASET_ID],
|
||||
StringSegment(value=dataset_id),
|
||||
)
|
||||
mock_graph_runtime_state.variable_pool.add(chunks_selector, StringSegment(value=""))
|
||||
|
||||
node_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"id": node_id,
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
node = KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
index_processor=mock_index_processor,
|
||||
summary_index_service=mock_summary_index_service,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = node._run()
|
||||
|
||||
# Assert
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert "Chunks is required" in result.error
|
||||
|
||||
def test_run_preview_mode_success(
|
||||
self,
|
||||
mock_graph_init_params,
|
||||
mock_graph_runtime_state,
|
||||
mock_index_processor,
|
||||
mock_summary_index_service,
|
||||
sample_node_data,
|
||||
sample_chunks,
|
||||
):
|
||||
"""Test _run succeeds in preview mode."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid.uuid4())
|
||||
document_id = str(uuid.uuid4())
|
||||
chunks_selector = ["start", "chunks"]
|
||||
|
||||
mock_graph_runtime_state.variable_pool.add(
|
||||
["sys", SystemVariableKey.DATASET_ID],
|
||||
StringSegment(value=dataset_id),
|
||||
)
|
||||
mock_graph_runtime_state.variable_pool.add(
|
||||
["sys", SystemVariableKey.DOCUMENT_ID],
|
||||
StringSegment(value=document_id),
|
||||
)
|
||||
mock_graph_runtime_state.variable_pool.add(
|
||||
["sys", SystemVariableKey.INVOKE_FROM],
|
||||
StringSegment(value=InvokeFrom.DEBUGGER),
|
||||
)
|
||||
mock_graph_runtime_state.variable_pool.add(chunks_selector, sample_chunks)
|
||||
|
||||
# Mock preview output
|
||||
mock_preview = Preview(
|
||||
chunk_structure="general_structure",
|
||||
preview=[PreviewItem(content="Chunk 1"), PreviewItem(content="Chunk 2")],
|
||||
total_segments=2,
|
||||
)
|
||||
mock_index_processor.get_preview_output.return_value = mock_preview
|
||||
|
||||
node_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"id": node_id,
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
node = KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
index_processor=mock_index_processor,
|
||||
summary_index_service=mock_summary_index_service,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = node._run()
|
||||
|
||||
# Assert
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert mock_index_processor.get_preview_output.called
|
||||
|
||||
def test_run_production_mode_success(
|
||||
self,
|
||||
mock_graph_init_params,
|
||||
mock_graph_runtime_state,
|
||||
mock_index_processor,
|
||||
mock_summary_index_service,
|
||||
sample_node_data,
|
||||
sample_chunks,
|
||||
):
|
||||
"""Test _run succeeds in production mode."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid.uuid4())
|
||||
document_id = str(uuid.uuid4())
|
||||
original_document_id = str(uuid.uuid4())
|
||||
batch = "batch_123"
|
||||
chunks_selector = ["start", "chunks"]
|
||||
|
||||
mock_graph_runtime_state.variable_pool.add(
|
||||
["sys", SystemVariableKey.DATASET_ID],
|
||||
StringSegment(value=dataset_id),
|
||||
)
|
||||
mock_graph_runtime_state.variable_pool.add(
|
||||
["sys", SystemVariableKey.DOCUMENT_ID],
|
||||
StringSegment(value=document_id),
|
||||
)
|
||||
mock_graph_runtime_state.variable_pool.add(
|
||||
["sys", SystemVariableKey.ORIGINAL_DOCUMENT_ID],
|
||||
StringSegment(value=original_document_id),
|
||||
)
|
||||
mock_graph_runtime_state.variable_pool.add(
|
||||
["sys", SystemVariableKey.BATCH],
|
||||
StringSegment(value=batch),
|
||||
)
|
||||
mock_graph_runtime_state.variable_pool.add(
|
||||
["sys", SystemVariableKey.INVOKE_FROM],
|
||||
StringSegment(value=InvokeFrom.SERVICE_API),
|
||||
)
|
||||
mock_graph_runtime_state.variable_pool.add(chunks_selector, sample_chunks)
|
||||
|
||||
# Mock index_and_clean output
|
||||
mock_index_processor.index_and_clean.return_value = {"status": "indexed"}
|
||||
|
||||
node_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"id": node_id,
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
node = KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
index_processor=mock_index_processor,
|
||||
summary_index_service=mock_summary_index_service,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = node._run()
|
||||
|
||||
# Assert
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert mock_summary_index_service.generate_and_vectorize_summary.called
|
||||
assert mock_index_processor.index_and_clean.called
|
||||
|
||||
def test_run_production_mode_without_batch(
|
||||
self,
|
||||
mock_graph_init_params,
|
||||
mock_graph_runtime_state,
|
||||
mock_index_processor,
|
||||
mock_summary_index_service,
|
||||
sample_node_data,
|
||||
sample_chunks,
|
||||
):
|
||||
"""Test _run fails when batch is not provided in production mode."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid.uuid4())
|
||||
document_id = str(uuid.uuid4())
|
||||
chunks_selector = ["start", "chunks"]
|
||||
|
||||
mock_graph_runtime_state.variable_pool.add(
|
||||
["sys", SystemVariableKey.DATASET_ID],
|
||||
StringSegment(value=dataset_id),
|
||||
)
|
||||
mock_graph_runtime_state.variable_pool.add(
|
||||
["sys", SystemVariableKey.DOCUMENT_ID],
|
||||
StringSegment(value=document_id),
|
||||
)
|
||||
mock_graph_runtime_state.variable_pool.add(
|
||||
["sys", SystemVariableKey.INVOKE_FROM],
|
||||
StringSegment(value=InvokeFrom.SERVICE_API),
|
||||
)
|
||||
mock_graph_runtime_state.variable_pool.add(chunks_selector, sample_chunks)
|
||||
|
||||
node_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"id": node_id,
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
node = KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
index_processor=mock_index_processor,
|
||||
summary_index_service=mock_summary_index_service,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = node._run()
|
||||
|
||||
# Assert
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert "Batch is required" in result.error
|
||||
|
||||
def test_run_with_knowledge_index_node_error(
|
||||
self,
|
||||
mock_graph_init_params,
|
||||
mock_graph_runtime_state,
|
||||
mock_index_processor,
|
||||
mock_summary_index_service,
|
||||
sample_node_data,
|
||||
sample_chunks,
|
||||
):
|
||||
"""Test _run handles KnowledgeIndexNodeError properly."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid.uuid4())
|
||||
document_id = str(uuid.uuid4())
|
||||
batch = "batch_123"
|
||||
chunks_selector = ["start", "chunks"]
|
||||
|
||||
mock_graph_runtime_state.variable_pool.add(
|
||||
["sys", SystemVariableKey.DATASET_ID],
|
||||
StringSegment(value=dataset_id),
|
||||
)
|
||||
mock_graph_runtime_state.variable_pool.add(
|
||||
["sys", SystemVariableKey.DOCUMENT_ID],
|
||||
StringSegment(value=document_id),
|
||||
)
|
||||
mock_graph_runtime_state.variable_pool.add(
|
||||
["sys", SystemVariableKey.BATCH],
|
||||
StringSegment(value=batch),
|
||||
)
|
||||
mock_graph_runtime_state.variable_pool.add(
|
||||
["sys", SystemVariableKey.INVOKE_FROM],
|
||||
StringSegment(value=InvokeFrom.SERVICE_API),
|
||||
)
|
||||
mock_graph_runtime_state.variable_pool.add(chunks_selector, sample_chunks)
|
||||
|
||||
# Mock to raise KnowledgeIndexNodeError
|
||||
mock_index_processor.index_and_clean.side_effect = KnowledgeIndexNodeError("Indexing failed")
|
||||
|
||||
node_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"id": node_id,
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
node = KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
index_processor=mock_index_processor,
|
||||
summary_index_service=mock_summary_index_service,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = node._run()
|
||||
|
||||
# Assert
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert "Indexing failed" in result.error
|
||||
assert result.error_type == "KnowledgeIndexNodeError"
|
||||
|
||||
def test_run_with_generic_exception(
|
||||
self,
|
||||
mock_graph_init_params,
|
||||
mock_graph_runtime_state,
|
||||
mock_index_processor,
|
||||
mock_summary_index_service,
|
||||
sample_node_data,
|
||||
sample_chunks,
|
||||
):
|
||||
"""Test _run handles generic exceptions properly."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid.uuid4())
|
||||
document_id = str(uuid.uuid4())
|
||||
batch = "batch_123"
|
||||
chunks_selector = ["start", "chunks"]
|
||||
|
||||
mock_graph_runtime_state.variable_pool.add(
|
||||
["sys", SystemVariableKey.DATASET_ID],
|
||||
StringSegment(value=dataset_id),
|
||||
)
|
||||
mock_graph_runtime_state.variable_pool.add(
|
||||
["sys", SystemVariableKey.DOCUMENT_ID],
|
||||
StringSegment(value=document_id),
|
||||
)
|
||||
mock_graph_runtime_state.variable_pool.add(
|
||||
["sys", SystemVariableKey.BATCH],
|
||||
StringSegment(value=batch),
|
||||
)
|
||||
mock_graph_runtime_state.variable_pool.add(
|
||||
["sys", SystemVariableKey.INVOKE_FROM],
|
||||
StringSegment(value=InvokeFrom.SERVICE_API),
|
||||
)
|
||||
mock_graph_runtime_state.variable_pool.add(chunks_selector, sample_chunks)
|
||||
|
||||
# Mock to raise generic exception
|
||||
mock_index_processor.index_and_clean.side_effect = Exception("Unexpected error")
|
||||
|
||||
node_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"id": node_id,
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
node = KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
index_processor=mock_index_processor,
|
||||
summary_index_service=mock_summary_index_service,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = node._run()
|
||||
|
||||
# Assert
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert "Unexpected error" in result.error
|
||||
assert result.error_type == "Exception"
|
||||
|
||||
def test_invoke_knowledge_index(
|
||||
self,
|
||||
mock_graph_init_params,
|
||||
mock_graph_runtime_state,
|
||||
mock_index_processor,
|
||||
mock_summary_index_service,
|
||||
sample_node_data,
|
||||
):
|
||||
# Arrange
|
||||
dataset_id = str(uuid.uuid4())
|
||||
document_id = str(uuid.uuid4())
|
||||
original_document_id = str(uuid.uuid4())
|
||||
batch = "batch_123"
|
||||
chunks = {"general_chunks": ["content"]}
|
||||
|
||||
mock_index_processor.index_and_clean.return_value = {"status": "indexed"}
|
||||
|
||||
node_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"id": node_id,
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
node = KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
index_processor=mock_index_processor,
|
||||
summary_index_service=mock_summary_index_service,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = node._invoke_knowledge_index(
|
||||
dataset_id=dataset_id,
|
||||
document_id=document_id,
|
||||
original_document_id=original_document_id,
|
||||
is_preview=False,
|
||||
batch=batch,
|
||||
chunks=chunks,
|
||||
summary_index_setting=None,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert mock_summary_index_service.generate_and_vectorize_summary.called
|
||||
assert mock_index_processor.index_and_clean.called
|
||||
assert result == {"status": "indexed"}
|
||||
|
||||
def test_version_method(self):
|
||||
"""Test version class method."""
|
||||
# Act
|
||||
version = KnowledgeIndexNode.version()
|
||||
|
||||
# Assert
|
||||
assert version == "1"
|
||||
|
||||
def test_get_streaming_template(
|
||||
self,
|
||||
mock_graph_init_params,
|
||||
mock_graph_runtime_state,
|
||||
mock_index_processor,
|
||||
mock_summary_index_service,
|
||||
sample_node_data,
|
||||
):
|
||||
"""Test get_streaming_template method."""
|
||||
# Arrange
|
||||
node_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"id": node_id,
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
node = KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
index_processor=mock_index_processor,
|
||||
summary_index_service=mock_summary_index_service,
|
||||
)
|
||||
|
||||
# Act
|
||||
template = node.get_streaming_template()
|
||||
|
||||
# Assert
|
||||
assert template is not None
|
||||
assert template.segments == []
|
||||
|
||||
|
||||
class TestInvokeKnowledgeIndex:
|
||||
def test_invoke_with_summary_index_setting(
|
||||
self,
|
||||
mock_graph_init_params,
|
||||
mock_graph_runtime_state,
|
||||
mock_index_processor,
|
||||
mock_summary_index_service,
|
||||
sample_node_data,
|
||||
):
|
||||
# Arrange
|
||||
dataset_id = str(uuid.uuid4())
|
||||
document_id = str(uuid.uuid4())
|
||||
original_document_id = str(uuid.uuid4())
|
||||
batch = "batch_123"
|
||||
chunks = {"general_chunks": ["content"]}
|
||||
summary_setting = {"enabled": True}
|
||||
|
||||
mock_index_processor.index_and_clean.return_value = {"status": "indexed"}
|
||||
|
||||
node_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"id": node_id,
|
||||
"data": sample_node_data.model_dump(),
|
||||
}
|
||||
|
||||
node = KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
index_processor=mock_index_processor,
|
||||
summary_index_service=mock_summary_index_service,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = node._invoke_knowledge_index(
|
||||
dataset_id=dataset_id,
|
||||
document_id=document_id,
|
||||
original_document_id=original_document_id,
|
||||
is_preview=False,
|
||||
batch=batch,
|
||||
chunks=chunks,
|
||||
summary_index_setting=summary_setting,
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_summary_index_service.generate_and_vectorize_summary.assert_called_once_with(
|
||||
dataset_id, document_id, False, summary_setting
|
||||
)
|
||||
mock_index_processor.index_and_clean.assert_called_once_with(
|
||||
dataset_id, document_id, original_document_id, chunks, batch, summary_setting
|
||||
)
|
||||
assert result == {"status": "indexed"}
|
||||
Loading…
Reference in New Issue
Block a user