mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into feat/hitl-frontend
This commit is contained in: commit 3170e12b0a
@@ -47,13 +47,9 @@ jobs:
         if: steps.changed-files.outputs.any_changed == 'true'
         run: uv run --directory api --dev lint-imports

-      - name: Run Basedpyright Checks
+      - name: Run Type Checks
         if: steps.changed-files.outputs.any_changed == 'true'
-        run: dev/basedpyright-check
-
-      - name: Run Mypy Type Checks
-        if: steps.changed-files.outputs.any_changed == 'true'
-        run: uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
+        run: make type-check

       - name: Dotenv check
         if: steps.changed-files.outputs.any_changed == 'true'
Makefile (10 changed lines)
@@ -68,9 +68,11 @@ lint:
	@echo "✅ Linting complete"

 type-check:
-	@echo "📝 Running type check with basedpyright..."
-	@uv run --directory api --dev basedpyright
-	@echo "✅ Type check complete"
+	@echo "📝 Running type checks (basedpyright + mypy + ty)..."
+	@./dev/basedpyright-check $(PATH_TO_CHECK)
+	@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
+	@cd api && uv run ty check
+	@echo "✅ Type checks complete"

 test:
	@echo "🧪 Running backend unit tests..."
@@ -130,7 +132,7 @@ help:
	@echo "  make format - Format code with ruff"
	@echo "  make check - Check code with ruff"
	@echo "  make lint - Format, fix, and lint code (ruff, imports, dotenv)"
-	@echo "  make type-check - Run type checking with basedpyright"
+	@echo "  make type-check - Run type checks (basedpyright, mypy, ty)"
	@echo "  make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
	@echo ""
	@echo "Docker Build Targets:"
@@ -617,6 +617,7 @@ PLUGIN_DAEMON_URL=http://127.0.0.1:5002
 PLUGIN_REMOTE_INSTALL_PORT=5003
 PLUGIN_REMOTE_INSTALL_HOST=localhost
 PLUGIN_MAX_PACKAGE_SIZE=15728640
+PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600
 INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1

 # Marketplace configuration

@@ -716,4 +717,3 @@ SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
 SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
 SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
 SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000
-
@@ -227,6 +227,9 @@ ignore_imports =
     core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
     core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods
     core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset
+    core.workflow.nodes.knowledge_index.knowledge_index_node -> services.summary_index_service
+    core.workflow.nodes.knowledge_index.knowledge_index_node -> tasks.generate_summary_index_task
+    core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.processor.paragraph_index_processor
     core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.retrieval_methods
     core.workflow.nodes.llm.node -> models.dataset
     core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer

@@ -300,6 +303,58 @@ ignore_imports =
     core.workflow.nodes.agent.agent_node -> services
     core.workflow.nodes.tool.tool_node -> services

+[importlinter:contract:model-runtime-no-internal-imports]
+name = Model Runtime Internal Imports
+type = forbidden
+source_modules =
+    core.model_runtime
+forbidden_modules =
+    configs
+    controllers
+    extensions
+    models
+    services
+    tasks
+    core.agent
+    core.app
+    core.base
+    core.callback_handler
+    core.datasource
+    core.db
+    core.entities
+    core.errors
+    core.extension
+    core.external_data_tool
+    core.file
+    core.helper
+    core.hosting_configuration
+    core.indexing_runner
+    core.llm_generator
+    core.logging
+    core.mcp
+    core.memory
+    core.model_manager
+    core.moderation
+    core.ops
+    core.plugin
+    core.prompt
+    core.provider_manager
+    core.rag
+    core.repositories
+    core.schemas
+    core.tools
+    core.trigger
+    core.variables
+    core.workflow
+ignore_imports =
+    core.model_runtime.model_providers.__base.ai_model -> configs
+    core.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
+    core.model_runtime.model_providers.__base.large_language_model -> configs
+    core.model_runtime.model_providers.__base.text_embedding_model -> core.entities.embedding_type
+    core.model_runtime.model_providers.model_provider_factory -> configs
+    core.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
+    core.model_runtime.model_providers.model_provider_factory -> models.provider_ids
+
 [importlinter:contract:rsc]
 name = RSC
 type = layers
@@ -243,6 +243,11 @@ class PluginConfig(BaseSettings):
         default=15728640 * 12,
     )

+    PLUGIN_MODEL_SCHEMA_CACHE_TTL: PositiveInt = Field(
+        description="TTL in seconds for caching plugin model schemas in Redis",
+        default=60 * 60,
+    )
+

 class MarketplaceConfig(BaseSettings):
     """
@ -6,7 +6,6 @@ from contexts.wrapper import RecyclableContextVar
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.trigger.provider import PluginTriggerProviderController
|
||||
|
|
@ -29,12 +28,6 @@ plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
|
|||
ContextVar("plugin_model_providers_lock")
|
||||
)
|
||||
|
||||
plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_model_schema_lock"))
|
||||
|
||||
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
|
||||
ContextVar("plugin_model_schemas")
|
||||
)
|
||||
|
||||
datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = (
|
||||
RecyclableContextVar(ContextVar("datasource_plugin_providers"))
|
||||
)
|
||||
|
|
|
|||
|
|
@@ -243,15 +243,13 @@ class InsertExploreBannerApi(Resource):
     def post(self):
         payload = InsertExploreBannerPayload.model_validate(console_ns.payload)

-        content = {
-            "category": payload.category,
-            "title": payload.title,
-            "description": payload.description,
-            "img-src": payload.img_src,
-        }
-
         banner = ExporleBanner(
-            content=content,
+            content={
+                "category": payload.category,
+                "title": payload.title,
+                "description": payload.description,
+                "img-src": payload.img_src,
+            },
             link=payload.link,
             sort=payload.sort,
             language=payload.language,
@@ -148,6 +148,7 @@ class DatasetUpdatePayload(BaseModel):
     embedding_model: str | None = None
     embedding_model_provider: str | None = None
     retrieval_model: dict[str, Any] | None = None
+    summary_index_setting: dict[str, Any] | None = None
     partial_member_list: list[dict[str, str]] | None = None
     external_retrieval_model: dict[str, Any] | None = None
     external_knowledge_id: str | None = None

@@ -288,7 +289,14 @@ class DatasetListApi(Resource):
     @enterprise_license_required
     def get(self):
         current_user, current_tenant_id = current_account_with_tenant()
-        query = ConsoleDatasetListQuery.model_validate(request.args.to_dict())
+        # Convert query parameters to dict, handling list parameters correctly
+        query_params: dict[str, str | list[str]] = dict(request.args.to_dict())
+        # Handle ids and tag_ids as lists (Flask request.args.getlist returns list even for single value)
+        if "ids" in request.args:
+            query_params["ids"] = request.args.getlist("ids")
+        if "tag_ids" in request.args:
+            query_params["tag_ids"] = request.args.getlist("tag_ids")
+        query = ConsoleDatasetListQuery.model_validate(query_params)
         # provider = request.args.get("provider", default="vendor")
         if query.ids:
            datasets, total = DatasetService.get_datasets_by_ids(query.ids, current_tenant_id)
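Note on the `ids`/`tag_ids` handling above: `request.args.to_dict()` keeps only the first value of a repeated query key, so multi-select filters would silently lose values without the explicit `getlist` calls. A minimal, self-contained illustration (the keys and values are made up):

```python
from werkzeug.datastructures import MultiDict

# Simulates ?ids=a&ids=b&page=1 as Flask's request.args would expose it
args = MultiDict([("ids", "a"), ("ids", "b"), ("page", "1")])

assert args.to_dict() == {"ids": "a", "page": "1"}  # repeated key collapses to its first value
assert args.getlist("ids") == ["a", "b"]            # the full list, which is what gets forwarded to Pydantic
```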
@@ -45,6 +45,7 @@ from models.dataset import DocumentPipelineExecutionLog
 from services.dataset_service import DatasetService, DocumentService
 from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
 from services.file_service import FileService
+from tasks.generate_summary_index_task import generate_summary_index_task

 from ..app.error import (
     ProviderModelCurrentlyNotSupportError,

@@ -103,6 +104,10 @@ class DocumentRenamePayload(BaseModel):
     name: str


+class GenerateSummaryPayload(BaseModel):
+    document_list: list[str]
+
+
 class DocumentBatchDownloadZipPayload(BaseModel):
     """Request payload for bulk downloading documents as a zip archive."""

@@ -125,6 +130,7 @@ register_schema_models(
     RetrievalModel,
     DocumentRetryPayload,
     DocumentRenamePayload,
+    GenerateSummaryPayload,
     DocumentBatchDownloadZipPayload,
 )

@@ -312,6 +318,13 @@ class DatasetDocumentListApi(Resource):

         paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
         documents = paginated_documents.items
+
+        DocumentService.enrich_documents_with_summary_index_status(
+            documents=documents,
+            dataset=dataset,
+            tenant_id=current_tenant_id,
+        )

         if fetch:
             for document in documents:
                 completed_segments = (

@@ -797,6 +810,7 @@ class DocumentApi(DocumentResource):
                 "display_status": document.display_status,
                 "doc_form": document.doc_form,
                 "doc_language": document.doc_language,
+                "need_summary": document.need_summary if document.need_summary is not None else False,
             }
         else:
             dataset_process_rules = DatasetService.get_process_rules(dataset_id)

@@ -832,6 +846,7 @@ class DocumentApi(DocumentResource):
                 "display_status": document.display_status,
                 "doc_form": document.doc_form,
                 "doc_language": document.doc_language,
+                "need_summary": document.need_summary if document.need_summary is not None else False,
             }

         return response, 200
@@ -1255,3 +1270,137 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
             "input_data": log.input_data,
             "datasource_node_id": log.datasource_node_id,
         }, 200
+
+
+@console_ns.route("/datasets/<uuid:dataset_id>/documents/generate-summary")
+class DocumentGenerateSummaryApi(Resource):
+    @console_ns.doc("generate_summary_for_documents")
+    @console_ns.doc(description="Generate summary index for documents")
+    @console_ns.doc(params={"dataset_id": "Dataset ID"})
+    @console_ns.expect(console_ns.models[GenerateSummaryPayload.__name__])
+    @console_ns.response(200, "Summary generation started successfully")
+    @console_ns.response(400, "Invalid request or dataset configuration")
+    @console_ns.response(403, "Permission denied")
+    @console_ns.response(404, "Dataset not found")
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
+    def post(self, dataset_id):
+        """
+        Generate summary index for specified documents.
+
+        This endpoint checks if the dataset configuration supports summary generation
+        (indexing_technique must be 'high_quality' and summary_index_setting.enable must be true),
+        then asynchronously generates summary indexes for the provided documents.
+        """
+        current_user, _ = current_account_with_tenant()
+        dataset_id = str(dataset_id)
+
+        # Get dataset
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound("Dataset not found.")
+
+        # Check permissions
+        if not current_user.is_dataset_editor:
+            raise Forbidden()
+
+        try:
+            DatasetService.check_dataset_permission(dataset, current_user)
+        except services.errors.account.NoPermissionError as e:
+            raise Forbidden(str(e))
+
+        # Validate request payload
+        payload = GenerateSummaryPayload.model_validate(console_ns.payload or {})
+        document_list = payload.document_list
+
+        if not document_list:
+            from werkzeug.exceptions import BadRequest
+
+            raise BadRequest("document_list cannot be empty.")
+
+        # Check if dataset configuration supports summary generation
+        if dataset.indexing_technique != "high_quality":
+            raise ValueError(
+                f"Summary generation is only available for 'high_quality' indexing technique. "
+                f"Current indexing technique: {dataset.indexing_technique}"
+            )
+
+        summary_index_setting = dataset.summary_index_setting
+        if not summary_index_setting or not summary_index_setting.get("enable"):
+            raise ValueError("Summary index is not enabled for this dataset. Please enable it in the dataset settings.")
+
+        # Verify all documents exist and belong to the dataset
+        documents = DocumentService.get_documents_by_ids(dataset_id, document_list)
+
+        if len(documents) != len(document_list):
+            found_ids = {doc.id for doc in documents}
+            missing_ids = set(document_list) - found_ids
+            raise NotFound(f"Some documents not found: {list(missing_ids)}")
+
+        # Dispatch async tasks for each document
+        for document in documents:
+            # Skip qa_model documents as they don't generate summaries
+            if document.doc_form == "qa_model":
+                logger.info("Skipping summary generation for qa_model document %s", document.id)
+                continue
+
+            # Dispatch async task
+            generate_summary_index_task.delay(dataset_id, document.id)
+            logger.info(
+                "Dispatched summary generation task for document %s in dataset %s",
+                document.id,
+                dataset_id,
+            )
+
+        return {"result": "success"}, 200
+
+
+@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/summary-status")
+class DocumentSummaryStatusApi(DocumentResource):
+    @console_ns.doc("get_document_summary_status")
+    @console_ns.doc(description="Get summary index generation status for a document")
+    @console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
+    @console_ns.response(200, "Summary status retrieved successfully")
+    @console_ns.response(404, "Document not found")
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dataset_id, document_id):
+        """
+        Get summary index generation status for a document.
+
+        Returns:
+        - total_segments: Total number of segments in the document
+        - summary_status: Dictionary with status counts
+            - completed: Number of summaries completed
+            - generating: Number of summaries being generated
+            - error: Number of summaries with errors
+            - not_started: Number of segments without summary records
+        - summaries: List of summary records with status and content preview
+        """
+        current_user, _ = current_account_with_tenant()
+        dataset_id = str(dataset_id)
+        document_id = str(document_id)
+
+        # Get dataset
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound("Dataset not found.")
+
+        # Check permissions
+        try:
+            DatasetService.check_dataset_permission(dataset, current_user)
+        except services.errors.account.NoPermissionError as e:
+            raise Forbidden(str(e))
+
+        # Get summary status detail from service
+        from services.summary_index_service import SummaryIndexService
+
+        result = SummaryIndexService.get_document_summary_status_detail(
+            document_id=document_id,
+            dataset_id=dataset_id,
+        )
+
+        return result, 200
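For orientation, a sketch of how a client might exercise the two endpoints added above. Only the routes, the `document_list` payload, and the response keys come from this diff; the base URL, auth header, and IDs are placeholders.

```python
import requests  # any HTTP client works; requests is just for illustration

BASE = "http://localhost:5001/console/api"  # assumed console API base URL
AUTH = {"Authorization": "Bearer <console-session-token>"}  # assumed auth scheme
dataset_id = "<dataset-uuid>"  # a high_quality dataset with summary_index_setting.enable = true

# Fan out summary generation; the endpoint returns as soon as the Celery tasks are dispatched.
resp = requests.post(
    f"{BASE}/datasets/{dataset_id}/documents/generate-summary",
    json={"document_list": ["<doc-id-1>", "<doc-id-2>"]},
    headers=AUTH,
)
resp.raise_for_status()  # body: {"result": "success"}

# Poll per-document progress via the companion status endpoint.
status = requests.get(
    f"{BASE}/datasets/{dataset_id}/documents/<doc-id-1>/summary-status",
    headers=AUTH,
).json()
print(status["summary_status"])  # counts for completed / generating / error / not_started
```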
@@ -41,6 +41,17 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
 from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task


+def _get_segment_with_summary(segment, dataset_id):
+    """Helper function to marshal segment and add summary information."""
+    from services.summary_index_service import SummaryIndexService
+
+    segment_dict = dict(marshal(segment, segment_fields))
+    # Query summary for this segment (only enabled summaries)
+    summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
+    segment_dict["summary"] = summary.summary_content if summary else None
+    return segment_dict
+
+
 class SegmentListQuery(BaseModel):
     limit: int = Field(default=20, ge=1, le=100)
     status: list[str] = Field(default_factory=list)

@@ -63,6 +74,7 @@ class SegmentUpdatePayload(BaseModel):
     keywords: list[str] | None = None
     regenerate_child_chunks: bool = False
     attachment_ids: list[str] | None = None
+    summary: str | None = None  # Summary content for summary index


 class BatchImportPayload(BaseModel):

@@ -181,8 +193,25 @@ class DatasetDocumentSegmentListApi(Resource):

         segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)

+        # Query summaries for all segments in this page (batch query for efficiency)
+        segment_ids = [segment.id for segment in segments.items]
+        summaries = {}
+        if segment_ids:
+            from services.summary_index_service import SummaryIndexService
+
+            summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
+            # Only include enabled summaries (already filtered by service)
+            summaries = {chunk_id: summary.summary_content for chunk_id, summary in summary_records.items()}
+
+        # Add summary to each segment
+        segments_with_summary = []
+        for segment in segments.items:
+            segment_dict = dict(marshal(segment, segment_fields))
+            segment_dict["summary"] = summaries.get(segment.id)
+            segments_with_summary.append(segment_dict)
+
         response = {
-            "data": marshal(segments.items, segment_fields),
+            "data": segments_with_summary,
             "limit": limit,
             "total": segments.total,
             "total_pages": segments.pages,

@@ -328,7 +357,7 @@ class DatasetDocumentSegmentAddApi(Resource):
         payload_dict = payload.model_dump(exclude_none=True)
         SegmentService.segment_create_args_validate(payload_dict, document)
         segment = SegmentService.create_segment(payload_dict, document, dataset)
-        return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
+        return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200


 @console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")

@@ -390,10 +419,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
         payload_dict = payload.model_dump(exclude_none=True)
         SegmentService.segment_create_args_validate(payload_dict, document)
+
+        # Update segment (summary update with change detection is handled in SegmentService.update_segment)
         segment = SegmentService.update_segment(
             SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
         )
-        return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
+        return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200

     @setup_required
     @login_required
@@ -1,6 +1,13 @@
-from flask_restx import Resource
+from flask_restx import Resource, fields

 from controllers.common.schema import register_schema_model
+from fields.hit_testing_fields import (
+    child_chunk_fields,
+    document_fields,
+    files_fields,
+    hit_testing_record_fields,
+    segment_fields,
+)
 from libs.login import login_required

 from .. import console_ns

@@ -14,13 +21,45 @@ from ..wraps import (
 register_schema_model(console_ns, HitTestingPayload)


+def _get_or_create_model(model_name: str, field_def):
+    """Get or create a flask_restx model to avoid dict type issues in Swagger."""
+    existing = console_ns.models.get(model_name)
+    if existing is None:
+        existing = console_ns.model(model_name, field_def)
+    return existing
+
+
+# Register models for flask_restx to avoid dict type issues in Swagger
+document_model = _get_or_create_model("HitTestingDocument", document_fields)
+
+segment_fields_copy = segment_fields.copy()
+segment_fields_copy["document"] = fields.Nested(document_model)
+segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy)
+
+child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields)
+files_model = _get_or_create_model("HitTestingFile", files_fields)
+
+hit_testing_record_fields_copy = hit_testing_record_fields.copy()
+hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model)
+hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model))
+hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model))
+hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy)
+
+# Response model for hit testing API
+hit_testing_response_fields = {
+    "query": fields.String,
+    "records": fields.List(fields.Nested(hit_testing_record_model)),
+}
+hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields)
+
+
 @console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
 class HitTestingApi(Resource, DatasetsHitTestingBase):
     @console_ns.doc("test_dataset_retrieval")
     @console_ns.doc(description="Test dataset knowledge retrieval")
     @console_ns.doc(params={"dataset_id": "Dataset ID"})
     @console_ns.expect(console_ns.models[HitTestingPayload.__name__])
-    @console_ns.response(200, "Hit testing completed successfully")
+    @console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model)
     @console_ns.response(404, "Dataset not found")
     @console_ns.response(400, "Invalid parameters")
     @setup_required
@@ -46,6 +46,7 @@ class DatasetCreatePayload(BaseModel):
     retrieval_model: RetrievalModel | None = None
     embedding_model: str | None = None
     embedding_model_provider: str | None = None
+    summary_index_setting: dict | None = None


 class DatasetUpdatePayload(BaseModel):

@@ -217,6 +218,7 @@ class DatasetListApi(DatasetApiResource):
                 embedding_model_provider=payload.embedding_model_provider,
                 embedding_model_name=payload.embedding_model,
                 retrieval_model=payload.retrieval_model,
+                summary_index_setting=payload.summary_index_setting,
             )
         except services.errors.dataset.DatasetNameDuplicateError:
             raise DatasetNameDuplicateError()
@@ -45,6 +45,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
     Segmentation,
 )
 from services.file_service import FileService
+from services.summary_index_service import SummaryIndexService


 class DocumentTextCreatePayload(BaseModel):

@@ -508,6 +509,12 @@ class DocumentListApi(DatasetApiResource):
         )
         documents = paginated_documents.items

+        DocumentService.enrich_documents_with_summary_index_status(
+            documents=documents,
+            dataset=dataset,
+            tenant_id=tenant_id,
+        )
+
         response = {
             "data": marshal(documents, document_fields),
             "has_more": len(documents) == query_params.limit,

@@ -612,6 +619,16 @@ class DocumentApi(DatasetApiResource):
         if metadata not in self.METADATA_CHOICES:
             raise InvalidMetadataError(f"Invalid metadata value: {metadata}")

+        # Calculate summary_index_status if needed
+        summary_index_status = None
+        has_summary_index = dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True
+        if has_summary_index and document.need_summary is True:
+            summary_index_status = SummaryIndexService.get_document_summary_index_status(
+                document_id=document_id,
+                dataset_id=dataset_id,
+                tenant_id=tenant_id,
+            )
+
         if metadata == "only":
             response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
         elif metadata == "without":

@@ -646,6 +663,8 @@ class DocumentApi(DatasetApiResource):
                 "display_status": document.display_status,
                 "doc_form": document.doc_form,
                 "doc_language": document.doc_language,
+                "summary_index_status": summary_index_status,
+                "need_summary": document.need_summary if document.need_summary is not None else False,
             }
         else:
             dataset_process_rules = DatasetService.get_process_rules(dataset_id)

@@ -681,6 +700,8 @@ class DocumentApi(DatasetApiResource):
                 "display_status": document.display_status,
                 "doc_form": document.doc_form,
                 "doc_language": document.doc_language,
+                "summary_index_status": summary_index_status,
+                "need_summary": document.need_summary if document.need_summary is not None else False,
             }

         return response
@@ -79,6 +79,7 @@ class AppGenerateResponseConverter(ABC):
                     "document_name": resource["document_name"],
                     "score": resource["score"],
                     "content": resource["content"],
+                    "summary": resource.get("summary"),
                 }
             )
         metadata["retriever_resources"] = updated_resources
@@ -3,6 +3,7 @@ from pydantic import BaseModel, Field, field_validator

 class PreviewDetail(BaseModel):
     content: str
+    summary: str | None = None
     child_chunks: list[str] | None = None
@@ -104,6 +104,8 @@ def download(f: File, /):
     ):
         return _download_file_content(f.storage_key)
     elif f.transfer_method == FileTransferMethod.REMOTE_URL:
+        if f.remote_url is None:
+            raise ValueError("Missing file remote_url")
         response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
         response.raise_for_status()
         return response.content

@@ -134,6 +136,8 @@ def _download_file_content(path: str, /):
 def _get_encoded_string(f: File, /):
     match f.transfer_method:
         case FileTransferMethod.REMOTE_URL:
+            if f.remote_url is None:
+                raise ValueError("Missing file remote_url")
             response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
             response.raise_for_status()
             data = response.content
@@ -4,8 +4,10 @@ Proxy requests to avoid SSRF

 import logging
 import time
+from typing import Any, TypeAlias

 import httpx
+from pydantic import TypeAdapter, ValidationError

 from configs import dify_config
 from core.helper.http_client_pooling import get_pooled_http_client

@@ -18,6 +20,9 @@ SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
 BACKOFF_FACTOR = 0.5
 STATUS_FORCELIST = [429, 500, 502, 503, 504]

+Headers: TypeAlias = dict[str, str]
+_HEADERS_ADAPTER = TypeAdapter(Headers)
+
 _SSL_VERIFIED_POOL_KEY = "ssrf:verified"
 _SSL_UNVERIFIED_POOL_KEY = "ssrf:unverified"
 _SSRF_CLIENT_LIMITS = httpx.Limits(

@@ -76,7 +81,7 @@ def _get_ssrf_client(ssl_verify_enabled: bool) -> httpx.Client:
     )


-def _get_user_provided_host_header(headers: dict | None) -> str | None:
+def _get_user_provided_host_header(headers: Headers | None) -> str | None:
     """
     Extract the user-provided Host header from the headers dict.

@@ -92,7 +97,7 @@ def _get_user_provided_host_header(headers: dict | None) -> str | None:
     return None


-def _inject_trace_headers(headers: dict | None) -> dict:
+def _inject_trace_headers(headers: Headers | None) -> Headers:
     """
     Inject W3C traceparent header for distributed tracing.

@@ -125,7 +130,7 @@ def _inject_trace_headers(headers: dict | None) -> dict:
     return headers


-def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def make_request(method: str, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
     # Convert requests-style allow_redirects to httpx-style follow_redirects
     if "allow_redirects" in kwargs:
         allow_redirects = kwargs.pop("allow_redirects")

@@ -142,10 +147,15 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):

     # prioritize per-call option, which can be switched on and off inside the HTTP node on the web UI
     verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
+    if not isinstance(verify_option, bool):
+        raise ValueError("ssl_verify must be a boolean")
     client = _get_ssrf_client(verify_option)

     # Inject traceparent header for distributed tracing (when OTEL is not enabled)
-    headers = kwargs.get("headers") or {}
+    try:
+        headers: Headers = _HEADERS_ADAPTER.validate_python(kwargs.get("headers") or {})
+    except ValidationError as e:
+        raise ValueError("headers must be a mapping of string keys to string values") from e
     headers = _inject_trace_headers(headers)
     kwargs["headers"] = headers

@@ -198,25 +208,25 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
     raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")


-def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def get(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
     return make_request("GET", url, max_retries=max_retries, **kwargs)


-def post(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def post(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
     return make_request("POST", url, max_retries=max_retries, **kwargs)


-def put(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def put(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
     return make_request("PUT", url, max_retries=max_retries, **kwargs)


-def patch(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def patch(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
     return make_request("PATCH", url, max_retries=max_retries, **kwargs)


-def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def delete(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
     return make_request("DELETE", url, max_retries=max_retries, **kwargs)


-def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def head(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
     return make_request("HEAD", url, max_retries=max_retries, **kwargs)
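The `TypeAdapter` guard introduced above rejects non-string header values up front instead of letting them fail deep inside httpx. A standalone illustration of the same validation (header names and values invented):

```python
from pydantic import TypeAdapter, ValidationError

headers_adapter = TypeAdapter(dict[str, str])  # mirrors the Headers alias in the diff

print(headers_adapter.validate_python({"X-Request-Id": "abc123"}))  # passes through unchanged

try:
    headers_adapter.validate_python({"X-Retry": 3})  # int value is rejected
except ValidationError as exc:
    print(f"{exc.error_count()} invalid header value(s)")
```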
@@ -311,14 +311,18 @@ class IndexingRunner:
         qa_preview_texts: list[QAPreviewDetail] = []

         total_segments = 0
+        # doc_form represents the segmentation method (general, parent-child, QA)
         index_type = doc_form
         index_processor = IndexProcessorFactory(index_type).init_index_processor()
+        # one extract_setting is one source document
         for extract_setting in extract_settings:
             # extract
             processing_rule = DatasetProcessRule(
                 mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"])
             )
+            # Extract document content
             text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
+            # Cleaning and segmentation
             documents = index_processor.transform(
                 text_docs,
                 current_user=None,

@@ -361,6 +365,12 @@ class IndexingRunner:

         if doc_form and doc_form == "qa_model":
             return IndexingEstimate(total_segments=total_segments * 20, qa_preview=qa_preview_texts, preview=[])
+
+        # Generate summary preview
+        summary_index_setting = tmp_processing_rule.get("summary_index_setting")
+        if summary_index_setting and summary_index_setting.get("enable") and preview_texts:
+            preview_texts = index_processor.generate_summary_preview(tenant_id, preview_texts, summary_index_setting)
+
         return IndexingEstimate(total_segments=total_segments, preview=preview_texts)

     def _extract(
@ -434,3 +434,20 @@ INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as ex
|
|||
You should edit the prompt according to the IDEAL OUTPUT."""
|
||||
|
||||
INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}."""
|
||||
|
||||
DEFAULT_GENERATOR_SUMMARY_PROMPT = (
|
||||
"""Summarize the following content. Extract only the key information and main points. """
|
||||
"""Remove redundant details.
|
||||
|
||||
Requirements:
|
||||
1. Write a concise summary in plain text
|
||||
2. Use the same language as the input content
|
||||
3. Focus on important facts, concepts, and details
|
||||
4. If images are included, describe their key information
|
||||
5. Do not use words like "好的", "ok", "I understand", "This text discusses", "The content mentions"
|
||||
6. Write directly without extra words
|
||||
|
||||
Output only the summary text. Start summarizing now:
|
||||
|
||||
"""
|
||||
)
|
||||
|
|
|
|||
|
|
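For context, a sketch of how a summary prompt like this is paired with a chunk at generation time. This is a hypothetical stand-alone call: the OpenAI client and model name are stand-ins, the import path is assumed, and in Dify the request actually flows through the model-runtime layer rather than a provider SDK.

```python
from openai import OpenAI  # stand-in client, not what Dify uses internally

from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT  # assumed module path

client = OpenAI()
chunk = "Dify is an open-source platform for building LLM applications..."

resp = client.chat.completions.create(
    model="gpt-4o-mini",  # illustrative model choice
    messages=[
        {"role": "system", "content": DEFAULT_GENERATOR_SUMMARY_PROMPT},
        {"role": "user", "content": chunk},
    ],
)
summary = resp.choices[0].message.content  # later persisted as a segment summary
```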
@@ -1,10 +1,11 @@
 import decimal
 import hashlib
-from threading import Lock
+import logging

-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, ValidationError
+from redis import RedisError

-import contexts
 from configs import dify_config
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
 from core.model_runtime.entities.model_entities import (

@@ -24,6 +25,9 @@ from core.model_runtime.errors.invoke import (
     InvokeServerUnavailableError,
 )
 from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
+from extensions.ext_redis import redis_client
+
+logger = logging.getLogger(__name__)


 class AIModel(BaseModel):

@@ -144,34 +148,60 @@ class AIModel(BaseModel):

         plugin_model_manager = PluginModelClient()
         cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
         # sort credentials
         sorted_credentials = sorted(credentials.items()) if credentials else []
         cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])

-        try:
-            contexts.plugin_model_schemas.get()
-        except LookupError:
-            contexts.plugin_model_schemas.set({})
-            contexts.plugin_model_schema_lock.set(Lock())
-
-        with contexts.plugin_model_schema_lock.get():
-            if cache_key in contexts.plugin_model_schemas.get():
-                return contexts.plugin_model_schemas.get()[cache_key]
-
-            schema = plugin_model_manager.get_model_schema(
-                tenant_id=self.tenant_id,
-                user_id="unknown",
-                plugin_id=self.plugin_id,
-                provider=self.provider_name,
-                model_type=self.model_type.value,
-                model=model,
-                credentials=credentials or {},
-            )
-
-            if schema:
-                contexts.plugin_model_schemas.get()[cache_key] = schema
-
-            return schema
+        cached_schema_json = None
+        try:
+            cached_schema_json = redis_client.get(cache_key)
+        except (RedisError, RuntimeError) as exc:
+            logger.warning(
+                "Failed to read plugin model schema cache for model %s: %s",
+                model,
+                str(exc),
+                exc_info=True,
+            )
+        if cached_schema_json:
+            try:
+                return AIModelEntity.model_validate_json(cached_schema_json)
+            except ValidationError:
+                logger.warning(
+                    "Failed to validate cached plugin model schema for model %s",
+                    model,
+                    exc_info=True,
+                )
+                try:
+                    redis_client.delete(cache_key)
+                except (RedisError, RuntimeError) as exc:
+                    logger.warning(
+                        "Failed to delete invalid plugin model schema cache for model %s: %s",
+                        model,
+                        str(exc),
+                        exc_info=True,
+                    )
+
+        schema = plugin_model_manager.get_model_schema(
+            tenant_id=self.tenant_id,
+            user_id="unknown",
+            plugin_id=self.plugin_id,
+            provider=self.provider_name,
+            model_type=self.model_type.value,
+            model=model,
+            credentials=credentials or {},
+        )
+
+        if schema:
+            try:
+                redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
+            except (RedisError, RuntimeError) as exc:
+                logger.warning(
+                    "Failed to write plugin model schema cache for model %s: %s",
+                    model,
+                    str(exc),
+                    exc_info=True,
+                )
+
+        return schema

     def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None:
         """
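This hunk (and its twin in model_provider_factory.py below) replaces a process-local ContextVar cache with a Redis read-through cache keyed by tenant, plugin, provider, model, and hashed credentials. Distilled, the pattern is: try the cache, fall back to the authoritative source, backfill with a TTL, and treat every cache failure as a warning rather than an error. A generic sketch with illustrative names:

```python
import json
import logging

from redis import Redis, RedisError

logger = logging.getLogger(__name__)
redis_client = Redis()  # stand-in for extensions.ext_redis.redis_client


def get_with_cache(cache_key: str, ttl_seconds: int, fetch):
    """Read-through cache: cache failures never break the request path."""
    try:
        cached = redis_client.get(cache_key)
        if cached:
            return json.loads(cached)
    except RedisError:
        logger.warning("cache read failed for %s", cache_key, exc_info=True)

    value = fetch()  # authoritative source (the plugin daemon in this diff)
    if value is not None:
        try:
            redis_client.setex(cache_key, ttl_seconds, json.dumps(value))
        except RedisError:
            logger.warning("cache write failed for %s", cache_key, exc_info=True)
    return value
```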
@@ -5,7 +5,11 @@ import logging
 from collections.abc import Sequence
 from threading import Lock

+from pydantic import ValidationError
+from redis import RedisError
+
 import contexts
+from configs import dify_config
 from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
 from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
 from core.model_runtime.model_providers.__base.ai_model import AIModel

@@ -18,6 +22,7 @@ from core.model_runtime.model_providers.__base.tts_model import TTSModel
 from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
 from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
 from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
+from extensions.ext_redis import redis_client
 from models.provider_ids import ModelProviderID

 logger = logging.getLogger(__name__)

@@ -175,34 +180,60 @@ class ModelProviderFactory:
         """
         plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider)
         cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}"
         # sort credentials
         sorted_credentials = sorted(credentials.items()) if credentials else []
         cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])

-        try:
-            contexts.plugin_model_schemas.get()
-        except LookupError:
-            contexts.plugin_model_schemas.set({})
-            contexts.plugin_model_schema_lock.set(Lock())
-
-        with contexts.plugin_model_schema_lock.get():
-            if cache_key in contexts.plugin_model_schemas.get():
-                return contexts.plugin_model_schemas.get()[cache_key]
-
-            schema = self.plugin_model_manager.get_model_schema(
-                tenant_id=self.tenant_id,
-                user_id="unknown",
-                plugin_id=plugin_id,
-                provider=provider_name,
-                model_type=model_type.value,
-                model=model,
-                credentials=credentials or {},
-            )
-
-            if schema:
-                contexts.plugin_model_schemas.get()[cache_key] = schema
-
-            return schema
+        cached_schema_json = None
+        try:
+            cached_schema_json = redis_client.get(cache_key)
+        except (RedisError, RuntimeError) as exc:
+            logger.warning(
+                "Failed to read plugin model schema cache for model %s: %s",
+                model,
+                str(exc),
+                exc_info=True,
+            )
+        if cached_schema_json:
+            try:
+                return AIModelEntity.model_validate_json(cached_schema_json)
+            except ValidationError:
+                logger.warning(
+                    "Failed to validate cached plugin model schema for model %s",
+                    model,
+                    exc_info=True,
+                )
+                try:
+                    redis_client.delete(cache_key)
+                except (RedisError, RuntimeError) as exc:
+                    logger.warning(
+                        "Failed to delete invalid plugin model schema cache for model %s: %s",
+                        model,
+                        str(exc),
+                        exc_info=True,
+                    )
+
+        schema = self.plugin_model_manager.get_model_schema(
+            tenant_id=self.tenant_id,
+            user_id="unknown",
+            plugin_id=plugin_id,
+            provider=provider_name,
+            model_type=model_type.value,
+            model=model,
+            credentials=credentials or {},
+        )
+
+        if schema:
+            try:
+                redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
+            except (RedisError, RuntimeError) as exc:
+                logger.warning(
+                    "Failed to write plugin model schema cache for model %s: %s",
+                    model,
+                    str(exc),
+                    exc_info=True,
+                )
+
+        return schema

     def get_models(
         self,
@@ -24,7 +24,13 @@ from core.rag.rerank.rerank_type import RerankMode
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.tools.signature import sign_upload_file
 from extensions.ext_database import db
-from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
+from models.dataset import (
+    ChildChunk,
+    Dataset,
+    DocumentSegment,
+    DocumentSegmentSummary,
+    SegmentAttachmentBinding,
+)
 from models.dataset import Document as DatasetDocument
 from models.model import UploadFile
 from services.external_knowledge_service import ExternalDatasetService

@@ -389,15 +395,15 @@ class RetrievalService:
             .all()
         }

-        records = []
-        include_segment_ids = set()
-        segment_child_map = {}
-
         valid_dataset_documents = {}
         image_doc_ids: list[Any] = []
         child_index_node_ids = []
         index_node_ids = []
         doc_to_document_map = {}
+        summary_segment_ids = set()  # Track segments retrieved via summary
+        summary_score_map: dict[str, float] = {}  # Map original_chunk_id to summary score

+        # First pass: collect all document IDs and identify summary documents
         for document in documents:
             document_id = document.metadata.get("document_id")
             if document_id not in dataset_documents:

@@ -408,16 +414,39 @@ class RetrievalService:
                 continue
             valid_dataset_documents[document_id] = dataset_document

+            doc_id = document.metadata.get("doc_id") or ""
+            doc_to_document_map[doc_id] = document
+
+            # Check if this is a summary document
+            is_summary = document.metadata.get("is_summary", False)
+            if is_summary:
+                # For summary documents, find the original chunk via original_chunk_id
+                original_chunk_id = document.metadata.get("original_chunk_id")
+                if original_chunk_id:
+                    summary_segment_ids.add(original_chunk_id)
+                    # Save summary's score for later use
+                    summary_score = document.metadata.get("score")
+                    if summary_score is not None:
+                        try:
+                            summary_score_float = float(summary_score)
+                            # If the same segment has multiple summary hits, take the highest score
+                            if original_chunk_id not in summary_score_map:
+                                summary_score_map[original_chunk_id] = summary_score_float
+                            else:
+                                summary_score_map[original_chunk_id] = max(
+                                    summary_score_map[original_chunk_id], summary_score_float
+                                )
+                        except (ValueError, TypeError):
+                            # Skip invalid score values
+                            pass
+                continue  # Skip adding to other lists for summary documents
+
             if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
-                doc_id = document.metadata.get("doc_id") or ""
-                doc_to_document_map[doc_id] = document
                 if document.metadata.get("doc_type") == DocType.IMAGE:
                     image_doc_ids.append(doc_id)
                 else:
                     child_index_node_ids.append(doc_id)
             else:
-                doc_id = document.metadata.get("doc_id") or ""
-                doc_to_document_map[doc_id] = document
                 if document.metadata.get("doc_type") == DocType.IMAGE:
                     image_doc_ids.append(doc_id)
                 else:

@@ -433,6 +462,7 @@ class RetrievalService:
         attachment_map: dict[str, list[dict[str, Any]]] = {}
         child_chunk_map: dict[str, list[ChildChunk]] = {}
         doc_segment_map: dict[str, list[str]] = {}
+        segment_summary_map: dict[str, str] = {}  # Map segment_id to summary content

         with session_factory.create_session() as session:
             attachments = cls.get_segment_attachment_infos(image_doc_ids, session)

@@ -447,6 +477,7 @@ class RetrievalService:
                     doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"])
                 else:
                     doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]]
+
             child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids))
             child_index_nodes = session.execute(child_chunk_stmt).scalars().all()

@@ -470,6 +501,7 @@ class RetrievalService:
             index_node_segments = session.execute(document_segment_stmt).scalars().all()  # type: ignore
             for index_node_segment in index_node_segments:
                 doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id]
+
             if segment_ids:
                 document_segment_stmt = select(DocumentSegment).where(
                     DocumentSegment.enabled == True,

@@ -481,6 +513,40 @@ class RetrievalService:
                 if index_node_segments:
                     segments.extend(index_node_segments)

+            # Handle summary documents: query segments by original_chunk_id
+            if summary_segment_ids:
+                summary_segment_ids_list = list(summary_segment_ids)
+                summary_segment_stmt = select(DocumentSegment).where(
+                    DocumentSegment.enabled == True,
+                    DocumentSegment.status == "completed",
+                    DocumentSegment.id.in_(summary_segment_ids_list),
+                )
+                summary_segments = session.execute(summary_segment_stmt).scalars().all()  # type: ignore
+                segments.extend(summary_segments)
+                # Add summary segment IDs to segment_ids for summary query
+                for seg in summary_segments:
+                    if seg.id not in segment_ids:
+                        segment_ids.append(seg.id)
+
+            # Batch query summaries for segments retrieved via summary (only enabled summaries)
+            if summary_segment_ids:
+                summaries = (
+                    session.query(DocumentSegmentSummary)
+                    .filter(
+                        DocumentSegmentSummary.chunk_id.in_(list(summary_segment_ids)),
+                        DocumentSegmentSummary.status == "completed",
+                        DocumentSegmentSummary.enabled == True,  # Only retrieve enabled summaries
+                    )
+                    .all()
+                )
+                for summary in summaries:
+                    if summary.summary_content:
+                        segment_summary_map[summary.chunk_id] = summary.summary_content
+
             include_segment_ids = set()
             segment_child_map: dict[str, dict[str, Any]] = {}
             records: list[dict[str, Any]] = []

             for segment in segments:
                 child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
                 attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])

@@ -489,30 +555,44 @@ class RetrievalService:
             if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
                 if segment.id not in include_segment_ids:
                     include_segment_ids.add(segment.id)
+                    # Check if this segment was retrieved via summary
+                    # Use summary score as base score if available, otherwise 0.0
+                    max_score = summary_score_map.get(segment.id, 0.0)
+
                     if child_chunks or attachment_infos:
                         child_chunk_details = []
-                        max_score = 0.0
                         for child_chunk in child_chunks:
-                            document = doc_to_document_map[child_chunk.index_node_id]
+                            child_document: Document | None = doc_to_document_map.get(child_chunk.index_node_id)
+                            if child_document:
+                                child_score = child_document.metadata.get("score", 0.0)
+                            else:
+                                child_score = 0.0
                             child_chunk_detail = {
                                 "id": child_chunk.id,
                                 "content": child_chunk.content,
                                 "position": child_chunk.position,
-                                "score": document.metadata.get("score", 0.0) if document else 0.0,
+                                "score": child_score,
                             }
                             child_chunk_details.append(child_chunk_detail)
-                            max_score = max(max_score, document.metadata.get("score", 0.0) if document else 0.0)
+                            max_score = max(max_score, child_score)
                         for attachment_info in attachment_infos:
-                            file_document = doc_to_document_map[attachment_info["id"]]
-                            max_score = max(
-                                max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0
-                            )
+                            file_document = doc_to_document_map.get(attachment_info["id"])
+                            if file_document:
+                                max_score = max(max_score, file_document.metadata.get("score", 0.0))

                         map_detail = {
                             "max_score": max_score,
                             "child_chunks": child_chunk_details,
                         }
                         segment_child_map[segment.id] = map_detail
+                    else:
+                        # No child chunks or attachments, use summary score if available
+                        summary_score = summary_score_map.get(segment.id)
+                        if summary_score is not None:
+                            segment_child_map[segment.id] = {
+                                "max_score": summary_score,
+                                "child_chunks": [],
+                            }
                     record: dict[str, Any] = {
                         "segment": segment,
                     }

@@ -520,14 +600,23 @@ class RetrievalService:
             else:
                 if segment.id not in include_segment_ids:
                     include_segment_ids.add(segment.id)
-                    max_score = 0.0
-                    segment_document = doc_to_document_map.get(segment.index_node_id)
-                    if segment_document:
-                        max_score = max(max_score, segment_document.metadata.get("score", 0.0))
+                    # Check if this segment was retrieved via summary
+                    # Use summary score if available (summary retrieval takes priority)
+                    max_score = summary_score_map.get(segment.id, 0.0)
+
+                    # If not retrieved via summary, use original segment's score
+                    if segment.id not in summary_score_map:
+                        segment_document = doc_to_document_map.get(segment.index_node_id)
+                        if segment_document:
+                            max_score = max(max_score, segment_document.metadata.get("score", 0.0))
+
+                    # Also consider attachment scores
+                    for attachment_info in attachment_infos:
+                        file_doc = doc_to_document_map.get(attachment_info["id"])
+                        if file_doc:
+                            max_score = max(max_score, file_doc.metadata.get("score", 0.0))

                     record = {
                         "segment": segment,
                         "score": max_score,

@@ -576,9 +665,16 @@ class RetrievalService:
                 else None
             )

+            # Extract summary if this segment was retrieved via summary
+            summary_content = segment_summary_map.get(segment.id)
+
             # Create RetrievalSegments object
             retrieval_segment = RetrievalSegments(
-                segment=segment, child_chunks=child_chunks_list, score=score, files=files
+                segment=segment,
+                child_chunks=child_chunks_list,
+                score=score,
+                files=files,
+                summary=summary_content,
             )
             result.append(retrieval_segment)
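The score bookkeeping above reduces to one rule: when several summary hits point at the same original chunk, keep the highest score. A tiny self-contained illustration with made-up hits:

```python
# (chunk_id, score) pairs as they might arrive from summary-index retrieval
summary_hits = [("chunk-1", 0.82), ("chunk-2", 0.55), ("chunk-1", 0.91)]

summary_score_map: dict[str, float] = {}
for chunk_id, score in summary_hits:
    prev = summary_score_map.get(chunk_id)
    summary_score_map[chunk_id] = score if prev is None else max(prev, score)

assert summary_score_map == {"chunk-1": 0.91, "chunk-2": 0.55}
```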
@@ -391,46 +391,78 @@ class QdrantVector(BaseVector):
        return docs

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        """Return docs most similar by bm25.
        """Return docs most similar by full-text search.

        Searches each keyword separately and merges results to ensure documents
        matching ANY keyword are returned (OR logic). Results are capped at top_k.

        Args:
            query: Search query text. Multi-word queries are split into keywords,
                with each keyword searched separately. Limited to 10 keywords.
            **kwargs: Additional search parameters (top_k, document_ids_filter)

        Returns:
            List of documents most similar to the query text and distance for each.
            List of up to top_k unique documents matching any query keyword.
        """
        from qdrant_client.http import models

        scroll_filter = models.Filter(
            must=[
                models.FieldCondition(
                    key="group_id",
                    match=models.MatchValue(value=self._group_id),
                ),
                models.FieldCondition(
                    key="page_content",
                    match=models.MatchText(text=query),
                ),
            ]
        )
        # Build base must conditions (AND logic) for metadata filters
        base_must_conditions: list = [
            models.FieldCondition(
                key="group_id",
                match=models.MatchValue(value=self._group_id),
            ),
        ]

        document_ids_filter = kwargs.get("document_ids_filter")
        if document_ids_filter:
            if scroll_filter.must:
                scroll_filter.must.append(
                    models.FieldCondition(
                        key="metadata.document_id",
                        match=models.MatchAny(any=document_ids_filter),
                    )
            base_must_conditions.append(
                models.FieldCondition(
                    key="metadata.document_id",
                    match=models.MatchAny(any=document_ids_filter),
                )
        response = self._client.scroll(
            collection_name=self._collection_name,
            scroll_filter=scroll_filter,
            limit=kwargs.get("top_k", 2),
            with_payload=True,
            with_vectors=True,
        )
        results = response[0]
        documents = []
        for result in results:
            if result:
                document = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY)
                documents.append(document)
            )

        # Split query into keywords, deduplicate and limit to prevent DoS
        keywords = list(dict.fromkeys(kw.strip() for kw in query.strip().split() if kw.strip()))[:10]

        if not keywords:
            return []

        top_k = kwargs.get("top_k", 2)
        seen_ids: set[str | int] = set()
        documents: list[Document] = []

        # Search each keyword separately and merge results.
        # This ensures each keyword gets its own search, preventing one keyword's
        # results from completely overshadowing another's due to scroll ordering.
        for keyword in keywords:
            scroll_filter = models.Filter(
                must=[
                    *base_must_conditions,
                    models.FieldCondition(
                        key="page_content",
                        match=models.MatchText(text=keyword),
                    ),
                ]
            )

            response = self._client.scroll(
                collection_name=self._collection_name,
                scroll_filter=scroll_filter,
                limit=top_k,
                with_payload=True,
                with_vectors=True,
            )
            results = response[0]

            for result in results:
                if result and result.id not in seen_ids:
                    seen_ids.add(result.id)
                    document = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY)
                    documents.append(document)
                    if len(documents) >= top_k:
                        return documents

        return documents
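The merge logic above is independent of Qdrant itself. Below is a minimal, dependency-free sketch of the same keyword OR-merge strategy; search_one_keyword is a hypothetical stand-in for the per-keyword scroll call, not the actual client API.

# Sketch of the keyword OR-merge above; names are illustrative, not Dify's API.
from collections.abc import Callable, Iterable


def merge_keyword_hits(
    query: str,
    search_one_keyword: Callable[[str, int], Iterable[tuple[str, str]]],
    top_k: int = 2,
    max_keywords: int = 10,
) -> list[str]:
    # Split, strip, deduplicate (dict.fromkeys keeps first-seen order), cap the count.
    keywords = list(dict.fromkeys(kw.strip() for kw in query.strip().split() if kw.strip()))[:max_keywords]
    if not keywords:
        return []

    seen_ids: set[str] = set()
    merged: list[str] = []
    for keyword in keywords:
        for point_id, content in search_one_keyword(keyword, top_k):
            if point_id not in seen_ids:  # OR semantics: any keyword match counts once
                seen_ids.add(point_id)
                merged.append(content)
                if len(merged) >= top_k:  # cap the merged result at top_k
                    return merged
    return merged


if __name__ == "__main__":
    fake_index = {
        "apple": [("1", "apple pie"), ("2", "apple tart")],
        "tart": [("2", "apple tart"), ("3", "lemon tart")],
    }
    print(merge_keyword_hits("apple tart", lambda kw, k: fake_index.get(kw, [])[:k], top_k=3))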
@@ -20,3 +20,4 @@ class RetrievalSegments(BaseModel):
    child_chunks: list[RetrievalChildChunk] | None = None
    score: float | None = None
    files: list[dict[str, str | int]] | None = None
    summary: str | None = None  # Summary content if retrieved via summary index
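For reference, a simplified stand-in for the model above (the real class also carries a full DocumentSegment and RetrievalChildChunk objects); this only illustrates how the new optional summary field rides along with a hit.

# Simplified, assumed shape of RetrievalSegments; not the shipped class.
from pydantic import BaseModel


class RetrievalSegmentsSketch(BaseModel):
    score: float | None = None
    files: list[dict[str, str | int]] | None = None
    summary: str | None = None  # populated only when the hit came via the summary index


hit = RetrievalSegmentsSketch(score=0.87, summary="Quarterly revenue grew 12%.")
print(hit.summary or "<no summary>")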
@@ -22,3 +22,4 @@ class RetrievalSourceMetadata(BaseModel):
    doc_metadata: dict[str, Any] | None = None
    title: str | None = None
    files: list[dict[str, Any]] | None = None
    summary: str | None = None
@@ -1,4 +1,7 @@
"""Abstract interface for document loader implementations."""
"""Word (.docx) document extractor used for RAG ingestion.

Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`).
"""

import logging
import mimetypes

@@ -8,7 +11,6 @@ import tempfile
import uuid
from urllib.parse import urlparse

import httpx
from docx import Document as DocxDocument
from docx.oxml.ns import qn
from docx.text.run import Run

@@ -44,7 +46,7 @@ class WordExtractor(BaseExtractor):

        # If the file is a web path, download it to a temporary file, and use that
        if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path):
            response = httpx.get(self.file_path, timeout=None)
            response = ssrf_proxy.get(self.file_path)

            if response.status_code != 200:
                response.close()

@@ -55,6 +57,7 @@ class WordExtractor(BaseExtractor):
        self.temp_file = tempfile.NamedTemporaryFile()  # noqa SIM115
        try:
            self.temp_file.write(response.content)
            self.temp_file.flush()
        finally:
            response.close()
        self.file_path = self.temp_file.name
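The download-and-persist pattern this hunk settles on, isolated below as a sketch: always close the response once its bytes are flushed. Plain httpx stands in for the internal ssrf_proxy helper, and the function name is illustrative.

# Sketch of the try/finally persistence pattern above; httpx replaces ssrf_proxy here.
import tempfile

import httpx


def download_to_tempfile(url: str) -> str:
    response = httpx.get(url)
    if response.status_code != 200:
        response.close()
        raise ValueError(f"URL returned status code {response.status_code}")
    temp_file = tempfile.NamedTemporaryFile(delete=False)  # caller owns cleanup
    try:
        temp_file.write(response.content)
        temp_file.flush()
    finally:
        response.close()  # closed whether or not the write succeeded
    return temp_file.name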
@@ -13,6 +13,7 @@ from urllib.parse import unquote, urlparse
import httpx

from configs import dify_config
from core.entities.knowledge_entities import PreviewDetail
from core.helper import ssrf_proxy
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.doc_type import DocType

@@ -45,6 +46,17 @@ class BaseIndexProcessor(ABC):
    def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
        raise NotImplementedError

    @abstractmethod
    def generate_summary_preview(
        self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
    ) -> list[PreviewDetail]:
        """
        For each segment in preview_texts, generate a summary using LLM and attach it to the segment.
        The summary can be stored in a new attribute, e.g., summary.
        This method should be implemented by subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def load(
        self,
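A minimal sketch of what a concrete subclass owes the new abstract method, assuming PreviewDetail is roughly a content/summary pair; the class and the truncation "summarizer" below are illustrative stubs, not Dify code.

# Hypothetical minimal implementer of generate_summary_preview.
from dataclasses import dataclass


@dataclass
class PreviewDetailSketch:
    content: str
    summary: str | None = None


class NoopIndexProcessor:
    def generate_summary_preview(
        self, tenant_id: str, preview_texts: list[PreviewDetailSketch], summary_index_setting: dict
    ) -> list[PreviewDetailSketch]:
        # Attach a summary to each preview item; a real implementation calls an LLM.
        for preview in preview_texts:
            preview.summary = preview.content[:80]
        return preview_texts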
@@ -1,9 +1,27 @@
"""Paragraph index processor."""

import logging
import re
import uuid
from collections.abc import Mapping
from typing import Any
from typing import Any, cast

logger = logging.getLogger(__name__)

from core.entities.knowledge_entities import PreviewDetail
from core.file import File, FileTransferMethod, FileType, file_manager
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import (
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContentUnionTypes,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.provider_manager import ProviderManager
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrieval_service import RetrievalService

@@ -17,12 +35,17 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.workflow.nodes.llm import llm_utils
from extensions.ext_database import db
from factories.file_factory import build_from_mapping
from libs import helper
from models import UploadFile
from models.account import Account
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import Rule
from services.summary_index_service import SummaryIndexService


class ParagraphIndexProcessor(BaseIndexProcessor):

@@ -108,6 +131,29 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
            keyword.add_texts(documents)

    def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
        # Note: Summary indexes are now disabled (not deleted) when segments are disabled.
        # This method is called for actual deletion scenarios (e.g., when segment is deleted).
        # For disable operations, disable_summaries_for_segments is called directly in the task.
        # Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
        delete_summaries = kwargs.get("delete_summaries", False)
        if delete_summaries:
            if node_ids:
                # Find segments by index_node_id
                segments = (
                    db.session.query(DocumentSegment)
                    .filter(
                        DocumentSegment.dataset_id == dataset.id,
                        DocumentSegment.index_node_id.in_(node_ids),
                    )
                    .all()
                )
                segment_ids = [segment.id for segment in segments]
                if segment_ids:
                    SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
            else:
                # Delete all summaries for the dataset
                SummaryIndexService.delete_summaries_for_segments(dataset, None)

        if dataset.indexing_technique == "high_quality":
            vector = Vector(dataset)
            if node_ids:

@@ -227,3 +273,322 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
            }
        else:
            raise ValueError("Chunks is not a list")

    def generate_summary_preview(
        self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
    ) -> list[PreviewDetail]:
        """
        For each segment, concurrently call generate_summary to generate a summary
        and write it to the summary attribute of PreviewDetail.
        In preview mode (indexing-estimate), if any summary generation fails, the method will raise an exception.
        """
        import concurrent.futures

        from flask import current_app

        # Capture Flask app context for worker threads
        flask_app = None
        try:
            flask_app = current_app._get_current_object()  # type: ignore
        except RuntimeError:
            logger.warning("No Flask application context available, summary generation may fail")

        def process(preview: PreviewDetail) -> None:
            """Generate summary for a single preview item."""
            if flask_app:
                # Ensure Flask app context in worker thread
                with flask_app.app_context():
                    summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting)
                    preview.summary = summary
            else:
                # Fallback: try without app context (may fail)
                summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting)
                preview.summary = summary

        # Generate summaries concurrently using ThreadPoolExecutor
        # Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
        timeout_seconds = min(300, 60 * len(preview_texts))
        errors: list[Exception] = []

        with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_texts))) as executor:
            futures = [executor.submit(process, preview) for preview in preview_texts]
            # Wait for all tasks to complete with timeout
            done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)

            # Cancel tasks that didn't complete in time
            if not_done:
                timeout_error_msg = (
                    f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s"
                )
                logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg)
                # In preview mode, timeout is also an error
                errors.append(TimeoutError(timeout_error_msg))
                for future in not_done:
                    future.cancel()
                # Wait a bit for cancellation to take effect
                concurrent.futures.wait(not_done, timeout=5)

            # Collect exceptions from completed futures
            for future in done:
                try:
                    future.result()  # This will raise any exception that occurred
                except Exception as e:
                    logger.exception("Error in summary generation future")
                    errors.append(e)

        # In preview mode (indexing-estimate), if there are any errors, fail the request
        if errors:
            error_messages = [str(e) for e in errors]
            error_summary = (
                f"Failed to generate summaries for {len(errors)} chunk(s). "
                f"Errors: {'; '.join(error_messages[:3])}"  # Show first 3 errors
            )
            if len(errors) > 3:
                error_summary += f" (and {len(errors) - 3} more)"
            logger.error("Summary generation failed in preview mode: %s", error_summary)
            raise ValueError(error_summary)

        return preview_texts

    @staticmethod
    def generate_summary(
        tenant_id: str,
        text: str,
        summary_index_setting: dict | None = None,
        segment_id: str | None = None,
    ) -> tuple[str, LLMUsage]:
        """
        Generate a summary for the given text using ModelInstance.invoke_llm and the default or custom summary
        prompt. Supports vision models by including images from the segment attachments or text content.

        Args:
            tenant_id: Tenant ID
            text: Text content to summarize
            summary_index_setting: Summary index configuration
            segment_id: Optional segment ID to fetch attachments from SegmentAttachmentBinding table

        Returns:
            Tuple of (summary_content, llm_usage) where llm_usage is LLMUsage object
        """
        if not summary_index_setting or not summary_index_setting.get("enable"):
            raise ValueError("summary_index_setting is required and must be enabled to generate summary.")

        model_name = summary_index_setting.get("model_name")
        model_provider_name = summary_index_setting.get("model_provider_name")
        summary_prompt = summary_index_setting.get("summary_prompt")

        if not model_name or not model_provider_name:
            raise ValueError("model_name and model_provider_name are required in summary_index_setting")

        # Import default summary prompt
        if not summary_prompt:
            summary_prompt = DEFAULT_GENERATOR_SUMMARY_PROMPT

        provider_manager = ProviderManager()
        provider_model_bundle = provider_manager.get_provider_model_bundle(
            tenant_id, model_provider_name, ModelType.LLM
        )
        model_instance = ModelInstance(provider_model_bundle, model_name)

        # Get model schema to check if vision is supported
        model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials)
        supports_vision = model_schema and model_schema.features and ModelFeature.VISION in model_schema.features

        # Extract images if model supports vision
        image_files = []
        if supports_vision:
            # First, try to get images from SegmentAttachmentBinding (preferred method)
            if segment_id:
                image_files = ParagraphIndexProcessor._extract_images_from_segment_attachments(tenant_id, segment_id)

            # If no images from attachments, fall back to extracting from text
            if not image_files:
                image_files = ParagraphIndexProcessor._extract_images_from_text(tenant_id, text)

        # Build prompt messages
        prompt_messages = []

        if image_files:
            # If we have images, create a UserPromptMessage with both text and images
            prompt_message_contents: list[PromptMessageContentUnionTypes] = []

            # Add images first
            for file in image_files:
                try:
                    file_content = file_manager.to_prompt_message_content(
                        file, image_detail_config=ImagePromptMessageContent.DETAIL.LOW
                    )
                    prompt_message_contents.append(file_content)
                except Exception as e:
                    logger.warning("Failed to convert image file to prompt message content: %s", str(e))
                    continue

            # Add text content
            if prompt_message_contents:  # Only add text if we successfully added images
                prompt_message_contents.append(TextPromptMessageContent(data=f"{summary_prompt}\n{text}"))
                prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
            else:
                # If image conversion failed, fall back to text-only
                prompt = f"{summary_prompt}\n{text}"
                prompt_messages.append(UserPromptMessage(content=prompt))
        else:
            # No images, use simple text prompt
            prompt = f"{summary_prompt}\n{text}"
            prompt_messages.append(UserPromptMessage(content=prompt))

        result = model_instance.invoke_llm(
            prompt_messages=cast(list[PromptMessage], prompt_messages), model_parameters={}, stream=False
        )

        # Type assertion: when stream=False, invoke_llm returns LLMResult, not Generator
        if not isinstance(result, LLMResult):
            raise ValueError("Expected LLMResult when stream=False")

        summary_content = getattr(result.message, "content", "")
        usage = result.usage

        # Deduct quota for summary generation (same as workflow nodes)
        try:
            llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
        except Exception as e:
            # Log but don't fail summary generation if quota deduction fails
            logger.warning("Failed to deduct quota for summary generation: %s", str(e))

        return summary_content, usage

    @staticmethod
    def _extract_images_from_text(tenant_id: str, text: str) -> list[File]:
        """
        Extract images from markdown text and convert them to File objects.

        Args:
            tenant_id: Tenant ID
            text: Text content that may contain markdown image links

        Returns:
            List of File objects representing images found in the text
        """
        # Extract markdown images using regex pattern
        pattern = r"!\[.*?\]\((.*?)\)"
        images = re.findall(pattern, text)

        if not images:
            return []

        upload_file_id_list = []

        for image in images:
            # For data before v0.10.0
            pattern = r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?"
            match = re.search(pattern, image)
            if match:
                upload_file_id = match.group(1)
                upload_file_id_list.append(upload_file_id)
                continue

            # For data after v0.10.0
            pattern = r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?"
            match = re.search(pattern, image)
            if match:
                upload_file_id = match.group(1)
                upload_file_id_list.append(upload_file_id)
                continue

            # For tools directory - direct file formats (e.g., .png, .jpg, etc.)
            pattern = r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?"
            match = re.search(pattern, image)
            if match:
                # Tool files are handled differently, skip for now
                continue

        if not upload_file_id_list:
            return []

        # Get unique IDs for database query
        unique_upload_file_ids = list(set(upload_file_id_list))
        upload_files = (
            db.session.query(UploadFile)
            .where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id)
            .all()
        )

        # Create File objects from UploadFile records
        file_objects = []
        for upload_file in upload_files:
            # Only process image files
            if not upload_file.mime_type or "image" not in upload_file.mime_type:
                continue

            mapping = {
                "upload_file_id": upload_file.id,
                "transfer_method": FileTransferMethod.LOCAL_FILE.value,
                "type": FileType.IMAGE.value,
            }

            try:
                file_obj = build_from_mapping(
                    mapping=mapping,
                    tenant_id=tenant_id,
                )
                file_objects.append(file_obj)
            except Exception as e:
                logger.warning("Failed to create File object from UploadFile %s: %s", upload_file.id, str(e))
                continue

        return file_objects

    @staticmethod
    def _extract_images_from_segment_attachments(tenant_id: str, segment_id: str) -> list[File]:
        """
        Extract images from SegmentAttachmentBinding table (preferred method).
        This matches how DatasetRetrieval gets segment attachments.

        Args:
            tenant_id: Tenant ID
            segment_id: Segment ID to fetch attachments for

        Returns:
            List of File objects representing images found in segment attachments
        """
        from sqlalchemy import select

        # Query attachments from SegmentAttachmentBinding table
        attachments_with_bindings = db.session.execute(
            select(SegmentAttachmentBinding, UploadFile)
            .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
            .where(
                SegmentAttachmentBinding.segment_id == segment_id,
                SegmentAttachmentBinding.tenant_id == tenant_id,
            )
        ).all()

        if not attachments_with_bindings:
            return []

        file_objects = []
        for _, upload_file in attachments_with_bindings:
            # Only process image files
            if not upload_file.mime_type or "image" not in upload_file.mime_type:
                continue

            try:
                # Create File object directly (similar to DatasetRetrieval)
                file_obj = File(
                    id=upload_file.id,
                    filename=upload_file.name,
                    extension="." + upload_file.extension,
                    mime_type=upload_file.mime_type,
                    tenant_id=tenant_id,
                    type=FileType.IMAGE,
                    transfer_method=FileTransferMethod.LOCAL_FILE,
                    remote_url=upload_file.source_url,
                    related_id=upload_file.id,
                    size=upload_file.size,
                    storage_key=upload_file.key,
                )
                file_objects.append(file_obj)
            except Exception as e:
                logger.warning("Failed to create File object from UploadFile %s: %s", upload_file.id, str(e))
                continue

        return file_objects
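The three URL shapes _extract_images_from_text recognizes, exercised on sample markdown; the UUIDs are made up for the demo.

# Demo of the extraction regexes above against synthetic image links.
import re

markdown = (
    "![a](/files/0a1b2c3d-0000-0000-0000-000000000001/image-preview) "
    "![b](/files/0a1b2c3d-0000-0000-0000-000000000002/file-preview?t=1) "
    "![c](/files/tools/0a1b2c3d-0000-0000-0000-000000000003.png)"
)

image_urls = re.findall(r"!\[.*?\]\((.*?)\)", markdown)
for url in image_urls:
    for label, pattern in (
        ("pre-v0.10.0", r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?"),
        ("post-v0.10.0", r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?"),
        ("tool file", r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?"),
    ):
        match = re.search(pattern, url)
        if match:
            print(label, match.group(1))
            break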
@@ -1,11 +1,14 @@
"""Paragraph index processor."""

import json
import logging
import uuid
from collections.abc import Mapping
from typing import Any

from configs import dify_config
from core.db.session_factory import session_factory
from core.entities.knowledge_entities import PreviewDetail
from core.model_manager import ModelInstance
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService

@@ -25,6 +28,9 @@ from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegm
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
from services.summary_index_service import SummaryIndexService

logger = logging.getLogger(__name__)


class ParentChildIndexProcessor(BaseIndexProcessor):

@@ -135,6 +141,30 @@ class ParentChildIndexProcessor(BaseIndexProcessor):

    def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
        # node_ids is segment's node_ids
        # Note: Summary indexes are now disabled (not deleted) when segments are disabled.
        # This method is called for actual deletion scenarios (e.g., when segment is deleted).
        # For disable operations, disable_summaries_for_segments is called directly in the task.
        # Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
        delete_summaries = kwargs.get("delete_summaries", False)
        if delete_summaries:
            if node_ids:
                # Find segments by index_node_id
                with session_factory.create_session() as session:
                    segments = (
                        session.query(DocumentSegment)
                        .filter(
                            DocumentSegment.dataset_id == dataset.id,
                            DocumentSegment.index_node_id.in_(node_ids),
                        )
                        .all()
                    )
                    segment_ids = [segment.id for segment in segments]
                    if segment_ids:
                        SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
            else:
                # Delete all summaries for the dataset
                SummaryIndexService.delete_summaries_for_segments(dataset, None)

        if dataset.indexing_technique == "high_quality":
            delete_child_chunks = kwargs.get("delete_child_chunks") or False
            precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids")

@@ -326,3 +356,91 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
            "preview": preview,
            "total_segments": len(parent_childs.parent_child_chunks),
        }

    def generate_summary_preview(
        self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
    ) -> list[PreviewDetail]:
        """
        For each parent chunk in preview_texts, concurrently call generate_summary to generate a summary
        and write it to the summary attribute of PreviewDetail.
        In preview mode (indexing-estimate), if any summary generation fails, the method will raise an exception.

        Note: For parent-child structure, we only generate summaries for parent chunks.
        """
        import concurrent.futures

        from flask import current_app

        # Capture Flask app context for worker threads
        flask_app = None
        try:
            flask_app = current_app._get_current_object()  # type: ignore
        except RuntimeError:
            logger.warning("No Flask application context available, summary generation may fail")

        def process(preview: PreviewDetail) -> None:
            """Generate summary for a single preview item (parent chunk)."""
            from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor

            if flask_app:
                # Ensure Flask app context in worker thread
                with flask_app.app_context():
                    summary, _ = ParagraphIndexProcessor.generate_summary(
                        tenant_id=tenant_id,
                        text=preview.content,
                        summary_index_setting=summary_index_setting,
                    )
                    preview.summary = summary
            else:
                # Fallback: try without app context (may fail)
                summary, _ = ParagraphIndexProcessor.generate_summary(
                    tenant_id=tenant_id,
                    text=preview.content,
                    summary_index_setting=summary_index_setting,
                )
                preview.summary = summary

        # Generate summaries concurrently using ThreadPoolExecutor
        # Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
        timeout_seconds = min(300, 60 * len(preview_texts))
        errors: list[Exception] = []

        with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_texts))) as executor:
            futures = [executor.submit(process, preview) for preview in preview_texts]
            # Wait for all tasks to complete with timeout
            done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)

            # Cancel tasks that didn't complete in time
            if not_done:
                timeout_error_msg = (
                    f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s"
                )
                logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg)
                # In preview mode, timeout is also an error
                errors.append(TimeoutError(timeout_error_msg))
                for future in not_done:
                    future.cancel()
                # Wait a bit for cancellation to take effect
                concurrent.futures.wait(not_done, timeout=5)

            # Collect exceptions from completed futures
            for future in done:
                try:
                    future.result()  # This will raise any exception that occurred
                except Exception as e:
                    logger.exception("Error in summary generation future")
                    errors.append(e)

        # In preview mode (indexing-estimate), if there are any errors, fail the request
        if errors:
            error_messages = [str(e) for e in errors]
            error_summary = (
                f"Failed to generate summaries for {len(errors)} chunk(s). "
                f"Errors: {'; '.join(error_messages[:3])}"  # Show first 3 errors
            )
            if len(errors) > 3:
                error_summary += f" (and {len(errors) - 3} more)"
            logger.error("Summary generation failed in preview mode: %s", error_summary)
            raise ValueError(error_summary)

        return preview_texts
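Both processors share the same wait/cancel/collect pattern; here it is reduced to its core as a runnable sketch, with time.sleep standing in for one LLM summary call.

# Reduced sketch of the executor timeout handling above; names are illustrative.
import concurrent.futures
import time


def run_all_or_fail(tasks: list[float], timeout_seconds: float = 1.0) -> None:
    errors: list[Exception] = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(time.sleep, t) for t in tasks]
        done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)
        if not_done:
            errors.append(TimeoutError(f"{len(not_done)} tasks still running"))
            for future in not_done:
                future.cancel()  # best effort; already-running threads cannot be interrupted
            concurrent.futures.wait(not_done, timeout=5)
        for future in done:
            try:
                future.result()  # re-raises any worker exception
            except Exception as e:
                errors.append(e)
    if errors:
        raise ValueError(f"{len(errors)} task(s) failed: {errors[:3]}")


run_all_or_fail([0.01, 0.02])  # completes quietly within the timeout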
@@ -11,6 +11,8 @@ import pandas as pd
from flask import Flask, current_app
from werkzeug.datastructures import FileStorage

from core.db.session_factory import session_factory
from core.entities.knowledge_entities import PreviewDetail
from core.llm_generator.llm_generator import LLMGenerator
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService

@@ -25,9 +27,10 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.account import Account
from models.dataset import Dataset
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule
from services.summary_index_service import SummaryIndexService

logger = logging.getLogger(__name__)

@@ -144,6 +147,31 @@ class QAIndexProcessor(BaseIndexProcessor):
        vector.create_multimodal(multimodal_documents)

    def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
        # Note: Summary indexes are now disabled (not deleted) when segments are disabled.
        # This method is called for actual deletion scenarios (e.g., when segment is deleted).
        # For disable operations, disable_summaries_for_segments is called directly in the task.
        # Note: qa_model doesn't generate summaries, but we clean them for completeness
        # Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
        delete_summaries = kwargs.get("delete_summaries", False)
        if delete_summaries:
            if node_ids:
                # Find segments by index_node_id
                with session_factory.create_session() as session:
                    segments = (
                        session.query(DocumentSegment)
                        .filter(
                            DocumentSegment.dataset_id == dataset.id,
                            DocumentSegment.index_node_id.in_(node_ids),
                        )
                        .all()
                    )
                    segment_ids = [segment.id for segment in segments]
                    if segment_ids:
                        SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
            else:
                # Delete all summaries for the dataset
                SummaryIndexService.delete_summaries_for_segments(dataset, None)

        vector = Vector(dataset)
        if node_ids:
            vector.delete_by_ids(node_ids)

@@ -212,6 +240,17 @@ class QAIndexProcessor(BaseIndexProcessor):
            "total_segments": len(qa_chunks.qa_chunks),
        }

    def generate_summary_preview(
        self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
    ) -> list[PreviewDetail]:
        """
        QA model doesn't generate summaries, so this method returns preview_texts unchanged.

        Note: QA model uses question-answer pairs, which don't require summary generation.
        """
        # QA model doesn't generate summaries, return as-is
        return preview_texts

    def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
        format_documents = []
        if document_node.page_content is None or not document_node.page_content.strip():
@@ -236,20 +236,24 @@ class DatasetRetrieval:
        if records:
            for record in records:
                segment = record.segment
                # Build content: if summary exists, add it before the segment content
                if segment.answer:
                    document_context_list.append(
                        DocumentContext(
                            content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
                            score=record.score,
                        )
                    )
                    segment_content = f"question:{segment.get_sign_content()} answer:{segment.answer}"
                else:
                    document_context_list.append(
                        DocumentContext(
                            content=segment.get_sign_content(),
                            score=record.score,
                        )
                    segment_content = segment.get_sign_content()

                # If summary exists, prepend it to the content
                if record.summary:
                    final_content = f"{record.summary}\n{segment_content}"
                else:
                    final_content = segment_content

                document_context_list.append(
                    DocumentContext(
                        content=final_content,
                        score=record.score,
                    )
                )
            if vision_enabled:
                attachments_with_bindings = db.session.execute(
                    select(SegmentAttachmentBinding, UploadFile)

@@ -316,6 +320,9 @@ class DatasetRetrieval:
                    source.content = f"question:{segment.content} \nanswer:{segment.answer}"
                else:
                    source.content = segment.content
                # Add summary if this segment was retrieved via summary
                if hasattr(record, "summary") and record.summary:
                    source.summary = record.summary
                retrieval_resource_list.append(source)
        if hit_callback and retrieval_resource_list:
            retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True)
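The content-assembly rule above, in isolation: Q&A segments are flattened to a question/answer pair, and a summary, when present, is prepended on its own line. A minimal sketch:

# Standalone version of the final_content construction above.
def build_context(content: str, answer: str | None, summary: str | None) -> str:
    segment_content = f"question:{content} answer:{answer}" if answer else content
    return f"{summary}\n{segment_content}" if summary else segment_content


print(build_context("What is Dify?", "An LLM app platform.", "Intro to Dify."))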
@@ -169,20 +169,24 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
        if records:
            for record in records:
                segment = record.segment
                # Build content: if summary exists, add it before the segment content
                if segment.answer:
                    document_context_list.append(
                        DocumentContext(
                            content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
                            score=record.score,
                        )
                    )
                    segment_content = f"question:{segment.get_sign_content()} answer:{segment.answer}"
                else:
                    document_context_list.append(
                        DocumentContext(
                            content=segment.get_sign_content(),
                            score=record.score,
                        )
                    segment_content = segment.get_sign_content()

                # If summary exists, prepend it to the content
                if record.summary:
                    final_content = f"{record.summary}\n{segment_content}"
                else:
                    final_content = segment_content

                document_context_list.append(
                    DocumentContext(
                        content=final_content,
                        score=record.score,
                    )
                )

        if self.return_resource:
            for record in records:

@@ -216,6 +220,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
                    source.content = f"question:{segment.content} \nanswer:{segment.answer}"
                else:
                    source.content = segment.content
                # Add summary if this segment was retrieved via summary
                if hasattr(record, "summary") and record.summary:
                    source.summary = record.summary
                retrieval_resource_list.append(source)

        if self.return_resource and retrieval_resource_list:
@@ -158,3 +158,5 @@ class KnowledgeIndexNodeData(BaseNodeData):
    type: str = "knowledge-index"
    chunk_structure: str
    index_chunk_variable_selector: list[str]
    indexing_technique: str | None = None
    summary_index_setting: dict | None = None
@@ -1,9 +1,11 @@
import concurrent.futures
import datetime
import logging
import time
from collections.abc import Mapping
from typing import Any

from flask import current_app
from sqlalchemy import func, select

from core.app.entities.app_invoke_entities import InvokeFrom

@@ -16,7 +18,9 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
from services.summary_index_service import SummaryIndexService
from tasks.generate_summary_index_task import generate_summary_index_task

from .entities import KnowledgeIndexNodeData
from .exc import (

@@ -67,7 +71,20 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
        # index knowledge
        try:
            if is_preview:
                outputs = self._get_preview_output(node_data.chunk_structure, chunks)
                # Preview mode: generate summaries for chunks directly without saving to database
                # Format preview and generate summaries on-the-fly
                # Get indexing_technique and summary_index_setting from node_data (workflow graph config)
                # or fallback to dataset if not available in node_data
                indexing_technique = node_data.indexing_technique or dataset.indexing_technique
                summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting

                outputs = self._get_preview_output_with_summaries(
                    node_data.chunk_structure,
                    chunks,
                    dataset=dataset,
                    indexing_technique=indexing_technique,
                    summary_index_setting=summary_index_setting,
                )
                return NodeRunResult(
                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
                    inputs=variables,

@@ -148,6 +165,11 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
                )
                .scalar()
            )
            # Update need_summary based on dataset's summary_index_setting
            if dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True:
                document.need_summary = True
            else:
                document.need_summary = False
            db.session.add(document)
            # update document segment status
            db.session.query(DocumentSegment).where(

@@ -163,6 +185,9 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):

        db.session.commit()

        # Generate summary index if enabled
        self._handle_summary_index_generation(dataset, document, variable_pool)

        return {
            "dataset_id": ds_id_value,
            "dataset_name": dataset_name_value,

@@ -173,9 +198,304 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
            "display_status": "completed",
        }

    def _get_preview_output(self, chunk_structure: str, chunks: Any) -> Mapping[str, Any]:
    def _handle_summary_index_generation(
        self,
        dataset: Dataset,
        document: Document,
        variable_pool: VariablePool,
    ) -> None:
        """
        Handle summary index generation based on mode (debug/preview or production).

        Args:
            dataset: Dataset containing the document
            document: Document to generate summaries for
            variable_pool: Variable pool to check invoke_from
        """
        # Only generate summary index for high_quality indexing technique
        if dataset.indexing_technique != "high_quality":
            return

        # Check if summary index is enabled
        summary_index_setting = dataset.summary_index_setting
        if not summary_index_setting or not summary_index_setting.get("enable"):
            return

        # Skip qa_model documents
        if document.doc_form == "qa_model":
            return

        # Determine if in preview/debug mode
        invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
        is_preview = invoke_from and invoke_from.value == InvokeFrom.DEBUGGER

        if is_preview:
            try:
                # Query segments that need summary generation
                query = db.session.query(DocumentSegment).filter_by(
                    dataset_id=dataset.id,
                    document_id=document.id,
                    status="completed",
                    enabled=True,
                )
                segments = query.all()

                if not segments:
                    logger.info("No segments found for document %s", document.id)
                    return

                # Filter segments based on mode
                segments_to_process = []
                for segment in segments:
                    # Skip if summary already exists
                    existing_summary = (
                        db.session.query(DocumentSegmentSummary)
                        .filter_by(chunk_id=segment.id, dataset_id=dataset.id, status="completed")
                        .first()
                    )
                    if existing_summary:
                        continue

                    # For parent-child mode, all segments are parent chunks, so process all
                    segments_to_process.append(segment)

                if not segments_to_process:
                    logger.info("No segments need summary generation for document %s", document.id)
                    return

                # Use ThreadPoolExecutor for concurrent generation
                flask_app = current_app._get_current_object()  # type: ignore
                max_workers = min(10, len(segments_to_process))  # Limit to 10 workers

                def process_segment(segment: DocumentSegment) -> None:
                    """Process a single segment in a thread with Flask app context."""
                    with flask_app.app_context():
                        try:
                            SummaryIndexService.generate_and_vectorize_summary(segment, dataset, summary_index_setting)
                        except Exception:
                            logger.exception(
                                "Failed to generate summary for segment %s",
                                segment.id,
                            )
                            # Continue processing other segments

                with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                    futures = [executor.submit(process_segment, segment) for segment in segments_to_process]
                    # Wait for all tasks to complete
                    concurrent.futures.wait(futures)

                logger.info(
                    "Successfully generated summary index for %s segments in document %s",
                    len(segments_to_process),
                    document.id,
                )
            except Exception:
                logger.exception("Failed to generate summary index for document %s", document.id)
                # Don't fail the entire indexing process if summary generation fails
        else:
            # Production mode: asynchronous generation
            logger.info(
                "Queuing summary index generation task for document %s (production mode)",
                document.id,
            )
            try:
                generate_summary_index_task.delay(dataset.id, document.id, None)
                logger.info("Summary index generation task queued for document %s", document.id)
            except Exception:
                logger.exception(
                    "Failed to queue summary index generation task for document %s",
                    document.id,
                )
                # Don't fail the entire indexing process if task queuing fails

    def _get_preview_output_with_summaries(
        self,
        chunk_structure: str,
        chunks: Any,
        dataset: Dataset,
        indexing_technique: str | None = None,
        summary_index_setting: dict | None = None,
    ) -> Mapping[str, Any]:
        """
        Generate preview output with summaries for chunks in preview mode.
        This method generates summaries on-the-fly without saving to database.

        Args:
            chunk_structure: Chunk structure type
            chunks: Chunks to generate preview for
            dataset: Dataset object (for tenant_id)
            indexing_technique: Indexing technique from node config or dataset
            summary_index_setting: Summary index setting from node config or dataset
        """
        index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
        return index_processor.format_preview(chunks)
        preview_output = index_processor.format_preview(chunks)

        # Check if summary index is enabled
        if indexing_technique != "high_quality":
            return preview_output

        if not summary_index_setting or not summary_index_setting.get("enable"):
            return preview_output

        # Generate summaries for chunks
        if "preview" in preview_output and isinstance(preview_output["preview"], list):
            chunk_count = len(preview_output["preview"])
            logger.info(
                "Generating summaries for %s chunks in preview mode (dataset: %s)",
                chunk_count,
                dataset.id,
            )
            # Use ParagraphIndexProcessor's generate_summary method
            from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor

            # Get Flask app for application context in worker threads
            flask_app = None
            try:
                flask_app = current_app._get_current_object()  # type: ignore
            except RuntimeError:
                logger.warning("No Flask application context available, summary generation may fail")

            def generate_summary_for_chunk(preview_item: dict) -> None:
                """Generate summary for a single chunk."""
                if "content" in preview_item:
                    # Set Flask application context in worker thread
                    if flask_app:
                        with flask_app.app_context():
                            summary, _ = ParagraphIndexProcessor.generate_summary(
                                tenant_id=dataset.tenant_id,
                                text=preview_item["content"],
                                summary_index_setting=summary_index_setting,
                            )
                            if summary:
                                preview_item["summary"] = summary
                    else:
                        # Fallback: try without app context (may fail)
                        summary, _ = ParagraphIndexProcessor.generate_summary(
                            tenant_id=dataset.tenant_id,
                            text=preview_item["content"],
                            summary_index_setting=summary_index_setting,
                        )
                        if summary:
                            preview_item["summary"] = summary

            # Generate summaries concurrently using ThreadPoolExecutor
            # Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
            timeout_seconds = min(300, 60 * len(preview_output["preview"]))
            errors: list[Exception] = []

            with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_output["preview"]))) as executor:
                futures = [
                    executor.submit(generate_summary_for_chunk, preview_item)
                    for preview_item in preview_output["preview"]
                ]
                # Wait for all tasks to complete with timeout
                done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)

                # Cancel tasks that didn't complete in time
                if not_done:
                    timeout_error_msg = (
                        f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s"
                    )
                    logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg)
                    # In preview mode, timeout is also an error
                    errors.append(TimeoutError(timeout_error_msg))
                    for future in not_done:
                        future.cancel()
                    # Wait a bit for cancellation to take effect
                    concurrent.futures.wait(not_done, timeout=5)

                # Collect exceptions from completed futures
                for future in done:
                    try:
                        future.result()  # This will raise any exception that occurred
                    except Exception as e:
                        logger.exception("Error in summary generation future")
                        errors.append(e)

            # In preview mode, if there are any errors, fail the request
            if errors:
                error_messages = [str(e) for e in errors]
                error_summary = (
                    f"Failed to generate summaries for {len(errors)} chunk(s). "
                    f"Errors: {'; '.join(error_messages[:3])}"  # Show first 3 errors
                )
                if len(errors) > 3:
                    error_summary += f" (and {len(errors) - 3} more)"
                logger.error("Summary generation failed in preview mode: %s", error_summary)
                raise KnowledgeIndexNodeError(error_summary)

            completed_count = sum(1 for item in preview_output["preview"] if item.get("summary") is not None)
            logger.info(
                "Completed summary generation for preview chunks: %s/%s succeeded",
                completed_count,
                len(preview_output["preview"]),
            )

        return preview_output

    def _get_preview_output(
        self,
        chunk_structure: str,
        chunks: Any,
        dataset: Dataset | None = None,
        variable_pool: VariablePool | None = None,
    ) -> Mapping[str, Any]:
        index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
        preview_output = index_processor.format_preview(chunks)

        # If dataset is provided, try to enrich preview with summaries
        if dataset and variable_pool:
            document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
            if document_id:
                document = db.session.query(Document).filter_by(id=document_id.value).first()
                if document:
                    # Query summaries for this document
                    summaries = (
                        db.session.query(DocumentSegmentSummary)
                        .filter_by(
                            dataset_id=dataset.id,
                            document_id=document.id,
                            status="completed",
                            enabled=True,
                        )
                        .all()
                    )

                    if summaries:
                        # Create a map of segment content to summary for matching
                        # Use content matching as chunks in preview might not be indexed yet
                        summary_by_content = {}
                        for summary in summaries:
                            segment = (
                                db.session.query(DocumentSegment)
                                .filter_by(id=summary.chunk_id, dataset_id=dataset.id)
                                .first()
                            )
                            if segment:
                                # Normalize content for matching (strip whitespace)
                                normalized_content = segment.content.strip()
                                summary_by_content[normalized_content] = summary.summary_content

                        # Enrich preview with summaries by content matching
                        if "preview" in preview_output and isinstance(preview_output["preview"], list):
                            matched_count = 0
                            for preview_item in preview_output["preview"]:
                                if "content" in preview_item:
                                    # Normalize content for matching
                                    normalized_chunk_content = preview_item["content"].strip()
                                    if normalized_chunk_content in summary_by_content:
                                        preview_item["summary"] = summary_by_content[normalized_chunk_content]
                                        matched_count += 1

                            if matched_count > 0:
                                logger.info(
                                    "Enriched preview with %s existing summaries (dataset: %s, document: %s)",
                                    matched_count,
                                    dataset.id,
                                    document.id,
                                )

        return preview_output

    @classmethod
    def version(cls) -> str:
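The enrichment path keys summaries by whitespace-stripped segment content so preview chunks that are not yet indexed can still be matched. A tiny self-contained illustration (the data is made up):

# Sketch of the content-normalized summary matching above.
stored = [("  Alpha beta.  ", "Summary A"), ("Gamma delta.", "Summary B")]
summary_by_content = {content.strip(): summary for content, summary in stored}

preview = [{"content": "Alpha beta."}, {"content": "Unseen chunk."}]
for item in preview:
    match = summary_by_content.get(item["content"].strip())
    if match:
        item["summary"] = match

print(preview)  # only the first item gains a summary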
@@ -419,6 +419,9 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
                    source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
                else:
                    source["content"] = segment.get_sign_content()
                # Add summary if available
                if record.summary:
                    source["summary"] = record.summary
                retrieval_resource_list.append(source)
        if retrieval_resource_list:
            retrieval_resource_list = sorted(
@@ -685,6 +685,8 @@ class LLMNode(Node[LLMNodeData]):
            if "content" not in item:
                raise InvalidContextStructureError(f"Invalid context structure: {item}")

            if item.get("summary"):
                context_str += item["summary"] + "\n"
            context_str += item["content"] + "\n"

            retriever_resource = self._convert_to_original_retriever_resource(item)

@@ -746,6 +748,7 @@ class LLMNode(Node[LLMNodeData]):
                page=metadata.get("page"),
                doc_metadata=metadata.get("doc_metadata"),
                files=context_dict.get("files"),
                summary=context_dict.get("summary"),
            )

            return source
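Per the first hunk above, a retrieved chunk's summary is emitted on its own line immediately before the chunk content when the LLM context string is assembled. A minimal sketch of that rule:

# Standalone version of the context assembly above.
def assemble_context(items: list[dict]) -> str:
    context_str = ""
    for item in items:
        if "content" not in item:
            raise ValueError(f"Invalid context structure: {item}")
        if item.get("summary"):
            context_str += item["summary"] + "\n"  # summary precedes the chunk
        context_str += item["content"] + "\n"
    return context_str


print(assemble_context([{"content": "Full chunk text.", "summary": "One-line gist."}]))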
@@ -102,6 +102,8 @@ def init_app(app: DifyApp) -> Celery:
    imports = [
        "tasks.async_workflow_tasks",  # trigger workers
        "tasks.trigger_processing_tasks",  # async trigger processing
        "tasks.generate_summary_index_task",  # summary index generation
        "tasks.regenerate_summary_index_task",  # summary index regeneration
    ]
    day = dify_config.CELERY_BEAT_SCHEDULER_TIME
@@ -39,6 +39,14 @@ dataset_retrieval_model_fields = {
    "score_threshold_enabled": fields.Boolean,
    "score_threshold": fields.Float,
}

dataset_summary_index_fields = {
    "enable": fields.Boolean,
    "model_name": fields.String,
    "model_provider_name": fields.String,
    "summary_prompt": fields.String,
}

external_retrieval_model_fields = {
    "top_k": fields.Integer,
    "score_threshold": fields.Float,

@@ -83,6 +91,7 @@ dataset_detail_fields = {
    "embedding_model_provider": fields.String,
    "embedding_available": fields.Boolean,
    "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields),
    "summary_index_setting": fields.Nested(dataset_summary_index_fields),
    "tags": fields.List(fields.Nested(tag_fields)),
    "doc_form": fields.String,
    "external_knowledge_info": fields.Nested(external_knowledge_info_fields),
@@ -33,6 +33,11 @@ document_fields = {
    "hit_count": fields.Integer,
    "doc_form": fields.String,
    "doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"),
    # Summary index generation status:
    # "SUMMARIZING" (when task is queued and generating)
    "summary_index_status": fields.String,
    # Whether this document needs summary index generation
    "need_summary": fields.Boolean,
}

document_with_segments_fields = {

@@ -60,6 +65,10 @@ document_with_segments_fields = {
    "completed_segments": fields.Integer,
    "total_segments": fields.Integer,
    "doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"),
    # Summary index generation status:
    # "SUMMARIZING" (when task is queued and generating)
    "summary_index_status": fields.String,
    "need_summary": fields.Boolean,  # Whether this document needs summary index generation
}

dataset_and_document_fields = {
@@ -58,4 +58,5 @@ hit_testing_record_fields = {
    "score": fields.Float,
    "tsne_position": fields.Raw,
    "files": fields.List(fields.Nested(files_fields)),
    "summary": fields.String,  # Summary content if retrieved via summary index
}
@@ -36,6 +36,7 @@ class RetrieverResource(ResponseModel):
    segment_position: int | None = None
    index_node_hash: str | None = None
    content: str | None = None
    summary: str | None = None
    created_at: int | None = None

    @field_validator("created_at", mode="before")
@@ -49,4 +49,5 @@ segment_fields = {
    "stopped_at": TimestampField,
    "child_chunks": fields.List(fields.Nested(child_chunk_fields)),
    "attachments": fields.List(fields.Nested(attachment_fields)),
    "summary": fields.String,  # Summary content for the segment
}
@@ -0,0 +1,107 @@
"""add summary index feature

Revision ID: 788d3099ae3a
Revises: 9d77545f524e
Create Date: 2026-01-27 18:15:45.277928

"""
from alembic import op
import models as models
import sqlalchemy as sa


def _is_pg(conn):
    return conn.dialect.name == "postgresql"


# revision identifiers, used by Alembic.
revision = '788d3099ae3a'
down_revision = '9d77545f524e'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()
    if _is_pg(conn):
        op.create_table(
            'document_segment_summaries',
            sa.Column('id', models.types.StringUUID(), nullable=False),
            sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
            sa.Column('document_id', models.types.StringUUID(), nullable=False),
            sa.Column('chunk_id', models.types.StringUUID(), nullable=False),
            sa.Column('summary_content', models.types.LongText(), nullable=True),
            sa.Column('summary_index_node_id', sa.String(length=255), nullable=True),
            sa.Column('summary_index_node_hash', sa.String(length=255), nullable=True),
            sa.Column('tokens', sa.Integer(), nullable=True),
            sa.Column('status', sa.String(length=32), server_default=sa.text("'generating'"), nullable=False),
            sa.Column('error', models.types.LongText(), nullable=True),
            sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
            sa.Column('disabled_at', sa.DateTime(), nullable=True),
            sa.Column('disabled_by', models.types.StringUUID(), nullable=True),
            sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
            sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
            sa.PrimaryKeyConstraint('id', name='document_segment_summaries_pkey'),
        )
        with op.batch_alter_table('document_segment_summaries', schema=None) as batch_op:
            batch_op.create_index('document_segment_summaries_chunk_id_idx', ['chunk_id'], unique=False)
            batch_op.create_index('document_segment_summaries_dataset_id_idx', ['dataset_id'], unique=False)
            batch_op.create_index('document_segment_summaries_document_id_idx', ['document_id'], unique=False)
            batch_op.create_index('document_segment_summaries_status_idx', ['status'], unique=False)

        with op.batch_alter_table('datasets', schema=None) as batch_op:
            batch_op.add_column(sa.Column('summary_index_setting', models.types.AdjustedJSON(), nullable=True))

        with op.batch_alter_table('documents', schema=None) as batch_op:
            batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=True))
    else:
        # MySQL: Use compatible syntax
        op.create_table(
            'document_segment_summaries',
            sa.Column('id', models.types.StringUUID(), nullable=False),
            sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
            sa.Column('document_id', models.types.StringUUID(), nullable=False),
            sa.Column('chunk_id', models.types.StringUUID(), nullable=False),
            sa.Column('summary_content', models.types.LongText(), nullable=True),
            sa.Column('summary_index_node_id', sa.String(length=255), nullable=True),
            sa.Column('summary_index_node_hash', sa.String(length=255), nullable=True),
            sa.Column('tokens', sa.Integer(), nullable=True),
            sa.Column('status', sa.String(length=32), server_default=sa.text("'generating'"), nullable=False),
            sa.Column('error', models.types.LongText(), nullable=True),
            sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
            sa.Column('disabled_at', sa.DateTime(), nullable=True),
            sa.Column('disabled_by', models.types.StringUUID(), nullable=True),
            sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
            sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
            sa.PrimaryKeyConstraint('id', name='document_segment_summaries_pkey'),
        )
        with op.batch_alter_table('document_segment_summaries', schema=None) as batch_op:
            batch_op.create_index('document_segment_summaries_chunk_id_idx', ['chunk_id'], unique=False)
            batch_op.create_index('document_segment_summaries_dataset_id_idx', ['dataset_id'], unique=False)
            batch_op.create_index('document_segment_summaries_document_id_idx', ['document_id'], unique=False)
            batch_op.create_index('document_segment_summaries_status_idx', ['status'], unique=False)

        with op.batch_alter_table('datasets', schema=None) as batch_op:
            batch_op.add_column(sa.Column('summary_index_setting', models.types.AdjustedJSON(), nullable=True))

        with op.batch_alter_table('documents', schema=None) as batch_op:
            batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=True))

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###

    with op.batch_alter_table('documents', schema=None) as batch_op:
        batch_op.drop_column('need_summary')

    with op.batch_alter_table('datasets', schema=None) as batch_op:
        batch_op.drop_column('summary_index_setting')

    with op.batch_alter_table('document_segment_summaries', schema=None) as batch_op:
        batch_op.drop_index('document_segment_summaries_status_idx')
        batch_op.drop_index('document_segment_summaries_document_id_idx')
        batch_op.drop_index('document_segment_summaries_dataset_id_idx')
        batch_op.drop_index('document_segment_summaries_chunk_id_idx')

    op.drop_table('document_segment_summaries')
    # ### end Alembic commands ###
|
||||
|
|
@ -72,6 +72,7 @@ class Dataset(Base):
    keyword_number = mapped_column(sa.Integer, nullable=True, server_default=sa.text("10"))
    collection_binding_id = mapped_column(StringUUID, nullable=True)
    retrieval_model = mapped_column(AdjustedJSON, nullable=True)
    summary_index_setting = mapped_column(AdjustedJSON, nullable=True)
    built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
    icon_info = mapped_column(AdjustedJSON, nullable=True)
    runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'"))

@ -419,6 +420,7 @@ class Document(Base):
    doc_metadata = mapped_column(AdjustedJSON, nullable=True)
    doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
    doc_language = mapped_column(String(255), nullable=True)
    need_summary: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))

    DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"]

@ -1575,3 +1577,36 @@ class SegmentAttachmentBinding(Base):
    segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    attachment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())


class DocumentSegmentSummary(Base):
    __tablename__ = "document_segment_summaries"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="document_segment_summaries_pkey"),
        sa.Index("document_segment_summaries_dataset_id_idx", "dataset_id"),
        sa.Index("document_segment_summaries_document_id_idx", "document_id"),
        sa.Index("document_segment_summaries_chunk_id_idx", "chunk_id"),
        sa.Index("document_segment_summaries_status_idx", "status"),
    )

    id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
    dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    # corresponds to DocumentSegment.id or parent chunk id
    chunk_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    summary_content: Mapped[str] = mapped_column(LongText, nullable=True)
    summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True)
    summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True)
    tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
    status: Mapped[str] = mapped_column(String(32), nullable=False, server_default=sa.text("'generating'"))
    error: Mapped[str] = mapped_column(LongText, nullable=True)
    enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
    disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    disabled_by = mapped_column(StringUUID, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
    )

    def __repr__(self):
        return f"<DocumentSegmentSummary id={self.id} chunk_id={self.chunk_id} status={self.status}>"

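For orientation, a minimal sketch of reading a segment's summary row through this model (assuming Dify's usual `db` session from `extensions.ext_database`; the IDs are placeholders):

    from extensions.ext_database import db
    from models.dataset import DocumentSegmentSummary

    summary = (
        db.session.query(DocumentSegmentSummary)
        .filter_by(chunk_id="segment-uuid", dataset_id="dataset-uuid")
        .first()
    )
    # Only completed, enabled summaries are usable for retrieval.
    if summary and summary.enabled and summary.status == "completed":
        print(summary.summary_content)
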
@ -657,16 +657,22 @@ class AccountTrialAppRecord(Base):
         return user
 
 
-class ExporleBanner(Base):
+class ExporleBanner(TypeBase):
     __tablename__ = "exporle_banners"
     __table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)
-    id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
-    content = mapped_column(sa.JSON, nullable=False)
-    link = mapped_column(String(255), nullable=False)
-    sort = mapped_column(sa.Integer, nullable=False)
-    status = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"))
-    created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
-    language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'::character varying"))
+    id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
+    content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
+    link: Mapped[str] = mapped_column(String(255), nullable=False)
+    sort: Mapped[int] = mapped_column(sa.Integer, nullable=False)
+    status: Mapped[str] = mapped_column(
+        sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"), default="enabled"
+    )
+    created_at: Mapped[datetime] = mapped_column(
+        sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
+    )
+    language: Mapped[str] = mapped_column(
+        String(255), nullable=False, server_default=sa.text("'en-US'::character varying"), default="en-US"
+    )
 
 
 class OAuthProviderApp(TypeBase):

@ -89,6 +89,7 @@ from tasks.disable_segments_from_index_task import disable_segments_from_index_t
from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.enable_segments_to_index_task import enable_segments_to_index_task
from tasks.recover_document_indexing_task import recover_document_indexing_task
from tasks.regenerate_summary_index_task import regenerate_summary_index_task
from tasks.remove_document_from_index_task import remove_document_from_index_task
from tasks.retry_document_indexing_task import retry_document_indexing_task
from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task

@ -211,6 +212,7 @@ class DatasetService:
        embedding_model_provider: str | None = None,
        embedding_model_name: str | None = None,
        retrieval_model: RetrievalModel | None = None,
        summary_index_setting: dict | None = None,
    ):
        # check if dataset name already exists
        if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():

@ -253,6 +255,8 @@ class DatasetService:
        dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None
        dataset.permission = permission or DatasetPermissionEnum.ONLY_ME
        dataset.provider = provider
        if summary_index_setting is not None:
            dataset.summary_index_setting = summary_index_setting
        db.session.add(dataset)
        db.session.flush()

@ -476,6 +480,11 @@ class DatasetService:
        if external_retrieval_model:
            dataset.retrieval_model = external_retrieval_model

        # Update summary index setting if provided
        summary_index_setting = data.get("summary_index_setting", None)
        if summary_index_setting is not None:
            dataset.summary_index_setting = summary_index_setting

        # Update basic dataset properties
        dataset.name = data.get("name", dataset.name)
        dataset.description = data.get("description", dataset.description)

@ -564,6 +573,9 @@ class DatasetService:
        # update Retrieval model
        if data.get("retrieval_model"):
            filtered_data["retrieval_model"] = data["retrieval_model"]
        # update summary index setting
        if data.get("summary_index_setting"):
            filtered_data["summary_index_setting"] = data.get("summary_index_setting")
        # update icon info
        if data.get("icon_info"):
            filtered_data["icon_info"] = data.get("icon_info")

@ -572,12 +584,27 @@ class DatasetService:
        db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
        db.session.commit()

        # Reload dataset to get updated values
        db.session.refresh(dataset)

        # update pipeline knowledge base node data
        DatasetService._update_pipeline_knowledge_base_node_data(dataset, user.id)

        # Trigger vector index task if indexing technique changed
        if action:
            deal_dataset_vector_index_task.delay(dataset.id, action)
            # If embedding_model changed, also regenerate summary vectors
            if action == "update":
                regenerate_summary_index_task.delay(
                    dataset.id,
                    regenerate_reason="embedding_model_changed",
                    regenerate_vectors_only=True,
                )

        # Note: summary_index_setting changes do not trigger automatic regeneration of existing summaries.
        # The new setting will only apply to:
        # 1. New documents added after the setting change
        # 2. Manual summary generation requests

        return dataset

@ -616,6 +643,7 @@ class DatasetService:
                    knowledge_index_node_data["chunk_structure"] = dataset.chunk_structure
                    knowledge_index_node_data["indexing_technique"] = dataset.indexing_technique  # pyright: ignore[reportAttributeAccessIssue]
                    knowledge_index_node_data["keyword_number"] = dataset.keyword_number
                    knowledge_index_node_data["summary_index_setting"] = dataset.summary_index_setting
                    node["data"] = knowledge_index_node_data
                    updated = True
            except Exception:

@ -854,6 +882,54 @@ class DatasetService:
            )
            filtered_data["collection_binding_id"] = dataset_collection_binding.id

    @staticmethod
    def _check_summary_index_setting_model_changed(dataset: Dataset, data: dict[str, Any]) -> bool:
        """
        Check if the summary_index_setting model (model_name or model_provider_name) has changed.

        Args:
            dataset: Current dataset object
            data: Update data dictionary

        Returns:
            bool: True if the summary model changed, False otherwise
        """
        # Check if summary_index_setting is being updated
        if "summary_index_setting" not in data or data.get("summary_index_setting") is None:
            return False

        new_summary_setting = data.get("summary_index_setting")
        old_summary_setting = dataset.summary_index_setting

        # If the new setting is disabled, there is no need to regenerate
        if not new_summary_setting or not new_summary_setting.get("enable"):
            return False

        # If no old setting exists, there is nothing to regenerate (no existing summaries).
        # Note: this task only regenerates existing summaries; it does not generate new ones.
        if not old_summary_setting:
            return False

        # Compare model_name and model_provider_name
        old_model_name = old_summary_setting.get("model_name")
        old_model_provider = old_summary_setting.get("model_provider_name")
        new_model_name = new_summary_setting.get("model_name")
        new_model_provider = new_summary_setting.get("model_provider_name")

        # Check if the model changed
        if old_model_name != new_model_name or old_model_provider != new_model_provider:
            logger.info(
                "Summary index setting model changed for dataset %s: old=%s/%s, new=%s/%s",
                dataset.id,
                old_model_provider,
                old_model_name,
                new_model_provider,
                new_model_name,
            )
            return True

        return False

    @staticmethod
    def update_rag_pipeline_dataset_settings(
        session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False

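For reference, a hypothetical pair of settings that would trip this check. Only the keys read above (enable, model_name, model_provider_name) come from the code; any other shape of summary_index_setting is an assumption:

    # Hypothetical: stored setting vs. incoming update.
    old = {"enable": True, "model_provider_name": "openai", "model_name": "gpt-4o-mini"}
    new = {"enable": True, "model_provider_name": "openai", "model_name": "gpt-4o"}
    # model_name differs, so _check_summary_index_setting_model_changed(dataset, {"summary_index_setting": new})
    # returns True, and a full regeneration (summary content + vectors) would be warranted.
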
@ -889,6 +965,9 @@ class DatasetService:
            else:
                raise ValueError("Invalid index method")
            dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
            # Update summary_index_setting if provided
            if knowledge_configuration.summary_index_setting is not None:
                dataset.summary_index_setting = knowledge_configuration.summary_index_setting
            session.add(dataset)
        else:
            if dataset.chunk_structure and dataset.chunk_structure != knowledge_configuration.chunk_structure:

@ -994,6 +1073,9 @@ class DatasetService:
            if dataset.keyword_number != knowledge_configuration.keyword_number:
                dataset.keyword_number = knowledge_configuration.keyword_number
            dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
            # Update summary_index_setting if provided
            if knowledge_configuration.summary_index_setting is not None:
                dataset.summary_index_setting = knowledge_configuration.summary_index_setting
            session.add(dataset)
            session.commit()
            if action:

@ -1314,6 +1396,50 @@ class DocumentService:
        upload_file = DocumentService._get_upload_file_for_upload_file_document(document)
        return file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)

    @staticmethod
    def enrich_documents_with_summary_index_status(
        documents: Sequence[Document],
        dataset: Dataset,
        tenant_id: str,
    ) -> None:
        """
        Enrich documents with summary_index_status based on the dataset's summary index settings.

        This method calculates and sets summary_index_status for each document that needs a summary.
        Documents that don't need a summary, or whose dataset has summary indexing disabled, get a status of None.

        Args:
            documents: List of Document instances to enrich
            dataset: Dataset instance containing summary_index_setting
            tenant_id: Tenant ID for the summary status lookup
        """
        # Check if the dataset has summary indexing enabled
        has_summary_index = dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True

        # Filter documents that need summary calculation
        documents_need_summary = [doc for doc in documents if doc.need_summary is True]
        document_ids_need_summary = [str(doc.id) for doc in documents_need_summary]

        # Calculate summary_index_status for documents that need a summary (only if summary indexing is enabled)
        summary_status_map: dict[str, str | None] = {}
        if has_summary_index and document_ids_need_summary:
            from services.summary_index_service import SummaryIndexService

            summary_status_map = SummaryIndexService.get_documents_summary_index_status(
                document_ids=document_ids_need_summary,
                dataset_id=dataset.id,
                tenant_id=tenant_id,
            )

        # Add summary_index_status to each document
        for document in documents:
            if has_summary_index and document.need_summary is True:
                # Get the status from the map, defaulting to None (not queued yet)
                document.summary_index_status = summary_status_map.get(str(document.id))  # type: ignore[attr-defined]
            else:
                # Use None if summary indexing is disabled or the document doesn't need a summary
                document.summary_index_status = None  # type: ignore[attr-defined]

    @staticmethod
    def prepare_document_batch_download_zip(
        *,

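A usage sketch for the enrichment helper (assuming the documents and their dataset are already loaded in an app context; the tenant accessor is illustrative):

    docs = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
    DocumentService.enrich_documents_with_summary_index_status(
        documents=docs,
        dataset=dataset,
        tenant_id=tenant_id,  # e.g. current_user.current_tenant_id in controller code
    )
    # Each document now carries a transient summary_index_status attribute for serialization.
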
@ -1964,6 +2090,8 @@ class DocumentService:
                    DuplicateDocumentIndexingTaskProxy(
                        dataset.tenant_id, dataset.id, duplicate_document_ids
                    ).delay()
                # Note: summary index generation is triggered in document_indexing_task after indexing
                # completes, to ensure segments are available. See tasks/document_indexing_task.py
            except LockNotOwnedError:
                pass

@ -2268,6 +2396,11 @@ class DocumentService:
        name: str,
        batch: str,
    ):
        # Set need_summary based on the dataset's summary_index_setting
        need_summary = False
        if dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True:
            need_summary = True

        document = Document(
            tenant_id=dataset.tenant_id,
            dataset_id=dataset.id,

@ -2281,6 +2414,7 @@ class DocumentService:
            created_by=account.id,
            doc_form=document_form,
            doc_language=document_language,
            need_summary=need_summary,
        )
        doc_metadata = {}
        if dataset.built_in_field_enabled:

@ -2505,6 +2639,7 @@ class DocumentService:
                embedding_model_provider=knowledge_config.embedding_model_provider,
                collection_binding_id=dataset_collection_binding_id,
                retrieval_model=retrieval_model.model_dump() if retrieval_model else None,
                summary_index_setting=knowledge_config.summary_index_setting,
                is_multimodal=knowledge_config.is_multimodal,
            )

@ -2686,6 +2821,14 @@ class DocumentService:
        if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int):
            raise ValueError("Process rule segmentation max_tokens is invalid")

        # Validate the summary index setting
        summary_index_setting = args["process_rule"].get("summary_index_setting")
        if summary_index_setting and summary_index_setting.get("enable"):
            if "model_name" not in summary_index_setting or not summary_index_setting["model_name"]:
                raise ValueError("Summary index model name is required")
            if "model_provider_name" not in summary_index_setting or not summary_index_setting["model_provider_name"]:
                raise ValueError("Summary index model provider name is required")

    @staticmethod
    def batch_update_document_status(
        dataset: Dataset, document_ids: list[str], action: Literal["enable", "disable", "archive", "un_archive"], user

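Going by this validation, a process_rule fragment that passes might look like the following sketch (provider and model names are hypothetical; any keys beyond those checked above are assumptions):

    process_rule = {
        "rules": {"segmentation": {"max_tokens": 500}},
        "summary_index_setting": {
            "enable": True,
            "model_provider_name": "openai",  # hypothetical provider
            "model_name": "gpt-4o-mini",      # hypothetical model
        },
    }
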
@ -3154,6 +3297,35 @@ class SegmentService:
            if args.enabled or keyword_changed:
                # update segment vector index
                VectorService.update_segment_vector(args.keywords, segment, dataset)
            # update summary index if a summary is provided and has changed
            if args.summary is not None:
                # When the user manually provides a summary, allow saving even if summary_index_setting
                # doesn't exist: summary_index_setting is only needed for LLM generation, not for manual
                # summary vectorization. Vectorization uses dataset.embedding_model, which doesn't
                # require summary_index_setting.
                if dataset.indexing_technique == "high_quality":
                    # Query the existing summary from the database
                    from models.dataset import DocumentSegmentSummary

                    existing_summary = (
                        db.session.query(DocumentSegmentSummary)
                        .where(
                            DocumentSegmentSummary.chunk_id == segment.id,
                            DocumentSegmentSummary.dataset_id == dataset.id,
                        )
                        .first()
                    )

                    # Check if the summary has changed
                    existing_summary_content = existing_summary.summary_content if existing_summary else None
                    if existing_summary_content != args.summary:
                        # The summary has changed, update it
                        from services.summary_index_service import SummaryIndexService

                        try:
                            SummaryIndexService.update_summary_for_segment(segment, dataset, args.summary)
                        except Exception:
                            logger.exception("Failed to update summary for segment %s", segment.id)
                            # Don't fail the entire update if the summary update fails
        else:
            segment_hash = helper.generate_text_hash(content)
            tokens = 0

@ -3228,6 +3400,73 @@ class SegmentService:
                elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX):
                    # update segment vector index
                    VectorService.update_segment_vector(args.keywords, segment, dataset)
                    # Handle the summary index when content changed
                    if dataset.indexing_technique == "high_quality":
                        from models.dataset import DocumentSegmentSummary

                        existing_summary = (
                            db.session.query(DocumentSegmentSummary)
                            .where(
                                DocumentSegmentSummary.chunk_id == segment.id,
                                DocumentSegmentSummary.dataset_id == dataset.id,
                            )
                            .first()
                        )

                        if args.summary is None:
                            # The user didn't provide a summary; auto-regenerate if the segment previously had one.
                            # Auto-regeneration only happens if summary_index_setting exists and enable is True.
                            if (
                                existing_summary
                                and dataset.summary_index_setting
                                and dataset.summary_index_setting.get("enable") is True
                            ):
                                # The segment previously had a summary; regenerate it with the new content
                                from services.summary_index_service import SummaryIndexService

                                try:
                                    SummaryIndexService.generate_and_vectorize_summary(
                                        segment, dataset, dataset.summary_index_setting
                                    )
                                    logger.info("Auto-regenerated summary for segment %s after content change", segment.id)
                                except Exception:
                                    logger.exception("Failed to auto-regenerate summary for segment %s", segment.id)
                                    # Don't fail the entire update if summary regeneration fails
                        else:
                            # The user provided a summary; check whether it has changed.
                            # Manual summary updates are allowed even if summary_index_setting doesn't exist.
                            existing_summary_content = existing_summary.summary_content if existing_summary else None
                            if existing_summary_content != args.summary:
                                # The summary has changed; use the user-provided summary
                                from services.summary_index_service import SummaryIndexService

                                try:
                                    SummaryIndexService.update_summary_for_segment(segment, dataset, args.summary)
                                    logger.info("Updated summary for segment %s with user-provided content", segment.id)
                                except Exception:
                                    logger.exception("Failed to update summary for segment %s", segment.id)
                                    # Don't fail the entire update if the summary update fails
                            else:
                                # The summary hasn't changed; regenerate it from the new content.
                                # Auto-regeneration only happens if summary_index_setting exists and enable is True.
                                if (
                                    existing_summary
                                    and dataset.summary_index_setting
                                    and dataset.summary_index_setting.get("enable") is True
                                ):
                                    from services.summary_index_service import SummaryIndexService

                                    try:
                                        SummaryIndexService.generate_and_vectorize_summary(
                                            segment, dataset, dataset.summary_index_setting
                                        )
                                        logger.info(
                                            "Regenerated summary for segment %s after content change (summary unchanged)",
                                            segment.id,
                                        )
                                    except Exception:
                                        logger.exception("Failed to regenerate summary for segment %s", segment.id)
                                        # Don't fail the entire update if summary regeneration fails
                    # update multimodal vector index
                    VectorService.update_multimodel_vector(segment, args.attachment_ids or [], dataset)
            except Exception as e:

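The branching above reduces to a small decision table; a reading aid, not part of the diff:

    # args.summary is None, segment had a summary, setting enabled -> regenerate content + vectors
    # args.summary provided and differs from the stored summary    -> save user summary, re-vectorize
    # args.summary provided but unchanged, setting enabled         -> regenerate from the new content
    # otherwise                                                    -> leave the summary untouched
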
@ -3616,6 +3855,39 @@ class SegmentService:
        )
        return result if isinstance(result, DocumentSegment) else None

    @classmethod
    def get_segments_by_document_and_dataset(
        cls,
        document_id: str,
        dataset_id: str,
        status: str | None = None,
        enabled: bool | None = None,
    ) -> Sequence[DocumentSegment]:
        """
        Get segments for a document in a dataset, with optional filtering.

        Args:
            document_id: Document ID
            dataset_id: Dataset ID
            status: Optional status filter (e.g., "completed")
            enabled: Optional enabled filter (True/False)

        Returns:
            Sequence of DocumentSegment instances
        """
        query = select(DocumentSegment).where(
            DocumentSegment.document_id == document_id,
            DocumentSegment.dataset_id == dataset_id,
        )

        if status is not None:
            query = query.where(DocumentSegment.status == status)

        if enabled is not None:
            query = query.where(DocumentSegment.enabled == enabled)

        return db.session.scalars(query).all()


class DatasetCollectionBindingService:
    @classmethod

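A usage sketch for the new helper (IDs are placeholders):

    completed_segments = SegmentService.get_segments_by_document_and_dataset(
        document_id="doc-uuid",
        dataset_id="dataset-uuid",
        status="completed",
        enabled=True,
    )
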
@ -119,6 +119,7 @@ class KnowledgeConfig(BaseModel):
    data_source: DataSource | None = None
    process_rule: ProcessRule | None = None
    retrieval_model: RetrievalModel | None = None
    summary_index_setting: dict | None = None
    doc_form: str = "text_model"
    doc_language: str = "English"
    embedding_model: str | None = None

@ -141,6 +142,7 @@ class SegmentUpdateArgs(BaseModel):
    regenerate_child_chunks: bool = False
    enabled: bool | None = None
    attachment_ids: list[str] | None = None
    summary: str | None = None  # Summary content for the summary index


class ChildChunkUpdateArgs(BaseModel):

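A sketch of threading a manual summary through the update args (assuming the model's existing content field; other fields elided):

    args = SegmentUpdateArgs(
        content="updated segment text",
        summary="one-line digest of the segment",
    )
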
@ -116,6 +116,8 @@ class KnowledgeConfiguration(BaseModel):
    embedding_model: str = ""
    keyword_number: int | None = 10
    retrieval_model: RetrievalSetting
    # add summary index setting
    summary_index_setting: dict | None = None

    @field_validator("embedding_model_provider", mode="before")
    @classmethod

@ -343,6 +343,9 @@ class RagPipelineDslService:
            dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
        elif knowledge_configuration.indexing_technique == "economy":
            dataset.keyword_number = knowledge_configuration.keyword_number
        # Update summary_index_setting if provided
        if knowledge_configuration.summary_index_setting is not None:
            dataset.summary_index_setting = knowledge_configuration.summary_index_setting
        dataset.pipeline_id = pipeline.id
        self._session.add(dataset)
        self._session.commit()

@ -477,6 +480,9 @@ class RagPipelineDslService:
            dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
        elif knowledge_configuration.indexing_technique == "economy":
            dataset.keyword_number = knowledge_configuration.keyword_number
        # Update summary_index_setting if provided
        if knowledge_configuration.summary_index_setting is not None:
            dataset.summary_index_setting = knowledge_configuration.summary_index_setting
        dataset.pipeline_id = pipeline.id
        self._session.add(dataset)
        self._session.commit()

File diff suppressed because it is too large

@ -118,6 +118,19 @@ def add_document_to_index_task(dataset_document_id: str):
        )
        session.commit()

        # Enable summary indexes for all segments in this document
        from services.summary_index_service import SummaryIndexService

        segment_ids_list = [segment.id for segment in segments]
        if segment_ids_list:
            try:
                SummaryIndexService.enable_summaries_for_segments(
                    dataset=dataset,
                    segment_ids=segment_ids_list,
                )
            except Exception as e:
                logger.warning("Failed to enable summaries for document %s: %s", dataset_document.id, str(e))

        end_at = time.perf_counter()
        logger.info(
            click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green")

@ -50,7 +50,9 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
         if segments:
             index_node_ids = [segment.index_node_id for segment in segments]
             index_processor = IndexProcessorFactory(doc_form).init_index_processor()
-            index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+            index_processor.clean(
+                dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
+            )
 
             for segment in segments:
                 image_upload_file_ids = get_image_upload_file_ids(segment.content)

@ -51,7 +51,9 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
         if segments:
             index_node_ids = [segment.index_node_id for segment in segments]
             index_processor = IndexProcessorFactory(doc_form).init_index_processor()
-            index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+            index_processor.clean(
+                dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
+            )
 
             for segment in segments:
                 image_upload_file_ids = get_image_upload_file_ids(segment.content)

@ -42,7 +42,9 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
         ).all()
         index_node_ids = [segment.index_node_id for segment in segments]
 
-        index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+        index_processor.clean(
+            dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
+        )
         segment_ids = [segment.id for segment in segments]
         segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
         session.execute(segment_delete_stmt)

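All three cleanup tasks now make the same call; condensed into one sketch (the comment glosses are inferred from the flag names and these call sites):

    index_processor.clean(
        dataset,
        index_node_ids,
        with_keywords=True,        # also remove keyword-table entries
        delete_child_chunks=True,  # drop child chunks for parent-child structures
        delete_summaries=True,     # new flag: remove summary index entries as well
    )
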
@ -47,6 +47,7 @@ def delete_segment_from_index_task(
        doc_form = dataset_document.doc_form

        # Proceed with index cleanup using the index_node_ids directly.
        # For actual deletion, we should delete summaries (not just disable them).
        index_processor = IndexProcessorFactory(doc_form).init_index_processor()
        index_processor.clean(
            dataset,

@ -54,6 +55,7 @@ def delete_segment_from_index_task(
            with_keywords=True,
            delete_child_chunks=True,
            precomputed_child_node_ids=child_node_ids,
            delete_summaries=True,  # Actually delete summaries when the segment is deleted
        )
        if dataset.is_multimodal:
            # delete segment attachment binding

@ -60,6 +60,18 @@ def disable_segment_from_index_task(segment_id: str):
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        index_processor.clean(dataset, [segment.index_node_id])

        # Disable the summary index for this segment
        from services.summary_index_service import SummaryIndexService

        try:
            SummaryIndexService.disable_summaries_for_segments(
                dataset=dataset,
                segment_ids=[segment.id],
                disabled_by=segment.disabled_by,
            )
        except Exception as e:
            logger.warning("Failed to disable summary for segment %s: %s", segment.id, str(e))

        end_at = time.perf_counter()
        logger.info(
            click.style(

@ -68,6 +68,21 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
            index_node_ids.extend(attachment_ids)
            index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)

            # Disable summary indexes for these segments
            from services.summary_index_service import SummaryIndexService

            segment_ids_list = [segment.id for segment in segments]
            try:
                # Get disabled_by from the first segment (they should all share the same disabled_by)
                disabled_by = segments[0].disabled_by if segments else None
                SummaryIndexService.disable_summaries_for_segments(
                    dataset=dataset,
                    segment_ids=segment_ids_list,
                    disabled_by=disabled_by,
                )
            except Exception as e:
                logger.warning("Failed to disable summaries for segments: %s", str(e))

            end_at = time.perf_counter()
            logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green"))
        except Exception:

@ -14,6 +14,7 @@ from enums.cloud_plan import CloudPlan
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document
from services.feature_service import FeatureService
from tasks.generate_summary_index_task import generate_summary_index_task

logger = logging.getLogger(__name__)

@ -99,6 +100,78 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
        indexing_runner.run(documents)
        end_at = time.perf_counter()
        logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))

        # Trigger summary index generation for completed documents if enabled.
        # Only generate for the high_quality indexing technique and when summary_index_setting is enabled.
        # Re-query the dataset to get the latest summary_index_setting (in case it was updated).
        dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
        if not dataset:
            logger.warning("Dataset %s not found after indexing", dataset_id)
            return

        if dataset.indexing_technique == "high_quality":
            summary_index_setting = dataset.summary_index_setting
            if summary_index_setting and summary_index_setting.get("enable"):
                # Expire the session so we see each document's latest indexing status
                session.expire_all()
                # Check each document's indexing status and trigger summary generation if completed
                for document_id in document_ids:
                    # Re-query the document to get the latest status (IndexingRunner may have updated it)
                    document = (
                        session.query(Document)
                        .where(Document.id == document_id, Document.dataset_id == dataset_id)
                        .first()
                    )
                    if document:
                        logger.info(
                            "Checking document %s for summary generation: status=%s, doc_form=%s, need_summary=%s",
                            document_id,
                            document.indexing_status,
                            document.doc_form,
                            document.need_summary,
                        )
                        if (
                            document.indexing_status == "completed"
                            and document.doc_form != "qa_model"
                            and document.need_summary is True
                        ):
                            try:
                                generate_summary_index_task.delay(dataset.id, document_id, None)
                                logger.info(
                                    "Queued summary index generation task for document %s in dataset %s "
                                    "after indexing completed",
                                    document_id,
                                    dataset.id,
                                )
                            except Exception:
                                logger.exception(
                                    "Failed to queue summary index generation task for document %s",
                                    document_id,
                                )
                                # Don't fail the entire indexing process if summary task queuing fails
                        else:
                            logger.info(
                                "Skipping summary generation for document %s: "
                                "status=%s, doc_form=%s, need_summary=%s",
                                document_id,
                                document.indexing_status,
                                document.doc_form,
                                document.need_summary,
                            )
                    else:
                        logger.warning("Document %s not found after indexing", document_id)
            else:
                logger.info(
                    "Summary index generation skipped for dataset %s: summary_index_setting.enable=%s",
                    dataset.id,
                    summary_index_setting.get("enable") if summary_index_setting else None,
                )
        else:
            logger.info(
                "Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')",
                dataset.id,
                dataset.indexing_technique,
            )
    except DocumentIsPausedError as ex:
        logger.info(click.style(str(ex), fg="yellow"))
    except Exception:

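The gating conditions above collapse into one predicate; a sketch:

    should_queue_summary = (
        dataset.indexing_technique == "high_quality"
        and (dataset.summary_index_setting or {}).get("enable")
        and document.indexing_status == "completed"
        and document.doc_form != "qa_model"
        and document.need_summary is True
    )
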
@ -106,6 +106,17 @@ def enable_segment_to_index_task(segment_id: str):
        # save vector index
        index_processor.load(dataset, [document], multimodal_documents=multimodel_documents)

        # Enable the summary index for this segment
        from services.summary_index_service import SummaryIndexService

        try:
            SummaryIndexService.enable_summaries_for_segments(
                dataset=dataset,
                segment_ids=[segment.id],
            )
        except Exception as e:
            logger.warning("Failed to enable summary for segment %s: %s", segment.id, str(e))

        end_at = time.perf_counter()
        logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green"))
    except Exception as e:

@ -106,6 +106,18 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
        # save vector index
        index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)

        # Enable summary indexes for these segments
        from services.summary_index_service import SummaryIndexService

        segment_ids_list = [segment.id for segment in segments]
        try:
            SummaryIndexService.enable_summaries_for_segments(
                dataset=dataset,
                segment_ids=segment_ids_list,
            )
        except Exception as e:
            logger.warning("Failed to enable summaries for segments: %s", str(e))

        end_at = time.perf_counter()
        logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green"))
    except Exception as e:

@ -0,0 +1,119 @@
"""Async task for generating summary indexes."""

import logging
import time

import click
from celery import shared_task

from core.db.session_factory import session_factory
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.summary_index_service import SummaryIndexService

logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids: list[str] | None = None):
    """
    Async generate summary index for document segments.

    Args:
        dataset_id: Dataset ID
        document_id: Document ID
        segment_ids: Optional list of specific segment IDs to process. If None, process all segments.

    Usage:
        generate_summary_index_task.delay(dataset_id, document_id)
        generate_summary_index_task.delay(dataset_id, document_id, segment_ids)
    """
    logger.info(
        click.style(
            f"Start generating summary index for document {document_id} in dataset {dataset_id}",
            fg="green",
        )
    )
    start_at = time.perf_counter()

    try:
        with session_factory.create_session() as session:
            dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
            if not dataset:
                logger.error(click.style(f"Dataset not found: {dataset_id}", fg="red"))
                return

            document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
            if not document:
                logger.error(click.style(f"Document not found: {document_id}", fg="red"))
                return

            # Check if the document needs a summary
            if not document.need_summary:
                logger.info(
                    click.style(
                        f"Skipping summary generation for document {document_id}: need_summary is False",
                        fg="cyan",
                    )
                )
                return

            # Only generate a summary index for the high_quality indexing technique
            if dataset.indexing_technique != "high_quality":
                logger.info(
                    click.style(
                        f"Skipping summary generation for dataset {dataset_id}: "
                        f"indexing_technique is {dataset.indexing_technique}, not 'high_quality'",
                        fg="cyan",
                    )
                )
                return

            # Check if the summary index is enabled
            summary_index_setting = dataset.summary_index_setting
            if not summary_index_setting or not summary_index_setting.get("enable"):
                logger.info(
                    click.style(
                        f"Summary index is disabled for dataset {dataset_id}",
                        fg="cyan",
                    )
                )
                return

            # Determine whether only parent chunks should be processed
            only_parent_chunks = dataset.chunk_structure == "parent_child_index"

            # Generate summaries
            summary_records = SummaryIndexService.generate_summaries_for_document(
                dataset=dataset,
                document=document,
                summary_index_setting=summary_index_setting,
                segment_ids=segment_ids,
                only_parent_chunks=only_parent_chunks,
            )

            end_at = time.perf_counter()
            logger.info(
                click.style(
                    f"Summary index generation completed for document {document_id}: "
                    f"{len(summary_records)} summaries generated, latency: {end_at - start_at}",
                    fg="green",
                )
            )

    except Exception as e:
        logger.exception("Failed to generate summary index for document %s", document_id)
        # Update document segments with an error status if needed
        if segment_ids:
            error_message = f"Summary generation failed: {str(e)}"
            with session_factory.create_session() as session:
                session.query(DocumentSegment).filter(
                    DocumentSegment.id.in_(segment_ids),
                    DocumentSegment.dataset_id == dataset_id,
                ).update(
                    {
                        DocumentSegment.error: error_message,
                    },
                    synchronize_session=False,
                )
                session.commit()

@ -0,0 +1,315 @@
"""Task for regenerating summary indexes when dataset settings change."""

import logging
import time
from collections import defaultdict

import click
from celery import shared_task
from sqlalchemy import or_, select

from core.db.session_factory import session_factory
from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
from models.dataset import Document as DatasetDocument
from services.summary_index_service import SummaryIndexService

logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
def regenerate_summary_index_task(
    dataset_id: str,
    regenerate_reason: str = "summary_model_changed",
    regenerate_vectors_only: bool = False,
):
    """
    Regenerate summary indexes for all documents in a dataset.

    This task is triggered when:
    1. The summary_index_setting model changes (regenerate_reason="summary_model_changed")
       - Regenerates summary content and vectors for all existing summaries
    2. The embedding_model changes (regenerate_reason="embedding_model_changed")
       - Only regenerates vectors for existing summaries (keeps summary content)

    Args:
        dataset_id: Dataset ID
        regenerate_reason: Reason for regeneration ("summary_model_changed" or "embedding_model_changed")
        regenerate_vectors_only: If True, only regenerate vectors without regenerating summary content
    """
    logger.info(
        click.style(
            f"Start regenerate summary index for dataset {dataset_id}, reason: {regenerate_reason}",
            fg="green",
        )
    )
    start_at = time.perf_counter()

    try:
        with session_factory.create_session() as session:
            dataset = session.query(Dataset).filter_by(id=dataset_id).first()
            if not dataset:
                logger.error(click.style(f"Dataset not found: {dataset_id}", fg="red"))
                return

            # Only regenerate the summary index for the high_quality indexing technique
            if dataset.indexing_technique != "high_quality":
                logger.info(
                    click.style(
                        f"Skipping summary regeneration for dataset {dataset_id}: "
                        f"indexing_technique is {dataset.indexing_technique}, not 'high_quality'",
                        fg="cyan",
                    )
                )
                return

            # Check if the summary index is enabled (only for a summary_model change).
            # For an embedding_model change, we still re-vectorize existing summaries even if the setting is disabled.
            summary_index_setting = dataset.summary_index_setting
            if not regenerate_vectors_only:
                # For a summary_model change, require summary_index_setting to be enabled
                if not summary_index_setting or not summary_index_setting.get("enable"):
                    logger.info(
                        click.style(
                            f"Summary index is disabled for dataset {dataset_id}",
                            fg="cyan",
                        )
                    )
                    return

            total_segments_processed = 0
            total_segments_failed = 0

            if regenerate_vectors_only:
                # For an embedding_model change: directly query all segments with existing summaries.
                # Don't require document indexing_status == "completed".
                # Include summaries with status "completed" or "error" (if they have content).
                segments_with_summaries = (
                    session.query(DocumentSegment, DocumentSegmentSummary)
                    .join(
                        DocumentSegmentSummary,
                        DocumentSegment.id == DocumentSegmentSummary.chunk_id,
                    )
                    .join(
                        DatasetDocument,
                        DocumentSegment.document_id == DatasetDocument.id,
                    )
                    .where(
                        DocumentSegment.dataset_id == dataset_id,
                        DocumentSegment.status == "completed",  # Segment must be completed
                        DocumentSegment.enabled == True,
                        DocumentSegmentSummary.dataset_id == dataset_id,
                        DocumentSegmentSummary.summary_content.isnot(None),  # Must have summary content
                        # Include completed summaries or error summaries (with content)
                        or_(
                            DocumentSegmentSummary.status == "completed",
                            DocumentSegmentSummary.status == "error",
                        ),
                        DatasetDocument.enabled == True,  # Document must be enabled
                        DatasetDocument.archived == False,  # Document must not be archived
                        DatasetDocument.doc_form != "qa_model",  # Skip qa_model documents
                    )
                    .order_by(DocumentSegment.document_id.asc(), DocumentSegment.position.asc())
                    .all()
                )

                if not segments_with_summaries:
                    logger.info(
                        click.style(
                            f"No segments with summaries found for re-vectorization in dataset {dataset_id}",
                            fg="cyan",
                        )
                    )
                    return

                logger.info(
                    "Found %s segments with summaries for re-vectorization in dataset %s",
                    len(segments_with_summaries),
                    dataset_id,
                )

                # Group by document for logging
                segments_by_document = defaultdict(list)
                for segment, summary_record in segments_with_summaries:
                    segments_by_document[segment.document_id].append((segment, summary_record))

                logger.info(
                    "Segments grouped into %s documents for re-vectorization",
                    len(segments_by_document),
                )

                for document_id, segment_summary_pairs in segments_by_document.items():
                    logger.info(
                        "Re-vectorizing summaries for %s segments in document %s",
                        len(segment_summary_pairs),
                        document_id,
                    )

                    for segment, summary_record in segment_summary_pairs:
                        try:
                            # Delete the old vector
                            if summary_record.summary_index_node_id:
                                try:
                                    from core.rag.datasource.vdb.vector_factory import Vector

                                    vector = Vector(dataset)
                                    vector.delete_by_ids([summary_record.summary_index_node_id])
                                except Exception as e:
                                    logger.warning(
                                        "Failed to delete old summary vector for segment %s: %s",
                                        segment.id,
                                        str(e),
                                    )

                            # Re-vectorize with the new embedding model
                            SummaryIndexService.vectorize_summary(summary_record, segment, dataset)
                            session.commit()
                            total_segments_processed += 1

                        except Exception as e:
                            logger.error(
                                "Failed to re-vectorize summary for segment %s: %s",
                                segment.id,
                                str(e),
                                exc_info=True,
                            )
                            total_segments_failed += 1
                            # Update the summary record with an error status
                            summary_record.status = "error"
                            summary_record.error = f"Re-vectorization failed: {str(e)}"
                            session.add(summary_record)
                            session.commit()
                            continue

            else:
                # For a summary_model change: require document indexing_status == "completed".
                # Get all documents with a completed indexing status.
                dataset_documents = session.scalars(
                    select(DatasetDocument).where(
                        DatasetDocument.dataset_id == dataset_id,
                        DatasetDocument.indexing_status == "completed",
                        DatasetDocument.enabled == True,
                        DatasetDocument.archived == False,
                    )
                ).all()

                if not dataset_documents:
                    logger.info(
                        click.style(
                            f"No documents found for summary regeneration in dataset {dataset_id}",
                            fg="cyan",
                        )
                    )
                    return

                logger.info(
                    "Found %s documents for summary regeneration in dataset %s",
                    len(dataset_documents),
                    dataset_id,
                )

                for dataset_document in dataset_documents:
                    # Skip qa_model documents
                    if dataset_document.doc_form == "qa_model":
                        continue

                    try:
                        # Get all segments with existing summaries
                        segments = (
                            session.query(DocumentSegment)
                            .join(
                                DocumentSegmentSummary,
                                DocumentSegment.id == DocumentSegmentSummary.chunk_id,
                            )
                            .where(
                                DocumentSegment.document_id == dataset_document.id,
                                DocumentSegment.dataset_id == dataset_id,
                                DocumentSegment.status == "completed",
                                DocumentSegment.enabled == True,
                                DocumentSegmentSummary.dataset_id == dataset_id,
                            )
                            .order_by(DocumentSegment.position.asc())
                            .all()
                        )

                        if not segments:
                            continue

                        logger.info(
                            "Regenerating summaries for %s segments in document %s",
                            len(segments),
                            dataset_document.id,
                        )

                        for segment in segments:
                            summary_record = None
                            try:
                                # Get the existing summary record
                                summary_record = (
                                    session.query(DocumentSegmentSummary)
                                    .filter_by(
                                        chunk_id=segment.id,
                                        dataset_id=dataset_id,
                                    )
                                    .first()
                                )

                                if not summary_record:
                                    logger.warning("Summary record not found for segment %s, skipping", segment.id)
                                    continue

                                # Regenerate both summary content and vectors (for a summary_model change)
                                SummaryIndexService.generate_and_vectorize_summary(
                                    segment, dataset, summary_index_setting
                                )
                                session.commit()
                                total_segments_processed += 1

                            except Exception as e:
                                logger.error(
                                    "Failed to regenerate summary for segment %s: %s",
                                    segment.id,
                                    str(e),
                                    exc_info=True,
                                )
                                total_segments_failed += 1
                                # Update the summary record with an error status
                                if summary_record:
                                    summary_record.status = "error"
                                    summary_record.error = f"Regeneration failed: {str(e)}"
                                    session.add(summary_record)
                                    session.commit()
                                continue

                    except Exception as e:
                        logger.error(
                            "Failed to process document %s for summary regeneration: %s",
                            dataset_document.id,
                            str(e),
                            exc_info=True,
                        )
                        continue

            end_at = time.perf_counter()
            if regenerate_vectors_only:
                logger.info(
                    click.style(
                        f"Summary re-vectorization completed for dataset {dataset_id}: "
                        f"{total_segments_processed} segments processed successfully, "
                        f"{total_segments_failed} segments failed, "
                        f"latency: {end_at - start_at:.2f}s",
                        fg="green",
                    )
                )
            else:
                logger.info(
                    click.style(
                        f"Summary index regeneration completed for dataset {dataset_id}: "
                        f"{total_segments_processed} segments processed successfully, "
                        f"{total_segments_failed} segments failed, "
                        f"latency: {end_at - start_at:.2f}s",
                        fg="green",
                    )
                )

    except Exception:
        logger.exception("Regenerate summary index failed for dataset %s", dataset_id)

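The two trigger modes, side by side (matching the docstring and the DatasetService call site above):

    # Summary model changed: regenerate summary content and vectors.
    regenerate_summary_index_task.delay(dataset.id, regenerate_reason="summary_model_changed")

    # Embedding model changed: keep summary text, rebuild vectors only.
    regenerate_summary_index_task.delay(
        dataset.id,
        regenerate_reason="embedding_model_changed",
        regenerate_vectors_only=True,
    )
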
@ -46,6 +46,21 @@ def remove_document_from_index_task(document_id: str):
        index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()

        segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all()

        # Disable summary indexes for all segments in this document
        from services.summary_index_service import SummaryIndexService

        segment_ids_list = [segment.id for segment in segments]
        if segment_ids_list:
            try:
                SummaryIndexService.disable_summaries_for_segments(
                    dataset=dataset,
                    segment_ids=segment_ids_list,
                    disabled_by=document.disabled_by,
                )
            except Exception as e:
                logger.warning("Failed to disable summaries for document %s: %s", document.id, str(e))

        index_node_ids = [segment.index_node_id for segment in segments]
        if index_node_ids:
            try:

@ -1,3 +1,5 @@
import uuid

from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector
from core.rag.models.document import Document
from tests.integration_tests.vdb.test_vector_store import (

@ -18,6 +20,10 @@ class QdrantVectorTest(AbstractVectorTest):
                api_key="difyai123456",
            ),
        )
        # Additional doc IDs for multi-keyword search tests
        self.doc_apple_id = ""
        self.doc_banana_id = ""
        self.doc_both_id = ""

    def search_by_vector(self):
        super().search_by_vector()

@ -27,6 +33,77 @@ class QdrantVectorTest(AbstractVectorTest):
        )
        assert len(hits_by_vector) == 0

    def _create_document(self, content: str, doc_id: str) -> Document:
        """Create a document with the given content and doc_id."""
        return Document(
            page_content=content,
            metadata={
                "doc_id": doc_id,
                "doc_hash": doc_id,
                "document_id": doc_id,
                "dataset_id": self.dataset_id,
            },
        )

    def setup_multi_keyword_documents(self):
        """Create test documents with different keyword combinations for multi-keyword search tests."""
        self.doc_apple_id = str(uuid.uuid4())
        self.doc_banana_id = str(uuid.uuid4())
        self.doc_both_id = str(uuid.uuid4())

        documents = [
            self._create_document("This document contains apple only", self.doc_apple_id),
            self._create_document("This document contains banana only", self.doc_banana_id),
            self._create_document("This document contains both apple and banana", self.doc_both_id),
        ]
        embeddings = [self.example_embedding] * len(documents)

        self.vector.add_texts(documents=documents, embeddings=embeddings)

    def search_by_full_text_multi_keyword(self):
        """Test that multi-keyword search returns docs matching ANY keyword (OR logic)."""
        # First verify that single-keyword searches work correctly
        hits_apple = self.vector.search_by_full_text(query="apple", top_k=10)
        apple_ids = {doc.metadata["doc_id"] for doc in hits_apple}
        assert self.doc_apple_id in apple_ids, "Document with 'apple' should be found"
        assert self.doc_both_id in apple_ids, "Document with 'apple and banana' should be found"

        hits_banana = self.vector.search_by_full_text(query="banana", top_k=10)
        banana_ids = {doc.metadata["doc_id"] for doc in hits_banana}
        assert self.doc_banana_id in banana_ids, "Document with 'banana' should be found"
        assert self.doc_both_id in banana_ids, "Document with 'apple and banana' should be found"

        # Test that multi-keyword search returns all matching documents
        hits = self.vector.search_by_full_text(query="apple banana", top_k=10)
        doc_ids = {doc.metadata["doc_id"] for doc in hits}

        assert self.doc_apple_id in doc_ids, "Document with 'apple' should be found in multi-keyword search"
        assert self.doc_banana_id in doc_ids, "Document with 'banana' should be found in multi-keyword search"
        assert self.doc_both_id in doc_ids, "Document with both keywords should be found"
        # Expect 3 results: doc_apple (apple only), doc_banana (banana only), doc_both (contains both)
        assert len(hits) == 3, f"Expected 3 documents, got {len(hits)}"

        # Test keyword-order independence
        hits_ba = self.vector.search_by_full_text(query="banana apple", top_k=10)
        ids_ba = {doc.metadata["doc_id"] for doc in hits_ba}
        assert doc_ids == ids_ba, "Keyword order should not affect search results"

        # Test that there are no duplicates in the results
        doc_id_list = [doc.metadata["doc_id"] for doc in hits]
        assert len(doc_id_list) == len(set(doc_id_list)), "Search results should not contain duplicates"

    def run_all_tests(self):
        self.create_vector()
        self.search_by_vector()
        self.search_by_full_text()
        self.text_exists()
        self.get_ids_by_metadata_field()
        # Multi-keyword search tests
        self.setup_multi_keyword_documents()
        self.search_by_full_text_multi_keyword()
        # Cleanup - delete_vector() removes the entire collection
        self.delete_vector()


def test_qdrant_vector(setup_mock_redis):
    QdrantVectorTest().run_all_tests()

@@ -1,7 +1,9 @@
"""Primarily used for testing merged cell scenarios"""

import io
import os
import tempfile
from pathlib import Path
from types import SimpleNamespace

from docx import Document

@@ -56,6 +58,42 @@ def test_parse_row():
        assert extractor._parse_row(row, {}, 3) == gt[idx]


def test_init_downloads_via_ssrf_proxy(monkeypatch):
    doc = Document()
    doc.add_paragraph("hello")
    buf = io.BytesIO()
    doc.save(buf)
    docx_bytes = buf.getvalue()

    calls: list[tuple[str, object]] = []

    class FakeResponse:
        status_code = 200
        content = docx_bytes

        def close(self) -> None:
            calls.append(("close", None))

    def fake_get(url: str, **kwargs):
        calls.append(("get", (url, kwargs)))
        return FakeResponse()

    monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=fake_get))

    extractor = WordExtractor("https://example.com/test.docx", "tenant_id", "user_id")
    try:
        assert calls
        assert calls[0][0] == "get"
        url, kwargs = calls[0][1]
        assert url == "https://example.com/test.docx"
        assert kwargs.get("timeout") is None
        assert extractor.web_path == "https://example.com/test.docx"
        assert extractor.file_path != extractor.web_path
        assert Path(extractor.file_path).read_bytes() == docx_bytes
    finally:
        extractor.temp_file.close()


def test_extract_images_from_docx(monkeypatch):
    external_bytes = b"ext-bytes"
    internal_bytes = b"int-bytes"
@@ -138,6 +138,7 @@ class TestDatasetServiceUpdateDataset:
                "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding"
            ) as mock_get_binding,
            patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task,
            patch("services.dataset_service.regenerate_summary_index_task") as mock_regenerate_task,
            patch(
                "services.dataset_service.current_user", create_autospec(Account, instance=True)
            ) as mock_current_user,

@@ -147,6 +148,7 @@ class TestDatasetServiceUpdateDataset:
                "model_manager": mock_model_manager,
                "get_binding": mock_get_binding,
                "task": mock_task,
                "regenerate_task": mock_regenerate_task,
                "current_user": mock_current_user,
            }

@@ -549,6 +551,13 @@ class TestDatasetServiceUpdateDataset:
        # Verify vector index task was triggered
        mock_internal_provider_dependencies["task"].delay.assert_called_once_with("dataset-123", "update")

        # Verify regenerate summary index task was triggered (when embedding_model changes)
        mock_internal_provider_dependencies["regenerate_task"].delay.assert_called_once_with(
            "dataset-123",
            regenerate_reason="embedding_model_changed",
            regenerate_vectors_only=True,
        )

        # Verify return value
        assert result == dataset
24 api/ty.toml

@@ -1,11 +1,33 @@
[src]
exclude = [
    # TODO: enable when violations fixed
    # deps groups (A1/A2/B/C/D/E)
    # A1: foundational runtime typing / provider plumbing
    "core/mcp/session",
    "core/model_runtime/model_providers",
    "core/workflow/nodes/protocols.py",
    "libs/gmpy2_pkcs10aep_cipher.py",
    # A2: workflow engine/nodes
    "core/workflow",
    "core/app/workflow",
    "core/helper/code_executor",
    # B: app runner + prompt
    "core/prompt",
    "core/app/apps/base_app_runner.py",
    "core/app/apps/workflow_app_runner.py",
    # C: services/controllers/fields/libs
    "services",
    "controllers/console/app",
    "controllers/console/explore",
    "controllers/console/datasets",
    "controllers/console/workspace",
    "controllers/service_api/wraps.py",
    "fields/conversation_fields.py",
    "libs/external_api.py",
    # D: observability + integrations
    "core/ops",
    "extensions",
    # E: vector DB integrations
    "core/rag/datasource/vdb",
    # non-production or generated code
    "migrations",
    "tests",
@@ -397,7 +397,7 @@ WEB_API_CORS_ALLOW_ORIGINS=*
# Specifies the allowed origins for cross-origin requests to the console API,
# e.g. https://cloud.dify.ai or * for all origins.
CONSOLE_CORS_ALLOW_ORIGINS=*
# When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the site’s top-level domain (e.g., `example.com`). Leading dots are optional.
# When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the site's top-level domain (e.g., `example.com`). Leading dots are optional.
COOKIE_DOMAIN=
# When the frontend and backend run on different subdomains, set NEXT_PUBLIC_COOKIE_DOMAIN=1.
NEXT_PUBLIC_COOKIE_DOMAIN=

@@ -1080,7 +1080,7 @@ ALIYUN_SLS_ENDPOINT=
ALIYUN_SLS_REGION=
# Aliyun SLS Project Name
ALIYUN_SLS_PROJECT_NAME=
# Number of days to retain workflow run logs (default: 365 days, 3650 for permanent storage)
ALIYUN_SLS_LOGSTORE_TTL=365
# Enable dual-write to both SLS LogStore and SQL database (default: false)
LOGSTORE_DUAL_WRITE_ENABLED=false

@@ -1375,6 +1375,7 @@ PLUGIN_DAEMON_PORT=5002
PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi
PLUGIN_DAEMON_URL=http://plugin_daemon:5002
PLUGIN_MAX_PACKAGE_SIZE=52428800
PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600
PLUGIN_PPROF_ENABLED=false

PLUGIN_DEBUGGING_HOST=0.0.0.0

@@ -589,6 +589,7 @@ x-shared-env: &shared-api-worker-env
  PLUGIN_DAEMON_KEY: ${PLUGIN_DAEMON_KEY:-lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi}
  PLUGIN_DAEMON_URL: ${PLUGIN_DAEMON_URL:-http://plugin_daemon:5002}
  PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800}
  PLUGIN_MODEL_SCHEMA_CACHE_TTL: ${PLUGIN_MODEL_SCHEMA_CACHE_TTL:-3600}
  PLUGIN_PPROF_ENABLED: ${PLUGIN_PPROF_ENABLED:-false}
  PLUGIN_DEBUGGING_HOST: ${PLUGIN_DEBUGGING_HOST:-0.0.0.0}
  PLUGIN_DEBUGGING_PORT: ${PLUGIN_DEBUGGING_PORT:-5003}
@@ -9,7 +9,7 @@ def parse_env_example(file_path):
    Parses the .env.example file and returns a dictionary with variable names as keys and default values as values.
    """
    env_vars = {}
    with open(file_path, "r") as f:
    with open(file_path, "r", encoding="utf-8") as f:
        for line_number, line in enumerate(f, 1):
            line = line.strip()
            # Ignore empty lines and comments

@@ -55,7 +55,7 @@ def insert_shared_env(template_path, output_path, shared_env_block, header_comments):
    Inserts the shared environment variables block and header comments into the template file,
    removing any existing x-shared-env anchors, and generates the final docker-compose.yaml file.
    """
    with open(template_path, "r") as f:
    with open(template_path, "r", encoding="utf-8") as f:
        template_content = f.read()

    # Remove existing x-shared-env: &shared-api-worker-env lines

@@ -69,7 +69,7 @@ def insert_shared_env(template_path, output_path, shared_env_block, header_comments):
    # Prepare the final content with header comments and shared env block
    final_content = f"{header_comments}\n{shared_env_block}\n\n{template_content}"

    with open(output_path, "w") as f:
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(final_content)
    print(f"Generated {output_path}")
@@ -31,6 +31,7 @@ import { fetchWorkflowDraft } from '@/service/workflow'
import { AppModeEnum } from '@/types/app'
import { getRedirection } from '@/utils/app-redirection'
import { cn } from '@/utils/classnames'
import { downloadBlob } from '@/utils/download'
import AppIcon from '../base/app-icon'
import AppOperations from './app-operations'

@@ -145,13 +146,8 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
      appID: appDetail.id,
      include,
    })
    const a = document.createElement('a')
    const file = new Blob([data], { type: 'application/yaml' })
    const url = URL.createObjectURL(file)
    a.href = url
    a.download = `${appDetail.name}.yml`
    a.click()
    URL.revokeObjectURL(url)
    downloadBlob({ data: file, fileName: `${appDetail.name}.yml` })
  }
  catch {
    notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) })
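
The hunk above, like several later hunks in this commit, swaps a hand-rolled anchor-element download for helpers from '@/utils/download'. That module's implementation is not part of this diff; the following is a minimal sketch of what it plausibly provides, inferred only from the visible call sites (downloadBlob({ data, fileName }) here, downloadUrl({ url, fileName, target }) in a later hunk), so every implementation detail below is an assumption.

// Hypothetical sketch of @/utils/download; only the two call signatures are
// confirmed by the diff, the bodies are assumed.
export const downloadBlob = ({ data, fileName }: { data: Blob, fileName: string }) => {
  // Create a temporary object URL, trigger a download via a detached anchor,
  // then release the URL so the Blob can be garbage-collected.
  const url = URL.createObjectURL(data)
  const a = document.createElement('a')
  a.href = url
  a.download = fileName
  a.click()
  URL.revokeObjectURL(url)
}

export const downloadUrl = ({ url, fileName, target }: { url: string, fileName?: string, target?: string }) => {
  // Same anchor trick for an existing URL (remote file or data/base64 URL).
  const a = document.createElement('a')
  a.href = url
  if (fileName)
    a.download = fileName
  if (target)
    a.target = target
  a.click()
}

Centralizing this removes the repeated createObjectURL/revokeObjectURL bookkeeping from each component, which is exactly what the per-component hunks in this commit delete.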
@@ -11,6 +11,7 @@ import { datasetDetailQueryKeyPrefix, useInvalidDatasetList } from '@/service/kn
import { useInvalid } from '@/service/use-base'
import { useExportPipelineDSL } from '@/service/use-pipeline'
import { cn } from '@/utils/classnames'
import { downloadBlob } from '@/utils/download'
import ActionButton from '../../base/action-button'
import Confirm from '../../base/confirm'
import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '../../base/portal-to-follow-elem'

@@ -64,13 +65,8 @@ const DropDown = ({
      pipelineId: pipeline_id,
      include,
    })
    const a = document.createElement('a')
    const file = new Blob([data], { type: 'application/yaml' })
    const url = URL.createObjectURL(file)
    a.href = url
    a.download = `${name}.pipeline`
    a.click()
    URL.revokeObjectURL(url)
    downloadBlob({ data: file, fileName: `${name}.pipeline` })
  }
  catch {
    Toast.notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) })
@@ -21,6 +21,7 @@ import { LanguagesSupported } from '@/i18n-config/language'
import { clearAllAnnotations, fetchExportAnnotationList } from '@/service/annotation'

import { cn } from '@/utils/classnames'
import { downloadBlob } from '@/utils/download'
import Button from '../../../base/button'
import AddAnnotationModal from '../add-annotation-modal'
import BatchAddModal from '../batch-add-annotation-modal'

@@ -56,28 +57,23 @@ const HeaderOptions: FC<Props> = ({
  )

  const JSONLOutput = () => {
    const a = document.createElement('a')
    const content = listTransformer(list).join('\n')
    const file = new Blob([content], { type: 'application/jsonl' })
    const url = URL.createObjectURL(file)
    a.href = url
    a.download = `annotations-${locale}.jsonl`
    a.click()
    URL.revokeObjectURL(url)
    downloadBlob({ data: file, fileName: `annotations-${locale}.jsonl` })
  }

  const fetchList = async () => {
  const fetchList = React.useCallback(async () => {
    const { data }: any = await fetchExportAnnotationList(appId)
    setList(data as AnnotationItemBasic[])
  }
  }, [appId])

  useEffect(() => {
    fetchList()
  }, [])
  }, [fetchList])
  useEffect(() => {
    if (controlUpdateList)
      fetchList()
  }, [controlUpdateList])
  }, [controlUpdateList, fetchList])

  const [showBulkImportModal, setShowBulkImportModal] = useState(false)
  const [showClearConfirm, setShowClearConfirm] = useState(false)
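
The fetchList change above is the standard exhaustive-deps fix: once a function appears in an effect's dependency array, it needs a stable identity via useCallback, otherwise the effect re-runs on every render because a fresh closure is created each time. A minimal standalone sketch of the same pattern, with illustrative names and a hypothetical endpoint:

import { useCallback, useEffect, useState } from 'react'

// Illustrative only: the fetcher keeps the same identity across renders as
// long as `id` is unchanged, so the effect below fires only when `id` changes.
const useItems = (id: string) => {
  const [items, setItems] = useState<string[]>([])
  const fetchItems = useCallback(async () => {
    const res = await fetch(`/api/items?id=${id}`) // hypothetical endpoint
    setItems(await res.json())
  }, [id])
  useEffect(() => {
    fetchItems()
  }, [fetchItems])
  return items
}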
@@ -2,7 +2,7 @@ import type { ReactNode } from 'react'
import type { IConfigVarProps } from './index'
import type { ExternalDataTool } from '@/models/common'
import type { PromptVariable } from '@/models/debug'
import { act, fireEvent, render, screen } from '@testing-library/react'
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import { vi } from 'vitest'
import Toast from '@/app/components/base/toast'

@@ -240,7 +240,9 @@ describe('ConfigVar', () => {
    const saveButton = await screen.findByRole('button', { name: 'common.operation.save' })
    fireEvent.click(saveButton)

    expect(onPromptVariablesChange).toHaveBeenCalledTimes(1)
    await waitFor(() => {
      expect(onPromptVariablesChange).toHaveBeenCalledTimes(1)
    })
  })

  it('should show error when variable key is duplicated', async () => {
@@ -0,0 +1,77 @@
import { fireEvent, render, screen } from '@testing-library/react'
import AutomaticBtn from './automatic-btn'

vi.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string) => key,
  }),
}))

describe('AutomaticBtn', () => {
  const mockOnClick = vi.fn()

  beforeEach(() => {
    vi.clearAllMocks()
  })

  describe('Rendering', () => {
    it('should render the button with correct text', () => {
      render(<AutomaticBtn onClick={mockOnClick} />)

      expect(screen.getByText('operation.automatic')).toBeInTheDocument()
    })

    it('should render the sparkling icon', () => {
      const { container } = render(<AutomaticBtn onClick={mockOnClick} />)

      // The icon should be an SVG element inside the button
      const svg = container.querySelector('svg')
      expect(svg).toBeTruthy()
    })

    it('should render as a button element', () => {
      render(<AutomaticBtn onClick={mockOnClick} />)

      expect(screen.getByRole('button')).toBeInTheDocument()
    })
  })

  describe('User Interactions', () => {
    it('should call onClick when button is clicked', () => {
      render(<AutomaticBtn onClick={mockOnClick} />)

      const button = screen.getByRole('button')
      fireEvent.click(button)

      expect(mockOnClick).toHaveBeenCalledTimes(1)
    })

    it('should call onClick multiple times on multiple clicks', () => {
      render(<AutomaticBtn onClick={mockOnClick} />)

      const button = screen.getByRole('button')

      fireEvent.click(button)
      fireEvent.click(button)
      fireEvent.click(button)

      expect(mockOnClick).toHaveBeenCalledTimes(3)
    })
  })

  describe('Styling', () => {
    it('should have secondary-accent variant', () => {
      render(<AutomaticBtn onClick={mockOnClick} />)

      const button = screen.getByRole('button')
      expect(button.className).toContain('secondary-accent')
    })

    it('should have small size', () => {
      render(<AutomaticBtn onClick={mockOnClick} />)

      const button = screen.getByRole('button')
      expect(button.className).toContain('small')
    })
  })
})
@@ -0,0 +1,134 @@
import type { App } from '@/types/app'
import { render, screen } from '@testing-library/react'
import { AppModeEnum } from '@/types/app'
import EmptyElement from './empty-element'

vi.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string) => key,
  }),
  Trans: ({ i18nKey, components }: { i18nKey: string, components: Record<string, React.ReactNode> }) => (
    <span data-testid="trans-component" data-i18n-key={i18nKey}>
      {i18nKey}
      {components.shareLink}
      {components.testLink}
    </span>
  ),
}))

vi.mock('@/utils/app-redirection', () => ({
  getRedirectionPath: (isTest: boolean, _app: App) => isTest ? '/test-path' : '/prod-path',
}))

vi.mock('@/utils/var', () => ({
  basePath: '/base',
}))

describe('EmptyElement', () => {
  const createMockAppDetail = (mode: AppModeEnum) => ({
    id: 'test-app-id',
    name: 'Test App',
    description: 'Test description',
    mode,
    icon_type: 'emoji',
    icon: 'test-icon',
    icon_background: '#ffffff',
    enable_site: true,
    enable_api: true,
    created_at: Date.now(),
    site: {
      access_token: 'test-token',
      app_base_url: 'https://app.example.com',
    },
  }) as unknown as App

  describe('Rendering', () => {
    it('should render empty element with title', () => {
      const appDetail = createMockAppDetail(AppModeEnum.CHAT)
      render(<EmptyElement appDetail={appDetail} />)

      expect(screen.getByText('table.empty.element.title')).toBeInTheDocument()
    })

    it('should render Trans component with i18n key', () => {
      const appDetail = createMockAppDetail(AppModeEnum.CHAT)
      render(<EmptyElement appDetail={appDetail} />)

      const transComponent = screen.getByTestId('trans-component')
      expect(transComponent).toHaveAttribute('data-i18n-key', 'table.empty.element.content')
    })

    it('should render ThreeDotsIcon SVG', () => {
      const appDetail = createMockAppDetail(AppModeEnum.CHAT)
      const { container } = render(<EmptyElement appDetail={appDetail} />)

      const svg = container.querySelector('svg')
      expect(svg).toBeInTheDocument()
    })
  })

  describe('App Mode Handling', () => {
    it('should use CHAT mode for chat apps', () => {
      const appDetail = createMockAppDetail(AppModeEnum.CHAT)
      render(<EmptyElement appDetail={appDetail} />)

      const link = screen.getAllByRole('link')[0]
      expect(link).toHaveAttribute('href', 'https://app.example.com/base/chat/test-token')
    })

    it('should use COMPLETION mode for completion apps', () => {
      const appDetail = createMockAppDetail(AppModeEnum.COMPLETION)
      render(<EmptyElement appDetail={appDetail} />)

      const link = screen.getAllByRole('link')[0]
      expect(link).toHaveAttribute('href', 'https://app.example.com/base/completion/test-token')
    })

    it('should use WORKFLOW mode for workflow apps', () => {
      const appDetail = createMockAppDetail(AppModeEnum.WORKFLOW)
      render(<EmptyElement appDetail={appDetail} />)

      const link = screen.getAllByRole('link')[0]
      expect(link).toHaveAttribute('href', 'https://app.example.com/base/workflow/test-token')
    })

    it('should use CHAT mode for advanced-chat apps', () => {
      const appDetail = createMockAppDetail(AppModeEnum.ADVANCED_CHAT)
      render(<EmptyElement appDetail={appDetail} />)

      const link = screen.getAllByRole('link')[0]
      expect(link).toHaveAttribute('href', 'https://app.example.com/base/chat/test-token')
    })

    it('should use CHAT mode for agent-chat apps', () => {
      const appDetail = createMockAppDetail(AppModeEnum.AGENT_CHAT)
      render(<EmptyElement appDetail={appDetail} />)

      const link = screen.getAllByRole('link')[0]
      expect(link).toHaveAttribute('href', 'https://app.example.com/base/chat/test-token')
    })
  })

  describe('Links', () => {
    it('should render share link with correct attributes', () => {
      const appDetail = createMockAppDetail(AppModeEnum.CHAT)
      render(<EmptyElement appDetail={appDetail} />)

      const links = screen.getAllByRole('link')
      const shareLink = links[0]

      expect(shareLink).toHaveAttribute('target', '_blank')
      expect(shareLink).toHaveAttribute('rel', 'noopener noreferrer')
    })

    it('should render test link with redirection path', () => {
      const appDetail = createMockAppDetail(AppModeEnum.CHAT)
      render(<EmptyElement appDetail={appDetail} />)

      const links = screen.getAllByRole('link')
      const testLink = links[1]

      expect(testLink).toHaveAttribute('href', '/test-path')
    })
  })
})
@@ -0,0 +1,210 @@
import type { QueryParam } from './index'
import { fireEvent, render, screen } from '@testing-library/react'
import Filter, { TIME_PERIOD_MAPPING } from './filter'

vi.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string, options?: { count?: number }) => {
      if (options?.count !== undefined)
        return `${key} (${options.count})`
      return key
    },
  }),
}))

vi.mock('@/service/use-log', () => ({
  useAnnotationsCount: () => ({
    data: { count: 10 },
    isLoading: false,
  }),
}))

describe('Filter', () => {
  const defaultQueryParams: QueryParam = {
    period: '9',
    annotation_status: 'all',
    keyword: '',
  }

  const mockSetQueryParams = vi.fn()
  const defaultProps = {
    appId: 'test-app-id',
    queryParams: defaultQueryParams,
    setQueryParams: mockSetQueryParams,
  }

  beforeEach(() => {
    vi.clearAllMocks()
  })

  describe('Rendering', () => {
    it('should render filter components', () => {
      render(<Filter {...defaultProps} />)

      expect(screen.getByPlaceholderText('operation.search')).toBeInTheDocument()
    })

    it('should return null when loading', () => {
      // This test verifies the component renders correctly with the mocked data
      const { container } = render(<Filter {...defaultProps} />)
      expect(container.firstChild).not.toBeNull()
    })

    it('should render sort component in chat mode', () => {
      render(<Filter {...defaultProps} isChatMode />)

      expect(screen.getByPlaceholderText('operation.search')).toBeInTheDocument()
    })

    it('should not render sort component when not in chat mode', () => {
      render(<Filter {...defaultProps} isChatMode={false} />)

      expect(screen.getByPlaceholderText('operation.search')).toBeInTheDocument()
    })
  })

  describe('TIME_PERIOD_MAPPING', () => {
    it('should have correct period keys', () => {
      expect(Object.keys(TIME_PERIOD_MAPPING)).toEqual(['1', '2', '3', '4', '5', '6', '7', '8', '9'])
    })

    it('should have today period with value 0', () => {
      expect(TIME_PERIOD_MAPPING['1'].value).toBe(0)
      expect(TIME_PERIOD_MAPPING['1'].name).toBe('today')
    })

    it('should have last7days period with value 7', () => {
      expect(TIME_PERIOD_MAPPING['2'].value).toBe(7)
      expect(TIME_PERIOD_MAPPING['2'].name).toBe('last7days')
    })

    it('should have last4weeks period with value 28', () => {
      expect(TIME_PERIOD_MAPPING['3'].value).toBe(28)
      expect(TIME_PERIOD_MAPPING['3'].name).toBe('last4weeks')
    })

    it('should have allTime period with value -1', () => {
      expect(TIME_PERIOD_MAPPING['9'].value).toBe(-1)
      expect(TIME_PERIOD_MAPPING['9'].name).toBe('allTime')
    })
  })

  describe('User Interactions', () => {
    it('should update keyword when typing in search input', () => {
      render(<Filter {...defaultProps} />)

      const searchInput = screen.getByPlaceholderText('operation.search')
      fireEvent.change(searchInput, { target: { value: 'test search' } })

      expect(mockSetQueryParams).toHaveBeenCalledWith({
        ...defaultQueryParams,
        keyword: 'test search',
      })
    })

    it('should clear keyword when clear button is clicked', () => {
      const propsWithKeyword = {
        ...defaultProps,
        queryParams: { ...defaultQueryParams, keyword: 'existing search' },
      }

      render(<Filter {...propsWithKeyword} />)

      const clearButton = screen.getByTestId('input-clear')
      fireEvent.click(clearButton)

      expect(mockSetQueryParams).toHaveBeenCalledWith({
        ...defaultQueryParams,
        keyword: '',
      })
    })
  })

  describe('Query Params', () => {
    it('should display "today" when period is set to 1', () => {
      const propsWithPeriod = {
        ...defaultProps,
        queryParams: { ...defaultQueryParams, period: '1' },
      }

      render(<Filter {...propsWithPeriod} />)

      // Period '1' maps to 'today' in TIME_PERIOD_MAPPING
      expect(screen.getByText('filter.period.today')).toBeInTheDocument()
    })

    it('should display "last7days" when period is set to 2', () => {
      const propsWithPeriod = {
        ...defaultProps,
        queryParams: { ...defaultQueryParams, period: '2' },
      }

      render(<Filter {...propsWithPeriod} />)

      expect(screen.getByText('filter.period.last7days')).toBeInTheDocument()
    })

    it('should display "allTime" when period is set to 9', () => {
      render(<Filter {...defaultProps} />)

      // Default period is '9' which maps to 'allTime'
      expect(screen.getByText('filter.period.allTime')).toBeInTheDocument()
    })

    it('should display annotated status with count when annotation_status is annotated', () => {
      const propsWithAnnotation = {
        ...defaultProps,
        queryParams: { ...defaultQueryParams, annotation_status: 'annotated' },
      }

      render(<Filter {...propsWithAnnotation} />)

      // The mock returns count: 10, so the text should include the count
      expect(screen.getByText('filter.annotation.annotated (10)')).toBeInTheDocument()
    })

    it('should display not_annotated status when annotation_status is not_annotated', () => {
      const propsWithNotAnnotated = {
        ...defaultProps,
        queryParams: { ...defaultQueryParams, annotation_status: 'not_annotated' },
      }

      render(<Filter {...propsWithNotAnnotated} />)

      expect(screen.getByText('filter.annotation.not_annotated')).toBeInTheDocument()
    })

    it('should display all annotation status when annotation_status is all', () => {
      render(<Filter {...defaultProps} />)

      // Default annotation_status is 'all'
      expect(screen.getByText('filter.annotation.all')).toBeInTheDocument()
    })
  })

  describe('Chat Mode', () => {
    it('should display sort component with sort_by parameter', () => {
      const propsWithSort = {
        ...defaultProps,
        isChatMode: true,
        queryParams: { ...defaultQueryParams, sort_by: 'created_at' },
      }

      render(<Filter {...propsWithSort} />)

      expect(screen.getByPlaceholderText('operation.search')).toBeInTheDocument()
    })

    it('should handle descending sort order', () => {
      const propsWithDescSort = {
        ...defaultProps,
        isChatMode: true,
        queryParams: { ...defaultQueryParams, sort_by: '-created_at' },
      }

      render(<Filter {...propsWithDescSort} />)

      expect(screen.getByPlaceholderText('operation.search')).toBeInTheDocument()
    })
  })
})
@@ -0,0 +1,221 @@
import { fireEvent, render, screen } from '@testing-library/react'
import ModelInfo from './model-info'

vi.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string) => key,
  }),
}))

vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({
  useTextGenerationCurrentProviderAndModelAndModelList: () => ({
    currentModel: {
      model: 'gpt-4',
      model_display_name: 'GPT-4',
    },
    currentProvider: {
      provider: 'openai',
      label: 'OpenAI',
    },
  }),
}))

vi.mock('@/app/components/header/account-setting/model-provider-page/model-icon', () => ({
  default: ({ modelName }: { provider: unknown, modelName: string }) => (
    <div data-testid="model-icon" data-model-name={modelName}>ModelIcon</div>
  ),
}))

vi.mock('@/app/components/header/account-setting/model-provider-page/model-name', () => ({
  default: ({ modelItem, showMode }: { modelItem: { model: string }, showMode: boolean }) => (
    <div data-testid="model-name" data-show-mode={showMode ? 'true' : 'false'}>
      {modelItem?.model}
    </div>
  ),
}))

vi.mock('@/app/components/base/portal-to-follow-elem', () => ({
  PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean }) => (
    <div data-testid="portal-elem" data-open={open ? 'true' : 'false'}>{children}</div>
  ),
  PortalToFollowElemTrigger: ({ children, onClick }: { children: React.ReactNode, onClick: () => void }) => (
    <div data-testid="portal-trigger" onClick={onClick}>{children}</div>
  ),
  PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) => (
    <div data-testid="portal-content">{children}</div>
  ),
}))

describe('ModelInfo', () => {
  const defaultModel = {
    name: 'gpt-4',
    provider: 'openai',
    completion_params: {
      temperature: 0.7,
      top_p: 0.9,
      presence_penalty: 0.1,
      max_tokens: 2048,
      stop: ['END'],
    },
  }

  describe('Rendering', () => {
    it('should render model icon', () => {
      render(<ModelInfo model={defaultModel} />)

      expect(screen.getByTestId('model-icon')).toBeInTheDocument()
    })

    it('should render model name', () => {
      render(<ModelInfo model={defaultModel} />)

      expect(screen.getByTestId('model-name')).toBeInTheDocument()
      expect(screen.getByTestId('model-name')).toHaveTextContent('gpt-4')
    })

    it('should render info icon button', () => {
      const { container } = render(<ModelInfo model={defaultModel} />)

      // The info button should contain an SVG icon
      const svgs = container.querySelectorAll('svg')
      expect(svgs.length).toBeGreaterThan(0)
    })

    it('should show model name with showMode prop', () => {
      render(<ModelInfo model={defaultModel} />)

      expect(screen.getByTestId('model-name')).toHaveAttribute('data-show-mode', 'true')
    })
  })

  describe('Info Panel Toggle', () => {
    it('should be closed by default', () => {
      render(<ModelInfo model={defaultModel} />)

      expect(screen.getByTestId('portal-elem')).toHaveAttribute('data-open', 'false')
    })

    it('should open when info button is clicked', () => {
      render(<ModelInfo model={defaultModel} />)

      const trigger = screen.getByTestId('portal-trigger')
      fireEvent.click(trigger)

      expect(screen.getByTestId('portal-elem')).toHaveAttribute('data-open', 'true')
    })

    it('should close when info button is clicked again', () => {
      render(<ModelInfo model={defaultModel} />)

      const trigger = screen.getByTestId('portal-trigger')

      // Open
      fireEvent.click(trigger)
      expect(screen.getByTestId('portal-elem')).toHaveAttribute('data-open', 'true')

      // Close
      fireEvent.click(trigger)
      expect(screen.getByTestId('portal-elem')).toHaveAttribute('data-open', 'false')
    })
  })

  describe('Model Parameters Display', () => {
    it('should render model params header', () => {
      render(<ModelInfo model={defaultModel} />)

      expect(screen.getByText('detail.modelParams')).toBeInTheDocument()
    })

    it('should render temperature parameter', () => {
      render(<ModelInfo model={defaultModel} />)

      expect(screen.getByText('Temperature')).toBeInTheDocument()
      expect(screen.getByText('0.7')).toBeInTheDocument()
    })

    it('should render top_p parameter', () => {
      render(<ModelInfo model={defaultModel} />)

      expect(screen.getByText('Top P')).toBeInTheDocument()
      expect(screen.getByText('0.9')).toBeInTheDocument()
    })

    it('should render presence_penalty parameter', () => {
      render(<ModelInfo model={defaultModel} />)

      expect(screen.getByText('Presence Penalty')).toBeInTheDocument()
      expect(screen.getByText('0.1')).toBeInTheDocument()
    })

    it('should render max_tokens parameter', () => {
      render(<ModelInfo model={defaultModel} />)

      expect(screen.getByText('Max Token')).toBeInTheDocument()
      expect(screen.getByText('2048')).toBeInTheDocument()
    })

    it('should render stop parameter as comma-separated values', () => {
      render(<ModelInfo model={defaultModel} />)

      expect(screen.getByText('Stop')).toBeInTheDocument()
      expect(screen.getByText('END')).toBeInTheDocument()
    })
  })

  describe('Missing Parameters', () => {
    it('should show dash for missing parameters', () => {
      const modelWithNoParams = {
        name: 'gpt-4',
        provider: 'openai',
        completion_params: {},
      }

      render(<ModelInfo model={modelWithNoParams} />)

      const dashes = screen.getAllByText('-')
      expect(dashes.length).toBeGreaterThan(0)
    })

    it('should show dash for non-array stop values', () => {
      const modelWithInvalidStop = {
        name: 'gpt-4',
        provider: 'openai',
        completion_params: {
          stop: 'not-an-array',
        },
      }

      render(<ModelInfo model={modelWithInvalidStop} />)

      const stopValues = screen.getAllByText('-')
      expect(stopValues.length).toBeGreaterThan(0)
    })

    it('should join array stop values with comma', () => {
      const modelWithMultipleStops = {
        name: 'gpt-4',
        provider: 'openai',
        completion_params: {
          stop: ['END', 'STOP', 'DONE'],
        },
      }

      render(<ModelInfo model={modelWithMultipleStops} />)

      expect(screen.getByText('END,STOP,DONE')).toBeInTheDocument()
    })
  })

  describe('Model without completion_params', () => {
    it('should handle undefined completion_params', () => {
      const modelWithNoCompletionParams = {
        name: 'gpt-4',
        provider: 'openai',
      }

      render(<ModelInfo model={modelWithNoCompletionParams} />)

      expect(screen.getByTestId('model-icon')).toBeInTheDocument()
    })
  })
})
@@ -0,0 +1,217 @@
import { act, fireEvent, render, screen } from '@testing-library/react'
import VarPanel from './var-panel'

vi.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string) => key,
  }),
}))

vi.mock('@/app/components/base/image-uploader/image-preview', () => ({
  default: ({ url, title, onCancel }: { url: string, title: string, onCancel: () => void }) => (
    <div data-testid="image-preview" data-url={url} data-title={title}>
      <button onClick={onCancel} data-testid="close-preview">Close</button>
    </div>
  ),
}))

describe('VarPanel', () => {
  const defaultProps = {
    varList: [
      { label: 'name', value: 'John Doe' },
      { label: 'age', value: '25' },
    ],
    message_files: [],
  }

  describe('Rendering', () => {
    it('should render variables section header', () => {
      render(<VarPanel {...defaultProps} />)

      expect(screen.getByText('detail.variables')).toBeInTheDocument()
    })

    it('should render variable labels with braces', () => {
      render(<VarPanel {...defaultProps} />)

      expect(screen.getByText('name')).toBeInTheDocument()
      expect(screen.getByText('age')).toBeInTheDocument()
    })

    it('should render variable values', () => {
      render(<VarPanel {...defaultProps} />)

      expect(screen.getByText('John Doe')).toBeInTheDocument()
      expect(screen.getByText('25')).toBeInTheDocument()
    })

    it('should render opening and closing braces', () => {
      render(<VarPanel {...defaultProps} />)

      const openingBraces = screen.getAllByText('{{')
      const closingBraces = screen.getAllByText('}}')

      expect(openingBraces.length).toBe(2)
      expect(closingBraces.length).toBe(2)
    })

    it('should render Variable02 icon', () => {
      const { container } = render(<VarPanel {...defaultProps} />)

      const svg = container.querySelector('svg')
      expect(svg).toBeInTheDocument()
    })
  })

  describe('Collapse/Expand', () => {
    it('should show expanded state by default', () => {
      render(<VarPanel {...defaultProps} />)

      expect(screen.getByText('John Doe')).toBeInTheDocument()
      expect(screen.getByText('25')).toBeInTheDocument()
    })

    it('should collapse when header is clicked', () => {
      render(<VarPanel {...defaultProps} />)

      const header = screen.getByText('detail.variables').closest('div')
      fireEvent.click(header!)

      expect(screen.queryByText('John Doe')).not.toBeInTheDocument()
      expect(screen.queryByText('25')).not.toBeInTheDocument()
    })

    it('should expand when clicked again', () => {
      render(<VarPanel {...defaultProps} />)

      const header = screen.getByText('detail.variables').closest('div')

      // Collapse
      fireEvent.click(header!)
      expect(screen.queryByText('John Doe')).not.toBeInTheDocument()

      // Expand
      fireEvent.click(header!)
      expect(screen.getByText('John Doe')).toBeInTheDocument()
    })

    it('should show arrow icon when collapsed', () => {
      const { container } = render(<VarPanel {...defaultProps} />)

      const header = screen.getByText('detail.variables').closest('div')
      fireEvent.click(header!)

      // When collapsed, there should be SVG icons in the component
      const svgs = container.querySelectorAll('svg')
      expect(svgs.length).toBeGreaterThan(0)
    })

    it('should show arrow icon when expanded', () => {
      const { container } = render(<VarPanel {...defaultProps} />)

      // When expanded, there should be SVG icons in the component
      const svgs = container.querySelectorAll('svg')
      expect(svgs.length).toBeGreaterThan(0)
    })
  })

  describe('Message Files', () => {
    it('should not render images section when message_files is empty', () => {
      render(<VarPanel {...defaultProps} />)

      expect(screen.queryByText('detail.uploadImages')).not.toBeInTheDocument()
    })

    it('should render images section when message_files has items', () => {
      const propsWithFiles = {
        ...defaultProps,
        message_files: ['https://example.com/image1.jpg', 'https://example.com/image2.jpg'],
      }

      render(<VarPanel {...propsWithFiles} />)

      expect(screen.getByText('detail.uploadImages')).toBeInTheDocument()
    })

    it('should render image thumbnails with correct background', () => {
      const propsWithFiles = {
        ...defaultProps,
        message_files: ['https://example.com/image1.jpg'],
      }

      const { container } = render(<VarPanel {...propsWithFiles} />)

      const thumbnail = container.querySelector('[style*="background-image"]')
      expect(thumbnail).toBeInTheDocument()
      expect(thumbnail).toHaveStyle({ backgroundImage: 'url(https://example.com/image1.jpg)' })
    })

    it('should open image preview when thumbnail is clicked', () => {
      const propsWithFiles = {
        ...defaultProps,
        message_files: ['https://example.com/image1.jpg'],
      }

      const { container } = render(<VarPanel {...propsWithFiles} />)

      const thumbnail = container.querySelector('[style*="background-image"]')
      fireEvent.click(thumbnail!)

      expect(screen.getByTestId('image-preview')).toBeInTheDocument()
      expect(screen.getByTestId('image-preview')).toHaveAttribute('data-url', 'https://example.com/image1.jpg')
    })

    it('should close image preview when close button is clicked', () => {
      const propsWithFiles = {
        ...defaultProps,
        message_files: ['https://example.com/image1.jpg'],
      }

      const { container } = render(<VarPanel {...propsWithFiles} />)

      // Open preview
      const thumbnail = container.querySelector('[style*="background-image"]')
      fireEvent.click(thumbnail!)

      expect(screen.getByTestId('image-preview')).toBeInTheDocument()

      // Close preview
      act(() => {
        fireEvent.click(screen.getByTestId('close-preview'))
      })

      expect(screen.queryByTestId('image-preview')).not.toBeInTheDocument()
    })
  })

  describe('Empty State', () => {
    it('should render with empty varList', () => {
      const emptyProps = {
        varList: [],
        message_files: [],
      }

      render(<VarPanel {...emptyProps} />)

      expect(screen.getByText('detail.variables')).toBeInTheDocument()
    })
  })

  describe('Multiple Images', () => {
    it('should render multiple image thumbnails', () => {
      const propsWithMultipleFiles = {
        ...defaultProps,
        message_files: [
          'https://example.com/image1.jpg',
          'https://example.com/image2.jpg',
          'https://example.com/image3.jpg',
        ],
      }

      const { container } = render(<VarPanel {...propsWithMultipleFiles} />)

      const thumbnails = container.querySelectorAll('[style*="background-image"]')
      expect(thumbnails.length).toBe(3)
    })
  })
})
@@ -0,0 +1,390 @@
import type { AppDetailResponse } from '@/models/app'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { AppModeEnum } from '@/types/app'
import TriggerCard from './trigger-card'

vi.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string, options?: { count?: number }) => {
      if (options?.count !== undefined)
        return `${key} (${options.count})`
      return key
    },
  }),
}))

vi.mock('@/context/app-context', () => ({
  useAppContext: () => ({
    isCurrentWorkspaceEditor: true,
  }),
}))

vi.mock('@/context/i18n', () => ({
  useDocLink: () => (path: string) => `https://docs.example.com${path}`,
}))

const mockSetTriggerStatus = vi.fn()
const mockSetTriggerStatuses = vi.fn()
vi.mock('@/app/components/workflow/store/trigger-status', () => ({
  useTriggerStatusStore: () => ({
    setTriggerStatus: mockSetTriggerStatus,
    setTriggerStatuses: mockSetTriggerStatuses,
  }),
}))

const mockUpdateTriggerStatus = vi.fn()
const mockInvalidateAppTriggers = vi.fn()
let mockTriggers: Array<{
  id: string
  node_id: string
  title: string
  trigger_type: string
  status: string
  provider_name?: string
}> = []
let mockIsLoading = false

vi.mock('@/service/use-tools', () => ({
  useAppTriggers: () => ({
    data: { data: mockTriggers },
    isLoading: mockIsLoading,
  }),
  useUpdateTriggerStatus: () => ({
    mutateAsync: mockUpdateTriggerStatus,
  }),
  useInvalidateAppTriggers: () => mockInvalidateAppTriggers,
}))

vi.mock('@/service/use-triggers', () => ({
  useAllTriggerPlugins: () => ({
    data: [
      { id: 'plugin-1', name: 'Test Plugin', icon: 'test-icon' },
    ],
  }),
}))

vi.mock('@/utils', () => ({
  canFindTool: () => false,
}))

vi.mock('@/app/components/workflow/block-icon', () => ({
  default: ({ type }: { type: string }) => (
    <div data-testid="block-icon" data-type={type}>BlockIcon</div>
  ),
}))

vi.mock('@/app/components/base/switch', () => ({
  default: ({ defaultValue, onChange, disabled }: { defaultValue: boolean, onChange: (v: boolean) => void, disabled: boolean }) => (
    <button
      data-testid="switch"
      data-checked={defaultValue ? 'true' : 'false'}
      data-disabled={disabled ? 'true' : 'false'}
      onClick={() => onChange(!defaultValue)}
    >
      Switch
    </button>
  ),
}))

describe('TriggerCard', () => {
  const mockAppInfo = {
    id: 'test-app-id',
    name: 'Test App',
    description: 'Test description',
    mode: AppModeEnum.WORKFLOW,
    icon_type: 'emoji',
    icon: 'test-icon',
    icon_background: '#ffffff',
    created_at: Date.now(),
    updated_at: Date.now(),
    enable_site: true,
    enable_api: true,
  } as AppDetailResponse

  const mockOnToggleResult = vi.fn()

  beforeEach(() => {
    vi.clearAllMocks()
    mockTriggers = []
    mockIsLoading = false
    mockUpdateTriggerStatus.mockResolvedValue({})
  })

  describe('Loading State', () => {
    it('should render loading skeleton when isLoading is true', () => {
      mockIsLoading = true

      const { container } = render(
        <TriggerCard appInfo={mockAppInfo} onToggleResult={mockOnToggleResult} />,
      )

      expect(container.querySelector('.animate-pulse')).toBeInTheDocument()
    })
  })

  describe('Empty State', () => {
    it('should show no triggers added message when triggers is empty', () => {
      mockTriggers = []

      render(<TriggerCard appInfo={mockAppInfo} onToggleResult={mockOnToggleResult} />)

      expect(screen.getByText('overview.triggerInfo.noTriggerAdded')).toBeInTheDocument()
    })

    it('should show trigger status description when no triggers', () => {
      mockTriggers = []

      render(<TriggerCard appInfo={mockAppInfo} onToggleResult={mockOnToggleResult} />)

      expect(screen.getByText('overview.triggerInfo.triggerStatusDescription')).toBeInTheDocument()
    })

    it('should show learn more link when no triggers', () => {
      mockTriggers = []

      render(<TriggerCard appInfo={mockAppInfo} onToggleResult={mockOnToggleResult} />)

      const learnMoreLink = screen.getByText('overview.triggerInfo.learnAboutTriggers')
      expect(learnMoreLink).toBeInTheDocument()
      expect(learnMoreLink).toHaveAttribute('href', 'https://docs.example.com/use-dify/nodes/trigger/overview')
    })
  })

  describe('With Triggers', () => {
    beforeEach(() => {
      mockTriggers = [
        {
          id: 'trigger-1',
          node_id: 'node-1',
          title: 'Webhook Trigger',
          trigger_type: 'trigger-webhook',
          status: 'enabled',
        },
        {
          id: 'trigger-2',
          node_id: 'node-2',
          title: 'Schedule Trigger',
          trigger_type: 'trigger-schedule',
          status: 'disabled',
        },
      ]
    })

    it('should show triggers count message', () => {
      render(<TriggerCard appInfo={mockAppInfo} onToggleResult={mockOnToggleResult} />)

      expect(screen.getByText('overview.triggerInfo.triggersAdded (2)')).toBeInTheDocument()
    })

    it('should render trigger titles', () => {
      render(<TriggerCard appInfo={mockAppInfo} onToggleResult={mockOnToggleResult} />)

      expect(screen.getByText('Webhook Trigger')).toBeInTheDocument()
      expect(screen.getByText('Schedule Trigger')).toBeInTheDocument()
    })

    it('should show running status for enabled triggers', () => {
      render(<TriggerCard appInfo={mockAppInfo} onToggleResult={mockOnToggleResult} />)

      expect(screen.getByText('overview.status.running')).toBeInTheDocument()
    })

    it('should show disable status for disabled triggers', () => {
      render(<TriggerCard appInfo={mockAppInfo} onToggleResult={mockOnToggleResult} />)

      expect(screen.getByText('overview.status.disable')).toBeInTheDocument()
    })

    it('should render block icons for each trigger', () => {
      render(<TriggerCard appInfo={mockAppInfo} onToggleResult={mockOnToggleResult} />)

      const blockIcons = screen.getAllByTestId('block-icon')
      expect(blockIcons.length).toBe(2)
    })

    it('should render switches for each trigger', () => {
      render(<TriggerCard appInfo={mockAppInfo} onToggleResult={mockOnToggleResult} />)

      const switches = screen.getAllByTestId('switch')
      expect(switches.length).toBe(2)
    })
  })

  describe('Toggle Trigger', () => {
    beforeEach(() => {
      mockTriggers = [
        {
          id: 'trigger-1',
          node_id: 'node-1',
          title: 'Test Trigger',
          trigger_type: 'trigger-webhook',
          status: 'disabled',
        },
      ]
    })

    it('should call updateTriggerStatus when toggle is clicked', async () => {
      render(<TriggerCard appInfo={mockAppInfo} onToggleResult={mockOnToggleResult} />)

      const switchBtn = screen.getByTestId('switch')
      fireEvent.click(switchBtn)

      await waitFor(() => {
        expect(mockUpdateTriggerStatus).toHaveBeenCalledWith({
          appId: 'test-app-id',
          triggerId: 'trigger-1',
          enableTrigger: true,
        })
      })
    })

    it('should update trigger status in store optimistically', async () => {
      render(<TriggerCard appInfo={mockAppInfo} onToggleResult={mockOnToggleResult} />)

      const switchBtn = screen.getByTestId('switch')
      fireEvent.click(switchBtn)

      await waitFor(() => {
        expect(mockSetTriggerStatus).toHaveBeenCalledWith('node-1', 'enabled')
      })
    })

    it('should invalidate app triggers after successful update', async () => {
      render(<TriggerCard appInfo={mockAppInfo} onToggleResult={mockOnToggleResult} />)

      const switchBtn = screen.getByTestId('switch')
      fireEvent.click(switchBtn)

      await waitFor(() => {
        expect(mockInvalidateAppTriggers).toHaveBeenCalledWith('test-app-id')
      })
    })

    it('should call onToggleResult with null on success', async () => {
      render(<TriggerCard appInfo={mockAppInfo} onToggleResult={mockOnToggleResult} />)

      const switchBtn = screen.getByTestId('switch')
      fireEvent.click(switchBtn)

      await waitFor(() => {
        expect(mockOnToggleResult).toHaveBeenCalledWith(null)
      })
    })

    it('should rollback status and call onToggleResult with error on failure', async () => {
      const error = new Error('Update failed')
      mockUpdateTriggerStatus.mockRejectedValueOnce(error)

      render(<TriggerCard appInfo={mockAppInfo} onToggleResult={mockOnToggleResult} />)

      const switchBtn = screen.getByTestId('switch')
      fireEvent.click(switchBtn)

      await waitFor(() => {
        expect(mockSetTriggerStatus).toHaveBeenCalledWith('node-1', 'disabled')
        expect(mockOnToggleResult).toHaveBeenCalledWith(error)
      })
    })
  })

  describe('Trigger Types', () => {
    it('should render webhook trigger type correctly', () => {
      mockTriggers = [
        {
          id: 'trigger-1',
          node_id: 'node-1',
          title: 'Webhook',
          trigger_type: 'trigger-webhook',
          status: 'enabled',
        },
      ]

      render(<TriggerCard appInfo={mockAppInfo} onToggleResult={mockOnToggleResult} />)

      const blockIcon = screen.getByTestId('block-icon')
      expect(blockIcon).toHaveAttribute('data-type', 'trigger-webhook')
    })

    it('should render schedule trigger type correctly', () => {
      mockTriggers = [
        {
          id: 'trigger-1',
          node_id: 'node-1',
          title: 'Schedule',
          trigger_type: 'trigger-schedule',
          status: 'enabled',
        },
      ]

      render(<TriggerCard appInfo={mockAppInfo} onToggleResult={mockOnToggleResult} />)

      const blockIcon = screen.getByTestId('block-icon')
      expect(blockIcon).toHaveAttribute('data-type', 'trigger-schedule')
    })

    it('should render plugin trigger type correctly', () => {
      mockTriggers = [
        {
          id: 'trigger-1',
          node_id: 'node-1',
          title: 'Plugin',
          trigger_type: 'trigger-plugin',
          status: 'enabled',
          provider_name: 'plugin-1',
        },
      ]

      render(<TriggerCard appInfo={mockAppInfo} onToggleResult={mockOnToggleResult} />)

      const blockIcon = screen.getByTestId('block-icon')
      expect(blockIcon).toHaveAttribute('data-type', 'trigger-plugin')
    })
  })

  describe('Editor Permissions', () => {
    it('should render switches for triggers', () => {
      mockTriggers = [
        {
          id: 'trigger-1',
          node_id: 'node-1',
          title: 'Test Trigger',
          trigger_type: 'trigger-webhook',
          status: 'enabled',
        },
      ]

      render(<TriggerCard appInfo={mockAppInfo} onToggleResult={mockOnToggleResult} />)

      const switchBtn = screen.getByTestId('switch')
      expect(switchBtn).toBeInTheDocument()
    })
  })

  describe('Status Sync', () => {
    it('should sync trigger statuses to store when data loads', () => {
      mockTriggers = [
        {
          id: 'trigger-1',
          node_id: 'node-1',
          title: 'Test',
          trigger_type: 'trigger-webhook',
          status: 'enabled',
        },
        {
          id: 'trigger-2',
          node_id: 'node-2',
          title: 'Test 2',
          trigger_type: 'trigger-schedule',
          status: 'disabled',
        },
      ]

      render(<TriggerCard appInfo={mockAppInfo} onToggleResult={mockOnToggleResult} />)

      expect(mockSetTriggerStatuses).toHaveBeenCalledWith({
        'node-1': 'enabled',
        'node-2': 'disabled',
      })
    })
  })
})
@ -33,6 +33,7 @@ import { fetchWorkflowDraft } from '@/service/workflow'
import { AppModeEnum } from '@/types/app'
import { getRedirection } from '@/utils/app-redirection'
import { cn } from '@/utils/classnames'
import { downloadBlob } from '@/utils/download'
import { formatTime } from '@/utils/time'
import { basePath } from '@/utils/var'

@ -161,13 +162,8 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
        appID: app.id,
        include,
      })
      const a = document.createElement('a')
      const file = new Blob([data], { type: 'application/yaml' })
      const url = URL.createObjectURL(file)
      a.href = url
      a.download = `${app.name}.yml`
      a.click()
      URL.revokeObjectURL(url)
      downloadBlob({ data: file, fileName: `${app.name}.yml` })
    }
    catch {
      notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) })

@ -346,7 +342,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
      dateFormat: `${t('segment.dateTimeFormat', { ns: 'datasetDocuments' })}`,
    })
    return `${t('segment.editedAt', { ns: 'datasetDocuments' })} ${timeText}`
  }, [app.updated_at, app.created_at])
  }, [app.updated_at, app.created_at, t])

  return (
    <>

@ -15,11 +15,11 @@ import ImagePreview from '@/app/components/base/image-uploader/image-preview'
import ProgressCircle from '@/app/components/base/progress-bar/progress-circle'
import { SupportUploadFileTypes } from '@/app/components/workflow/types'
import { cn } from '@/utils/classnames'
import { downloadUrl } from '@/utils/download'
import { formatFileSize } from '@/utils/format'
import FileImageRender from '../file-image-render'
import FileTypeIcon from '../file-type-icon'
import {
  downloadFile,
  fileIsUploaded,
  getFileAppearanceType,
  getFileExtension,

@ -140,7 +140,7 @@ const FileInAttachmentItem = ({
        showDownloadAction && (
          <ActionButton onClick={(e) => {
            e.stopPropagation()
            downloadFile(url || base64Url || '', name)
            downloadUrl({ url: url || base64Url || '', fileName: name, target: '_blank' })
          }}
          >
            <RiDownloadLine className="h-4 w-4" />

@ -8,9 +8,9 @@ import Button from '@/app/components/base/button'
import { ReplayLine } from '@/app/components/base/icons/src/vender/other'
import ImagePreview from '@/app/components/base/image-uploader/image-preview'
import ProgressCircle from '@/app/components/base/progress-bar/progress-circle'
import { downloadUrl } from '@/utils/download'
import FileImageRender from '../file-image-render'
import {
  downloadFile,
  fileIsUploaded,
} from '../utils'

@ -85,7 +85,7 @@ const FileImageItem = ({
          className="absolute bottom-0.5 right-0.5 flex h-6 w-6 items-center justify-center rounded-lg bg-components-actionbar-bg shadow-md"
          onClick={(e) => {
            e.stopPropagation()
            downloadFile(download_url || '', name)
            downloadUrl({ url: download_url || '', fileName: name, target: '_blank' })
          }}
        >
          <RiDownloadLine className="h-4 w-4 text-text-tertiary" />

@ -12,10 +12,10 @@ import VideoPreview from '@/app/components/base/file-uploader/video-preview'
import { ReplayLine } from '@/app/components/base/icons/src/vender/other'
import ProgressCircle from '@/app/components/base/progress-bar/progress-circle'
import { cn } from '@/utils/classnames'
import { downloadUrl } from '@/utils/download'
import { formatFileSize } from '@/utils/format'
import FileTypeIcon from '../file-type-icon'
import {
  downloadFile,
  fileIsUploaded,
  getFileAppearanceType,
  getFileExtension,

@ -100,7 +100,7 @@ const FileItem = ({
          className="absolute -right-1 -top-1 hidden group-hover/file-item:flex"
          onClick={(e) => {
            e.stopPropagation()
            downloadFile(download_url || '', name)
            downloadUrl({ url: download_url || '', fileName: name, target: '_blank' })
          }}
        >
          <RiDownloadLine className="h-3.5 w-3.5 text-text-tertiary" />

@ -1,4 +1,3 @@
import type { MockInstance } from 'vitest'
import mime from 'mime'
import { SupportUploadFileTypes } from '@/app/components/workflow/types'
import { upload } from '@/service/base'

@ -6,7 +5,6 @@ import { TransferMethod } from '@/types/app'
import { FILE_EXTS } from '../prompt-editor/constants'
import { FileAppearanceTypeEnum } from './types'
import {
  downloadFile,
  fileIsUploaded,
  fileUpload,
  getFileAppearanceType,

@ -782,74 +780,4 @@ describe('file-uploader utils', () => {
      } as any)).toBe(true)
    })
  })

  describe('downloadFile', () => {
    let mockAnchor: HTMLAnchorElement
    let createElementMock: MockInstance
    let appendChildMock: MockInstance
    let removeChildMock: MockInstance

    beforeEach(() => {
      // Mock createElement and appendChild
      mockAnchor = {
        href: '',
        download: '',
        style: { display: '' },
        target: '',
        title: '',
        click: vi.fn(),
      } as unknown as HTMLAnchorElement

      createElementMock = vi.spyOn(document, 'createElement').mockReturnValue(mockAnchor as any)
      appendChildMock = vi.spyOn(document.body, 'appendChild').mockImplementation((node: Node) => {
        return node
      })
      removeChildMock = vi.spyOn(document.body, 'removeChild').mockImplementation((node: Node) => {
        return node
      })
    })

    afterEach(() => {
      vi.resetAllMocks()
    })

    it('should create and trigger download with correct attributes', () => {
      const url = 'https://example.com/test.pdf'
      const filename = 'test.pdf'

      downloadFile(url, filename)

      // Verify anchor element was created with correct properties
      expect(createElementMock).toHaveBeenCalledWith('a')
      expect(mockAnchor.href).toBe(url)
      expect(mockAnchor.download).toBe(filename)
      expect(mockAnchor.style.display).toBe('none')
      expect(mockAnchor.target).toBe('_blank')
      expect(mockAnchor.title).toBe(filename)

      // Verify DOM operations
      expect(appendChildMock).toHaveBeenCalledWith(mockAnchor)
      expect(mockAnchor.click).toHaveBeenCalled()
      expect(removeChildMock).toHaveBeenCalledWith(mockAnchor)
    })

    it('should handle empty filename', () => {
      const url = 'https://example.com/test.pdf'
      const filename = ''

      downloadFile(url, filename)

      expect(mockAnchor.download).toBe('')
      expect(mockAnchor.title).toBe('')
    })

    it('should handle empty url', () => {
      const url = ''
      const filename = 'test.pdf'

      downloadFile(url, filename)

      expect(mockAnchor.href).toBe('')
    })
  })
})

@ -249,15 +249,3 @@ export const fileIsUploaded = (file: FileEntity) => {
  if (file.transferMethod === TransferMethod.remote_url && file.progress === 100)
    return true
}

export const downloadFile = (url: string, filename: string) => {
  const anchor = document.createElement('a')
  anchor.href = url
  anchor.download = filename
  anchor.style.display = 'none'
  anchor.target = '_blank'
  anchor.title = filename
  document.body.appendChild(anchor)
  anchor.click()
  document.body.removeChild(anchor)
}

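Note: the hunks above and below replace scattered per-component anchor-element logic with shared helpers from @/utils/download. That module itself is not shown in this diff; only the names downloadUrl and downloadBlob and their call shapes are confirmed. A minimal sketch consistent with the call sites and with the removed downloadFile helper might look like this:

// Sketch only: inferred from the call sites in this diff, not the actual module.
type DownloadUrlOptions = { url: string, fileName: string, target?: '_blank' }
type DownloadBlobOptions = { data: Blob, fileName: string }

export const downloadUrl = ({ url, fileName, target }: DownloadUrlOptions) => {
  // Create a hidden anchor, trigger the download, then clean up the DOM.
  const anchor = document.createElement('a')
  anchor.href = url
  anchor.download = fileName
  anchor.style.display = 'none'
  if (target)
    anchor.target = target
  document.body.appendChild(anchor)
  anchor.click()
  document.body.removeChild(anchor)
}

export const downloadBlob = ({ data, fileName }: DownloadBlobOptions) => {
  // Wrap the blob in an object URL, reuse downloadUrl, and release the URL afterwards.
  const url = URL.createObjectURL(data)
  downloadUrl({ url, fileName })
  URL.revokeObjectURL(url)
}
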
@ -8,6 +8,7 @@ import { createPortal } from 'react-dom'
import { useHotkeys } from 'react-hotkeys-hook'
import Toast from '@/app/components/base/toast'
import Tooltip from '@/app/components/base/tooltip'
import { downloadUrl } from '@/utils/download'

type ImagePreviewProps = {
  url: string

@ -60,27 +61,14 @@ const ImagePreview: FC<ImagePreviewProps> = ({

  const downloadImage = () => {
    // Open in a new window, considering the case when the page is inside an iframe
    if (url.startsWith('http') || url.startsWith('https')) {
      const a = document.createElement('a')
      a.href = url
      a.target = '_blank'
      a.download = title
      a.click()
    }
    else if (url.startsWith('data:image')) {
      // Base64 image
      const a = document.createElement('a')
      a.href = url
      a.target = '_blank'
      a.download = title
      a.click()
    }
    else {
      Toast.notify({
        type: 'error',
        message: `Unable to open image: ${url}`,
      })
    if (url.startsWith('http') || url.startsWith('https') || url.startsWith('data:image')) {
      downloadUrl({ url, fileName: title, target: '_blank' })
      return
    }
    Toast.notify({
      type: 'error',
      message: `Unable to open image: ${url}`,
    })
  }

  const zoomIn = () => {

@ -135,12 +123,7 @@ const ImagePreview: FC<ImagePreviewProps> = ({
    catch (err) {
      console.error('Failed to copy image:', err)

      const link = document.createElement('a')
      link.href = url
      link.download = `${title}.png`
      document.body.appendChild(link)
      link.click()
      document.body.removeChild(link)
      downloadUrl({ url, fileName: `${title}.png` })

      Toast.notify({
        type: 'info',

@ -215,6 +198,7 @@ const ImagePreview: FC<ImagePreviewProps> = ({
        tabIndex={-1}
      >
        { }
        {/* eslint-disable-next-line next/no-img-element */}
        <img
          ref={imgRef}
          alt={title}

@ -8,6 +8,7 @@ import { useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import ActionButton from '@/app/components/base/action-button'
import Tooltip from '@/app/components/base/tooltip'
import { downloadUrl } from '@/utils/download'

type Props = {
  content: string

@ -40,11 +41,10 @@ const ShareQRCode = ({ content }: Props) => {
  }, [isShow])

  const downloadQR = () => {
    const canvas = document.getElementsByTagName('canvas')[0]
    const link = document.createElement('a')
    link.download = 'qrcode.png'
    link.href = canvas.toDataURL()
    link.click()
    const canvas = qrCodeRef.current?.querySelector('canvas')
    if (!(canvas instanceof HTMLCanvasElement))
      return
    downloadUrl({ url: canvas.toDataURL(), fileName: 'qrcode.png' })
  }

  const handlePanelClick = (event: React.MouseEvent) => {

@ -0,0 +1,57 @@
import { render, screen } from '@testing-library/react'
import Usage from './usage'

vi.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string) => key,
  }),
}))

const mockPlan = {
  usage: {
    annotatedResponse: 50,
  },
  total: {
    annotatedResponse: 100,
  },
}

vi.mock('@/context/provider-context', () => ({
  useProviderContext: () => ({
    plan: mockPlan,
  }),
}))

describe('Usage', () => {
  // Rendering: renders UsageInfo with correct props from context
  describe('Rendering', () => {
    it('should render usage info with data from provider context', () => {
      // Arrange & Act
      render(<Usage />)

      // Assert
      expect(screen.getByText('annotatedResponse.quotaTitle')).toBeInTheDocument()
    })

    it('should pass className to UsageInfo component', () => {
      // Arrange
      const testClassName = 'mt-4'

      // Act
      const { container } = render(<Usage className={testClassName} />)

      // Assert
      const wrapper = container.firstChild as HTMLElement
      expect(wrapper).toHaveClass(testClassName)
    })

    it('should display usage and total values from context', () => {
      // Arrange & Act
      render(<Usage />)

      // Assert
      expect(screen.getByText('50')).toBeInTheDocument()
      expect(screen.getByText('100')).toBeInTheDocument()
    })
  })
})

@ -73,6 +73,56 @@ describe('Billing', () => {
    })
  })

  it('returns the refetched url from the async callback', async () => {
    const newUrl = 'https://new-billing-url'
    refetchMock.mockResolvedValue({ data: newUrl })
    render(<Billing />)

    const actionButton = screen.getByRole('button', { name: /billing\.viewBillingTitle/ })
    fireEvent.click(actionButton)

    await waitFor(() => expect(openAsyncWindowMock).toHaveBeenCalled())
    const [asyncCallback] = openAsyncWindowMock.mock.calls[0]

    // Execute the async callback passed to openAsyncWindow
    const result = await asyncCallback()
    expect(result).toBe(newUrl)
    expect(refetchMock).toHaveBeenCalled()
  })

  it('returns null when refetch returns no url', async () => {
    refetchMock.mockResolvedValue({ data: null })
    render(<Billing />)

    const actionButton = screen.getByRole('button', { name: /billing\.viewBillingTitle/ })
    fireEvent.click(actionButton)

    await waitFor(() => expect(openAsyncWindowMock).toHaveBeenCalled())
    const [asyncCallback] = openAsyncWindowMock.mock.calls[0]

    // Execute the async callback when url is null
    const result = await asyncCallback()
    expect(result).toBeNull()
  })

  it('handles errors in onError callback', async () => {
    const consoleError = vi.spyOn(console, 'error').mockImplementation(() => {})
    render(<Billing />)

    const actionButton = screen.getByRole('button', { name: /billing\.viewBillingTitle/ })
    fireEvent.click(actionButton)

    await waitFor(() => expect(openAsyncWindowMock).toHaveBeenCalled())
    const [, options] = openAsyncWindowMock.mock.calls[0]

    // Execute the onError callback
    const testError = new Error('Test error')
    options.onError(testError)
    expect(consoleError).toHaveBeenCalledWith('Failed to fetch billing url', testError)

    consoleError.mockRestore()
  })

  it('disables the button while billing url is fetching', () => {
    fetching = true
    render(<Billing />)

@ -125,4 +125,70 @@ describe('PlanComp', () => {

    expect(setShowAccountSettingModalMock).toHaveBeenCalledWith(null)
  })

  it('does not trigger verify when isPending is true', async () => {
    isPending = true
    render(<PlanComp loc="billing-page" />)

    const verifyBtn = screen.getByText('education.toVerified')
    fireEvent.click(verifyBtn)

    await waitFor(() => expect(mutateAsyncMock).not.toHaveBeenCalled())
  })

  it('renders sandbox plan', () => {
    providerContextMock.mockReturnValue({
      plan: { ...planMock, type: Plan.sandbox },
      enableEducationPlan: false,
      allowRefreshEducationVerify: false,
      isEducationAccount: false,
    })
    render(<PlanComp loc="billing-page" />)

    expect(screen.getByText('billing.plans.sandbox.name')).toBeInTheDocument()
  })

  it('renders team plan', () => {
    providerContextMock.mockReturnValue({
      plan: { ...planMock, type: Plan.team },
      enableEducationPlan: false,
      allowRefreshEducationVerify: false,
      isEducationAccount: false,
    })
    render(<PlanComp loc="billing-page" />)

    expect(screen.getByText('billing.plans.team.name')).toBeInTheDocument()
  })

  it('shows verify button when education account is about to expire', () => {
    providerContextMock.mockReturnValue({
      plan: planMock,
      enableEducationPlan: true,
      allowRefreshEducationVerify: true,
      isEducationAccount: true,
    })
    render(<PlanComp loc="billing-page" />)

    expect(screen.getByText('education.toVerified')).toBeInTheDocument()
  })

  it('handles modal onConfirm and onCancel callbacks', async () => {
    mutateAsyncMock.mockRejectedValueOnce(new Error('boom'))
    render(<PlanComp loc="billing-page" />)

    // Trigger verify to show modal
    const verifyBtn = screen.getByText('education.toVerified')
    fireEvent.click(verifyBtn)

    await waitFor(() => expect(screen.getByTestId('verify-modal').getAttribute('data-is-show')).toBe('true'))

    // Get the props passed to the modal and call onConfirm/onCancel
    const lastCall = verifyStateModalMock.mock.calls[verifyStateModalMock.mock.calls.length - 1][0]
    expect(lastCall.onConfirm).toBeDefined()
    expect(lastCall.onCancel).toBeDefined()

    // Call onConfirm to close modal
    lastCall.onConfirm()
    lastCall.onCancel()
  })
})

@ -52,6 +52,24 @@ describe('Pricing Assets', () => {
    expect(rects.some(rect => rect.getAttribute('fill') === 'var(--color-saas-dify-blue-accessible)')).toBe(true)
  })

  it('should render inactive state for Cloud', () => {
    // Arrange
    const { container } = render(<Cloud isActive={false} />)

    // Assert
    const rects = Array.from(container.querySelectorAll('rect'))
    expect(rects.some(rect => rect.getAttribute('fill') === 'var(--color-text-primary)')).toBe(true)
  })

  it('should render active state for SelfHosted', () => {
    // Arrange
    const { container } = render(<SelfHosted isActive />)

    // Assert
    const rects = Array.from(container.querySelectorAll('rect'))
    expect(rects.some(rect => rect.getAttribute('fill') === 'var(--color-saas-dify-blue-accessible)')).toBe(true)
  })

  it('should render inactive state for SelfHosted', () => {
    // Arrange
    const { container } = render(<SelfHosted isActive={false} />)

@ -0,0 +1,301 @@
import type { CurrentPlanInfoBackend } from '../type'
import { DocumentProcessingPriority, Plan } from '../type'
import { getPlanVectorSpaceLimitMB, parseCurrentPlan, parseVectorSpaceToMB } from './index'

describe('billing utils', () => {
  // parseVectorSpaceToMB tests
  describe('parseVectorSpaceToMB', () => {
    it('should parse MB values correctly', () => {
      expect(parseVectorSpaceToMB('50MB')).toBe(50)
      expect(parseVectorSpaceToMB('100MB')).toBe(100)
    })

    it('should parse GB values and convert to MB', () => {
      expect(parseVectorSpaceToMB('5GB')).toBe(5 * 1024)
      expect(parseVectorSpaceToMB('20GB')).toBe(20 * 1024)
    })

    it('should be case insensitive', () => {
      expect(parseVectorSpaceToMB('50mb')).toBe(50)
      expect(parseVectorSpaceToMB('5gb')).toBe(5 * 1024)
    })

    it('should return 0 for invalid format', () => {
      expect(parseVectorSpaceToMB('50')).toBe(0)
      expect(parseVectorSpaceToMB('invalid')).toBe(0)
      expect(parseVectorSpaceToMB('')).toBe(0)
      expect(parseVectorSpaceToMB('50TB')).toBe(0)
    })
  })

  // getPlanVectorSpaceLimitMB tests
  describe('getPlanVectorSpaceLimitMB', () => {
    it('should return correct vector space for sandbox plan', () => {
      expect(getPlanVectorSpaceLimitMB(Plan.sandbox)).toBe(50)
    })

    it('should return correct vector space for professional plan', () => {
      expect(getPlanVectorSpaceLimitMB(Plan.professional)).toBe(5 * 1024)
    })

    it('should return correct vector space for team plan', () => {
      expect(getPlanVectorSpaceLimitMB(Plan.team)).toBe(20 * 1024)
    })

    it('should return 0 for invalid plan', () => {
      // @ts-expect-error - Testing invalid plan input
      expect(getPlanVectorSpaceLimitMB('invalid')).toBe(0)
    })
  })

  // parseCurrentPlan tests
  describe('parseCurrentPlan', () => {
    const createMockPlanData = (overrides: Partial<CurrentPlanInfoBackend> = {}): CurrentPlanInfoBackend => ({
      billing: {
        enabled: true,
        subscription: {
          plan: Plan.sandbox,
        },
      },
      members: {
        size: 1,
        limit: 1,
      },
      apps: {
        size: 2,
        limit: 5,
      },
      vector_space: {
        size: 10,
        limit: 50,
      },
      annotation_quota_limit: {
        size: 5,
        limit: 10,
      },
      documents_upload_quota: {
        size: 20,
        limit: 0,
      },
      docs_processing: DocumentProcessingPriority.standard,
      can_replace_logo: false,
      model_load_balancing_enabled: false,
      dataset_operator_enabled: false,
      education: {
        enabled: false,
        activated: false,
      },
      webapp_copyright_enabled: false,
      workspace_members: {
        size: 1,
        limit: 1,
      },
      is_allow_transfer_workspace: false,
      knowledge_pipeline: {
        publish_enabled: false,
      },
      ...overrides,
    })

    it('should parse plan type correctly', () => {
      const data = createMockPlanData()
      const result = parseCurrentPlan(data)
      expect(result.type).toBe(Plan.sandbox)
    })

    it('should parse usage values correctly', () => {
      const data = createMockPlanData()
      const result = parseCurrentPlan(data)

      expect(result.usage.vectorSpace).toBe(10)
      expect(result.usage.buildApps).toBe(2)
      expect(result.usage.teamMembers).toBe(1)
      expect(result.usage.annotatedResponse).toBe(5)
      expect(result.usage.documentsUploadQuota).toBe(20)
    })

    it('should parse total limits correctly', () => {
      const data = createMockPlanData()
      const result = parseCurrentPlan(data)

      expect(result.total.vectorSpace).toBe(50)
      expect(result.total.buildApps).toBe(5)
      expect(result.total.teamMembers).toBe(1)
      expect(result.total.annotatedResponse).toBe(10)
    })

    it('should convert 0 limits to NUM_INFINITE (-1)', () => {
      const data = createMockPlanData({
        documents_upload_quota: {
          size: 20,
          limit: 0,
        },
      })
      const result = parseCurrentPlan(data)
      expect(result.total.documentsUploadQuota).toBe(-1)
    })

    it('should handle api_rate_limit quota', () => {
      const data = createMockPlanData({
        api_rate_limit: {
          usage: 100,
          limit: 5000,
          reset_date: null,
        },
      })
      const result = parseCurrentPlan(data)

      expect(result.usage.apiRateLimit).toBe(100)
      expect(result.total.apiRateLimit).toBe(5000)
    })

    it('should handle trigger_event quota', () => {
      const data = createMockPlanData({
        trigger_event: {
          usage: 50,
          limit: 3000,
          reset_date: null,
        },
      })
      const result = parseCurrentPlan(data)

      expect(result.usage.triggerEvents).toBe(50)
      expect(result.total.triggerEvents).toBe(3000)
    })

    it('should use fallback for api_rate_limit when not provided', () => {
      const data = createMockPlanData()
      const result = parseCurrentPlan(data)

      // Fallback to plan preset value for sandbox: 5000
      expect(result.total.apiRateLimit).toBe(5000)
    })

    it('should convert 0 or -1 rate limits to NUM_INFINITE', () => {
      const data = createMockPlanData({
        api_rate_limit: {
          usage: 0,
          limit: 0,
          reset_date: null,
        },
      })
      const result = parseCurrentPlan(data)
      expect(result.total.apiRateLimit).toBe(-1)

      const data2 = createMockPlanData({
        api_rate_limit: {
          usage: 0,
          limit: -1,
          reset_date: null,
        },
      })
      const result2 = parseCurrentPlan(data2)
      expect(result2.total.apiRateLimit).toBe(-1)
    })

    it('should handle reset dates with milliseconds timestamp', () => {
      const futureDate = Date.now() + 86400000 // Tomorrow in ms
      const data = createMockPlanData({
        api_rate_limit: {
          usage: 100,
          limit: 5000,
          reset_date: futureDate,
        },
      })
      const result = parseCurrentPlan(data)

      expect(result.reset.apiRateLimit).toBe(1)
    })

    it('should handle reset dates with seconds timestamp', () => {
      const futureDate = Math.floor(Date.now() / 1000) + 86400 // Tomorrow in seconds
      const data = createMockPlanData({
        api_rate_limit: {
          usage: 100,
          limit: 5000,
          reset_date: futureDate,
        },
      })
      const result = parseCurrentPlan(data)

      expect(result.reset.apiRateLimit).toBe(1)
    })

    it('should handle reset dates in YYYYMMDD format', () => {
      const tomorrow = new Date()
      tomorrow.setDate(tomorrow.getDate() + 1)
      const year = tomorrow.getFullYear()
      const month = String(tomorrow.getMonth() + 1).padStart(2, '0')
      const day = String(tomorrow.getDate()).padStart(2, '0')
      const dateNumber = Number.parseInt(`${year}${month}${day}`, 10)

      const data = createMockPlanData({
        api_rate_limit: {
          usage: 100,
          limit: 5000,
          reset_date: dateNumber,
        },
      })
      const result = parseCurrentPlan(data)

      expect(result.reset.apiRateLimit).toBe(1)
    })

    it('should return null for invalid reset dates', () => {
      const data = createMockPlanData({
        api_rate_limit: {
          usage: 100,
          limit: 5000,
          reset_date: 0,
        },
      })
      const result = parseCurrentPlan(data)
      expect(result.reset.apiRateLimit).toBeNull()
    })

    it('should return null for negative reset dates', () => {
      const data = createMockPlanData({
        api_rate_limit: {
          usage: 100,
          limit: 5000,
          reset_date: -1,
        },
      })
      const result = parseCurrentPlan(data)
      expect(result.reset.apiRateLimit).toBeNull()
    })

    it('should return null when reset date is in the past', () => {
      const pastDate = Date.now() - 86400000 // Yesterday
      const data = createMockPlanData({
        api_rate_limit: {
          usage: 100,
          limit: 5000,
          reset_date: pastDate,
        },
      })
      const result = parseCurrentPlan(data)
      expect(result.reset.apiRateLimit).toBeNull()
    })

    it('should handle missing apps field', () => {
      const data = createMockPlanData()
      // @ts-expect-error - Testing edge case
      delete data.apps
      const result = parseCurrentPlan(data)
      expect(result.usage.buildApps).toBe(0)
    })

    it('should return null for unrecognized date format', () => {
      const data = createMockPlanData({
        api_rate_limit: {
          usage: 100,
          limit: 5000,
          reset_date: 12345, // Unrecognized format
        },
      })
      const result = parseCurrentPlan(data)
      expect(result.reset.apiRateLimit).toBeNull()
    })
  })
})

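Note: these tests pin down the parsing behavior without showing the implementation, which lives in the billing utils module. A minimal parseVectorSpaceToMB consistent with the assertions above (a sketch, not the actual source) could be:

// Sketch only: one implementation that satisfies the tests above.
export const parseVectorSpaceToMB = (value: string): number => {
  // Accept forms like '50MB' or '5GB' (case-insensitive); anything else yields 0.
  const match = /^(\d+)(MB|GB)$/i.exec(value)
  if (!match)
    return 0
  const amount = Number.parseInt(match[1], 10)
  return match[2].toUpperCase() === 'GB' ? amount * 1024 : amount
}
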
@ -0,0 +1,24 @@
import { cleanup, render, screen } from '@testing-library/react'
import { afterEach, describe, expect, it } from 'vitest'
import ApiIndex from './index'

afterEach(() => {
  cleanup()
})

describe('ApiIndex', () => {
  it('should render without crashing', () => {
    render(<ApiIndex />)
    expect(screen.getByText('index')).toBeInTheDocument()
  })

  it('should render a div with text "index"', () => {
    const { container } = render(<ApiIndex />)
    expect(container.firstChild).toBeInstanceOf(HTMLDivElement)
    expect(container.textContent).toBe('index')
  })

  it('should be a valid function component', () => {
    expect(typeof ApiIndex).toBe('function')
  })
})

@ -0,0 +1,111 @@
import { cleanup, render, screen } from '@testing-library/react'
import { afterEach, describe, expect, it } from 'vitest'
import { ChunkContainer, ChunkLabel, QAPreview } from './chunk'

afterEach(() => {
  cleanup()
})

describe('ChunkLabel', () => {
  it('should render label text', () => {
    render(<ChunkLabel label="Chunk 1" characterCount={100} />)
    expect(screen.getByText('Chunk 1')).toBeInTheDocument()
  })

  it('should render character count', () => {
    render(<ChunkLabel label="Chunk 1" characterCount={150} />)
    expect(screen.getByText('150 characters')).toBeInTheDocument()
  })

  it('should render separator dot', () => {
    render(<ChunkLabel label="Chunk 1" characterCount={100} />)
    expect(screen.getByText('·')).toBeInTheDocument()
  })

  it('should render with zero character count', () => {
    render(<ChunkLabel label="Empty Chunk" characterCount={0} />)
    expect(screen.getByText('0 characters')).toBeInTheDocument()
  })

  it('should render with large character count', () => {
    render(<ChunkLabel label="Large Chunk" characterCount={999999} />)
    expect(screen.getByText('999999 characters')).toBeInTheDocument()
  })
})

describe('ChunkContainer', () => {
  it('should render label and character count', () => {
    render(<ChunkContainer label="Container 1" characterCount={200}>Content</ChunkContainer>)
    expect(screen.getByText('Container 1')).toBeInTheDocument()
    expect(screen.getByText('200 characters')).toBeInTheDocument()
  })

  it('should render children content', () => {
    render(<ChunkContainer label="Container 1" characterCount={200}>Test Content</ChunkContainer>)
    expect(screen.getByText('Test Content')).toBeInTheDocument()
  })

  it('should render with complex children', () => {
    render(
      <ChunkContainer label="Container" characterCount={100}>
        <div data-testid="child-div">
          <span>Nested content</span>
        </div>
      </ChunkContainer>,
    )
    expect(screen.getByTestId('child-div')).toBeInTheDocument()
    expect(screen.getByText('Nested content')).toBeInTheDocument()
  })

  it('should render empty children', () => {
    render(<ChunkContainer label="Empty" characterCount={0}>{null}</ChunkContainer>)
    expect(screen.getByText('Empty')).toBeInTheDocument()
  })
})

describe('QAPreview', () => {
  const mockQA = {
    question: 'What is the meaning of life?',
    answer: 'The meaning of life is 42.',
  }

  it('should render question text', () => {
    render(<QAPreview qa={mockQA} />)
    expect(screen.getByText('What is the meaning of life?')).toBeInTheDocument()
  })

  it('should render answer text', () => {
    render(<QAPreview qa={mockQA} />)
    expect(screen.getByText('The meaning of life is 42.')).toBeInTheDocument()
  })

  it('should render Q label', () => {
    render(<QAPreview qa={mockQA} />)
    expect(screen.getByText('Q')).toBeInTheDocument()
  })

  it('should render A label', () => {
    render(<QAPreview qa={mockQA} />)
    expect(screen.getByText('A')).toBeInTheDocument()
  })

  it('should render with empty strings', () => {
    render(<QAPreview qa={{ question: '', answer: '' }} />)
    expect(screen.getByText('Q')).toBeInTheDocument()
    expect(screen.getByText('A')).toBeInTheDocument()
  })

  it('should render with long text', () => {
    const longQuestion = 'Q'.repeat(500)
    const longAnswer = 'A'.repeat(500)
    render(<QAPreview qa={{ question: longQuestion, answer: longAnswer }} />)
    expect(screen.getByText(longQuestion)).toBeInTheDocument()
    expect(screen.getByText(longAnswer)).toBeInTheDocument()
  })

  it('should render with special characters', () => {
    render(<QAPreview qa={{ question: 'What about <script>?', answer: '& special chars!' }} />)
    expect(screen.getByText('What about <script>?')).toBeInTheDocument()
    expect(screen.getByText('& special chars!')).toBeInTheDocument()
  })
})

@ -1,6 +1,6 @@
import type { ErrorDocsResponse } from '@/models/datasets'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { cleanup, fireEvent, render, screen, waitFor } from '@testing-library/react'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { retryErrorDocs } from '@/service/datasets'
import { useDatasetErrorDocs } from '@/service/knowledge/use-dataset'
import RetryButton from './index-failed'

@ -19,6 +19,11 @@ vi.mock('@/service/datasets', () => ({
const mockUseDatasetErrorDocs = vi.mocked(useDatasetErrorDocs)
const mockRetryErrorDocs = vi.mocked(retryErrorDocs)

afterEach(() => {
  cleanup()
  vi.clearAllMocks()
})

// Helper to create mock query result
const createMockQueryResult = (
  data: ErrorDocsResponse | undefined,

@ -139,6 +144,11 @@ describe('RetryButton (IndexFailed)', () => {
        document_ids: ['doc1', 'doc2'],
      })
    })

    // Wait for all state updates to complete
    await waitFor(() => {
      expect(mockRefetch).toHaveBeenCalled()
    })
  })

  it('should refetch error docs after successful retry', async () => {

@ -169,8 +179,10 @@ describe('RetryButton (IndexFailed)', () => {
      }, false),
    )

    // Delay the response to test loading state
    mockRetryErrorDocs.mockImplementation(() => new Promise(resolve => setTimeout(() => resolve({ result: 'success' }), 100)))
    let resolveRetry: ((value: { result: 'success' }) => void) | undefined
    mockRetryErrorDocs.mockImplementation(() => new Promise((resolve) => {
      resolveRetry = resolve
    }))

    render(<RetryButton datasetId="test-dataset" />)

@ -183,6 +195,11 @@ describe('RetryButton (IndexFailed)', () => {
      expect(button).toHaveClass('cursor-not-allowed')
      expect(button).toHaveClass('text-text-disabled')
    })

    resolveRetry?.({ result: 'success' })
    await waitFor(() => {
      expect(mockRefetch).toHaveBeenCalled()
    })
  })
})

@ -202,8 +219,13 @@ describe('RetryButton (IndexFailed)', () => {
    const retryButton = screen.getByText(/retry/i)
    fireEvent.click(retryButton)

    // Wait for retry to complete and state to update
    await waitFor(() => {
      expect(mockRetryErrorDocs).toHaveBeenCalled()
    })

    // Button should still be visible after failed retry
    await waitFor(() => {
      // Button should still be visible after failed retry
      expect(screen.getByText(/retry/i)).toBeInTheDocument()
    })
  })

@ -275,6 +297,11 @@ describe('RetryButton (IndexFailed)', () => {
        document_ids: [],
      })
    })

    // Wait for all state updates to complete
    await waitFor(() => {
      expect(mockRefetch).toHaveBeenCalled()
    })
  })
})
})

File diff suppressed because it is too large

@ -23,9 +23,10 @@ vi.mock('@/app/components/base/toast', () => ({
  },
}))

// Mock downloadFile utility
vi.mock('@/utils/format', () => ({
  downloadFile: vi.fn(),
// Mock download utilities
vi.mock('@/utils/download', () => ({
  downloadBlob: vi.fn(),
  downloadUrl: vi.fn(),
}))

// Capture Confirm callbacks

@ -502,8 +503,8 @@ describe('TemplateCard', () => {
    })
  })

  it('should call downloadFile on successful export', async () => {
    const { downloadFile } = await import('@/utils/format')
  it('should call downloadBlob on successful export', async () => {
    const { downloadBlob } = await import('@/utils/download')
    mockExportPipelineDSL.mockImplementation((_id, callbacks) => {
      callbacks.onSuccess({ data: 'yaml_content' })
      return Promise.resolve()

@ -514,7 +515,7 @@ describe('TemplateCard', () => {
    fireEvent.click(exportButton)

    await waitFor(() => {
      expect(downloadFile).toHaveBeenCalledWith(expect.objectContaining({
      expect(downloadBlob).toHaveBeenCalledWith(expect.objectContaining({
        fileName: 'Test Pipeline.pipeline',
      }))
    })

@ -16,7 +16,7 @@ import {
  useInvalidCustomizedTemplateList,
  usePipelineTemplateById,
} from '@/service/use-pipeline'
import { downloadFile } from '@/utils/format'
import { downloadBlob } from '@/utils/download'
import Actions from './actions'
import Content from './content'
import Details from './details'

@ -108,10 +108,7 @@ const TemplateCard = ({
    await exportPipelineDSL(pipeline.id, {
      onSuccess: (res) => {
        const blob = new Blob([res.data], { type: 'application/yaml' })
        downloadFile({
          data: blob,
          fileName: `${pipeline.name}.pipeline`,
        })
        downloadBlob({ data: blob, fileName: `${pipeline.name}.pipeline` })
        Toast.notify({
          type: 'success',
          message: t('exportDSL.successTip', { ns: 'datasetPipeline' }),

@ -125,11 +125,25 @@ const WaterCrawl: FC<Props> = ({
      await sleep(2500)
      return await waitForCrawlFinished(jobId)
    }
    catch (e: any) {
      const errorBody = await e.json()
    catch (error: unknown) {
      let errorMessage = ''

      const maybeErrorWithJson = error as { json?: () => Promise<unknown>, message?: unknown } | null
      if (maybeErrorWithJson?.json) {
        try {
          const errorBody = await maybeErrorWithJson.json() as { message?: unknown } | null
          if (typeof errorBody?.message === 'string')
            errorMessage = errorBody.message
        }
        catch {}
      }

      if (!errorMessage && typeof maybeErrorWithJson?.message === 'string')
        errorMessage = maybeErrorWithJson.message

      return {
        isError: true,
        errorMessage: errorBody.message,
        errorMessage,
        data: {
          data: [],
        },

Some files were not shown because too many files have changed in this diff