Merge remote-tracking branch 'origin/feat/support-agent-sandbox' into feat/support-agent-sandbox

zhsama 2026-01-29 23:53:34 +08:00
commit 1a51f52061
199 changed files with 6616 additions and 2628 deletions

View File

@@ -480,4 +480,4 @@ const useButtonState = () => {
### Related Skills
- `frontend-testing` - For testing refactored components
- `web/testing/testing.md` - Testing specification
- `web/docs/test.md` - Testing specification

View File

@@ -7,7 +7,7 @@ description: Generate Vitest + React Testing Library tests for Dify frontend com
This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices.
> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. Use Vitest mock/timer APIs (`vi.*`).
> **⚠️ Authoritative Source**: This skill is derived from `web/docs/test.md`. Use Vitest mock/timer APIs (`vi.*`).
## When to Apply This Skill
@@ -309,7 +309,7 @@ For more detailed information, refer to:
### Primary Specification (MUST follow)
- **`web/testing/testing.md`** - The canonical testing specification. This skill is derived from this document.
- **`web/docs/test.md`** - The canonical testing specification. This skill is derived from this document.
### Reference Examples in Codebase

View File

@@ -4,7 +4,7 @@ This guide defines the workflow for generating tests, especially for complex com
## Scope Clarification
This guide addresses **multi-file workflow** (how to process multiple test files). For coverage requirements within a single test file, see `web/testing/testing.md` § Coverage Goals.
This guide addresses **multi-file workflow** (how to process multiple test files). For coverage requirements within a single test file, see `web/docs/test.md` § Coverage Goals.
| Scope | Rule |
|-------|------|

View File

@@ -72,6 +72,7 @@ jobs:
OPENDAL_FS_ROOT: /tmp/dify-storage
run: |
uv run --project api pytest \
-n auto \
--timeout "${PYTEST_TIMEOUT:-180}" \
api/tests/integration_tests/workflow \
api/tests/integration_tests/tools \

View File

@@ -47,13 +47,9 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: uv run --directory api --dev lint-imports
- name: Run Basedpyright Checks
- name: Run Type Checks
if: steps.changed-files.outputs.any_changed == 'true'
run: dev/basedpyright-check
- name: Run Mypy Type Checks
if: steps.changed-files.outputs.any_changed == 'true'
run: uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
run: make type-check
- name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true'

View File

@@ -7,7 +7,7 @@ Dify is an open-source platform for developing LLM applications with an intuitiv
The codebase is split into:
- **Backend API** (`/api`): Python Flask application organized with Domain-Driven Design
- **Frontend Web** (`/web`): Next.js 15 application using TypeScript and React 19
- **Frontend Web** (`/web`): Next.js application using TypeScript and React
- **Docker deployment** (`/docker`): Containerized deployment configurations
## Backend Workflow
@@ -18,36 +18,7 @@ The codebase is split into:
## Frontend Workflow
```bash
cd web
pnpm lint:fix
pnpm type-check:tsgo
pnpm test
```
### Frontend Linting
ESLint is used for frontend code quality. Available commands:
```bash
# Lint all files (report only)
pnpm lint
# Lint and auto-fix issues
pnpm lint:fix
# Lint specific files or directories
pnpm lint:fix app/components/base/button/
pnpm lint:fix app/components/base/button/index.tsx
# Lint quietly (errors only, no warnings)
pnpm lint:quiet
# Check code complexity
pnpm lint:complexity
```
**Important**: Always run `pnpm lint:fix` before committing. The pre-commit hook runs `lint-staged` which only lints staged files.
- Read `web/AGENTS.md` for details
## Testing & Quality Practices

View File

@@ -77,7 +77,7 @@ How we prioritize:
For setting up the frontend service, please refer to our comprehensive [guide](https://github.com/langgenius/dify/blob/main/web/README.md) in the `web/README.md` file. This document provides detailed instructions to help you set up the frontend environment properly.
**Testing**: All React components must have comprehensive test coverage. See [web/testing/testing.md](https://github.com/langgenius/dify/blob/main/web/testing/testing.md) for the canonical frontend testing guidelines and follow every requirement described there.
**Testing**: All React components must have comprehensive test coverage. See [web/docs/test.md](https://github.com/langgenius/dify/blob/main/web/docs/test.md) for the canonical frontend testing guidelines and follow every requirement described there.
#### Backend

View File

@@ -68,9 +68,11 @@ lint:
@echo "✅ Linting complete"
type-check:
@echo "📝 Running type check with basedpyright..."
@uv run --directory api --dev basedpyright
@echo "✅ Type check complete"
@echo "📝 Running type checks (basedpyright + mypy + ty)..."
@./dev/basedpyright-check $(PATH_TO_CHECK)
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
@cd api && uv run ty check
@echo "✅ Type checks complete"
test:
@echo "🧪 Running backend unit tests..."
@@ -78,7 +80,7 @@ test:
echo "Target: $(TARGET_TESTS)"; \
uv run --project api --dev pytest $(TARGET_TESTS); \
else \
uv run --project api --dev dev/pytest/pytest_unit_tests.sh; \
PYTEST_XDIST_ARGS="-n auto" uv run --project api --dev dev/pytest/pytest_unit_tests.sh; \
fi
@echo "✅ Tests complete"
@@ -130,7 +132,7 @@ help:
@echo " make format - Format code with ruff"
@echo " make check - Check code with ruff"
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
@echo " make type-check - Run type checking with basedpyright"
@echo " make type-check - Run type checks (basedpyright, mypy, ty)"
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
@echo ""
@echo "Docker Build Targets:"

View File

@@ -620,6 +620,7 @@ PLUGIN_DAEMON_URL=http://127.0.0.1:5002
PLUGIN_REMOTE_INSTALL_PORT=5003
PLUGIN_REMOTE_INSTALL_HOST=localhost
PLUGIN_MAX_PACKAGE_SIZE=15728640
PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600
INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
# Marketplace configuration

View File

@@ -227,6 +227,9 @@ ignore_imports =
core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods
core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset
core.workflow.nodes.knowledge_index.knowledge_index_node -> services.summary_index_service
core.workflow.nodes.knowledge_index.knowledge_index_node -> tasks.generate_summary_index_task
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.processor.paragraph_index_processor
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.retrieval_methods
core.workflow.nodes.llm.node -> models.dataset
core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer
@@ -300,6 +303,58 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> services
core.workflow.nodes.tool.tool_node -> services
[importlinter:contract:model-runtime-no-internal-imports]
name = Model Runtime Internal Imports
type = forbidden
source_modules =
core.model_runtime
forbidden_modules =
configs
controllers
extensions
models
services
tasks
core.agent
core.app
core.base
core.callback_handler
core.datasource
core.db
core.entities
core.errors
core.extension
core.external_data_tool
core.file
core.helper
core.hosting_configuration
core.indexing_runner
core.llm_generator
core.logging
core.mcp
core.memory
core.model_manager
core.moderation
core.ops
core.plugin
core.prompt
core.provider_manager
core.rag
core.repositories
core.schemas
core.tools
core.trigger
core.variables
core.workflow
ignore_imports =
core.model_runtime.model_providers.__base.ai_model -> configs
core.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
core.model_runtime.model_providers.__base.large_language_model -> configs
core.model_runtime.model_providers.__base.text_embedding_model -> core.entities.embedding_type
core.model_runtime.model_providers.model_provider_factory -> configs
core.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
core.model_runtime.model_providers.model_provider_factory -> models.provider_ids
[importlinter:contract:rsc]
name = RSC
type = layers

View File

@@ -243,6 +243,11 @@ class PluginConfig(BaseSettings):
default=15728640 * 12,
)
PLUGIN_MODEL_SCHEMA_CACHE_TTL: PositiveInt = Field(
description="TTL in seconds for caching plugin model schemas in Redis",
default=60 * 60,
)
class CliApiConfig(BaseSettings):
"""

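The hunk above (together with the `.env.example` default of 3600 seconds earlier in this commit) introduces a Redis-backed, TTL-bounded cache for plugin model schemas; the context-variable cache it replaces is removed in the next file. A minimal sketch of the pattern, assuming pydantic entities and a redis-py client; the function and key names are illustrative, not Dify's actual API:

```python
from collections.abc import Callable

from pydantic import BaseModel


def get_model_schema_cached(
    redis_client,
    cache_key: str,
    ttl: int,  # e.g. PLUGIN_MODEL_SCHEMA_CACHE_TTL, default 60 * 60
    fetch_schema: Callable[[], BaseModel],
    schema_cls: type[BaseModel],
) -> BaseModel:
    """Return the cached schema if present; otherwise fetch it from the
    plugin daemon and store it with an expiry."""
    cached = redis_client.get(cache_key)
    if cached is not None:
        return schema_cls.model_validate_json(cached)
    schema = fetch_schema()
    redis_client.setex(cache_key, ttl, schema.model_dump_json())
    return schema
```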
View File

@@ -6,7 +6,6 @@ from contexts.wrapper import RecyclableContextVar
if TYPE_CHECKING:
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.model_runtime.entities.model_entities import AIModelEntity
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.trigger.provider import PluginTriggerProviderController
@@ -29,12 +28,6 @@ plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("plugin_model_providers_lock")
)
plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_model_schema_lock"))
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
ContextVar("plugin_model_schemas")
)
datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = (
RecyclableContextVar(ContextVar("datasource_plugin_providers"))
)

View File

@@ -243,15 +243,13 @@ class InsertExploreBannerApi(Resource):
def post(self):
payload = InsertExploreBannerPayload.model_validate(console_ns.payload)
content = {
"category": payload.category,
"title": payload.title,
"description": payload.description,
"img-src": payload.img_src,
}
banner = ExporleBanner(
content=content,
content={
"category": payload.category,
"title": payload.title,
"description": payload.description,
"img-src": payload.img_src,
},
link=payload.link,
sort=payload.sort,
language=payload.language,

View File

@@ -51,7 +51,7 @@ class AppImportPayload(BaseModel):
app_id: str | None = Field(None)
class AppImportBundlePayload(BaseModel):
class AppImportBundleConfirmPayload(BaseModel):
name: str | None = None
description: str | None = None
icon_type: str | None = None
@@ -149,15 +149,38 @@ class AppImportCheckDependenciesApi(Resource):
return result.model_dump(mode="json"), 200
@console_ns.route("/apps/imports-bundle")
class AppImportBundleApi(Resource):
@console_ns.route("/apps/imports-bundle/prepare")
class AppImportBundlePrepareApi(Resource):
"""Step 1: Get upload URL for bundle import."""
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self):
from services.app_bundle_service import AppBundleService
current_user, current_tenant_id = current_account_with_tenant()
result = AppBundleService.prepare_import(
tenant_id=current_tenant_id,
account_id=current_user.id,
)
return {"import_id": result.import_id, "upload_url": result.upload_url}, 200
@console_ns.route("/apps/imports-bundle/<string:import_id>/confirm")
class AppImportBundleConfirmApi(Resource):
"""Step 2: Confirm bundle import after upload."""
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_import_model)
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self):
def post(self, import_id: str):
from flask import request
from core.app.entities.app_bundle_entities import BundleFormatError
@@ -165,22 +188,12 @@ class AppImportBundleApi(Resource):
current_user, _ = current_account_with_tenant()
if "file" not in request.files:
return {"error": "No file provided"}, 400
file = request.files["file"]
if not file.filename or not file.filename.endswith(".zip"):
return {"error": "Invalid file format, expected .zip"}, 400
zip_bytes = file.read()
form_data = request.form.to_dict()
args = AppImportBundlePayload.model_validate(form_data)
args = AppImportBundleConfirmPayload.model_validate(request.get_json() or {})
try:
result = AppBundleService.import_bundle(
result = AppBundleService.confirm_import(
import_id=import_id,
account=current_user,
zip_bytes=zip_bytes,
name=args.name,
description=args.description,
icon_type=args.icon_type,

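The rewrite above splits bundle import into a two-step flow: `POST /apps/imports-bundle/prepare` returns an `import_id` and an upload URL, the client uploads the zip bytes directly, then `POST /apps/imports-bundle/<import_id>/confirm` finalizes the import with optional JSON overrides instead of the old multipart form. A hedged client-side sketch; the base URL and auth header are placeholders:

```python
import requests

BASE = "https://example.com/console/api"  # hypothetical console API base
HEADERS = {"Authorization": "Bearer <token>"}  # illustrative auth

# Step 1: obtain an import_id plus a presigned (or file-proxy) upload URL.
prep = requests.post(f"{BASE}/apps/imports-bundle/prepare", headers=HEADERS).json()

# Step 2a: PUT the bundle bytes straight to the upload URL.
with open("bundle.zip", "rb") as f:
    requests.put(prep["upload_url"], data=f).raise_for_status()

# Step 2b: confirm; the body mirrors AppImportBundleConfirmPayload
# (name, description, icon_type, ... are optional overrides).
requests.post(
    f"{BASE}/apps/imports-bundle/{prep['import_id']}/confirm",
    headers=HEADERS,
    json={"name": "My App"},
).raise_for_status()
```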
View File

@@ -70,9 +70,7 @@ class ContextGeneratePayload(BaseModel):
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
available_vars: list[AvailableVarPayload] = Field(..., description="Available variables from upstream nodes")
parameter_info: ParameterInfoPayload = Field(..., description="Target parameter metadata from the frontend")
code_context: CodeContextPayload = Field(
description="Existing code node context for incremental generation"
)
code_context: CodeContextPayload = Field(description="Existing code node context for incremental generation")
class SuggestedQuestionsPayload(BaseModel):

View File

@@ -148,6 +148,7 @@ class DatasetUpdatePayload(BaseModel):
embedding_model: str | None = None
embedding_model_provider: str | None = None
retrieval_model: dict[str, Any] | None = None
summary_index_setting: dict[str, Any] | None = None
partial_member_list: list[dict[str, str]] | None = None
external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None
@@ -288,7 +289,14 @@ class DatasetListApi(Resource):
@enterprise_license_required
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
query = ConsoleDatasetListQuery.model_validate(request.args.to_dict())
# Convert query parameters to dict, handling list parameters correctly
query_params: dict[str, str | list[str]] = dict(request.args.to_dict())
# Handle ids and tag_ids as lists (Flask request.args.getlist returns list even for single value)
if "ids" in request.args:
query_params["ids"] = request.args.getlist("ids")
if "tag_ids" in request.args:
query_params["tag_ids"] = request.args.getlist("tag_ids")
query = ConsoleDatasetListQuery.model_validate(query_params)
# provider = request.args.get("provider", default="vendor")
if query.ids:
datasets, total = DatasetService.get_datasets_by_ids(query.ids, current_tenant_id)
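The `getlist` handling above is needed because `request.args.to_dict()` keeps only the first value of a repeated query key, so multi-valued `ids`/`tag_ids` parameters would otherwise be silently truncated. A quick demonstration with Werkzeug's `MultiDict`:

```python
from werkzeug.datastructures import MultiDict

# Equivalent of ?ids=a&ids=b&page=1
args = MultiDict([("ids", "a"), ("ids", "b"), ("page", "1")])

assert args.to_dict() == {"ids": "a", "page": "1"}  # "b" is dropped
assert args.getlist("ids") == ["a", "b"]            # full list preserved
```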

View File

@@ -45,6 +45,7 @@ from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from services.file_service import FileService
from tasks.generate_summary_index_task import generate_summary_index_task
from ..app.error import (
ProviderModelCurrentlyNotSupportError,
@@ -103,6 +104,10 @@ class DocumentRenamePayload(BaseModel):
name: str
class GenerateSummaryPayload(BaseModel):
document_list: list[str]
class DocumentBatchDownloadZipPayload(BaseModel):
"""Request payload for bulk downloading documents as a zip archive."""
@@ -125,6 +130,7 @@ register_schema_models(
RetrievalModel,
DocumentRetryPayload,
DocumentRenamePayload,
GenerateSummaryPayload,
DocumentBatchDownloadZipPayload,
)
@@ -312,6 +318,13 @@ class DatasetDocumentListApi(Resource):
paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items
DocumentService.enrich_documents_with_summary_index_status(
documents=documents,
dataset=dataset,
tenant_id=current_tenant_id,
)
if fetch:
for document in documents:
completed_segments = (
@@ -797,6 +810,7 @@ class DocumentApi(DocumentResource):
"display_status": document.display_status,
"doc_form": document.doc_form,
"doc_language": document.doc_language,
"need_summary": document.need_summary if document.need_summary is not None else False,
}
else:
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
@@ -832,6 +846,7 @@
"display_status": document.display_status,
"doc_form": document.doc_form,
"doc_language": document.doc_language,
"need_summary": document.need_summary if document.need_summary is not None else False,
}
return response, 200
@@ -1255,3 +1270,137 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
"input_data": log.input_data,
"datasource_node_id": log.datasource_node_id,
}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/generate-summary")
class DocumentGenerateSummaryApi(Resource):
@console_ns.doc("generate_summary_for_documents")
@console_ns.doc(description="Generate summary index for documents")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect(console_ns.models[GenerateSummaryPayload.__name__])
@console_ns.response(200, "Summary generation started successfully")
@console_ns.response(400, "Invalid request or dataset configuration")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
"""
Generate summary index for specified documents.
This endpoint checks if the dataset configuration supports summary generation
(indexing_technique must be 'high_quality' and summary_index_setting.enable must be true),
then asynchronously generates summary indexes for the provided documents.
"""
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
# Get dataset
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# Check permissions
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# Validate request payload
payload = GenerateSummaryPayload.model_validate(console_ns.payload or {})
document_list = payload.document_list
if not document_list:
from werkzeug.exceptions import BadRequest
raise BadRequest("document_list cannot be empty.")
# Check if dataset configuration supports summary generation
if dataset.indexing_technique != "high_quality":
raise ValueError(
f"Summary generation is only available for 'high_quality' indexing technique. "
f"Current indexing technique: {dataset.indexing_technique}"
)
summary_index_setting = dataset.summary_index_setting
if not summary_index_setting or not summary_index_setting.get("enable"):
raise ValueError("Summary index is not enabled for this dataset. Please enable it in the dataset settings.")
# Verify all documents exist and belong to the dataset
documents = DocumentService.get_documents_by_ids(dataset_id, document_list)
if len(documents) != len(document_list):
found_ids = {doc.id for doc in documents}
missing_ids = set(document_list) - found_ids
raise NotFound(f"Some documents not found: {list(missing_ids)}")
# Dispatch async tasks for each document
for document in documents:
# Skip qa_model documents as they don't generate summaries
if document.doc_form == "qa_model":
logger.info("Skipping summary generation for qa_model document %s", document.id)
continue
# Dispatch async task
generate_summary_index_task.delay(dataset_id, document.id)
logger.info(
"Dispatched summary generation task for document %s in dataset %s",
document.id,
dataset_id,
)
return {"result": "success"}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/summary-status")
class DocumentSummaryStatusApi(DocumentResource):
@console_ns.doc("get_document_summary_status")
@console_ns.doc(description="Get summary index generation status for a document")
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@console_ns.response(200, "Summary status retrieved successfully")
@console_ns.response(404, "Document not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, document_id):
"""
Get summary index generation status for a document.
Returns:
- total_segments: Total number of segments in the document
- summary_status: Dictionary with status counts
- completed: Number of summaries completed
- generating: Number of summaries being generated
- error: Number of summaries with errors
- not_started: Number of segments without summary records
- summaries: List of summary records with status and content preview
"""
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
document_id = str(document_id)
# Get dataset
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# Check permissions
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# Get summary status detail from service
from services.summary_index_service import SummaryIndexService
result = SummaryIndexService.get_document_summary_status_detail(
document_id=document_id,
dataset_id=dataset_id,
)
return result, 200
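A hedged sketch of driving the two endpoints added above; host, auth header, and IDs are placeholders, and the response shapes follow the docstrings:

```python
import requests

BASE = "https://example.com/console/api"  # hypothetical console API base
HEADERS = {"Authorization": "Bearer <token>"}  # illustrative auth
dataset_id, document_id = "<dataset-uuid>", "<document-uuid>"

# Requires indexing_technique == "high_quality" and an enabled
# summary_index_setting on the dataset; qa_model documents are skipped.
requests.post(
    f"{BASE}/datasets/{dataset_id}/documents/generate-summary",
    headers=HEADERS,
    json={"document_list": [document_id]},
).raise_for_status()  # -> {"result": "success"}

# Poll per-document progress: completed / generating / error / not_started
# counts plus per-segment summary records.
status = requests.get(
    f"{BASE}/datasets/{dataset_id}/documents/{document_id}/summary-status",
    headers=HEADERS,
).json()
print(status["summary_status"])
```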

View File

@@ -41,6 +41,17 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
def _get_segment_with_summary(segment, dataset_id):
"""Helper function to marshal segment and add summary information."""
from services.summary_index_service import SummaryIndexService
segment_dict = dict(marshal(segment, segment_fields))
# Query summary for this segment (only enabled summaries)
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
segment_dict["summary"] = summary.summary_content if summary else None
return segment_dict
class SegmentListQuery(BaseModel):
limit: int = Field(default=20, ge=1, le=100)
status: list[str] = Field(default_factory=list)
@@ -63,6 +74,7 @@ class SegmentUpdatePayload(BaseModel):
keywords: list[str] | None = None
regenerate_child_chunks: bool = False
attachment_ids: list[str] | None = None
summary: str | None = None # Summary content for summary index
class BatchImportPayload(BaseModel):
@@ -181,8 +193,25 @@ class DatasetDocumentSegmentListApi(Resource):
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
# Query summaries for all segments in this page (batch query for efficiency)
segment_ids = [segment.id for segment in segments.items]
summaries = {}
if segment_ids:
from services.summary_index_service import SummaryIndexService
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
# Only include enabled summaries (already filtered by service)
summaries = {chunk_id: summary.summary_content for chunk_id, summary in summary_records.items()}
# Add summary to each segment
segments_with_summary = []
for segment in segments.items:
segment_dict = dict(marshal(segment, segment_fields))
segment_dict["summary"] = summaries.get(segment.id)
segments_with_summary.append(segment_dict)
response = {
"data": marshal(segments.items, segment_fields),
"data": segments_with_summary,
"limit": limit,
"total": segments.total,
"total_pages": segments.pages,
@@ -328,7 +357,7 @@ class DatasetDocumentSegmentAddApi(Resource):
payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document)
segment = SegmentService.create_segment(payload_dict, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
@@ -390,10 +419,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document)
# Update segment (summary update with change detection is handled in SegmentService.update_segment)
segment = SegmentService.update_segment(
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
@setup_required
@login_required

View File

@@ -1,6 +1,13 @@
from flask_restx import Resource
from flask_restx import Resource, fields
from controllers.common.schema import register_schema_model
from fields.hit_testing_fields import (
child_chunk_fields,
document_fields,
files_fields,
hit_testing_record_fields,
segment_fields,
)
from libs.login import login_required
from .. import console_ns
@@ -14,13 +21,45 @@ from ..wraps import (
register_schema_model(console_ns, HitTestingPayload)
def _get_or_create_model(model_name: str, field_def):
"""Get or create a flask_restx model to avoid dict type issues in Swagger."""
existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)
return existing
# Register models for flask_restx to avoid dict type issues in Swagger
document_model = _get_or_create_model("HitTestingDocument", document_fields)
segment_fields_copy = segment_fields.copy()
segment_fields_copy["document"] = fields.Nested(document_model)
segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy)
child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields)
files_model = _get_or_create_model("HitTestingFile", files_fields)
hit_testing_record_fields_copy = hit_testing_record_fields.copy()
hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model)
hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model))
hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model))
hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy)
# Response model for hit testing API
hit_testing_response_fields = {
"query": fields.String,
"records": fields.List(fields.Nested(hit_testing_record_model)),
}
hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields)
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
class HitTestingApi(Resource, DatasetsHitTestingBase):
@console_ns.doc("test_dataset_retrieval")
@console_ns.doc(description="Test dataset knowledge retrieval")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect(console_ns.models[HitTestingPayload.__name__])
@console_ns.response(200, "Hit testing completed successfully")
@console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model)
@console_ns.response(404, "Dataset not found")
@console_ns.response(400, "Invalid parameters")
@setup_required

View File

@@ -15,12 +15,8 @@ api = ExternalApi(
files_ns = Namespace("files", description="File operations", path="/")
from . import (
app_assets_download,
app_assets_upload,
image_preview,
sandbox_archive,
sandbox_file_downloads,
storage_download,
storage_files,
tool_files,
upload,
)
@@ -29,14 +25,10 @@ api.add_namespace(files_ns)
__all__ = [
"api",
"app_assets_download",
"app_assets_upload",
"bp",
"files_ns",
"image_preview",
"sandbox_archive",
"sandbox_file_downloads",
"storage_download",
"storage_files",
"tool_files",
"upload",
]

View File

@@ -1,77 +0,0 @@
from urllib.parse import quote
from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, NotFound
from controllers.files import files_ns
from core.app_assets.storage import AppAssetSigner, AssetPath
from extensions.ext_storage import storage
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppAssetDownloadQuery(BaseModel):
expires_at: int = Field(..., description="Unix timestamp when the link expires")
nonce: str = Field(..., description="Random string for signature")
sign: str = Field(..., description="HMAC signature")
files_ns.schema_model(
AppAssetDownloadQuery.__name__,
AppAssetDownloadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@files_ns.route("/app-assets/<string:asset_type>/<string:tenant_id>/<string:app_id>/<string:resource_id>/download")
@files_ns.route(
"/app-assets/<string:asset_type>/<string:tenant_id>/<string:app_id>/<string:resource_id>/<string:sub_resource_id>/download"
)
class AppAssetDownloadApi(Resource):
def get(
self,
asset_type: str,
tenant_id: str,
app_id: str,
resource_id: str,
sub_resource_id: str | None = None,
):
args = AppAssetDownloadQuery.model_validate(request.args.to_dict(flat=True))
try:
asset_path = AssetPath.from_components(
asset_type=asset_type,
tenant_id=tenant_id,
app_id=app_id,
resource_id=resource_id,
sub_resource_id=sub_resource_id,
)
except ValueError as exc:
raise Forbidden(str(exc)) from exc
if not AppAssetSigner.verify_download_signature(
asset_path=asset_path,
expires_at=args.expires_at,
nonce=args.nonce,
sign=args.sign,
):
raise Forbidden("Invalid or expired download link")
storage_key = asset_path.get_storage_key()
try:
generator = storage.load_stream(storage_key)
except FileNotFoundError as exc:
raise NotFound("File not found") from exc
encoded_filename = quote(storage_key.split("/")[-1])
return Response(
generator,
mimetype="application/octet-stream",
direct_passthrough=True,
headers={
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
},
)

View File

@@ -1,61 +0,0 @@
from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden
from controllers.files import files_ns
from core.app_assets.storage import AppAssetSigner, AssetPath
from services.app_asset_service import AppAssetService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppAssetUploadQuery(BaseModel):
expires_at: int = Field(..., description="Unix timestamp when the link expires")
nonce: str = Field(..., description="Random string for signature")
sign: str = Field(..., description="HMAC signature")
files_ns.schema_model(
AppAssetUploadQuery.__name__,
AppAssetUploadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@files_ns.route("/app-assets/<string:asset_type>/<string:tenant_id>/<string:app_id>/<string:resource_id>/upload")
@files_ns.route(
"/app-assets/<string:asset_type>/<string:tenant_id>/<string:app_id>/<string:resource_id>/<string:sub_resource_id>/upload"
)
class AppAssetUploadApi(Resource):
def put(
self,
asset_type: str,
tenant_id: str,
app_id: str,
resource_id: str,
sub_resource_id: str | None = None,
):
args = AppAssetUploadQuery.model_validate(request.args.to_dict(flat=True))
try:
asset_path = AssetPath.from_components(
asset_type=asset_type,
tenant_id=tenant_id,
app_id=app_id,
resource_id=resource_id,
sub_resource_id=sub_resource_id,
)
except ValueError as exc:
raise Forbidden(str(exc)) from exc
if not AppAssetSigner.verify_upload_signature(
asset_path=asset_path,
expires_at=args.expires_at,
nonce=args.nonce,
sign=args.sign,
):
raise Forbidden("Invalid or expired upload link")
content = request.get_data()
AppAssetService.get_storage().save(asset_path, content)
return Response(status=204)

View File

@@ -1,76 +0,0 @@
from uuid import UUID
from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, NotFound
from controllers.files import files_ns
from core.sandbox.security.archive_signer import SandboxArchivePath, SandboxArchiveSigner
from extensions.ext_storage import storage
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class SandboxArchiveQuery(BaseModel):
expires_at: int = Field(..., description="Unix timestamp when the link expires")
nonce: str = Field(..., description="Random string for signature")
sign: str = Field(..., description="HMAC signature")
files_ns.schema_model(
SandboxArchiveQuery.__name__,
SandboxArchiveQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@files_ns.route("/sandbox-archives/<string:tenant_id>/<string:sandbox_id>/download")
class SandboxArchiveDownloadApi(Resource):
def get(self, tenant_id: str, sandbox_id: str):
args = SandboxArchiveQuery.model_validate(request.args.to_dict(flat=True))
try:
archive_path = SandboxArchivePath(tenant_id=UUID(tenant_id), sandbox_id=UUID(sandbox_id))
except ValueError as exc:
raise Forbidden(str(exc)) from exc
if not SandboxArchiveSigner.verify_download_signature(
archive_path=archive_path,
expires_at=args.expires_at,
nonce=args.nonce,
sign=args.sign,
):
raise Forbidden("Invalid or expired download link")
try:
generator = storage.load_stream(archive_path.get_storage_key())
except FileNotFoundError as exc:
raise NotFound("Archive not found") from exc
return Response(
generator,
mimetype="application/gzip",
direct_passthrough=True,
)
@files_ns.route("/sandbox-archives/<string:tenant_id>/<string:sandbox_id>/upload")
class SandboxArchiveUploadApi(Resource):
def put(self, tenant_id: str, sandbox_id: str):
args = SandboxArchiveQuery.model_validate(request.args.to_dict(flat=True))
try:
archive_path = SandboxArchivePath(tenant_id=UUID(tenant_id), sandbox_id=UUID(sandbox_id))
except ValueError as exc:
raise Forbidden(str(exc)) from exc
if not SandboxArchiveSigner.verify_upload_signature(
archive_path=archive_path,
expires_at=args.expires_at,
nonce=args.nonce,
sign=args.sign,
):
raise Forbidden("Invalid or expired upload link")
storage.save(archive_path.get_storage_key(), request.get_data())
return Response(status=204)

View File

@@ -1,96 +0,0 @@
from urllib.parse import quote
from uuid import UUID
from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, NotFound
from controllers.files import files_ns
from core.sandbox.security.sandbox_file_signer import SandboxFileDownloadPath, SandboxFileSigner
from extensions.ext_storage import storage
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class SandboxFileDownloadQuery(BaseModel):
expires_at: int = Field(..., description="Unix timestamp when the link expires")
nonce: str = Field(..., description="Random string for signature")
sign: str = Field(..., description="HMAC signature")
files_ns.schema_model(
SandboxFileDownloadQuery.__name__,
SandboxFileDownloadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@files_ns.route(
"/sandbox-file-downloads/<string:tenant_id>/<string:sandbox_id>/<string:export_id>/<path:filename>/download"
)
class SandboxFileDownloadDownloadApi(Resource):
def get(self, tenant_id: str, sandbox_id: str, export_id: str, filename: str):
args = SandboxFileDownloadQuery.model_validate(request.args.to_dict(flat=True))
try:
export_path = SandboxFileDownloadPath(
tenant_id=UUID(tenant_id),
sandbox_id=UUID(sandbox_id),
export_id=export_id,
filename=filename,
)
except ValueError as exc:
raise Forbidden(str(exc)) from exc
if not SandboxFileSigner.verify_download_signature(
export_path=export_path,
expires_at=args.expires_at,
nonce=args.nonce,
sign=args.sign,
):
raise Forbidden("Invalid or expired download link")
try:
generator = storage.load_stream(export_path.get_storage_key())
except FileNotFoundError as exc:
raise NotFound("File not found") from exc
encoded_filename = quote(filename.split("/")[-1])
return Response(
generator,
mimetype="application/octet-stream",
direct_passthrough=True,
headers={
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
},
)
@files_ns.route(
"/sandbox-file-downloads/<string:tenant_id>/<string:sandbox_id>/<string:export_id>/<path:filename>/upload"
)
class SandboxFileDownloadUploadApi(Resource):
def put(self, tenant_id: str, sandbox_id: str, export_id: str, filename: str):
args = SandboxFileDownloadQuery.model_validate(request.args.to_dict(flat=True))
try:
export_path = SandboxFileDownloadPath(
tenant_id=UUID(tenant_id),
sandbox_id=UUID(sandbox_id),
export_id=export_id,
filename=filename,
)
except ValueError as exc:
raise Forbidden(str(exc)) from exc
if not SandboxFileSigner.verify_upload_signature(
export_path=export_path,
expires_at=args.expires_at,
nonce=args.nonce,
sign=args.sign,
):
raise Forbidden("Invalid or expired upload link")
storage.save(export_path.get_storage_key(), request.get_data())
return Response(status=204)

View File

@@ -1,56 +0,0 @@
from urllib.parse import quote, unquote
from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, NotFound
from controllers.files import files_ns
from extensions.ext_storage import storage
from extensions.storage.file_presign_storage import FilePresignStorage
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class StorageDownloadQuery(BaseModel):
timestamp: str = Field(..., description="Unix timestamp used in the signature")
nonce: str = Field(..., description="Random string for signature")
sign: str = Field(..., description="HMAC signature")
files_ns.schema_model(
StorageDownloadQuery.__name__,
StorageDownloadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@files_ns.route("/storage/<path:filename>/download")
class StorageFileDownloadApi(Resource):
def get(self, filename: str):
filename = unquote(filename)
args = StorageDownloadQuery.model_validate(request.args.to_dict(flat=True))
if not FilePresignStorage.verify_signature(
filename=filename,
timestamp=args.timestamp,
nonce=args.nonce,
sign=args.sign,
):
raise Forbidden("Invalid or expired download link")
try:
generator = storage.load_stream(filename)
except FileNotFoundError:
raise NotFound("File not found")
encoded_filename = quote(filename.split("/")[-1])
return Response(
generator,
mimetype="application/octet-stream",
direct_passthrough=True,
headers={
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
},
)

View File

@@ -0,0 +1,80 @@
"""Token-based file proxy controller for storage operations.
This controller handles file download and upload operations using opaque UUID tokens.
The token maps to the real storage key in Redis, so the actual storage path is never
exposed in the URL.
Routes:
GET /files/storage-files/{token} - Download a file
PUT /files/storage-files/{token} - Upload a file
The operation type (download/upload) is determined by the ticket stored in Redis,
not by the HTTP method. This ensures a download ticket cannot be used for upload
and vice versa.
"""
from urllib.parse import quote
from flask import Response, request
from flask_restx import Resource
from werkzeug.exceptions import Forbidden, NotFound, RequestEntityTooLarge
from controllers.files import files_ns
from extensions.ext_storage import storage
from services.storage_ticket_service import StorageTicketService
@files_ns.route("/storage-files/<string:token>")
class StorageFilesApi(Resource):
"""Handle file operations through token-based URLs."""
def get(self, token: str):
"""Download a file using a token.
The ticket must have op="download", otherwise returns 403.
"""
ticket = StorageTicketService.get_ticket(token)
if ticket is None:
raise Forbidden("Invalid or expired token")
if ticket.op != "download":
raise Forbidden("This token is not valid for download")
try:
generator = storage.load_stream(ticket.storage_key)
except FileNotFoundError:
raise NotFound("File not found")
filename = ticket.filename or ticket.storage_key.rsplit("/", 1)[-1]
encoded_filename = quote(filename)
return Response(
generator,
mimetype="application/octet-stream",
direct_passthrough=True,
headers={
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
},
)
def put(self, token: str):
"""Upload a file using a token.
The ticket must have op="upload", otherwise returns 403.
If the request body exceeds max_bytes, returns 413.
"""
ticket = StorageTicketService.get_ticket(token)
if ticket is None:
raise Forbidden("Invalid or expired token")
if ticket.op != "upload":
raise Forbidden("This token is not valid for upload")
content = request.get_data()
if ticket.max_bytes is not None and len(content) > ticket.max_bytes:
raise RequestEntityTooLarge(f"Upload exceeds maximum size of {ticket.max_bytes} bytes")
storage.save(ticket.storage_key, content)
return Response(status=204)
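`StorageTicketService` itself is not shown in this diff; a plausible sketch of the lifecycle it implies: the field names match what the controller reads, while the Redis key layout and TTL are assumptions.

```python
import json
import uuid
from dataclasses import asdict, dataclass


@dataclass
class StorageTicket:
    op: str                       # "download" or "upload"; checked by the controller
    storage_key: str              # real storage path, never exposed in the URL
    filename: str | None = None   # optional download filename override
    max_bytes: int | None = None  # optional upload size cap


def create_ticket(redis_client, ticket: StorageTicket, ttl: int = 300) -> str:
    """Mint an opaque token for /files/storage-files/{token}."""
    token = str(uuid.uuid4())
    redis_client.setex(f"storage_ticket:{token}", ttl, json.dumps(asdict(ticket)))
    return token


def get_ticket(redis_client, token: str) -> StorageTicket | None:
    raw = redis_client.get(f"storage_ticket:{token}")
    return StorageTicket(**json.loads(raw)) if raw is not None else None
```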

View File

@@ -46,6 +46,7 @@ class DatasetCreatePayload(BaseModel):
retrieval_model: RetrievalModel | None = None
embedding_model: str | None = None
embedding_model_provider: str | None = None
summary_index_setting: dict | None = None
class DatasetUpdatePayload(BaseModel):
@@ -217,6 +218,7 @@ class DatasetListApi(DatasetApiResource):
embedding_model_provider=payload.embedding_model_provider,
embedding_model_name=payload.embedding_model,
retrieval_model=payload.retrieval_model,
summary_index_setting=payload.summary_index_setting,
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()

View File

@@ -45,6 +45,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
Segmentation,
)
from services.file_service import FileService
from services.summary_index_service import SummaryIndexService
class DocumentTextCreatePayload(BaseModel):
@@ -508,6 +509,12 @@ class DocumentListApi(DatasetApiResource):
)
documents = paginated_documents.items
DocumentService.enrich_documents_with_summary_index_status(
documents=documents,
dataset=dataset,
tenant_id=tenant_id,
)
response = {
"data": marshal(documents, document_fields),
"has_more": len(documents) == query_params.limit,
@@ -612,6 +619,16 @@ class DocumentApi(DatasetApiResource):
if metadata not in self.METADATA_CHOICES:
raise InvalidMetadataError(f"Invalid metadata value: {metadata}")
# Calculate summary_index_status if needed
summary_index_status = None
has_summary_index = dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True
if has_summary_index and document.need_summary is True:
summary_index_status = SummaryIndexService.get_document_summary_index_status(
document_id=document_id,
dataset_id=dataset_id,
tenant_id=tenant_id,
)
if metadata == "only":
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
elif metadata == "without":
@@ -646,6 +663,8 @@
"display_status": document.display_status,
"doc_form": document.doc_form,
"doc_language": document.doc_language,
"summary_index_status": summary_index_status,
"need_summary": document.need_summary if document.need_summary is not None else False,
}
else:
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
@@ -681,6 +700,8 @@
"display_status": document.display_status,
"doc_form": document.doc_form,
"doc_language": document.doc_language,
"summary_index_status": summary_index_status,
"need_summary": document.need_summary if document.need_summary is not None else False,
}
return response

View File

@@ -79,6 +79,7 @@ class AppGenerateResponseConverter(ABC):
"document_name": resource["document_name"],
"score": resource["score"],
"content": resource["content"],
"summary": resource.get("summary"),
}
)
metadata["retriever_resources"] = updated_resources

View File

@@ -1,12 +1,17 @@
from __future__ import annotations
import re
from datetime import UTC, datetime
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
from core.app.entities.app_asset_entities import AppAssetFileTree
# Constants
BUNDLE_DSL_FILENAME_PATTERN = re.compile(r"^[^/]+\.ya?ml$")
BUNDLE_MAX_SIZE = 50 * 1024 * 1024 # 50MB
MANIFEST_FILENAME = "manifest.json"
MANIFEST_SCHEMA_VERSION = "1.0"
# Exceptions
@@ -22,21 +27,70 @@ class ZipSecurityError(Exception):
pass
# Entities
# Manifest DTOs
class ManifestFileEntry(BaseModel):
"""Maps node_id to file path in the bundle."""
model_config = ConfigDict(extra="forbid")
node_id: str
path: str
class ManifestIntegrity(BaseModel):
"""Basic integrity check fields."""
model_config = ConfigDict(extra="forbid")
file_count: int
class ManifestAppAssets(BaseModel):
"""App assets section containing the full tree."""
model_config = ConfigDict(extra="forbid")
tree: AppAssetFileTree
class BundleManifest(BaseModel):
"""
Bundle manifest for app asset import/export.
Schema version 1.0:
- dsl_filename: DSL file name in bundle root (e.g. "my_app.yml")
- tree: Full AppAssetFileTree (files + folders) for 100% restoration including node IDs
- files: Explicit node_id -> path mapping for file nodes only
- integrity: Basic file_count validation
"""
model_config = ConfigDict(extra="forbid")
schema_version: str = Field(default=MANIFEST_SCHEMA_VERSION)
generated_at: datetime = Field(default_factory=lambda: datetime.now(tz=UTC))
dsl_filename: str = Field(description="DSL file name in bundle root")
app_assets: ManifestAppAssets
files: list[ManifestFileEntry]
integrity: ManifestIntegrity
@property
def assets_prefix(self) -> str:
"""Assets directory name (DSL filename without extension)."""
return self.dsl_filename.rsplit(".", 1)[0]
@classmethod
def from_tree(cls, tree: AppAssetFileTree, dsl_filename: str) -> BundleManifest:
"""Build manifest from an AppAssetFileTree."""
files = [ManifestFileEntry(node_id=n.id, path=tree.get_path(n.id)) for n in tree.walk_files()]
return cls(
dsl_filename=dsl_filename,
app_assets=ManifestAppAssets(tree=tree),
files=files,
integrity=ManifestIntegrity(file_count=len(files)),
)
# Export result
class BundleExportResult(BaseModel):
download_url: str = Field(description="Temporary download URL for the ZIP")
filename: str = Field(description="Suggested filename for the ZIP")
class SourceFileEntry(BaseModel):
path: str = Field(description="File path within the ZIP")
node_id: str = Field(description="Node ID in the asset tree")
class ExtractedFile(BaseModel):
path: str = Field(description="Relative path of the extracted file")
content: bytes = Field(description="File content as bytes")
class ExtractedFolder(BaseModel):
path: str = Field(description="Relative path of the extracted folder")
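For reference, a `manifest.json` that would validate against these DTOs (note `extra="forbid"` on every model) might look like the following; the tree payload is only a guess at its shape, since `AppAssetFileTree` is defined elsewhere:

```python
# Illustrative schema-1.0 manifest; the tree contents are assumed.
manifest = {
    "schema_version": "1.0",
    "generated_at": "2026-01-29T15:53:34+00:00",
    "dsl_filename": "my_app.yml",  # assets_prefix -> "my_app"
    "app_assets": {"tree": {"files": [], "folders": []}},  # full tree, node IDs included
    "files": [
        {"node_id": "a1b2", "path": "my_app/skills/hello.py"},
    ],
    "integrity": {"file_count": 1},
}
```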

View File

@@ -1,25 +1,31 @@
"""App assets storage layer.
This module provides storage abstractions for app assets (draft files, build zips,
resolved assets, skill bundles, source zips, bundle exports/imports).
Key components:
- AssetPath: Factory for creating typed storage paths
- AppAssetStorage: High-level storage operations with presign support
All presign operations use the unified FilePresignStorage wrapper, which automatically
falls back to Dify's file proxy when the underlying storage doesn't support presigned URLs.
"""
from __future__ import annotations
import base64
import hashlib
import hmac
import os
import time
import urllib.parse
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable
from collections.abc import Generator, Iterable
from dataclasses import dataclass
from typing import Any, ClassVar
from uuid import UUID
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
from extensions.storage.cached_presign_storage import CachedPresignStorage
from libs import rsa
from extensions.storage.file_presign_storage import FilePresignStorage
_ASSET_BASE = "app_assets"
_SILENT_STORAGE_NOT_FOUND = b"File Not Found"
_ASSET_PATH_REGISTRY: dict[str, tuple[bool, Callable[..., SignedAssetPath]]] = {}
_ASSET_PATH_REGISTRY: dict[str, tuple[bool, Any]] = {}
def _require_uuid(value: str, field_name: str) -> None:
@@ -29,12 +35,14 @@ def _require_uuid(value: str, field_name: str) -> None:
raise ValueError(f"{field_name} must be a UUID") from exc
def register_asset_path(asset_type: str, *, requires_node: bool, factory: Callable[..., SignedAssetPath]) -> None:
def register_asset_path(asset_type: str, *, requires_node: bool, factory: Any) -> None:
_ASSET_PATH_REGISTRY[asset_type] = (requires_node, factory)
@dataclass(frozen=True)
class AssetPathBase(ABC):
"""Base class for all asset paths."""
asset_type: ClassVar[str]
tenant_id: str
app_id: str
@@ -50,49 +58,24 @@ class AssetPathBase(ABC):
raise NotImplementedError
class SignedAssetPath(AssetPathBase, ABC):
@abstractmethod
def signature_parts(self) -> tuple[str, str | None]:
"""Return (resource_id, sub_resource_id) used for signing.
sub_resource_id should be None when not applicable.
"""
@abstractmethod
def proxy_path_parts(self) -> list[str]:
raise NotImplementedError
@dataclass(frozen=True)
class _DraftAssetPath(SignedAssetPath):
class _DraftAssetPath(AssetPathBase):
asset_type: ClassVar[str] = "draft"
def get_storage_key(self) -> str:
return f"{_ASSET_BASE}/{self.tenant_id}/{self.app_id}/draft/{self.resource_id}"
def signature_parts(self) -> tuple[str, str | None]:
return (self.resource_id, None)
def proxy_path_parts(self) -> list[str]:
return [self.asset_type, self.tenant_id, self.app_id, self.resource_id]
@dataclass(frozen=True)
class _BuildZipAssetPath(SignedAssetPath):
class _BuildZipAssetPath(AssetPathBase):
asset_type: ClassVar[str] = "build-zip"
def get_storage_key(self) -> str:
return f"{_ASSET_BASE}/{self.tenant_id}/{self.app_id}/artifacts/{self.resource_id}.zip"
def signature_parts(self) -> tuple[str, str | None]:
return (self.resource_id, None)
def proxy_path_parts(self) -> list[str]:
return [self.asset_type, self.tenant_id, self.app_id, self.resource_id]
@dataclass(frozen=True)
class _ResolvedAssetPath(SignedAssetPath):
class _ResolvedAssetPath(AssetPathBase):
asset_type: ClassVar[str] = "resolved"
node_id: str
@@ -103,80 +86,76 @@ class _ResolvedAssetPath(SignedAssetPath):
def get_storage_key(self) -> str:
return f"{_ASSET_BASE}/{self.tenant_id}/{self.app_id}/artifacts/{self.resource_id}/resolved/{self.node_id}"
def signature_parts(self) -> tuple[str, str | None]:
return (self.resource_id, self.node_id)
def proxy_path_parts(self) -> list[str]:
return [self.asset_type, self.tenant_id, self.app_id, self.resource_id, self.node_id]
@dataclass(frozen=True)
class _SkillBundleAssetPath(SignedAssetPath):
class _SkillBundleAssetPath(AssetPathBase):
asset_type: ClassVar[str] = "skill-bundle"
def get_storage_key(self) -> str:
return f"{_ASSET_BASE}/{self.tenant_id}/{self.app_id}/artifacts/{self.resource_id}/skill_artifact_set.json"
def signature_parts(self) -> tuple[str, str | None]:
return (self.resource_id, None)
def proxy_path_parts(self) -> list[str]:
return [self.asset_type, self.tenant_id, self.app_id, self.resource_id]
@dataclass(frozen=True)
class _SourceZipAssetPath(SignedAssetPath):
class _SourceZipAssetPath(AssetPathBase):
asset_type: ClassVar[str] = "source-zip"
def get_storage_key(self) -> str:
return f"{_ASSET_BASE}/{self.tenant_id}/{self.app_id}/sources/{self.resource_id}.zip"
def signature_parts(self) -> tuple[str, str | None]:
return (self.resource_id, None)
def proxy_path_parts(self) -> list[str]:
return [self.asset_type, self.tenant_id, self.app_id, self.resource_id]
@dataclass(frozen=True)
class _BundleExportZipAssetPath(SignedAssetPath):
class _BundleExportZipAssetPath(AssetPathBase):
asset_type: ClassVar[str] = "bundle-export-zip"
def get_storage_key(self) -> str:
return f"{_ASSET_BASE}/{self.tenant_id}/{self.app_id}/bundle_exports/{self.resource_id}.zip"
def signature_parts(self) -> tuple[str, str | None]:
return (self.resource_id, None)
def proxy_path_parts(self) -> list[str]:
return [self.asset_type, self.tenant_id, self.app_id, self.resource_id]
@dataclass(frozen=True)
class BundleImportZipPath:
"""Path for temporary import zip files."""
tenant_id: str
import_id: str
def __post_init__(self) -> None:
_require_uuid(self.tenant_id, "tenant_id")
def get_storage_key(self) -> str:
return f"{_ASSET_BASE}/{self.tenant_id}/imports/{self.import_id}.zip"
class AssetPath:
"""Factory for creating typed asset paths."""
@staticmethod
def draft(tenant_id: str, app_id: str, node_id: str) -> SignedAssetPath:
def draft(tenant_id: str, app_id: str, node_id: str) -> AssetPathBase:
return _DraftAssetPath(tenant_id=tenant_id, app_id=app_id, resource_id=node_id)
@staticmethod
def build_zip(tenant_id: str, app_id: str, assets_id: str) -> SignedAssetPath:
def build_zip(tenant_id: str, app_id: str, assets_id: str) -> AssetPathBase:
return _BuildZipAssetPath(tenant_id=tenant_id, app_id=app_id, resource_id=assets_id)
@staticmethod
def resolved(tenant_id: str, app_id: str, assets_id: str, node_id: str) -> SignedAssetPath:
def resolved(tenant_id: str, app_id: str, assets_id: str, node_id: str) -> AssetPathBase:
return _ResolvedAssetPath(tenant_id=tenant_id, app_id=app_id, resource_id=assets_id, node_id=node_id)
@staticmethod
def skill_bundle(tenant_id: str, app_id: str, assets_id: str) -> SignedAssetPath:
def skill_bundle(tenant_id: str, app_id: str, assets_id: str) -> AssetPathBase:
return _SkillBundleAssetPath(tenant_id=tenant_id, app_id=app_id, resource_id=assets_id)
@staticmethod
def source_zip(tenant_id: str, app_id: str, workflow_id: str) -> SignedAssetPath:
def source_zip(tenant_id: str, app_id: str, workflow_id: str) -> AssetPathBase:
return _SourceZipAssetPath(tenant_id=tenant_id, app_id=app_id, resource_id=workflow_id)
@staticmethod
def bundle_export_zip(tenant_id: str, app_id: str, export_id: str) -> SignedAssetPath:
def bundle_export_zip(tenant_id: str, app_id: str, export_id: str) -> AssetPathBase:
return _BundleExportZipAssetPath(tenant_id=tenant_id, app_id=app_id, resource_id=export_id)
@staticmethod
def bundle_import_zip(tenant_id: str, import_id: str) -> BundleImportZipPath:
return BundleImportZipPath(tenant_id=tenant_id, import_id=import_id)
@staticmethod
def from_components(
asset_type: str,
@@ -184,7 +163,7 @@ class AssetPath:
app_id: str,
resource_id: str,
sub_resource_id: str | None = None,
) -> SignedAssetPath:
) -> AssetPathBase:
entry = _ASSET_PATH_REGISTRY.get(asset_type)
if not entry:
raise ValueError(f"Unsupported asset type: {asset_type}")
@@ -206,120 +185,26 @@ register_asset_path("source-zip", requires_node=False, factory=AssetPath.source_
register_asset_path("bundle-export-zip", requires_node=False, factory=AssetPath.bundle_export_zip)
class AppAssetSigner:
SIGNATURE_PREFIX = "app-asset"
SIGNATURE_VERSION = "v1"
OPERATION_DOWNLOAD = "download"
OPERATION_UPLOAD = "upload"
@classmethod
def create_download_signature(cls, asset_path: SignedAssetPath, expires_at: int, nonce: str) -> str:
return cls._create_signature(
asset_path=asset_path,
operation=cls.OPERATION_DOWNLOAD,
expires_at=expires_at,
nonce=nonce,
)
@classmethod
def create_upload_signature(cls, asset_path: SignedAssetPath, expires_at: int, nonce: str) -> str:
return cls._create_signature(
asset_path=asset_path,
operation=cls.OPERATION_UPLOAD,
expires_at=expires_at,
nonce=nonce,
)
@classmethod
def verify_download_signature(cls, asset_path: SignedAssetPath, expires_at: int, nonce: str, sign: str) -> bool:
return cls._verify_signature(
asset_path=asset_path,
operation=cls.OPERATION_DOWNLOAD,
expires_at=expires_at,
nonce=nonce,
sign=sign,
)
@classmethod
def verify_upload_signature(cls, asset_path: SignedAssetPath, expires_at: int, nonce: str, sign: str) -> bool:
return cls._verify_signature(
asset_path=asset_path,
operation=cls.OPERATION_UPLOAD,
expires_at=expires_at,
nonce=nonce,
sign=sign,
)
@classmethod
def _verify_signature(
cls,
*,
asset_path: SignedAssetPath,
operation: str,
expires_at: int,
nonce: str,
sign: str,
) -> bool:
if expires_at <= 0:
return False
expected_sign = cls._create_signature(
asset_path=asset_path,
operation=operation,
expires_at=expires_at,
nonce=nonce,
)
if not hmac.compare_digest(sign, expected_sign):
return False
current_time = int(time.time())
if expires_at < current_time:
return False
if expires_at - current_time > dify_config.FILES_ACCESS_TIMEOUT:
return False
return True
@classmethod
def _create_signature(cls, *, asset_path: SignedAssetPath, operation: str, expires_at: int, nonce: str) -> str:
key = cls._tenant_key(asset_path.tenant_id)
message = cls._signature_message(
asset_path=asset_path,
operation=operation,
expires_at=expires_at,
nonce=nonce,
)
sign = hmac.new(key, message.encode(), hashlib.sha256).digest()
return base64.urlsafe_b64encode(sign).decode()
@classmethod
def _signature_message(cls, *, asset_path: SignedAssetPath, operation: str, expires_at: int, nonce: str) -> str:
resource_id, sub_resource_id = asset_path.signature_parts()
return (
f"{cls.SIGNATURE_PREFIX}|{cls.SIGNATURE_VERSION}|{operation}|"
f"{asset_path.asset_type}|{asset_path.tenant_id}|{asset_path.app_id}|"
f"{resource_id}|{sub_resource_id or ''}|{expires_at}|{nonce}"
)
@classmethod
def _tenant_key(cls, tenant_id: str) -> bytes:
try:
rsa_key, _ = rsa.get_decrypt_decoding(tenant_id)
except rsa.PrivkeyNotFoundError as exc:
raise ValueError(f"Tenant private key missing for tenant_id={tenant_id}") from exc
private_key = rsa_key.export_key()
return hashlib.sha256(private_key).digest()
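For reference, the signing pieces above compose as follows — a minimal standalone sketch, assuming a fixed key in place of the tenant-derived one and illustrative path components:

```python
import base64
import hashlib
import hmac

# Stand-in for _tenant_key(): SHA-256 of the tenant's private key bytes.
key = hashlib.sha256(b"tenant-private-key-bytes").digest()

# Mirrors _signature_message(): prefix|version|operation|type|tenant|app|resource|sub|expires|nonce
message = "app-asset|v1|download|draft|tenant-1|app-1|node-1||1767225600|abcd1234"
sign = base64.urlsafe_b64encode(hmac.new(key, message.encode(), hashlib.sha256).digest()).decode()

# Verification recomputes the signature and compares in constant time before
# checking the expiry window, as _verify_signature() does above.
expected = base64.urlsafe_b64encode(hmac.new(key, message.encode(), hashlib.sha256).digest()).decode()
assert hmac.compare_digest(sign, expected)
```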
class AppAssetStorage:
_base_storage: BaseStorage
"""High-level storage operations for app assets.
Wraps BaseStorage with:
- FilePresignStorage for presign fallback support
- CachedPresignStorage for URL caching
Usage:
storage = AppAssetStorage(base_storage, redis_client=redis)
storage.save(asset_path, content)
url = storage.get_download_url(asset_path)
"""
_storage: CachedPresignStorage
def __init__(self, storage: BaseStorage, *, redis_client: Any, cache_key_prefix: str = "app_assets") -> None:
self._base_storage = storage
# Wrap with FilePresignStorage for fallback support, then CachedPresignStorage for caching
presign_storage = FilePresignStorage(storage)
self._storage = CachedPresignStorage(
storage=storage,
storage=presign_storage,
redis_client=redis_client,
cache_key_prefix=cache_key_prefix,
)
@ -329,87 +214,51 @@ class AppAssetStorage:
return self._storage
def save(self, asset_path: AssetPathBase, content: bytes) -> None:
self._storage.save(self.get_storage_key(asset_path), content)
self._storage.save(asset_path.get_storage_key(), content)
def load(self, asset_path: AssetPathBase) -> bytes:
return self._storage.load_once(self.get_storage_key(asset_path))
return self._storage.load_once(asset_path.get_storage_key())
def load_stream(self, asset_path: AssetPathBase) -> Generator[bytes, None, None]:
return self._storage.load_stream(asset_path.get_storage_key())
def load_or_none(self, asset_path: AssetPathBase) -> bytes | None:
try:
data = self._storage.load_once(self.get_storage_key(asset_path))
data = self._storage.load_once(asset_path.get_storage_key())
except FileNotFoundError:
return None
if data == _SILENT_STORAGE_NOT_FOUND:
return None
return data
def exists(self, asset_path: AssetPathBase) -> bool:
return self._storage.exists(asset_path.get_storage_key())
def delete(self, asset_path: AssetPathBase) -> None:
self._storage.delete(self.get_storage_key(asset_path))
self._storage.delete(asset_path.get_storage_key())
def get_storage_key(self, asset_path: AssetPathBase) -> str:
return asset_path.get_storage_key()
def get_download_url(self, asset_path: AssetPathBase, expires_in: int = 3600) -> str:
return self._storage.get_download_url(asset_path.get_storage_key(), expires_in)
def get_download_url(self, asset_path: SignedAssetPath, expires_in: int = 3600) -> str:
storage_key = self.get_storage_key(asset_path)
def get_download_urls(self, asset_paths: Iterable[AssetPathBase], expires_in: int = 3600) -> list[str]:
storage_keys = [p.get_storage_key() for p in asset_paths]
return self._storage.get_download_urls(storage_keys, expires_in)
def get_upload_url(self, asset_path: AssetPathBase, expires_in: int = 3600) -> str:
return self._storage.get_upload_url(asset_path.get_storage_key(), expires_in)
# Bundle import convenience methods
def get_import_upload_url(self, path: BundleImportZipPath, expires_in: int = 3600) -> str:
return self._storage.get_upload_url(path.get_storage_key(), expires_in)
def get_import_download_url(self, path: BundleImportZipPath, expires_in: int = 3600) -> str:
return self._storage.get_download_url(path.get_storage_key(), expires_in)
def delete_import_zip(self, path: BundleImportZipPath) -> None:
"""Delete import zip file. Errors are logged but not raised."""
try:
return self._storage.get_download_url(storage_key, expires_in)
except NotImplementedError:
pass
self._storage.delete(path.get_storage_key())
except Exception:
import logging
return self._generate_signed_proxy_download_url(asset_path, expires_in)
def get_download_urls(
self,
asset_paths: Iterable[SignedAssetPath],
expires_in: int = 3600,
) -> list[str]:
asset_paths_list = list(asset_paths)
storage_keys = [self.get_storage_key(asset_path) for asset_path in asset_paths_list]
try:
return self._storage.get_download_urls(storage_keys, expires_in)
except NotImplementedError:
pass
return [self._generate_signed_proxy_download_url(asset_path, expires_in) for asset_path in asset_paths_list]
def get_upload_url(
self,
asset_path: SignedAssetPath,
expires_in: int = 3600,
) -> str:
storage_key = self.get_storage_key(asset_path)
try:
return self._storage.get_upload_url(storage_key, expires_in)
except NotImplementedError:
pass
return self._generate_signed_proxy_upload_url(asset_path, expires_in)
def _generate_signed_proxy_download_url(self, asset_path: SignedAssetPath, expires_in: int) -> str:
expires_in = min(expires_in, dify_config.FILES_ACCESS_TIMEOUT)
expires_at = int(time.time()) + max(expires_in, 1)
nonce = os.urandom(16).hex()
sign = AppAssetSigner.create_download_signature(asset_path=asset_path, expires_at=expires_at, nonce=nonce)
base_url = dify_config.FILES_URL
url = self._build_proxy_url(base_url=base_url, asset_path=asset_path, action="download")
query = urllib.parse.urlencode({"expires_at": expires_at, "nonce": nonce, "sign": sign})
return f"{url}?{query}"
def _generate_signed_proxy_upload_url(self, asset_path: SignedAssetPath, expires_in: int) -> str:
expires_in = min(expires_in, dify_config.FILES_ACCESS_TIMEOUT)
expires_at = int(time.time()) + max(expires_in, 1)
nonce = os.urandom(16).hex()
sign = AppAssetSigner.create_upload_signature(asset_path=asset_path, expires_at=expires_at, nonce=nonce)
base_url = dify_config.FILES_URL
url = self._build_proxy_url(base_url=base_url, asset_path=asset_path, action="upload")
query = urllib.parse.urlencode({"expires_at": expires_at, "nonce": nonce, "sign": sign})
return f"{url}?{query}"
@staticmethod
def _build_proxy_url(*, base_url: str, asset_path: SignedAssetPath, action: str) -> str:
encoded_parts = [urllib.parse.quote(part, safe="") for part in asset_path.proxy_path_parts()]
path = "/".join(encoded_parts)
return f"{base_url}/files/app-assets/{path}/{action}"
logging.getLogger(__name__).debug("Failed to delete import zip: %s", path.get_storage_key())

View File

@ -1,5 +1 @@
from .source_zip_extractor import SourceZipExtractor
__all__ = [
"SourceZipExtractor",
]
# App bundle utilities - manifest-driven import/export handled by AppBundleService

View File

@ -1,98 +0,0 @@
from __future__ import annotations
import io
import zipfile
from typing import TYPE_CHECKING
from uuid import uuid4
from core.app.entities.app_asset_entities import AppAssetFileTree, AppAssetNode
from core.app.entities.app_bundle_entities import ExtractedFile, ExtractedFolder, ZipSecurityError
from core.app_assets.storage import AssetPath
if TYPE_CHECKING:
from core.app_assets.storage import AppAssetStorage
class SourceZipExtractor:
def __init__(self, storage: AppAssetStorage) -> None:
self._storage = storage
def extract_entries(
self, zip_bytes: bytes, *, expected_prefix: str
) -> tuple[list[ExtractedFolder], list[ExtractedFile]]:
folders: list[ExtractedFolder] = []
files: list[ExtractedFile] = []
with zipfile.ZipFile(io.BytesIO(zip_bytes), "r") as zf:
for info in zf.infolist():
name = info.filename
self._validate_path(name)
if not name.startswith(expected_prefix):
continue
relative_path = name[len(expected_prefix) :].lstrip("/")
if not relative_path:
continue
if info.is_dir():
folders.append(ExtractedFolder(path=relative_path.rstrip("/")))
else:
content = zf.read(info)
files.append(ExtractedFile(path=relative_path, content=content))
return folders, files
def build_tree_and_save(
self,
folders: list[ExtractedFolder],
files: list[ExtractedFile],
tenant_id: str,
app_id: str,
) -> AppAssetFileTree:
tree = AppAssetFileTree()
path_to_node_id: dict[str, str] = {}
all_folder_paths = {f.path for f in folders}
for file in files:
self._ensure_parent_folders(file.path, all_folder_paths)
sorted_folders = sorted(all_folder_paths, key=lambda p: p.count("/"))
for folder_path in sorted_folders:
node_id = str(uuid4())
name = folder_path.rsplit("/", 1)[-1]
parent_path = folder_path.rsplit("/", 1)[0] if "/" in folder_path else None
parent_id = path_to_node_id.get(parent_path) if parent_path else None
node = AppAssetNode.create_folder(node_id, name, parent_id)
tree.add(node)
path_to_node_id[folder_path] = node_id
sorted_files = sorted(files, key=lambda f: f.path)
for file in sorted_files:
node_id = str(uuid4())
name = file.path.rsplit("/", 1)[-1]
parent_path = file.path.rsplit("/", 1)[0] if "/" in file.path else None
parent_id = path_to_node_id.get(parent_path) if parent_path else None
node = AppAssetNode.create_file(node_id, name, parent_id, len(file.content))
tree.add(node)
asset_path = AssetPath.draft(tenant_id, app_id, node_id)
self._storage.save(asset_path, file.content)
return tree
def _validate_path(self, path: str) -> None:
if ".." in path:
raise ZipSecurityError(f"Path traversal detected: {path}")
if path.startswith("/"):
raise ZipSecurityError(f"Absolute path detected: {path}")
if "\\" in path:
raise ZipSecurityError(f"Backslash in path: {path}")
def _ensure_parent_folders(self, file_path: str, folder_set: set[str]) -> None:
parts = file_path.split("/")[:-1]
for i in range(1, len(parts) + 1):
parent = "/".join(parts[:i])
folder_set.add(parent)
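The removed extractor's `_validate_path` guarded against classic zip-slip inputs; a self-contained sketch of the same checks:

```python
# Zip-slip guards: reject traversal, absolute paths, and backslashes.
def validate_zip_member(path: str) -> None:
    if ".." in path:
        raise ValueError(f"Path traversal detected: {path}")
    if path.startswith("/"):
        raise ValueError(f"Absolute path detected: {path}")
    if "\\" in path:
        raise ValueError(f"Backslash in path: {path}")

validate_zip_member("docs/readme.md")   # passes
# validate_zip_member("../etc/passwd")  # would raise ValueError
```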

View File

@ -3,6 +3,7 @@ from pydantic import BaseModel, Field, field_validator
class PreviewDetail(BaseModel):
content: str
summary: str | None = None
child_chunks: list[str] | None = None

View File

@ -123,6 +123,8 @@ def download(f: File, /):
):
return _download_file_content(f.storage_key)
elif f.transfer_method == FileTransferMethod.REMOTE_URL:
if f.remote_url is None:
raise ValueError("Missing file remote_url")
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status()
return response.content
@ -153,6 +155,8 @@ def _download_file_content(path: str, /):
def _get_encoded_string(f: File, /):
match f.transfer_method:
case FileTransferMethod.REMOTE_URL:
if f.remote_url is None:
raise ValueError("Missing file remote_url")
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status()
data = response.content

View File

@ -4,8 +4,10 @@ Proxy requests to avoid SSRF
import logging
import time
from typing import Any, TypeAlias
import httpx
from pydantic import TypeAdapter, ValidationError
from configs import dify_config
from core.helper.http_client_pooling import get_pooled_http_client
@ -18,6 +20,9 @@ SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]
Headers: TypeAlias = dict[str, str]
_HEADERS_ADAPTER = TypeAdapter(Headers)
_SSL_VERIFIED_POOL_KEY = "ssrf:verified"
_SSL_UNVERIFIED_POOL_KEY = "ssrf:unverified"
_SSRF_CLIENT_LIMITS = httpx.Limits(
@ -76,7 +81,7 @@ def _get_ssrf_client(ssl_verify_enabled: bool) -> httpx.Client:
)
def _get_user_provided_host_header(headers: dict | None) -> str | None:
def _get_user_provided_host_header(headers: Headers | None) -> str | None:
"""
Extract the user-provided Host header from the headers dict.
@ -92,7 +97,7 @@ def _get_user_provided_host_header(headers: dict | None) -> str | None:
return None
def _inject_trace_headers(headers: dict | None) -> dict:
def _inject_trace_headers(headers: Headers | None) -> Headers:
"""
Inject W3C traceparent header for distributed tracing.
@ -125,7 +130,7 @@ def _inject_trace_headers(headers: dict | None) -> dict:
return headers
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
def make_request(method: str, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
# Convert requests-style allow_redirects to httpx-style follow_redirects
if "allow_redirects" in kwargs:
allow_redirects = kwargs.pop("allow_redirects")
@ -142,10 +147,15 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
# prioritize per-call option, which can be switched on and off inside the HTTP node on the web UI
verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
if not isinstance(verify_option, bool):
raise ValueError("ssl_verify must be a boolean")
client = _get_ssrf_client(verify_option)
# Inject traceparent header for distributed tracing (when OTEL is not enabled)
headers = kwargs.get("headers") or {}
try:
headers: Headers = _HEADERS_ADAPTER.validate_python(kwargs.get("headers") or {})
except ValidationError as e:
raise ValueError("headers must be a mapping of string keys to string values") from e
headers = _inject_trace_headers(headers)
kwargs["headers"] = headers
@ -198,25 +208,25 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")
def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
def get(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return make_request("GET", url, max_retries=max_retries, **kwargs)
def post(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
def post(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return make_request("POST", url, max_retries=max_retries, **kwargs)
def put(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
def put(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return make_request("PUT", url, max_retries=max_retries, **kwargs)
def patch(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
def patch(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return make_request("PATCH", url, max_retries=max_retries, **kwargs)
def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
def delete(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return make_request("DELETE", url, max_retries=max_retries, **kwargs)
def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
def head(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return make_request("HEAD", url, max_retries=max_retries, **kwargs)

View File

@ -311,14 +311,18 @@ class IndexingRunner:
qa_preview_texts: list[QAPreviewDetail] = []
total_segments = 0
# doc_form represents the segmentation method (general, parent-child, QA)
index_type = doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
# one extract_setting is one source document
for extract_setting in extract_settings:
# extract
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"])
)
# Extract document content
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
# Cleaning and segmentation
documents = index_processor.transform(
text_docs,
current_user=None,
@ -361,6 +365,12 @@ class IndexingRunner:
if doc_form and doc_form == "qa_model":
return IndexingEstimate(total_segments=total_segments * 20, qa_preview=qa_preview_texts, preview=[])
# Generate summary preview
summary_index_setting = tmp_processing_rule.get("summary_index_setting")
if summary_index_setting and summary_index_setting.get("enable") and preview_texts:
preview_texts = index_processor.generate_summary_preview(tenant_id, preview_texts, summary_index_setting)
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
def _extract(

View File

@ -471,7 +471,6 @@ class LLMGenerator:
prompt_messages=complete_messages,
output_model=CodeNodeStructuredOutput,
model_parameters=model_parameters,
stream=True,
tenant_id=tenant_id,
)
@ -553,16 +552,10 @@ class LLMGenerator:
completion_params = model_config.get("completion_params", {}) if model_config else {}
try:
response = invoke_llm_with_pydantic_model(
provider=model_instance.provider,
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
output_model=SuggestedQuestionsOutput,
model_parameters=completion_params,
stream=True,
tenant_id=tenant_id,
)
response = invoke_llm_with_pydantic_model(
provider=model_instance.provider,
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
output_model=SuggestedQuestionsOutput,
model_parameters=completion_params,
tenant_id=tenant_id,
)
return {"questions": response.questions, "error": ""}
@ -842,15 +835,11 @@ Generate {language} code to extract/transform available variables for the target
try:
from core.llm_generator.output_parser.structured_output import invoke_llm_with_pydantic_model
response = invoke_llm_with_pydantic_model(
provider=model_instance.provider,
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=list(prompt_messages),
output_model=InstructionModifyOutput,
model_parameters=model_parameters,
stream=True,
)
response = invoke_llm_with_pydantic_model(
provider=model_instance.provider,
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=list(prompt_messages),
output_model=InstructionModifyOutput,
model_parameters=model_parameters,
)
return response.model_dump(mode="python")
except InvokeError as e:
error = str(e)

View File

@ -1,8 +1,8 @@
import json
from collections.abc import Generator, Mapping, Sequence
from collections.abc import Mapping, Sequence
from copy import deepcopy
from enum import StrEnum
from typing import Any, Literal, TypeVar, cast, overload
from typing import Any, TypeVar, cast
import json_repair
from pydantic import BaseModel, TypeAdapter, ValidationError
@ -14,13 +14,9 @@ from core.model_manager import ModelInstance
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
LLMResultChunkWithStructuredOutput,
LLMResultWithStructuredOutput,
)
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageTool,
SystemPromptMessage,
@ -52,7 +48,6 @@ TOOL_CALL_FEATURES = {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, Mode
T = TypeVar("T", bound=BaseModel)
@overload
def invoke_llm_with_structured_output(
*,
provider: str,
@ -63,58 +58,10 @@ def invoke_llm_with_structured_output(
model_parameters: Mapping | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[True],
user: str | None = None,
callbacks: list[Callback] | None = None,
tenant_id: str | None = None,
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
@overload
def invoke_llm_with_structured_output(
*,
provider: str,
model_schema: AIModelEntity,
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
json_schema: Mapping[str, Any],
model_parameters: Mapping | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[False],
user: str | None = None,
callbacks: list[Callback] | None = None,
tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput: ...
@overload
def invoke_llm_with_structured_output(
*,
provider: str,
model_schema: AIModelEntity,
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
json_schema: Mapping[str, Any],
model_parameters: Mapping | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
def invoke_llm_with_structured_output(
*,
provider: str,
model_schema: AIModelEntity,
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
json_schema: Mapping[str, Any],
model_parameters: Mapping | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
) -> LLMResultWithStructuredOutput:
"""
Invoke large language model with structured output.
@ -129,7 +76,6 @@ def invoke_llm_with_structured_output(
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
:param tenant_id: tenant ID for file reference conversion. When provided and
@ -165,91 +111,33 @@ def invoke_llm_with_structured_output(
model_parameters=model_parameters_with_json_schema,
tools=tools,
stop=stop,
stream=stream,
stream=False,
user=user,
callbacks=callbacks,
)
if isinstance(llm_result, LLMResult):
# Non-streaming result
structured_output = _extract_structured_output(llm_result)
# Non-streaming result
structured_output = _extract_structured_output(llm_result)
# Fill missing fields with default values
structured_output = fill_defaults_from_schema(structured_output, json_schema)
# Fill missing fields with default values
structured_output = fill_defaults_from_schema(structured_output, json_schema)
# Convert file references if tenant_id is provided
if tenant_id is not None:
structured_output = convert_file_refs_in_output(
output=structured_output,
json_schema=json_schema,
tenant_id=tenant_id,
)
return LLMResultWithStructuredOutput(
structured_output=structured_output,
model=llm_result.model,
message=llm_result.message,
usage=llm_result.usage,
system_fingerprint=llm_result.system_fingerprint,
prompt_messages=llm_result.prompt_messages,
# Convert file references if tenant_id is provided
if tenant_id is not None:
structured_output = convert_file_refs_in_output(
output=structured_output,
json_schema=json_schema,
tenant_id=tenant_id,
)
else:
def generator() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
result_text: str = ""
tool_call_args: dict[str, str] = {} # tool_call_id -> arguments
prompt_messages: Sequence[PromptMessage] = []
system_fingerprint: str | None = None
for event in llm_result:
if isinstance(event, LLMResultChunk):
prompt_messages = event.prompt_messages
system_fingerprint = event.system_fingerprint
# Collect text content
result_text += event.delta.message.get_text_content()
# Collect tool call arguments
if event.delta.message.tool_calls:
for tool_call in event.delta.message.tool_calls:
call_id = tool_call.id or ""
if tool_call.function.arguments:
tool_call_args[call_id] = tool_call_args.get(call_id, "") + tool_call.function.arguments
yield LLMResultChunkWithStructuredOutput(
model=model_schema.model,
prompt_messages=prompt_messages,
system_fingerprint=system_fingerprint,
delta=event.delta,
)
# Extract structured output: prefer tool call, fallback to text
structured_output = _extract_structured_output_from_stream(result_text, tool_call_args)
# Fill missing fields with default values
structured_output = fill_defaults_from_schema(structured_output, json_schema)
# Convert file references if tenant_id is provided
if tenant_id is not None:
structured_output = convert_file_refs_in_output(
output=structured_output,
json_schema=json_schema,
tenant_id=tenant_id,
)
yield LLMResultChunkWithStructuredOutput(
structured_output=structured_output,
model=model_schema.model,
prompt_messages=prompt_messages,
system_fingerprint=system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=""),
usage=None,
finish_reason=None,
),
)
return generator()
return LLMResultWithStructuredOutput(
structured_output=structured_output,
model=llm_result.model,
message=llm_result.message,
usage=llm_result.usage,
system_fingerprint=llm_result.system_fingerprint,
prompt_messages=llm_result.prompt_messages,
)
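With the streaming overloads gone, the call-site contract is a single synchronous result — a sketch assuming `model_instance`, `model_schema`, and `tenant_id` are in scope:

```python
result = invoke_llm_with_structured_output(
    provider=model_instance.provider,
    model_schema=model_schema,
    model_instance=model_instance,
    prompt_messages=[UserPromptMessage(content="Return a JSON greeting.")],
    json_schema={"type": "object", "properties": {"greeting": {"type": "string"}}},
    tenant_id=tenant_id,
)
print(result.structured_output)  # e.g. {"greeting": "hello"}
```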
def invoke_llm_with_pydantic_model(
@ -262,7 +150,6 @@ def invoke_llm_with_pydantic_model(
model_parameters: Mapping | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True, # Some model plugin implementations don't support stream=False
user: str | None = None,
callbacks: list[Callback] | None = None,
tenant_id: str | None = None,
@ -281,36 +168,6 @@ def invoke_llm_with_pydantic_model(
"""
json_schema = _schema_from_pydantic(output_model)
if stream:
result_generator = invoke_llm_with_structured_output(
provider=provider,
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=json_schema,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=True,
user=user,
callbacks=callbacks,
tenant_id=tenant_id,
)
# Consume the generator to get the final chunk with structured_output
last_chunk: LLMResultChunkWithStructuredOutput | None = None
for chunk in result_generator:
last_chunk = chunk
if last_chunk is None:
raise OutputParserError("No chunks received from LLM")
structured_output = last_chunk.structured_output
if structured_output is None:
raise OutputParserError("Structured output is empty")
return _validate_structured_output(output_model, structured_output)
result = invoke_llm_with_structured_output(
provider=provider,
model_schema=model_schema,
@ -320,7 +177,6 @@ def invoke_llm_with_pydantic_model(
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=False,
user=user,
callbacks=callbacks,
tenant_id=tenant_id,
@ -416,7 +272,7 @@ def _parse_tool_call_arguments(arguments: str) -> Mapping[str, Any]:
repaired = json_repair.loads(arguments)
if not isinstance(repaired, dict):
raise OutputParserError(f"Failed to parse tool call arguments: {arguments}")
return cast(dict, repaired)
return repaired
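`json_repair.loads` tolerates the truncated or malformed JSON that models sometimes emit for tool-call arguments; a small sketch (exact repair output depends on the library version):

```python
import json_repair

raw = '{"name": "dify", "tags": ["rag",'  # truncated model output
repaired = json_repair.loads(raw)
print(repaired)  # typically {'name': 'dify', 'tags': ['rag']}
```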
def _get_default_value_for_type(type_name: str | list[str] | None) -> Any:

View File

@ -435,3 +435,20 @@ INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as ex
You should edit the prompt according to the IDEAL OUTPUT."""
INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}."""
DEFAULT_GENERATOR_SUMMARY_PROMPT = (
"""Summarize the following content. Extract only the key information and main points. """
"""Remove redundant details.
Requirements:
1. Write a concise summary in plain text
2. Use the same language as the input content
3. Focus on important facts, concepts, and details
4. If images are included, describe their key information
5. Do not use words like "好的", "ok", "I understand", "This text discusses", "The content mentions"
6. Write directly without extra words
Output only the summary text. Start summarizing now:
"""
)

View File

@ -1,10 +1,11 @@
import decimal
import hashlib
from threading import Lock
import logging
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, ValidationError
from redis import RedisError
import contexts
from configs import dify_config
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from core.model_runtime.entities.model_entities import (
@ -24,6 +25,9 @@ from core.model_runtime.errors.invoke import (
InvokeServerUnavailableError,
)
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
class AIModel(BaseModel):
@ -144,34 +148,60 @@ class AIModel(BaseModel):
plugin_model_manager = PluginModelClient()
cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
# sort credentials
sorted_credentials = sorted(credentials.items()) if credentials else []
cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
cached_schema_json = None
try:
contexts.plugin_model_schemas.get()
except LookupError:
contexts.plugin_model_schemas.set({})
contexts.plugin_model_schema_lock.set(Lock())
with contexts.plugin_model_schema_lock.get():
if cache_key in contexts.plugin_model_schemas.get():
return contexts.plugin_model_schemas.get()[cache_key]
schema = plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model_type=self.model_type.value,
model=model,
credentials=credentials or {},
cached_schema_json = redis_client.get(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to read plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
if cached_schema_json:
try:
return AIModelEntity.model_validate_json(cached_schema_json)
except ValidationError:
logger.warning(
"Failed to validate cached plugin model schema for model %s",
model,
exc_info=True,
)
try:
redis_client.delete(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to delete invalid plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
if schema:
contexts.plugin_model_schemas.get()[cache_key] = schema
schema = plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model_type=self.model_type.value,
model=model,
credentials=credentials or {},
)
return schema
if schema:
try:
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to write plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
return schema
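Both schema lookups above follow the same read-through cache shape; a generic sketch with illustrative key, TTL, and serialization:

```python
import json

def get_cached(redis_client, cache_key: str, ttl: int, fetch, parse=json.loads, dump=json.dumps):
    try:
        cached = redis_client.get(cache_key)
    except Exception:
        cached = None  # cache read errors fall through to the source of truth
    if cached:
        try:
            return parse(cached)
        except ValueError:
            redis_client.delete(cache_key)  # drop entries that no longer parse
    value = fetch()
    if value is not None:
        try:
            redis_client.setex(cache_key, ttl, dump(value))
        except Exception:
            pass  # writes are best-effort; the next call re-fetches
    return value
```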
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None:
"""

View File

@ -5,7 +5,11 @@ import logging
from collections.abc import Sequence
from threading import Lock
from pydantic import ValidationError
from redis import RedisError
import contexts
from configs import dify_config
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
from core.model_runtime.model_providers.__base.ai_model import AIModel
@ -18,6 +22,7 @@ from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from extensions.ext_redis import redis_client
from models.provider_ids import ModelProviderID
logger = logging.getLogger(__name__)
@ -175,34 +180,60 @@ class ModelProviderFactory:
"""
plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider)
cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}"
# sort credentials
sorted_credentials = sorted(credentials.items()) if credentials else []
cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
cached_schema_json = None
try:
contexts.plugin_model_schemas.get()
except LookupError:
contexts.plugin_model_schemas.set({})
contexts.plugin_model_schema_lock.set(Lock())
with contexts.plugin_model_schema_lock.get():
if cache_key in contexts.plugin_model_schemas.get():
return contexts.plugin_model_schemas.get()[cache_key]
schema = self.plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials or {},
cached_schema_json = redis_client.get(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to read plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
if cached_schema_json:
try:
return AIModelEntity.model_validate_json(cached_schema_json)
except ValidationError:
logger.warning(
"Failed to validate cached plugin model schema for model %s",
model,
exc_info=True,
)
try:
redis_client.delete(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to delete invalid plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
if schema:
contexts.plugin_model_schemas.get()[cache_key] = schema
schema = self.plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials or {},
)
return schema
if schema:
try:
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to write plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
return schema
def get_models(
self,

View File

@ -114,46 +114,32 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
model_instance=model_instance,
prompt_messages=payload.prompt_messages,
json_schema=payload.structured_output_schema,
model_parameters=payload.completion_params,
tools=payload.tools,
stop=payload.stop,
stream=True if payload.stream is None else payload.stream,
user=user_id,
model_parameters=payload.completion_params,
user=user_id
)
if isinstance(response, Generator):
if response.usage:
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
for chunk in response:
if chunk.delta.usage:
llm_utils.deduct_llm_quota(
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
)
chunk.prompt_messages = []
yield chunk
def handle_non_streaming(
response: LLMResultWithStructuredOutput,
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
yield LLMResultChunkWithStructuredOutput(
model=response.model,
prompt_messages=[],
system_fingerprint=response.system_fingerprint,
structured_output=response.structured_output,
delta=LLMResultChunkDelta(
index=0,
message=response.message,
usage=response.usage,
finish_reason="",
),
)
return handle()
else:
if response.usage:
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
def handle_non_streaming(
response: LLMResultWithStructuredOutput,
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
yield LLMResultChunkWithStructuredOutput(
model=response.model,
prompt_messages=[],
system_fingerprint=response.system_fingerprint,
structured_output=response.structured_output,
delta=LLMResultChunkDelta(
index=0,
message=response.message,
usage=response.usage,
finish_reason="",
),
)
return handle_non_streaming(response)
return handle_non_streaming(response)
@classmethod
def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):

View File

@ -24,7 +24,13 @@ from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.signature import sign_upload_file
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
from models.dataset import (
ChildChunk,
Dataset,
DocumentSegment,
DocumentSegmentSummary,
SegmentAttachmentBinding,
)
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from services.external_knowledge_service import ExternalDatasetService
@ -389,15 +395,15 @@ class RetrievalService:
.all()
}
records = []
include_segment_ids = set()
segment_child_map = {}
valid_dataset_documents = {}
image_doc_ids: list[Any] = []
child_index_node_ids = []
index_node_ids = []
doc_to_document_map = {}
summary_segment_ids = set() # Track segments retrieved via summary
summary_score_map: dict[str, float] = {} # Map original_chunk_id to summary score
# First pass: collect all document IDs and identify summary documents
for document in documents:
document_id = document.metadata.get("document_id")
if document_id not in dataset_documents:
@ -408,16 +414,39 @@ class RetrievalService:
continue
valid_dataset_documents[document_id] = dataset_document
doc_id = document.metadata.get("doc_id") or ""
doc_to_document_map[doc_id] = document
# Check if this is a summary document
is_summary = document.metadata.get("is_summary", False)
if is_summary:
# For summary documents, find the original chunk via original_chunk_id
original_chunk_id = document.metadata.get("original_chunk_id")
if original_chunk_id:
summary_segment_ids.add(original_chunk_id)
# Save summary's score for later use
summary_score = document.metadata.get("score")
if summary_score is not None:
try:
summary_score_float = float(summary_score)
# If the same segment has multiple summary hits, take the highest score
if original_chunk_id not in summary_score_map:
summary_score_map[original_chunk_id] = summary_score_float
else:
summary_score_map[original_chunk_id] = max(
summary_score_map[original_chunk_id], summary_score_float
)
except (ValueError, TypeError):
# Skip invalid score values
pass
continue # Skip adding to other lists for summary documents
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
doc_id = document.metadata.get("doc_id") or ""
doc_to_document_map[doc_id] = document
if document.metadata.get("doc_type") == DocType.IMAGE:
image_doc_ids.append(doc_id)
else:
child_index_node_ids.append(doc_id)
else:
doc_id = document.metadata.get("doc_id") or ""
doc_to_document_map[doc_id] = document
if document.metadata.get("doc_type") == DocType.IMAGE:
image_doc_ids.append(doc_id)
else:
@ -433,6 +462,7 @@ class RetrievalService:
attachment_map: dict[str, list[dict[str, Any]]] = {}
child_chunk_map: dict[str, list[ChildChunk]] = {}
doc_segment_map: dict[str, list[str]] = {}
segment_summary_map: dict[str, str] = {} # Map segment_id to summary content
with session_factory.create_session() as session:
attachments = cls.get_segment_attachment_infos(image_doc_ids, session)
@ -447,6 +477,7 @@ class RetrievalService:
doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"])
else:
doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]]
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids))
child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
@ -470,6 +501,7 @@ class RetrievalService:
index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
for index_node_segment in index_node_segments:
doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id]
if segment_ids:
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.enabled == True,
@ -481,6 +513,40 @@ class RetrievalService:
if index_node_segments:
segments.extend(index_node_segments)
# Handle summary documents: query segments by original_chunk_id
if summary_segment_ids:
summary_segment_ids_list = list(summary_segment_ids)
summary_segment_stmt = select(DocumentSegment).where(
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id.in_(summary_segment_ids_list),
)
summary_segments = session.execute(summary_segment_stmt).scalars().all() # type: ignore
segments.extend(summary_segments)
# Add summary segment IDs to segment_ids for summary query
for seg in summary_segments:
if seg.id not in segment_ids:
segment_ids.append(seg.id)
# Batch query summaries for segments retrieved via summary (only enabled summaries)
if summary_segment_ids:
summaries = (
session.query(DocumentSegmentSummary)
.filter(
DocumentSegmentSummary.chunk_id.in_(list(summary_segment_ids)),
DocumentSegmentSummary.status == "completed",
DocumentSegmentSummary.enabled == True, # Only retrieve enabled summaries
)
.all()
)
for summary in summaries:
if summary.summary_content:
segment_summary_map[summary.chunk_id] = summary.summary_content
include_segment_ids = set()
segment_child_map: dict[str, dict[str, Any]] = {}
records: list[dict[str, Any]] = []
for segment in segments:
child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
@ -489,30 +555,44 @@ class RetrievalService:
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
# Check if this segment was retrieved via summary
# Use summary score as base score if available, otherwise 0.0
max_score = summary_score_map.get(segment.id, 0.0)
if child_chunks or attachment_infos:
child_chunk_details = []
max_score = 0.0
for child_chunk in child_chunks:
document = doc_to_document_map[child_chunk.index_node_id]
child_document: Document | None = doc_to_document_map.get(child_chunk.index_node_id)
if child_document:
child_score = child_document.metadata.get("score", 0.0)
else:
child_score = 0.0
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0) if document else 0.0,
"score": child_score,
}
child_chunk_details.append(child_chunk_detail)
max_score = max(max_score, document.metadata.get("score", 0.0) if document else 0.0)
max_score = max(max_score, child_score)
for attachment_info in attachment_infos:
file_document = doc_to_document_map[attachment_info["id"]]
max_score = max(
max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0
)
file_document = doc_to_document_map.get(attachment_info["id"])
if file_document:
max_score = max(max_score, file_document.metadata.get("score", 0.0))
map_detail = {
"max_score": max_score,
"child_chunks": child_chunk_details,
}
segment_child_map[segment.id] = map_detail
else:
# No child chunks or attachments, use summary score if available
summary_score = summary_score_map.get(segment.id)
if summary_score is not None:
segment_child_map[segment.id] = {
"max_score": summary_score,
"child_chunks": [],
}
record: dict[str, Any] = {
"segment": segment,
}
@ -520,14 +600,23 @@ class RetrievalService:
else:
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
max_score = 0.0
segment_document = doc_to_document_map.get(segment.index_node_id)
if segment_document:
max_score = max(max_score, segment_document.metadata.get("score", 0.0))
# Check if this segment was retrieved via summary
# Use summary score if available (summary retrieval takes priority)
max_score = summary_score_map.get(segment.id, 0.0)
# If not retrieved via summary, use original segment's score
if segment.id not in summary_score_map:
segment_document = doc_to_document_map.get(segment.index_node_id)
if segment_document:
max_score = max(max_score, segment_document.metadata.get("score", 0.0))
# Also consider attachment scores
for attachment_info in attachment_infos:
file_doc = doc_to_document_map.get(attachment_info["id"])
if file_doc:
max_score = max(max_score, file_doc.metadata.get("score", 0.0))
record = {
"segment": segment,
"score": max_score,
@ -576,9 +665,16 @@ class RetrievalService:
else None
)
# Extract summary if this segment was retrieved via summary
summary_content = segment_summary_map.get(segment.id)
# Create RetrievalSegments object
retrieval_segment = RetrievalSegments(
segment=segment, child_chunks=child_chunks_list, score=score, files=files
segment=segment,
child_chunks=child_chunks_list,
score=score,
files=files,
summary=summary_content,
)
result.append(retrieval_segment)

View File

@ -20,3 +20,4 @@ class RetrievalSegments(BaseModel):
child_chunks: list[RetrievalChildChunk] | None = None
score: float | None = None
files: list[dict[str, str | int]] | None = None
summary: str | None = None # Summary content if retrieved via summary index

View File

@ -22,3 +22,4 @@ class RetrievalSourceMetadata(BaseModel):
doc_metadata: dict[str, Any] | None = None
title: str | None = None
files: list[dict[str, Any]] | None = None
summary: str | None = None

View File

@ -1,4 +1,7 @@
"""Abstract interface for document loader implementations."""
"""Word (.docx) document extractor used for RAG ingestion.
Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`).
"""
import logging
import mimetypes
@ -8,7 +11,6 @@ import tempfile
import uuid
from urllib.parse import urlparse
import httpx
from docx import Document as DocxDocument
from docx.oxml.ns import qn
from docx.text.run import Run
@ -44,7 +46,7 @@ class WordExtractor(BaseExtractor):
# If the file is a web path, download it to a temporary file, and use that
if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path):
response = httpx.get(self.file_path, timeout=None)
response = ssrf_proxy.get(self.file_path)
if response.status_code != 200:
response.close()
@ -55,6 +57,7 @@ class WordExtractor(BaseExtractor):
self.temp_file = tempfile.NamedTemporaryFile() # noqa SIM115
try:
self.temp_file.write(response.content)
self.temp_file.flush()
finally:
response.close()
self.file_path = self.temp_file.name
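The added `flush()` matters because python-docx reopens the file by path; a sketch of the download-to-tempfile flow with a placeholder URL:

```python
import tempfile
from core.helper import ssrf_proxy

response = ssrf_proxy.get("https://example.com/doc.docx")  # placeholder URL
temp = tempfile.NamedTemporaryFile()
try:
    temp.write(response.content)
    temp.flush()  # ensure bytes reach disk before another reader opens the path
finally:
    response.close()
path = temp.name  # safe to hand to DocxDocument(path) now
```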

View File

@ -13,6 +13,7 @@ from urllib.parse import unquote, urlparse
import httpx
from configs import dify_config
from core.entities.knowledge_entities import PreviewDetail
from core.helper import ssrf_proxy
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.doc_type import DocType
@ -45,6 +46,17 @@ class BaseIndexProcessor(ABC):
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
raise NotImplementedError
@abstractmethod
def generate_summary_preview(
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
) -> list[PreviewDetail]:
"""
For each segment in preview_texts, generate a summary using LLM and attach it to the segment.
The summary can be stored in a new attribute, e.g., summary.
This method should be implemented by subclasses.
"""
raise NotImplementedError
@abstractmethod
def load(
self,

View File

@ -1,9 +1,27 @@
"""Paragraph index processor."""
import logging
import re
import uuid
from collections.abc import Mapping
from typing import Any
from typing import Any, cast
logger = logging.getLogger(__name__)
from core.entities.knowledge_entities import PreviewDetail
from core.file import File, FileTransferMethod, FileType, file_manager
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentUnionTypes,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.provider_manager import ProviderManager
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrieval_service import RetrievalService
@ -17,12 +35,17 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.workflow.nodes.llm import llm_utils
from extensions.ext_database import db
from factories.file_factory import build_from_mapping
from libs import helper
from models import UploadFile
from models.account import Account
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import Rule
from services.summary_index_service import SummaryIndexService
class ParagraphIndexProcessor(BaseIndexProcessor):
@ -108,6 +131,29 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
keyword.add_texts(documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
# For disable operations, disable_summaries_for_segments is called directly in the task.
# Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
delete_summaries = kwargs.get("delete_summaries", False)
if delete_summaries:
if node_ids:
# Find segments by index_node_id
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
)
.all()
)
segment_ids = [segment.id for segment in segments]
if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
else:
# Delete all summaries for the dataset
SummaryIndexService.delete_summaries_for_segments(dataset, None)
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
if node_ids:
@ -227,3 +273,322 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
}
else:
raise ValueError("Chunks is not a list")
def generate_summary_preview(
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
) -> list[PreviewDetail]:
"""
For each segment, concurrently call generate_summary to generate a summary
and write it to the summary attribute of PreviewDetail.
In preview mode (indexing-estimate), if any summary generation fails, the method will raise an exception.
"""
import concurrent.futures
from flask import current_app
# Capture Flask app context for worker threads
flask_app = None
try:
flask_app = current_app._get_current_object() # type: ignore
except RuntimeError:
logger.warning("No Flask application context available, summary generation may fail")
def process(preview: PreviewDetail) -> None:
"""Generate summary for a single preview item."""
if flask_app:
# Ensure Flask app context in worker thread
with flask_app.app_context():
summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting)
preview.summary = summary
else:
# Fallback: try without app context (may fail)
summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting)
preview.summary = summary
# Generate summaries concurrently using ThreadPoolExecutor
# Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
timeout_seconds = min(300, 60 * len(preview_texts))
errors: list[Exception] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_texts))) as executor:
futures = [executor.submit(process, preview) for preview in preview_texts]
# Wait for all tasks to complete with timeout
done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)
# Cancel tasks that didn't complete in time
if not_done:
timeout_error_msg = (
f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s"
)
logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg)
# In preview mode, timeout is also an error
errors.append(TimeoutError(timeout_error_msg))
for future in not_done:
future.cancel()
# Wait a bit for cancellation to take effect
concurrent.futures.wait(not_done, timeout=5)
# Collect exceptions from completed futures
for future in done:
try:
future.result() # This will raise any exception that occurred
except Exception as e:
logger.exception("Error in summary generation future")
errors.append(e)
# In preview mode (indexing-estimate), if there are any errors, fail the request
if errors:
error_messages = [str(e) for e in errors]
error_summary = (
f"Failed to generate summaries for {len(errors)} chunk(s). "
f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors
)
if len(errors) > 3:
error_summary += f" (and {len(errors) - 3} more)"
logger.error("Summary generation failed in preview mode: %s", error_summary)
raise ValueError(error_summary)
return preview_texts
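The app-context-per-worker-thread pattern above, in miniature — `do_llm_call` and `items` are stand-in names:

```python
import concurrent.futures
from flask import current_app

app = current_app._get_current_object()  # capture the real app outside the pool

def work(item):
    with app.app_context():  # each worker thread needs its own context
        return do_llm_call(item)

with concurrent.futures.ThreadPoolExecutor(max_workers=4) as pool:
    futures = [pool.submit(work, item) for item in items]
    done, not_done = concurrent.futures.wait(futures, timeout=60)
```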
@staticmethod
def generate_summary(
tenant_id: str,
text: str,
summary_index_setting: dict | None = None,
segment_id: str | None = None,
) -> tuple[str, LLMUsage]:
"""
Generate a summary for the given text using ModelInstance.invoke_llm with the default or a
custom summary prompt. Supports vision models by including images from segment attachments
or from the text content itself.
Args:
tenant_id: Tenant ID
text: Text content to summarize
summary_index_setting: Summary index configuration
segment_id: Optional segment ID to fetch attachments from SegmentAttachmentBinding table
Returns:
Tuple of (summary_content, llm_usage) where llm_usage is LLMUsage object
"""
if not summary_index_setting or not summary_index_setting.get("enable"):
raise ValueError("summary_index_setting is required and must be enabled to generate summary.")
model_name = summary_index_setting.get("model_name")
model_provider_name = summary_index_setting.get("model_provider_name")
summary_prompt = summary_index_setting.get("summary_prompt")
if not model_name or not model_provider_name:
raise ValueError("model_name and model_provider_name are required in summary_index_setting")
# Import default summary prompt
if not summary_prompt:
summary_prompt = DEFAULT_GENERATOR_SUMMARY_PROMPT
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id, model_provider_name, ModelType.LLM
)
model_instance = ModelInstance(provider_model_bundle, model_name)
# Get model schema to check if vision is supported
model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials)
supports_vision = model_schema and model_schema.features and ModelFeature.VISION in model_schema.features
# Extract images if model supports vision
image_files = []
if supports_vision:
# First, try to get images from SegmentAttachmentBinding (preferred method)
if segment_id:
image_files = ParagraphIndexProcessor._extract_images_from_segment_attachments(tenant_id, segment_id)
# If no images from attachments, fall back to extracting from text
if not image_files:
image_files = ParagraphIndexProcessor._extract_images_from_text(tenant_id, text)
# Build prompt messages
prompt_messages = []
if image_files:
# If we have images, create a UserPromptMessage with both text and images
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
# Add images first
for file in image_files:
try:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=ImagePromptMessageContent.DETAIL.LOW
)
prompt_message_contents.append(file_content)
except Exception as e:
logger.warning("Failed to convert image file to prompt message content: %s", str(e))
continue
# Add text content
if prompt_message_contents: # Only add text if we successfully added images
prompt_message_contents.append(TextPromptMessageContent(data=f"{summary_prompt}\n{text}"))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
# If image conversion failed, fall back to text-only
prompt = f"{summary_prompt}\n{text}"
prompt_messages.append(UserPromptMessage(content=prompt))
else:
# No images, use simple text prompt
prompt = f"{summary_prompt}\n{text}"
prompt_messages.append(UserPromptMessage(content=prompt))
result = model_instance.invoke_llm(
prompt_messages=cast(list[PromptMessage], prompt_messages), model_parameters={}, stream=False
)
# Type assertion: when stream=False, invoke_llm returns LLMResult, not Generator
if not isinstance(result, LLMResult):
raise ValueError("Expected LLMResult when stream=False")
summary_content = getattr(result.message, "content", "")
usage = result.usage
# Deduct quota for summary generation (same as workflow nodes)
try:
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
except Exception as e:
# Log but don't fail summary generation if quota deduction fails
logger.warning("Failed to deduct quota for summary generation: %s", str(e))
return summary_content, usage
@staticmethod
def _extract_images_from_text(tenant_id: str, text: str) -> list[File]:
"""
Extract images from markdown text and convert them to File objects.
Args:
tenant_id: Tenant ID
text: Text content that may contain markdown image links
Returns:
List of File objects representing images found in the text
"""
# Extract markdown images using regex pattern
pattern = r"!\[.*?\]\((.*?)\)"
images = re.findall(pattern, text)
if not images:
return []
upload_file_id_list = []
for image in images:
# For data before v0.10.0
pattern = r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?"
match = re.search(pattern, image)
if match:
upload_file_id = match.group(1)
upload_file_id_list.append(upload_file_id)
continue
# For data after v0.10.0
pattern = r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?"
match = re.search(pattern, image)
if match:
upload_file_id = match.group(1)
upload_file_id_list.append(upload_file_id)
continue
# For tools directory - direct file formats (e.g., .png, .jpg, etc.)
pattern = r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?"
match = re.search(pattern, image)
if match:
# Tool files are handled differently, skip for now
continue
if not upload_file_id_list:
return []
# Get unique IDs for database query
unique_upload_file_ids = list(set(upload_file_id_list))
upload_files = (
db.session.query(UploadFile)
.where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id)
.all()
)
# Create File objects from UploadFile records
file_objects = []
for upload_file in upload_files:
# Only process image files
if not upload_file.mime_type or "image" not in upload_file.mime_type:
continue
mapping = {
"upload_file_id": upload_file.id,
"transfer_method": FileTransferMethod.LOCAL_FILE.value,
"type": FileType.IMAGE.value,
}
try:
file_obj = build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
file_objects.append(file_obj)
except Exception as e:
logger.warning("Failed to create File object from UploadFile %s: %s", upload_file.id, str(e))
continue
return file_objects
@staticmethod
def _extract_images_from_segment_attachments(tenant_id: str, segment_id: str) -> list[File]:
"""
Extract images from SegmentAttachmentBinding table (preferred method).
This matches how DatasetRetrieval gets segment attachments.
Args:
tenant_id: Tenant ID
segment_id: Segment ID to fetch attachments for
Returns:
List of File objects representing images found in segment attachments
"""
from sqlalchemy import select
# Query attachments from SegmentAttachmentBinding table
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.segment_id == segment_id,
SegmentAttachmentBinding.tenant_id == tenant_id,
)
).all()
if not attachments_with_bindings:
return []
file_objects = []
for _, upload_file in attachments_with_bindings:
# Only process image files
if not upload_file.mime_type or "image" not in upload_file.mime_type:
continue
try:
# Create File object directly (similar to DatasetRetrieval)
file_obj = File(
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
size=upload_file.size,
storage_key=upload_file.key,
)
file_objects.append(file_obj)
except Exception as e:
logger.warning("Failed to create File object from UploadFile %s: %s", upload_file.id, str(e))
continue
return file_objects
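For reference, a minimal standalone sketch of what the extraction above does; the sample links are hypothetical, and the two preview patterns are condensed into a single regex here:

import re

sample = (
    "![a](/files/0a1b2c3d-0000-0000-0000-000000000001/image-preview?ts=1)\n"
    "![b](/files/0a1b2c3d-0000-0000-0000-000000000002/file-preview)\n"
    "![c](/files/tools/0a1b2c3d-0000-0000-0000-000000000003.png)"
)

# Pull the URL out of each markdown image, then keep only upload-file preview links
upload_file_ids = []
for url in re.findall(r"!\[.*?\]\((.*?)\)", sample):
    match = re.search(r"/files/([a-f0-9\-]+)/(?:image|file)-preview(?:\?.*?)?", url)
    if match:
        upload_file_ids.append(match.group(1))
    # /files/tools/... links are deliberately skipped, as in the code above

print(upload_file_ids)  # the first two IDs; the tool-file link contributes nothing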
View File
@ -1,11 +1,14 @@
"""Paragraph index processor."""
import json
import logging
import uuid
from collections.abc import Mapping
from typing import Any
from configs import dify_config
from core.db.session_factory import session_factory
from core.entities.knowledge_entities import PreviewDetail
from core.model_manager import ModelInstance
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService
@ -25,6 +28,9 @@ from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegm
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__)
class ParentChildIndexProcessor(BaseIndexProcessor):
@ -135,6 +141,30 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# node_ids is segment's node_ids
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
# For disable operations, disable_summaries_for_segments is called directly in the task.
# Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
delete_summaries = kwargs.get("delete_summaries", False)
if delete_summaries:
if node_ids:
# Find segments by index_node_id
with session_factory.create_session() as session:
segments = (
session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
)
.all()
)
segment_ids = [segment.id for segment in segments]
if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
else:
# Delete all summaries for the dataset
SummaryIndexService.delete_summaries_for_segments(dataset, None)
if dataset.indexing_technique == "high_quality":
delete_child_chunks = kwargs.get("delete_child_chunks") or False
precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids")
@ -326,3 +356,91 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
"preview": preview,
"total_segments": len(parent_childs.parent_child_chunks),
}
def generate_summary_preview(
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
) -> list[PreviewDetail]:
"""
For each parent chunk in preview_texts, concurrently call generate_summary to generate a summary
and write it to the summary attribute of PreviewDetail.
In preview mode (indexing-estimate), if any summary generation fails, the method will raise an exception.
Note: For parent-child structure, we only generate summaries for parent chunks.
"""
import concurrent.futures

from flask import current_app

# Guard: ThreadPoolExecutor(max_workers=0) raises ValueError, so bail out early on empty input
if not preview_texts:
    return preview_texts
# Capture Flask app context for worker threads
flask_app = None
try:
flask_app = current_app._get_current_object() # type: ignore
except RuntimeError:
logger.warning("No Flask application context available, summary generation may fail")
def process(preview: PreviewDetail) -> None:
"""Generate summary for a single preview item (parent chunk)."""
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
if flask_app:
# Ensure Flask app context in worker thread
with flask_app.app_context():
summary, _ = ParagraphIndexProcessor.generate_summary(
tenant_id=tenant_id,
text=preview.content,
summary_index_setting=summary_index_setting,
)
preview.summary = summary
else:
# Fallback: try without app context (may fail)
summary, _ = ParagraphIndexProcessor.generate_summary(
tenant_id=tenant_id,
text=preview.content,
summary_index_setting=summary_index_setting,
)
preview.summary = summary
# Generate summaries concurrently using ThreadPoolExecutor
# Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
timeout_seconds = min(300, 60 * len(preview_texts))
errors: list[Exception] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_texts))) as executor:
futures = [executor.submit(process, preview) for preview in preview_texts]
# Wait for all tasks to complete with timeout
done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)
# Cancel tasks that didn't complete in time
if not_done:
timeout_error_msg = (
f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s"
)
logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg)
# In preview mode, timeout is also an error
errors.append(TimeoutError(timeout_error_msg))
for future in not_done:
future.cancel()
# Wait a bit for cancellation to take effect
concurrent.futures.wait(not_done, timeout=5)
# Collect exceptions from completed futures
for future in done:
try:
future.result() # This will raise any exception that occurred
except Exception as e:
logger.exception("Error in summary generation future")
errors.append(e)
# In preview mode (indexing-estimate), if there are any errors, fail the request
if errors:
error_messages = [str(e) for e in errors]
error_summary = (
f"Failed to generate summaries for {len(errors)} chunk(s). "
f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors
)
if len(errors) > 3:
error_summary += f" (and {len(errors) - 3} more)"
logger.error("Summary generation failed in preview mode: %s", error_summary)
raise ValueError(error_summary)
return preview_texts
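A sketch of how a caller might exercise this preview path; the PreviewDetail constructor call and the summary_index_setting keys are assumptions based on usage elsewhere in this commit:

from core.entities.knowledge_entities import PreviewDetail

processor = ParentChildIndexProcessor()
previews = [PreviewDetail(content="First parent chunk ..."), PreviewDetail(content="Second parent chunk ...")]
try:
    processor.generate_summary_preview(
        tenant_id="tenant-id",
        preview_texts=previews,
        summary_index_setting={"enable": True, "model_provider_name": "openai", "model_name": "gpt-4o-mini"},
    )
    print([p.summary for p in previews])
except ValueError as exc:
    # In preview mode, per-chunk failures and timeouts are aggregated into one error
    print(f"Preview summary generation failed: {exc}")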
View File
@ -11,6 +11,8 @@ import pandas as pd
from flask import Flask, current_app
from werkzeug.datastructures import FileStorage
from core.db.session_factory import session_factory
from core.entities.knowledge_entities import PreviewDetail
from core.llm_generator.llm_generator import LLMGenerator
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService
@ -25,9 +27,10 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.account import Account
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule
from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__)
@ -144,6 +147,31 @@ class QAIndexProcessor(BaseIndexProcessor):
vector.create_multimodal(multimodal_documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
# For disable operations, disable_summaries_for_segments is called directly in the task.
# Note: qa_model doesn't generate summaries, but we clean them for completeness
# Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
delete_summaries = kwargs.get("delete_summaries", False)
if delete_summaries:
if node_ids:
# Find segments by index_node_id
with session_factory.create_session() as session:
segments = (
session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
)
.all()
)
segment_ids = [segment.id for segment in segments]
if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
else:
# Delete all summaries for the dataset
SummaryIndexService.delete_summaries_for_segments(dataset, None)
vector = Vector(dataset)
if node_ids:
vector.delete_by_ids(node_ids)
@ -212,6 +240,17 @@ class QAIndexProcessor(BaseIndexProcessor):
"total_segments": len(qa_chunks.qa_chunks),
}
def generate_summary_preview(
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
) -> list[PreviewDetail]:
"""
QA model doesn't generate summaries, so this method returns preview_texts unchanged.
Note: QA model uses question-answer pairs, which don't require summary generation.
"""
# QA model doesn't generate summaries, return as-is
return preview_texts
def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
format_documents = []
if document_node.page_content is None or not document_node.page_content.strip():
View File
@ -236,20 +236,24 @@ class DatasetRetrieval:
if records:
for record in records:
segment = record.segment
# Build content: if summary exists, add it before the segment content
if segment.answer:
    segment_content = f"question:{segment.get_sign_content()} answer:{segment.answer}"
else:
    segment_content = segment.get_sign_content()
# If summary exists, prepend it to the content
if record.summary:
    final_content = f"{record.summary}\n{segment_content}"
else:
    final_content = segment_content
document_context_list.append(
    DocumentContext(
        content=final_content,
        score=record.score,
    )
)
if vision_enabled:
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
@ -316,6 +320,9 @@ class DatasetRetrieval:
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source.content = segment.content
# Add summary if this segment was retrieved via summary
if hasattr(record, "summary") and record.summary:
source.summary = record.summary
retrieval_resource_list.append(source)
if hit_callback and retrieval_resource_list:
retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True)
View File
@ -7,9 +7,9 @@ from uuid import UUID, uuid4
from core.sandbox.entities.files import SandboxFileDownloadTicket, SandboxFileNode
from core.sandbox.inspector.base import SandboxFileSource
from core.sandbox.storage import sandbox_file_storage
from core.sandbox.storage.archive_storage import SandboxArchivePath
from core.sandbox.storage.sandbox_file_storage import SandboxFileDownloadPath
from core.virtual_environment.__base.exec import CommandExecutionError
from core.virtual_environment.__base.helpers import execute
from extensions.ext_storage import storage
@ -68,15 +68,14 @@ print(json.dumps(entries))
def _get_archive_download_url(self) -> str:
"""Get a pre-signed download URL for the sandbox archive."""
from extensions.storage.file_presign_storage import FilePresignStorage
archive_path = SandboxArchivePath(tenant_id=UUID(self._tenant_id), sandbox_id=UUID(self._sandbox_id))
storage_key = archive_path.get_storage_key()
if not storage.exists(storage_key):
raise ValueError("Sandbox archive not found")
presign_storage = FilePresignStorage(storage.storage_runner)
return presign_storage.get_download_url(storage_key, self._EXPORT_EXPIRES_IN_SECONDS)
def _create_zip_sandbox(self) -> ZipSandbox:
"""Create a ZipSandbox instance for archive operations."""
View File
@ -7,8 +7,8 @@ from uuid import UUID, uuid4
from core.sandbox.entities.files import SandboxFileDownloadTicket, SandboxFileNode
from core.sandbox.inspector.base import SandboxFileSource
from core.sandbox.storage import sandbox_file_storage
from core.sandbox.storage.sandbox_file_storage import SandboxFileDownloadPath
from core.virtual_environment.__base.exec import CommandExecutionError
from core.virtual_environment.__base.helpers import execute
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
View File
@ -1 +0,0 @@
"""Sandbox security helpers."""
View File
@ -1,152 +0,0 @@
from __future__ import annotations
import base64
import hashlib
import hmac
import os
import time
import urllib.parse
from dataclasses import dataclass
from uuid import UUID
from configs import dify_config
from libs import rsa
@dataclass(frozen=True)
class SandboxArchivePath:
tenant_id: UUID
sandbox_id: UUID
def get_storage_key(self) -> str:
return f"sandbox/{self.tenant_id}/{self.sandbox_id}.tar.gz"
def proxy_path(self) -> str:
return f"{self.tenant_id}/{self.sandbox_id}"
class SandboxArchiveSigner:
SIGNATURE_PREFIX = "sandbox-archive"
SIGNATURE_VERSION = "v1"
OPERATION_DOWNLOAD = "download"
OPERATION_UPLOAD = "upload"
@classmethod
def create_download_signature(cls, archive_path: SandboxArchivePath, expires_at: int, nonce: str) -> str:
return cls._create_signature(
archive_path=archive_path,
operation=cls.OPERATION_DOWNLOAD,
expires_at=expires_at,
nonce=nonce,
)
@classmethod
def create_upload_signature(cls, archive_path: SandboxArchivePath, expires_at: int, nonce: str) -> str:
return cls._create_signature(
archive_path=archive_path,
operation=cls.OPERATION_UPLOAD,
expires_at=expires_at,
nonce=nonce,
)
@classmethod
def verify_download_signature(
cls, archive_path: SandboxArchivePath, expires_at: int, nonce: str, sign: str
) -> bool:
return cls._verify_signature(
archive_path=archive_path,
operation=cls.OPERATION_DOWNLOAD,
expires_at=expires_at,
nonce=nonce,
sign=sign,
)
@classmethod
def verify_upload_signature(cls, archive_path: SandboxArchivePath, expires_at: int, nonce: str, sign: str) -> bool:
return cls._verify_signature(
archive_path=archive_path,
operation=cls.OPERATION_UPLOAD,
expires_at=expires_at,
nonce=nonce,
sign=sign,
)
@classmethod
def _verify_signature(
cls,
*,
archive_path: SandboxArchivePath,
operation: str,
expires_at: int,
nonce: str,
sign: str,
) -> bool:
if expires_at <= 0:
return False
expected_sign = cls._create_signature(
archive_path=archive_path,
operation=operation,
expires_at=expires_at,
nonce=nonce,
)
if not hmac.compare_digest(sign, expected_sign):
return False
current_time = int(time.time())
if expires_at < current_time:
return False
if expires_at - current_time > dify_config.FILES_ACCESS_TIMEOUT:
return False
return True
@classmethod
def build_signed_url(
cls,
*,
archive_path: SandboxArchivePath,
expires_in: int,
action: str,
) -> str:
expires_in = min(expires_in, dify_config.FILES_ACCESS_TIMEOUT)
expires_at = int(time.time()) + max(expires_in, 1)
nonce = os.urandom(16).hex()
sign = cls._create_signature(
archive_path=archive_path,
operation=action,
expires_at=expires_at,
nonce=nonce,
)
base_url = dify_config.FILES_URL
url = f"{base_url}/files/sandbox-archives/{archive_path.proxy_path()}/{action}"
query = urllib.parse.urlencode({"expires_at": expires_at, "nonce": nonce, "sign": sign})
return f"{url}?{query}"
@classmethod
def _create_signature(
cls,
*,
archive_path: SandboxArchivePath,
operation: str,
expires_at: int,
nonce: str,
) -> str:
key = cls._tenant_key(str(archive_path.tenant_id))
message = (
f"{cls.SIGNATURE_PREFIX}|{cls.SIGNATURE_VERSION}|{operation}|"
f"{archive_path.tenant_id}|{archive_path.sandbox_id}|{expires_at}|{nonce}"
)
sign = hmac.new(key, message.encode(), hashlib.sha256).digest()
return base64.urlsafe_b64encode(sign).decode()
@classmethod
def _tenant_key(cls, tenant_id: str) -> bytes:
try:
rsa_key, _ = rsa.get_decrypt_decoding(tenant_id)
except rsa.PrivkeyNotFoundError as exc:
raise ValueError(f"Tenant private key missing for tenant_id={tenant_id}") from exc
private_key = rsa_key.export_key()
return hashlib.sha256(private_key).digest()
View File
@ -1,155 +0,0 @@
from __future__ import annotations
import base64
import hashlib
import hmac
import os
import time
import urllib.parse
from dataclasses import dataclass
from uuid import UUID
from configs import dify_config
from libs import rsa
@dataclass(frozen=True)
class SandboxFileDownloadPath:
tenant_id: UUID
sandbox_id: UUID
export_id: str
filename: str
def get_storage_key(self) -> str:
return f"sandbox_file_downloads/{self.tenant_id}/{self.sandbox_id}/{self.export_id}/{self.filename}"
def proxy_path(self) -> str:
encoded_parts = [
urllib.parse.quote(str(self.tenant_id), safe=""),
urllib.parse.quote(str(self.sandbox_id), safe=""),
urllib.parse.quote(self.export_id, safe=""),
urllib.parse.quote(self.filename, safe=""),
]
return "/".join(encoded_parts)
class SandboxFileSigner:
SIGNATURE_PREFIX = "sandbox-file-download"
SIGNATURE_VERSION = "v1"
OPERATION_DOWNLOAD = "download"
OPERATION_UPLOAD = "upload"
@classmethod
def build_signed_url(
cls,
*,
export_path: SandboxFileDownloadPath,
expires_in: int,
action: str,
) -> str:
expires_in = min(expires_in, dify_config.FILES_ACCESS_TIMEOUT)
expires_at = int(time.time()) + max(expires_in, 1)
nonce = os.urandom(16).hex()
sign = cls._create_signature(
export_path=export_path,
operation=action,
expires_at=expires_at,
nonce=nonce,
)
base_url = dify_config.FILES_URL
url = f"{base_url}/files/sandbox-file-downloads/{export_path.proxy_path()}/{action}"
query = urllib.parse.urlencode({"expires_at": expires_at, "nonce": nonce, "sign": sign})
return f"{url}?{query}"
@classmethod
def verify_download_signature(
cls,
*,
export_path: SandboxFileDownloadPath,
expires_at: int,
nonce: str,
sign: str,
) -> bool:
return cls._verify_signature(
export_path=export_path,
operation=cls.OPERATION_DOWNLOAD,
expires_at=expires_at,
nonce=nonce,
sign=sign,
)
@classmethod
def verify_upload_signature(
cls,
*,
export_path: SandboxFileDownloadPath,
expires_at: int,
nonce: str,
sign: str,
) -> bool:
return cls._verify_signature(
export_path=export_path,
operation=cls.OPERATION_UPLOAD,
expires_at=expires_at,
nonce=nonce,
sign=sign,
)
@classmethod
def _verify_signature(
cls,
*,
export_path: SandboxFileDownloadPath,
operation: str,
expires_at: int,
nonce: str,
sign: str,
) -> bool:
if expires_at <= 0:
return False
expected_sign = cls._create_signature(
export_path=export_path,
operation=operation,
expires_at=expires_at,
nonce=nonce,
)
if not hmac.compare_digest(sign, expected_sign):
return False
current_time = int(time.time())
if expires_at < current_time:
return False
if expires_at - current_time > dify_config.FILES_ACCESS_TIMEOUT:
return False
return True
@classmethod
def _create_signature(
cls,
*,
export_path: SandboxFileDownloadPath,
operation: str,
expires_at: int,
nonce: str,
) -> str:
key = cls._tenant_key(str(export_path.tenant_id))
message = (
f"{cls.SIGNATURE_PREFIX}|{cls.SIGNATURE_VERSION}|{operation}|"
f"{export_path.tenant_id}|{export_path.sandbox_id}|{export_path.export_id}|{export_path.filename}|"
f"{expires_at}|{nonce}"
)
digest = hmac.new(key, message.encode(), hashlib.sha256).digest()
return base64.urlsafe_b64encode(digest).decode()
@classmethod
def _tenant_key(cls, tenant_id: str) -> bytes:
try:
rsa_key, _ = rsa.get_decrypt_decoding(tenant_id)
except rsa.PrivkeyNotFoundError as exc:
raise ValueError(f"Tenant private key missing for tenant_id={tenant_id}") from exc
private_key = rsa_key.export_key()
return hashlib.sha256(private_key).digest()
View File
@ -1,11 +1,13 @@
from .archive_storage import ArchiveSandboxStorage, SandboxArchivePath
from .noop_storage import NoopSandboxStorage
from .sandbox_file_storage import SandboxFileDownloadPath, SandboxFileStorage, sandbox_file_storage
from .sandbox_storage import SandboxStorage
__all__ = [
"ArchiveSandboxStorage",
"NoopSandboxStorage",
"SandboxArchivePath",
"SandboxFileDownloadPath",
"SandboxFileStorage",
"SandboxStorage",
"sandbox_file_storage",
View File
@ -1,18 +1,31 @@
"""Archive-based sandbox storage for persisting sandbox state.
This module provides storage operations for sandbox workspace archives (tar.gz),
enabling state persistence across sandbox sessions.
Storage key format: sandbox/{tenant_id}/{sandbox_id}.tar.gz
All presign operations use the unified FilePresignStorage wrapper, which automatically
falls back to Dify's file proxy when the underlying storage doesn't support presigned URLs.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from uuid import UUID
from core.sandbox.security.archive_signer import SandboxArchivePath, SandboxArchiveSigner
from core.virtual_environment.__base.exec import PipelineExecutionError
from core.virtual_environment.__base.helpers import pipeline
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from extensions.ext_storage import storage
from extensions.storage.base_storage import BaseStorage
from extensions.storage.file_presign_storage import FilePresignStorage
from .sandbox_storage import SandboxStorage
logger = logging.getLogger(__name__)
WORKSPACE_DIR = "."
ARCHIVE_DOWNLOAD_TIMEOUT = 60 * 5
ARCHIVE_UPLOAD_TIMEOUT = 60 * 5
@ -21,40 +34,67 @@ def build_tar_exclude_args(patterns: list[str]) -> list[str]:
return [f"--exclude={p}" for p in patterns]
@dataclass(frozen=True)
class SandboxArchivePath:
"""Path for sandbox workspace archives."""
tenant_id: UUID
sandbox_id: UUID
def get_storage_key(self) -> str:
return f"sandbox/{self.tenant_id}/{self.sandbox_id}.tar.gz"
class ArchiveSandboxStorage(SandboxStorage):
"""Archive-based storage for sandbox workspace persistence.
Uses tar.gz archives to save and restore sandbox workspace state.
Requires a presign-capable storage wrapper for generating download/upload URLs.
"""
_tenant_id: str
_sandbox_id: str
_exclude_patterns: list[str]
_storage: FilePresignStorage
def __init__(
    self,
    tenant_id: str,
    sandbox_id: str,
    storage: BaseStorage,
    exclude_patterns: list[str] | None = None,
):
self._tenant_id = tenant_id
self._sandbox_id = sandbox_id
self._exclude_patterns = exclude_patterns or []
# Wrap with FilePresignStorage for presign fallback support
self._storage = FilePresignStorage(storage)
@property
def _archive_path(self) -> SandboxArchivePath:
return SandboxArchivePath(UUID(self._tenant_id), UUID(self._sandbox_id))
@property
def _storage_key(self) -> str:
return self._archive_path.get_storage_key()
@property
def _archive_name(self) -> str:
return f"{self._sandbox_id}.tar.gz"
@property
def _archive_tmp_path(self) -> str:
return f"/tmp/{self._archive_name}"
def mount(self, sandbox: VirtualEnvironment) -> bool:
"""Load archive from storage into sandbox workspace."""
if not self.exists():
logger.debug("No archive found for sandbox %s, skipping mount", self._sandbox_id)
return False
download_url = self._storage.get_download_url(self._storage_key, ARCHIVE_DOWNLOAD_TIMEOUT)
archive_name = self._archive_name
try:
(
pipeline(sandbox)
@ -74,13 +114,10 @@ class ArchiveSandboxStorage(SandboxStorage):
return True
def unmount(self, sandbox: VirtualEnvironment) -> bool:
"""Save sandbox workspace to storage as archive."""
upload_url = self._storage.get_upload_url(self._storage_key, ARCHIVE_UPLOAD_TIMEOUT)
archive_path = self._archive_tmp_path
(
pipeline(sandbox)
.add(
@ -105,11 +142,13 @@ class ArchiveSandboxStorage(SandboxStorage):
return True
def exists(self) -> bool:
"""Check if archive exists in storage."""
return self._storage.exists(self._storage_key)
def delete(self) -> None:
"""Delete archive from storage."""
try:
self._storage.delete(self._storage_key)
logger.info("Deleted archive for sandbox %s", self._sandbox_id)
except Exception:
logger.exception("Failed to delete archive for sandbox %s", self._sandbox_id)
View File
@ -1,23 +1,58 @@
"""Sandbox file storage for exporting files from sandbox environments.
This module provides storage operations for files exported from sandbox environments,
including download tickets for both runtime and archive-based file sources.
Storage key format: sandbox_file_downloads/{tenant_id}/{sandbox_id}/{export_id}/{filename}
All presign operations use the unified FilePresignStorage wrapper, which automatically
falls back to Dify's file proxy when the underlying storage doesn't support presigned URLs.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from uuid import UUID
from extensions.storage.base_storage import BaseStorage
from extensions.storage.cached_presign_storage import CachedPresignStorage
from extensions.storage.file_presign_storage import FilePresignStorage
@dataclass(frozen=True)
class SandboxFileDownloadPath:
"""Path for sandbox file exports."""
tenant_id: UUID
sandbox_id: UUID
export_id: str
filename: str
def get_storage_key(self) -> str:
return f"sandbox_file_downloads/{self.tenant_id}/{self.sandbox_id}/{self.export_id}/{self.filename}"
class SandboxFileStorage:
    """Storage operations for sandbox file exports.

    Wraps BaseStorage with:
    - FilePresignStorage for presign fallback support
    - CachedPresignStorage for URL caching

    Usage:
        storage = SandboxFileStorage(base_storage, redis_client=redis)
        storage.save(download_path, content)
        url = storage.get_download_url(download_path)
    """

    _storage: CachedPresignStorage

    def __init__(self, storage: BaseStorage, *, redis_client: Any) -> None:
        # Wrap with FilePresignStorage for fallback support, then CachedPresignStorage for caching
        presign_storage = FilePresignStorage(storage)
        self._storage = CachedPresignStorage(
            storage=presign_storage,
            redis_client=redis_client,
            cache_key_prefix="sandbox_file_downloads",
        )
@ -26,29 +61,19 @@ class SandboxFileStorage:
self._storage.save(download_path.get_storage_key(), content)
def get_download_url(self, download_path: SandboxFileDownloadPath, expires_in: int = 3600) -> str:
return self._storage.get_download_url(download_path.get_storage_key(), expires_in)
def get_upload_url(self, download_path: SandboxFileDownloadPath, expires_in: int = 3600) -> str:
return self._storage.get_upload_url(download_path.get_storage_key(), expires_in)
class _LazySandboxFileStorage:
"""Lazy initializer for singleton SandboxFileStorage.
Delays storage initialization until first access, ensuring Flask app
context is available.
"""
_instance: SandboxFileStorage | None
def __init__(self) -> None:
@ -56,12 +81,16 @@ class _LazySandboxFileStorage:
def _get_instance(self) -> SandboxFileStorage:
if self._instance is None:
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
if not hasattr(storage, "storage_runner"):
raise RuntimeError(
"Storage is not initialized; call storage.init_app before using sandbox_file_storage"
)
self._instance = SandboxFileStorage(
storage=storage.storage_runner,
redis_client=redis_client,
)
return self._instance
@ -69,4 +98,4 @@ class _LazySandboxFileStorage:
return getattr(self._get_instance(), name)
sandbox_file_storage: SandboxFileStorage = _LazySandboxFileStorage()  # type: ignore[assignment]
View File
@ -2,21 +2,40 @@ import logging
from core.app_assets.storage import AssetPath
from core.skill.entities.skill_bundle import SkillBundle
from extensions.ext_redis import redis_client
from services.app_asset_service import AppAssetService
logger = logging.getLogger(__name__)
class SkillManager:
_CACHE_KEY_PREFIX = "skill_bundle"
_CACHE_TTL_SECONDS = 60 * 60 * 24
@staticmethod
def get_cache_key(
tenant_id: str,
app_id: str,
assets_id: str,
) -> str:
return f"{SkillManager._CACHE_KEY_PREFIX}:{tenant_id}:{app_id}:{assets_id}"
@staticmethod
def load_bundle(
tenant_id: str,
app_id: str,
assets_id: str,
) -> SkillBundle:
cache_key = SkillManager.get_cache_key(tenant_id, app_id, assets_id)
data = redis_client.get(cache_key)
if data:
return SkillBundle.model_validate_json(data)
asset_path = AssetPath.skill_bundle(tenant_id, app_id, assets_id)
data = AppAssetService.get_storage().load(asset_path)
bundle = SkillBundle.model_validate_json(data)
redis_client.setex(cache_key, SkillManager._CACHE_TTL_SECONDS, bundle.model_dump_json(indent=2).encode("utf-8"))
return bundle
@staticmethod
def save_bundle(
@ -30,3 +49,5 @@ class SkillManager:
asset_path,
bundle.model_dump_json(indent=2).encode("utf-8"),
)
cache_key = SkillManager.get_cache_key(tenant_id, app_id, assets_id)
redis_client.delete(cache_key)
View File
@ -169,20 +169,24 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
if records:
for record in records:
segment = record.segment
# Build content: if summary exists, add it before the segment content
if segment.answer:
    segment_content = f"question:{segment.get_sign_content()} answer:{segment.answer}"
else:
    segment_content = segment.get_sign_content()
# If summary exists, prepend it to the content
if record.summary:
    final_content = f"{record.summary}\n{segment_content}"
else:
    final_content = segment_content
document_context_list.append(
    DocumentContext(
        content=final_content,
        score=record.score,
    )
)
if self.return_resource:
for record in records:
@ -216,6 +220,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source.content = segment.content
# Add summary if this segment was retrieved via summary
if hasattr(record, "summary") and record.summary:
source.summary = record.summary
retrieval_resource_list.append(source)
if self.return_resource and retrieval_resource_list:
View File
@ -148,7 +148,8 @@ class DockerDemuxer:
to periodically check for errors and closed state instead of blocking forever.
"""
if self._error:
raise TransportEOFError(f"Demuxer error: {self._error}") from self._error
error = cast(BaseException, self._error)
raise TransportEOFError(f"Demuxer error: {error}") from error
while True:
try:
@ -163,7 +164,8 @@ class DockerDemuxer:
if self._closed:
raise TransportEOFError("Demuxer closed")
if self._error:
raise TransportEOFError(f"Demuxer error: {self._error}") from self._error
error = cast(BaseException, self._error)
raise TransportEOFError(f"Demuxer error: {error}") from error
# No error, continue waiting
def close(self) -> None:
@ -292,6 +294,8 @@ class DockerDaemonEnvironment(VirtualEnvironment):
@classmethod
def validate(cls, options: Mapping[str, Any]) -> None:
# Import Docker SDK lazily so it is loaded after gevent monkey-patching.
import docker.errors
import docker
docker_sock = options.get(cls.OptionsKey.DOCKER_SOCK, cls._DEFAULT_DOCKER_SOCK)
@ -364,6 +368,7 @@ class DockerDaemonEnvironment(VirtualEnvironment):
NOTE: I guess nobody will use more than 5 different docker sockets in practice....
"""
import docker
return docker.DockerClient(base_url=docker_sock)
@classmethod
@ -373,6 +378,7 @@ class DockerDaemonEnvironment(VirtualEnvironment):
Get the Docker low-level API client.
"""
import docker
return docker.APIClient(base_url=docker_sock)
def get_docker_sock(self) -> str:
@ -431,6 +437,12 @@ class DockerDaemonEnvironment(VirtualEnvironment):
return self._container_path(path)
def upload_file(self, path: str, content: BytesIO) -> None:
"""Upload a file to the container.
Files and intermediate directories are created with world-writable permissions
(0o777 for directories, 0o666 for files) to avoid permission issues when the container
runs as a non-root user but Docker's put_archive creates files as root.
"""
container = self._get_container()
normalized = PurePosixPath(path)
@ -442,6 +454,7 @@ class DockerDaemonEnvironment(VirtualEnvironment):
with tarfile.open(fileobj=tar_stream, mode="w") as tar:
tar_info = tarfile.TarInfo(name=file_name)
tar_info.size = len(payload)
tar_info.mode = 0o666
tar.addfile(tar_info, BytesIO(payload))
tar_stream.seek(0)
container.put_archive(parent_dir, tar_stream.read()) # pyright: ignore[reportUnknownMemberType] #
@ -454,8 +467,18 @@ class DockerDaemonEnvironment(VirtualEnvironment):
payload = content.getvalue()
tar_stream = BytesIO()
with tarfile.open(fileobj=tar_stream, mode="w") as tar:
# Add intermediate directories with proper permissions
for i in range(len(relative_path.parts) - 1):
dir_path = PurePosixPath(*relative_path.parts[: i + 1])
dir_info = tarfile.TarInfo(name=dir_path.as_posix() + "/")
dir_info.type = tarfile.DIRTYPE
dir_info.mode = 0o777
tar.addfile(dir_info)
# Add the file
tar_info = tarfile.TarInfo(name=relative_path.as_posix())
tar_info.size = len(payload)
tar_info.mode = 0o666
tar.addfile(tar_info, BytesIO(payload))
tar_stream.seek(0)
container.put_archive(self._working_dir, tar_stream.read()) # pyright: ignore[reportUnknownMemberType] #
@ -479,7 +502,7 @@ class DockerDaemonEnvironment(VirtualEnvironment):
return BytesIO(extracted.read())
def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
import docker
import docker.errors
container = self._get_container()
container_path = self._container_path(directory_path)
@ -525,7 +548,7 @@ class DockerDaemonEnvironment(VirtualEnvironment):
pass
def release_environment(self) -> None:
import docker
import docker.errors
try:
container = self._get_container()
View File
@ -158,3 +158,5 @@ class KnowledgeIndexNodeData(BaseNodeData):
type: str = "knowledge-index"
chunk_structure: str
index_chunk_variable_selector: list[str]
indexing_technique: str | None = None
summary_index_setting: dict | None = None
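For illustration, a hypothetical node configuration exercising the two new fields; the summary_index_setting keys mirror dataset_summary_index_fields later in this commit:

node_data = KnowledgeIndexNodeData(
    chunk_structure="parent-child",  # value assumed for illustration
    index_chunk_variable_selector=["sys", "chunks"],
    indexing_technique="high_quality",
    summary_index_setting={
        "enable": True,
        "model_provider_name": "openai",
        "model_name": "gpt-4o-mini",
        "summary_prompt": "Summarize this chunk for retrieval:",
    },
)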
View File
@ -1,9 +1,11 @@
import concurrent.futures
import datetime
import logging
import time
from collections.abc import Mapping
from typing import Any
from flask import current_app
from sqlalchemy import func, select
from core.app.entities.app_invoke_entities import InvokeFrom
@ -16,7 +18,9 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
from services.summary_index_service import SummaryIndexService
from tasks.generate_summary_index_task import generate_summary_index_task
from .entities import KnowledgeIndexNodeData
from .exc import (
@ -67,7 +71,20 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
# index knowledge
try:
if is_preview:
outputs = self._get_preview_output(node_data.chunk_structure, chunks)
# Preview mode: generate summaries for chunks directly without saving to database
# Format preview and generate summaries on-the-fly
# Get indexing_technique and summary_index_setting from node_data (workflow graph config)
# or fallback to dataset if not available in node_data
indexing_technique = node_data.indexing_technique or dataset.indexing_technique
summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting
outputs = self._get_preview_output_with_summaries(
node_data.chunk_structure,
chunks,
dataset=dataset,
indexing_technique=indexing_technique,
summary_index_setting=summary_index_setting,
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
@ -148,6 +165,11 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
)
.scalar()
)
# Update need_summary based on dataset's summary_index_setting
if dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True:
document.need_summary = True
else:
document.need_summary = False
db.session.add(document)
# update document segment status
db.session.query(DocumentSegment).where(
@ -163,6 +185,9 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
db.session.commit()
# Generate summary index if enabled
self._handle_summary_index_generation(dataset, document, variable_pool)
return {
"dataset_id": ds_id_value,
"dataset_name": dataset_name_value,
@ -173,9 +198,304 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
"display_status": "completed",
}
def _handle_summary_index_generation(
self,
dataset: Dataset,
document: Document,
variable_pool: VariablePool,
) -> None:
"""
Handle summary index generation based on mode (debug/preview or production).
Args:
dataset: Dataset containing the document
document: Document to generate summaries for
variable_pool: Variable pool to check invoke_from
"""
# Only generate summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
return
# Check if summary index is enabled
summary_index_setting = dataset.summary_index_setting
if not summary_index_setting or not summary_index_setting.get("enable"):
return
# Skip qa_model documents
if document.doc_form == "qa_model":
return
# Determine if in preview/debug mode
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
is_preview = invoke_from and invoke_from.value == InvokeFrom.DEBUGGER
if is_preview:
try:
# Query segments that need summary generation
query = db.session.query(DocumentSegment).filter_by(
dataset_id=dataset.id,
document_id=document.id,
status="completed",
enabled=True,
)
segments = query.all()
if not segments:
logger.info("No segments found for document %s", document.id)
return
# Filter segments based on mode
segments_to_process = []
for segment in segments:
# Skip if summary already exists
existing_summary = (
db.session.query(DocumentSegmentSummary)
.filter_by(chunk_id=segment.id, dataset_id=dataset.id, status="completed")
.first()
)
if existing_summary:
continue
# For parent-child mode, all segments are parent chunks, so process all
segments_to_process.append(segment)
if not segments_to_process:
logger.info("No segments need summary generation for document %s", document.id)
return
# Use ThreadPoolExecutor for concurrent generation
flask_app = current_app._get_current_object() # type: ignore
max_workers = min(10, len(segments_to_process)) # Limit to 10 workers
def process_segment(segment: DocumentSegment) -> None:
"""Process a single segment in a thread with Flask app context."""
with flask_app.app_context():
try:
SummaryIndexService.generate_and_vectorize_summary(segment, dataset, summary_index_setting)
except Exception:
logger.exception(
"Failed to generate summary for segment %s",
segment.id,
)
# Continue processing other segments
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [executor.submit(process_segment, segment) for segment in segments_to_process]
# Wait for all tasks to complete
concurrent.futures.wait(futures)
logger.info(
"Successfully generated summary index for %s segments in document %s",
len(segments_to_process),
document.id,
)
except Exception:
logger.exception("Failed to generate summary index for document %s", document.id)
# Don't fail the entire indexing process if summary generation fails
else:
# Production mode: asynchronous generation
logger.info(
"Queuing summary index generation task for document %s (production mode)",
document.id,
)
try:
generate_summary_index_task.delay(dataset.id, document.id, None)
logger.info("Summary index generation task queued for document %s", document.id)
except Exception:
logger.exception(
"Failed to queue summary index generation task for document %s",
document.id,
)
# Don't fail the entire indexing process if task queuing fails
def _get_preview_output_with_summaries(
self,
chunk_structure: str,
chunks: Any,
dataset: Dataset,
indexing_technique: str | None = None,
summary_index_setting: dict | None = None,
) -> Mapping[str, Any]:
"""
Generate preview output with summaries for chunks in preview mode.
This method generates summaries on-the-fly without saving to database.
Args:
chunk_structure: Chunk structure type
chunks: Chunks to generate preview for
dataset: Dataset object (for tenant_id)
indexing_technique: Indexing technique from node config or dataset
summary_index_setting: Summary index setting from node config or dataset
"""
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
preview_output = index_processor.format_preview(chunks)
# Check if summary index is enabled
if indexing_technique != "high_quality":
return preview_output
if not summary_index_setting or not summary_index_setting.get("enable"):
return preview_output
# Generate summaries for chunks
if "preview" in preview_output and isinstance(preview_output["preview"], list):
chunk_count = len(preview_output["preview"])
logger.info(
"Generating summaries for %s chunks in preview mode (dataset: %s)",
chunk_count,
dataset.id,
)
# Use ParagraphIndexProcessor's generate_summary method
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
# Get Flask app for application context in worker threads
flask_app = None
try:
flask_app = current_app._get_current_object() # type: ignore
except RuntimeError:
logger.warning("No Flask application context available, summary generation may fail")
def generate_summary_for_chunk(preview_item: dict) -> None:
"""Generate summary for a single chunk."""
if "content" in preview_item:
# Set Flask application context in worker thread
if flask_app:
with flask_app.app_context():
summary, _ = ParagraphIndexProcessor.generate_summary(
tenant_id=dataset.tenant_id,
text=preview_item["content"],
summary_index_setting=summary_index_setting,
)
if summary:
preview_item["summary"] = summary
else:
# Fallback: try without app context (may fail)
summary, _ = ParagraphIndexProcessor.generate_summary(
tenant_id=dataset.tenant_id,
text=preview_item["content"],
summary_index_setting=summary_index_setting,
)
if summary:
preview_item["summary"] = summary
# Generate summaries concurrently using ThreadPoolExecutor
# Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
timeout_seconds = min(300, 60 * len(preview_output["preview"]))
errors: list[Exception] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_output["preview"]))) as executor:
futures = [
executor.submit(generate_summary_for_chunk, preview_item)
for preview_item in preview_output["preview"]
]
# Wait for all tasks to complete with timeout
done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)
# Cancel tasks that didn't complete in time
if not_done:
timeout_error_msg = (
f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s"
)
logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg)
# In preview mode, timeout is also an error
errors.append(TimeoutError(timeout_error_msg))
for future in not_done:
future.cancel()
# Wait a bit for cancellation to take effect
concurrent.futures.wait(not_done, timeout=5)
# Collect exceptions from completed futures
for future in done:
try:
future.result() # This will raise any exception that occurred
except Exception as e:
logger.exception("Error in summary generation future")
errors.append(e)
# In preview mode, if there are any errors, fail the request
if errors:
error_messages = [str(e) for e in errors]
error_summary = (
f"Failed to generate summaries for {len(errors)} chunk(s). "
f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors
)
if len(errors) > 3:
error_summary += f" (and {len(errors) - 3} more)"
logger.error("Summary generation failed in preview mode: %s", error_summary)
raise KnowledgeIndexNodeError(error_summary)
completed_count = sum(1 for item in preview_output["preview"] if item.get("summary") is not None)
logger.info(
"Completed summary generation for preview chunks: %s/%s succeeded",
completed_count,
len(preview_output["preview"]),
)
return preview_output
def _get_preview_output(
self,
chunk_structure: str,
chunks: Any,
dataset: Dataset | None = None,
variable_pool: VariablePool | None = None,
) -> Mapping[str, Any]:
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
preview_output = index_processor.format_preview(chunks)
# If dataset is provided, try to enrich preview with summaries
if dataset and variable_pool:
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
if document_id:
document = db.session.query(Document).filter_by(id=document_id.value).first()
if document:
# Query summaries for this document
summaries = (
db.session.query(DocumentSegmentSummary)
.filter_by(
dataset_id=dataset.id,
document_id=document.id,
status="completed",
enabled=True,
)
.all()
)
if summaries:
# Create a map of segment content to summary for matching
# Use content matching as chunks in preview might not be indexed yet
summary_by_content = {}
for summary in summaries:
segment = (
db.session.query(DocumentSegment)
.filter_by(id=summary.chunk_id, dataset_id=dataset.id)
.first()
)
if segment:
# Normalize content for matching (strip whitespace)
normalized_content = segment.content.strip()
summary_by_content[normalized_content] = summary.summary_content
# Enrich preview with summaries by content matching
if "preview" in preview_output and isinstance(preview_output["preview"], list):
matched_count = 0
for preview_item in preview_output["preview"]:
if "content" in preview_item:
# Normalize content for matching
normalized_chunk_content = preview_item["content"].strip()
if normalized_chunk_content in summary_by_content:
preview_item["summary"] = summary_by_content[normalized_chunk_content]
matched_count += 1
if matched_count > 0:
logger.info(
"Enriched preview with %s existing summaries (dataset: %s, document: %s)",
matched_count,
dataset.id,
document.id,
)
return preview_output
@classmethod
def version(cls) -> str:
View File
@ -419,6 +419,9 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
else:
source["content"] = segment.get_sign_content()
# Add summary if available
if record.summary:
source["summary"] = record.summary
retrieval_resource_list.append(source)
if retrieval_resource_list:
retrieval_resource_list = sorted(
View File
@ -522,7 +522,6 @@ class LLMNode(Node[LLMNodeData]):
json_schema=output_schema,
model_parameters=node_data_model.completion_params,
stop=list(stop or []),
stream=False,
user=user_id,
tenant_id=tenant_id,
)
@ -1093,6 +1092,8 @@ class LLMNode(Node[LLMNodeData]):
if "content" not in item:
raise InvalidContextStructureError(f"Invalid context structure: {item}")
if item.get("summary"):
context_str += item["summary"] + "\n"
context_str += item["content"] + "\n"
retriever_resource = self._convert_to_original_retriever_resource(item)
@ -1154,6 +1155,7 @@ class LLMNode(Node[LLMNodeData]):
page=metadata.get("page"),
doc_metadata=metadata.get("doc_metadata"),
files=context_dict.get("files"),
summary=context_dict.get("summary"),
)
return source
@ -1915,6 +1917,7 @@ class LLMNode(Node[LLMNodeData]):
) -> Generator[NodeEventBase, None, LLMGenerationData]:
result: LLMGenerationData | None = None
# FIXME(Mairuis): Async processing for bash session.
with SandboxBashSession(sandbox=sandbox, node_id=self.id, tools=tool_dependencies) as session:
prompt_files = self._extract_prompt_files(variable_pool)
model_features = self._get_model_features(model_instance)
View File
@ -1,7 +1,8 @@
from .zip_sandbox import SandboxDownloadItem, SandboxFile, SandboxUploadItem, ZipSandbox
__all__ = [
"SandboxDownloadItem",
"SandboxFile",
"SandboxUploadItem",
"ZipSandbox",
]
View File
@ -27,10 +27,20 @@ from .strategy import ZipStrategy
@dataclass(frozen=True)
class SandboxDownloadItem:
"""Item for downloading: URL -> sandbox path."""
url: str
path: str
@dataclass(frozen=True)
class SandboxUploadItem:
"""Item for uploading: sandbox path -> URL."""
path: str
url: str
@dataclass(frozen=True)
class SandboxFile:
"""A handle to a file in the sandbox."""
@ -210,25 +220,6 @@ class ZipSandbox:
# ========== Download operations ==========
def download(self, urls: list[str], *, dest_dir: str = ".") -> list[str]:
if not urls:
return []
dest_dir = self._normalize_path(dest_dir)
paths = [self._dest_path_for_url(dest_dir, u) for u in urls]
p = pipeline(self.vm)
p.add(["mkdir", "-p", dest_dir], error_message="Failed to create download directory")
for url, out_path in zip(urls, paths, strict=True):
p.add(["curl", "-fsSL", url, "-o", out_path], error_message="Failed to download file")
try:
p.execute(timeout=self._DEFAULT_TIMEOUT_SECONDS, raise_on_error=True)
except Exception as exc:
raise RuntimeError(str(exc)) from exc
return paths
def download_items(self, items: list[SandboxDownloadItem], *, dest_dir: str = ".") -> list[str]:
if not items:
return []
@ -286,6 +277,32 @@ class ZipSandbox:
except CommandExecutionError as exc:
raise RuntimeError(str(exc)) from exc
def upload_items(self, items: list[SandboxUploadItem], *, src_dir: str = ".") -> None:
"""Upload multiple files from sandbox to target URLs.
Args:
items: List of SandboxUploadItem(path, url)
src_dir: Base directory containing the files
"""
if not items:
return
src_dir = self._normalize_path(src_dir)
p = pipeline(self.vm)
for item in items:
rel = self._normalize_path(item.path)
src_path = posixpath.join(src_dir, rel) if src_dir not in ("", ".") else rel
p.add(
["curl", "-fsSL", "-X", "PUT", "-T", src_path, item.url],
error_message=f"Failed to upload {item.path}",
)
try:
p.execute(timeout=self._DEFAULT_TIMEOUT_SECONDS, raise_on_error=True)
except Exception as exc:
raise RuntimeError(str(exc)) from exc
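Given a ZipSandbox instance `zip_sandbox`, pairing this with presigned upload URLs might look like the following; the URLs are placeholders:

items = [
    SandboxUploadItem(path="outputs/report.pdf", url="https://storage.example.com/presigned-put-1"),
    SandboxUploadItem(path="outputs/data.csv", url="https://storage.example.com/presigned-put-2"),
]
# Each item becomes one `curl -fsSL -X PUT -T <src> <url>` step in a single pipeline run
zip_sandbox.upload_items(items, src_dir="/workspace")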
# ========== Archive operations ==========
def zip(self, src: str = ".", *, include_base: bool = True) -> SandboxFile:
View File
@ -102,6 +102,8 @@ def init_app(app: DifyApp) -> Celery:
imports = [
"tasks.async_workflow_tasks", # trigger workers
"tasks.trigger_processing_tasks", # async trigger processing
"tasks.generate_summary_index_task", # summary index generation
"tasks.regenerate_summary_index_task", # summary index regeneration
]
day = dify_config.CELERY_BEAT_SCHEDULER_TIME
View File
@ -1,71 +1,56 @@
"""Storage wrapper that provides presigned URL support with fallback to signed proxy URLs."""
"""Storage wrapper that provides presigned URL support with fallback to ticket-based URLs.
import base64
import hashlib
import hmac
import os
import time
import urllib.parse
This is the unified presign wrapper for all storage operations. When the underlying
storage backend doesn't support presigned URLs (raises NotImplementedError), it falls
back to generating ticket-based URLs that route through Dify's file proxy endpoints.
Usage:
from extensions.storage.file_presign_storage import FilePresignStorage
# Wrap any BaseStorage to add presign support
presign_storage = FilePresignStorage(base_storage)
download_url = presign_storage.get_download_url("path/to/file.txt", expires_in=3600)
upload_url = presign_storage.get_upload_url("path/to/file.txt", expires_in=3600)
When the underlying storage doesn't support presigned URLs, the fallback URLs follow the format:
{FILES_URL}/files/storage-tickets/{token}
The token is a UUID that maps to the real storage key in Redis.
"""
from configs import dify_config
from extensions.storage.storage_wrapper import StorageWrapper
class FilePresignStorage(StorageWrapper):
"""Storage wrapper that provides presigned URL support.
"""Storage wrapper that provides presigned URL support with ticket fallback.
If the wrapped storage supports presigned URLs, delegates to it.
Otherwise, generates signed proxy URLs for download.
Otherwise, generates ticket-based URLs for both download and upload operations.
"""
SIGNATURE_PREFIX = "storage-download"
def get_download_url(self, filename: str, expires_in: int = 3600) -> str:
"""Get a presigned download URL, falling back to ticket URL if not supported."""
try:
return super().get_download_url(filename, expires_in)
return self._storage.get_download_url(filename, expires_in)
except NotImplementedError:
return self._generate_signed_proxy_url(filename, expires_in)
from services.storage_ticket_service import StorageTicketService
def get_upload_url(self, filename: str, expires_in: int = 3600) -> str:
try:
return super().get_upload_url(filename, expires_in)
except NotImplementedError:
return self._generate_signed_upload_url(filename)
return StorageTicketService.create_download_url(filename, expires_in=expires_in)
def get_download_urls(self, filenames: list[str], expires_in: int = 3600) -> list[str]:
"""Get presigned download URLs for multiple files."""
try:
return super().get_download_urls(filenames, expires_in)
return self._storage.get_download_urls(filenames, expires_in)
except NotImplementedError:
return [self._generate_signed_proxy_url(filename, expires_in) for filename in filenames]
from services.storage_ticket_service import StorageTicketService
def _generate_signed_upload_url(self, filename: str) -> str:
# TODO: Implement this
raise NotImplementedError("This storage backend doesn't support pre-signed URLs")
return [StorageTicketService.create_download_url(f, expires_in=expires_in) for f in filenames]
def _generate_signed_proxy_url(self, filename: str, expires_in: int = 3600) -> str:
base_url = dify_config.FILES_URL
encoded_filename = urllib.parse.quote(filename, safe="")
url = f"{base_url}/files/storage/{encoded_filename}/download"
def get_upload_url(self, filename: str, expires_in: int = 3600) -> str:
"""Get a presigned upload URL, falling back to ticket URL if not supported."""
try:
return self._storage.get_upload_url(filename, expires_in)
except NotImplementedError:
from services.storage_ticket_service import StorageTicketService
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
sign = self._create_signature(filename, timestamp, nonce)
query = urllib.parse.urlencode({"timestamp": timestamp, "nonce": nonce, "sign": sign})
return f"{url}?{query}"
@classmethod
def _create_signature(cls, filename: str, timestamp: str, nonce: str) -> str:
key = dify_config.SECRET_KEY.encode()
msg = f"{cls.SIGNATURE_PREFIX}|{filename}|{timestamp}|{nonce}"
sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
return base64.urlsafe_b64encode(sign).decode()
@classmethod
def verify_signature(cls, *, filename: str, timestamp: str, nonce: str, sign: str) -> bool:
expected_sign = cls._create_signature(filename, timestamp, nonce)
if sign != expected_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
return StorageTicketService.create_upload_url(filename, expires_in=expires_in)
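For reference, a minimal sketch of the fallback contract this wrapper relies on: a backend that raises NotImplementedError from its presign methods triggers the ticket path. The DummyStorage class below is illustrative, not part of the codebase.

```python
class DummyStorage:
    """Illustrative backend with no native presign support."""

    def get_download_url(self, filename: str, expires_in: int = 3600) -> str:
        raise NotImplementedError  # no native presign; wrapper falls back to tickets


wrapped = FilePresignStorage(DummyStorage())
# Returns a ticket URL of the form {FILES_URL}/files/storage-files/{token}
url = wrapped.get_download_url("path/to/file.txt", expires_in=300)
```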

View File

@ -39,6 +39,14 @@ dataset_retrieval_model_fields = {
"score_threshold_enabled": fields.Boolean,
"score_threshold": fields.Float,
}
dataset_summary_index_fields = {
"enable": fields.Boolean,
"model_name": fields.String,
"model_provider_name": fields.String,
"summary_prompt": fields.String,
}
external_retrieval_model_fields = {
"top_k": fields.Integer,
"score_threshold": fields.Float,
@ -83,6 +91,7 @@ dataset_detail_fields = {
"embedding_model_provider": fields.String,
"embedding_available": fields.Boolean,
"retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields),
"summary_index_setting": fields.Nested(dataset_summary_index_fields),
"tags": fields.List(fields.Nested(tag_fields)),
"doc_form": fields.String,
"external_knowledge_info": fields.Nested(external_knowledge_info_fields),

View File

@ -33,6 +33,11 @@ document_fields = {
"hit_count": fields.Integer,
"doc_form": fields.String,
"doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"),
# Summary index generation status:
# "SUMMARIZING" (set while a generation task is queued or running)
"summary_index_status": fields.String,
# Whether this document needs summary index generation
"need_summary": fields.Boolean,
}
document_with_segments_fields = {
@ -60,6 +65,10 @@ document_with_segments_fields = {
"completed_segments": fields.Integer,
"total_segments": fields.Integer,
"doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"),
# Summary index generation status:
# "SUMMARIZING" (set while a generation task is queued or running)
"summary_index_status": fields.String,
"need_summary": fields.Boolean, # Whether this document needs summary index generation
}
dataset_and_document_fields = {

View File

@ -58,4 +58,5 @@ hit_testing_record_fields = {
"score": fields.Float,
"tsne_position": fields.Raw,
"files": fields.List(fields.Nested(files_fields)),
"summary": fields.String, # Summary content if retrieved via summary index
}

View File

@ -36,6 +36,7 @@ class RetrieverResource(ResponseModel):
segment_position: int | None = None
index_node_hash: str | None = None
content: str | None = None
summary: str | None = None
created_at: int | None = None
@field_validator("created_at", mode="before")

View File

@ -49,4 +49,5 @@ segment_fields = {
"stopped_at": TimestampField,
"child_chunks": fields.List(fields.Nested(child_chunk_fields)),
"attachments": fields.List(fields.Nested(attachment_fields)),
"summary": fields.String, # Summary content for the segment
}

View File

@ -1,7 +1,7 @@
"""sandbox_providers
Revision ID: aab323465866
Revises: 9d77545f524e
Revises: 788d3099ae3a
Create Date: 2026-01-08 10:31:05.062722
"""
@ -11,7 +11,7 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'aab323465866'
down_revision = '9d77545f524e'
down_revision = '788d3099ae3a'
branch_labels = None
depends_on = None

View File

@ -0,0 +1,107 @@
"""add summary index feature
Revision ID: 788d3099ae3a
Revises: 9d77545f524e
Create Date: 2026-01-27 18:15:45.277928
"""
from alembic import op
import models as models
import sqlalchemy as sa
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '788d3099ae3a'
down_revision = '9d77545f524e'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
op.create_table('document_segment_summaries',
sa.Column('id', models.types.StringUUID(), nullable=False),
sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
sa.Column('document_id', models.types.StringUUID(), nullable=False),
sa.Column('chunk_id', models.types.StringUUID(), nullable=False),
sa.Column('summary_content', models.types.LongText(), nullable=True),
sa.Column('summary_index_node_id', sa.String(length=255), nullable=True),
sa.Column('summary_index_node_hash', sa.String(length=255), nullable=True),
sa.Column('tokens', sa.Integer(), nullable=True),
sa.Column('status', sa.String(length=32), server_default=sa.text("'generating'"), nullable=False),
sa.Column('error', models.types.LongText(), nullable=True),
sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
sa.Column('disabled_at', sa.DateTime(), nullable=True),
sa.Column('disabled_by', models.types.StringUUID(), nullable=True),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='document_segment_summaries_pkey')
)
with op.batch_alter_table('document_segment_summaries', schema=None) as batch_op:
batch_op.create_index('document_segment_summaries_chunk_id_idx', ['chunk_id'], unique=False)
batch_op.create_index('document_segment_summaries_dataset_id_idx', ['dataset_id'], unique=False)
batch_op.create_index('document_segment_summaries_document_id_idx', ['document_id'], unique=False)
batch_op.create_index('document_segment_summaries_status_idx', ['status'], unique=False)
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.add_column(sa.Column('summary_index_setting', models.types.AdjustedJSON(), nullable=True))
with op.batch_alter_table('documents', schema=None) as batch_op:
batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=True))
else:
# MySQL: Use compatible syntax
op.create_table(
'document_segment_summaries',
sa.Column('id', models.types.StringUUID(), nullable=False),
sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
sa.Column('document_id', models.types.StringUUID(), nullable=False),
sa.Column('chunk_id', models.types.StringUUID(), nullable=False),
sa.Column('summary_content', models.types.LongText(), nullable=True),
sa.Column('summary_index_node_id', sa.String(length=255), nullable=True),
sa.Column('summary_index_node_hash', sa.String(length=255), nullable=True),
sa.Column('tokens', sa.Integer(), nullable=True),
sa.Column('status', sa.String(length=32), server_default=sa.text("'generating'"), nullable=False),
sa.Column('error', models.types.LongText(), nullable=True),
sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
sa.Column('disabled_at', sa.DateTime(), nullable=True),
sa.Column('disabled_by', models.types.StringUUID(), nullable=True),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='document_segment_summaries_pkey'),
)
with op.batch_alter_table('document_segment_summaries', schema=None) as batch_op:
batch_op.create_index('document_segment_summaries_chunk_id_idx', ['chunk_id'], unique=False)
batch_op.create_index('document_segment_summaries_dataset_id_idx', ['dataset_id'], unique=False)
batch_op.create_index('document_segment_summaries_document_id_idx', ['document_id'], unique=False)
batch_op.create_index('document_segment_summaries_status_idx', ['status'], unique=False)
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.add_column(sa.Column('summary_index_setting', models.types.AdjustedJSON(), nullable=True))
with op.batch_alter_table('documents', schema=None) as batch_op:
batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('documents', schema=None) as batch_op:
batch_op.drop_column('need_summary')
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.drop_column('summary_index_setting')
with op.batch_alter_table('document_segment_summaries', schema=None) as batch_op:
batch_op.drop_index('document_segment_summaries_status_idx')
batch_op.drop_index('document_segment_summaries_document_id_idx')
batch_op.drop_index('document_segment_summaries_dataset_id_idx')
batch_op.drop_index('document_segment_summaries_chunk_id_idx')
op.drop_table('document_segment_summaries')
# ### end Alembic commands ###
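If it helps to exercise this revision in isolation, a minimal programmatic sketch with Alembic follows; the config path is an assumption, since Dify normally drives migrations through Flask-Migrate.

```python
# Sketch: apply or revert this revision programmatically (config path assumed).
from alembic import command
from alembic.config import Config

cfg = Config("api/alembic.ini")           # assumption: adjust to the real config location
command.upgrade(cfg, "788d3099ae3a")      # runs upgrade() above
# command.downgrade(cfg, "9d77545f524e")  # reverts to the parent revision
```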

View File

@ -72,6 +72,7 @@ class Dataset(Base):
keyword_number = mapped_column(sa.Integer, nullable=True, server_default=sa.text("10"))
collection_binding_id = mapped_column(StringUUID, nullable=True)
retrieval_model = mapped_column(AdjustedJSON, nullable=True)
summary_index_setting = mapped_column(AdjustedJSON, nullable=True)
built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
icon_info = mapped_column(AdjustedJSON, nullable=True)
runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'"))
@ -419,6 +420,7 @@ class Document(Base):
doc_metadata = mapped_column(AdjustedJSON, nullable=True)
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
doc_language = mapped_column(String(255), nullable=True)
need_summary: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"]
@ -1575,3 +1577,36 @@ class SegmentAttachmentBinding(Base):
segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
attachment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
class DocumentSegmentSummary(Base):
__tablename__ = "document_segment_summaries"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="document_segment_summaries_pkey"),
sa.Index("document_segment_summaries_dataset_id_idx", "dataset_id"),
sa.Index("document_segment_summaries_document_id_idx", "document_id"),
sa.Index("document_segment_summaries_chunk_id_idx", "chunk_id"),
sa.Index("document_segment_summaries_status_idx", "status"),
)
id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# corresponds to DocumentSegment.id or parent chunk id
chunk_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
summary_content: Mapped[str] = mapped_column(LongText, nullable=True)
summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True)
summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True)
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
status: Mapped[str] = mapped_column(String(32), nullable=False, server_default=sa.text("'generating'"))
error: Mapped[str] = mapped_column(LongText, nullable=True)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
disabled_by = mapped_column(StringUUID, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
def __repr__(self):
return f"<DocumentSegmentSummary id={self.id} chunk_id={self.chunk_id} status={self.status}>"

View File

@ -659,16 +659,22 @@ class AccountTrialAppRecord(Base):
return user
class ExporleBanner(Base):
class ExporleBanner(TypeBase):
__tablename__ = "exporle_banners"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
content = mapped_column(sa.JSON, nullable=False)
link = mapped_column(String(255), nullable=False)
sort = mapped_column(sa.Integer, nullable=False)
status = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"))
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'::character varying"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
link: Mapped[str] = mapped_column(String(255), nullable=False)
sort: Mapped[int] = mapped_column(sa.Integer, nullable=False)
status: Mapped[str] = mapped_column(
sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"), default="enabled"
)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
language: Mapped[str] = mapped_column(
String(255), nullable=False, server_default=sa.text("'en-US'::character varying"), default="en-US"
)
class OAuthProviderApp(TypeBase):

View File

@ -181,6 +181,7 @@ dev = [
# "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved.
"sseclient-py>=1.8.0",
"pytest-timeout>=2.4.0",
"pytest-xdist>=3.8.0",
]
############################################################

View File

@ -18,7 +18,6 @@ from core.app_assets.storage import AppAssetStorage, AssetPath
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from extensions.storage.silent_storage import SilentStorage
from models.app_asset import AppAssets
from models.model import App
@ -43,9 +42,12 @@ class AppAssetService:
This method creates an AppAssetStorage each time it's called,
ensuring storage.storage_runner is only accessed after init_app.
The storage is wrapped with FilePresignStorage for presign fallback support
and CachedPresignStorage for URL caching.
"""
return AppAssetStorage(
storage=SilentStorage(storage.storage_runner),
storage=storage.storage_runner,
redis_client=redis_client,
cache_key_prefix="app_assets",
)
@ -54,6 +56,22 @@ class AppAssetService:
def _lock(app_id: str):
return redis_client.lock(f"app_asset:lock:{app_id}", timeout=AppAssetService._LOCK_TIMEOUT_SECONDS)
@staticmethod
def get_assets_by_version(tenant_id: str, app_id: str, workflow_id: str | None = None) -> AppAssets:
"""Get asset tree by workflow_id (published) or draft if workflow_id is None."""
with Session(db.engine) as session:
version = workflow_id or AppAssets.VERSION_DRAFT
assets = (
session.query(AppAssets)
.filter(
AppAssets.tenant_id == tenant_id,
AppAssets.app_id == app_id,
AppAssets.version == version,
)
.first()
)
return assets or AppAssets(tenant_id=tenant_id, app_id=app_id, version=version)
@staticmethod
def get_draft_assets(tenant_id: str, app_id: str) -> list[AssetItem]:
with Session(db.engine) as session:

View File

@ -1,26 +1,47 @@
"""Service for exporting and importing App Bundles (DSL + assets).
Bundle structure:
bundle.zip/
{app_name}.yml # DSL file
manifest.json # Asset manifest (required for import)
{app_name}/ # Asset files
folder/file.txt
...
Import flow (sandbox-based):
1. prepare_import: Frontend gets upload URL, stores import_id in Redis
2. Frontend uploads zip to storage
3. confirm_import: Sandbox downloads zip, extracts, uploads assets via presigned URLs
Manifest format (schema_version 1.0):
- app_assets.tree: Full AppAssetFileTree for lossless node-ID restoration
- files: node_id -> path mapping for file nodes
- integrity.file_count: Basic validation
"""
from __future__ import annotations
import io
import json
import logging
import re
import zipfile
from dataclasses import dataclass
from uuid import uuid4
import yaml
from pydantic import ValidationError
from sqlalchemy.orm import Session
from core.app.entities.app_bundle_entities import (
BUNDLE_DSL_FILENAME_PATTERN,
BUNDLE_MAX_SIZE,
MANIFEST_FILENAME,
BundleExportResult,
BundleFormatError,
ZipSecurityError,
BundleManifest,
)
from core.app_assets.storage import AssetPath
from core.app_bundle import SourceZipExtractor
from core.zip_sandbox import SandboxDownloadItem, ZipSandbox
from core.app_assets.storage import AppAssetStorage, AssetPath, BundleImportZipPath
from core.zip_sandbox import SandboxDownloadItem, SandboxUploadItem, ZipSandbox
from extensions.ext_database import db
from models import Account, App
from extensions.ext_redis import redis_client
from models.account import Account
from models.model import App
from .app_asset_package_service import AppAssetPackageService
from .app_asset_service import AppAssetService
@ -28,6 +49,15 @@ from .app_dsl_service import AppDslService, Import
logger = logging.getLogger(__name__)
_IMPORT_REDIS_PREFIX = "app_bundle:import:"
_IMPORT_TTL_SECONDS = 3600 # 1 hour
@dataclass
class ImportPrepareResult:
import_id: str
upload_url: str
class AppBundleService:
@staticmethod
@ -38,14 +68,10 @@ class AppBundleService:
marked_name: str = "",
marked_comment: str = "",
):
"""
Publish App Bundle (workflow + assets).
Coordinates WorkflowService and AppAssetService publishing in a single transaction.
"""
"""Publish App Bundle (workflow + assets) in a single transaction."""
from models.workflow import Workflow
from services.workflow_service import WorkflowService
# 1. Publish workflow
workflow: Workflow = WorkflowService().publish_workflow(
session=session,
app_model=app_model,
@ -53,17 +79,16 @@ class AppBundleService:
marked_name=marked_name,
marked_comment=marked_comment,
)
# 2. Publish assets (bound to workflow_id)
AppAssetPackageService.publish(
session=session,
app_model=app_model,
account_id=account.id,
workflow_id=workflow.id,
)
return workflow
# ========== Export ==========
@staticmethod
def export_bundle(
*,
@ -73,14 +98,14 @@ class AppBundleService:
workflow_id: str | None = None,
expires_in: int = 10 * 60,
) -> BundleExportResult:
"""Export bundle and return a temporary download URL.
Uses sandbox VM to build the ZIP, avoiding memory pressure in API process.
"""
"""Export bundle with manifest.json and return a temporary download URL."""
tenant_id = app_model.tenant_id
app_id = app_model.id
safe_name = AppBundleService._sanitize_filename(app_model.name)
filename = f"{safe_name}.zip"
dsl_filename = f"{safe_name}.yml"
app_assets = AppAssetService.get_assets_by_version(tenant_id, app_id, workflow_id)
manifest = BundleManifest.from_tree(app_assets.asset_tree, dsl_filename)
export_id = uuid4().hex
export_path = AssetPath.bundle_export_zip(tenant_id, app_id, export_id)
@ -95,147 +120,170 @@ class AppBundleService:
with ZipSandbox(tenant_id=tenant_id, user_id=account_id, app_id="app-bundle-export") as zs:
zs.write_file(f"bundle_root/{safe_name}.yml", dsl_content.encode("utf-8"))
zs.write_file(f"bundle_root/{MANIFEST_FILENAME}", manifest.model_dump_json(indent=2).encode("utf-8"))
# Published assets: use stored source zip and unzip into <safe_name>/...
if workflow_id is not None:
source_zip_path = AssetPath.source_zip(tenant_id, app_id, workflow_id)
source_url = asset_storage.get_download_url(source_zip_path, expires_in)
zs.download_archive(source_url, path="tmp/source_assets.zip")
zs.unzip(archive_path="tmp/source_assets.zip", dest_dir=f"bundle_root/{safe_name}")
else:
# Draft assets: download individual files and place under <safe_name>/...
asset_items = AppAssetService.get_draft_assets(tenant_id, app_id)
asset_urls = asset_storage.get_download_urls(
[AssetPath.draft(tenant_id, app_id, a.asset_id) for a in asset_items], expires_in
)
zs.download_items(
[
SandboxDownloadItem(url=url, path=f"{safe_name}/{a.path}")
for a, url in zip(asset_items, asset_urls, strict=True)
],
dest_dir="bundle_root",
)
if asset_items:
asset_urls = asset_storage.get_download_urls(
[AssetPath.draft(tenant_id, app_id, a.asset_id) for a in asset_items], expires_in
)
zs.download_items(
[
SandboxDownloadItem(url=url, path=f"{safe_name}/{a.path}")
for a, url in zip(asset_items, asset_urls, strict=True)
],
dest_dir="bundle_root",
)
archive = zs.zip(src="bundle_root", include_base=False)
zs.upload(archive, upload_url)
download_url = asset_storage.get_download_url(export_path, expires_in)
return BundleExportResult(download_url=download_url, filename=filename)
return BundleExportResult(download_url=download_url, filename=f"{safe_name}.zip")
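A caller-side sketch of the export path; variable names are illustrative and the keyword parameters follow the method body shown above:

```python
# Sketch: export a published bundle and hand the short-lived URL to the client.
result = AppBundleService.export_bundle(
    app_model=app,
    account_id=current_user.id,
    workflow_id=workflow.id,  # None would export the draft assets instead
)
# result.download_url expires after expires_in seconds (10 minutes by default)
```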
# ========== Import ==========
@staticmethod
def import_bundle(
def prepare_import(tenant_id: str, account_id: str) -> ImportPrepareResult:
"""Prepare import: generate import_id and upload URL."""
import_id = uuid4().hex
import_path = AssetPath.bundle_import_zip(tenant_id, import_id)
asset_storage = AppAssetService.get_storage()
upload_url = asset_storage.get_import_upload_url(import_path, _IMPORT_TTL_SECONDS)
redis_client.setex(
f"{_IMPORT_REDIS_PREFIX}{import_id}",
_IMPORT_TTL_SECONDS,
json.dumps({"tenant_id": tenant_id, "account_id": account_id}),
)
return ImportPrepareResult(import_id=import_id, upload_url=upload_url)
@staticmethod
def confirm_import(
import_id: str,
account: Account,
zip_bytes: bytes,
*,
name: str | None = None,
description: str | None = None,
icon_type: str | None = None,
icon: str | None = None,
icon_background: str | None = None,
) -> Import:
if len(zip_bytes) > BUNDLE_MAX_SIZE:
raise BundleFormatError(f"Bundle size exceeds limit: {BUNDLE_MAX_SIZE} bytes")
"""Confirm import: download zip in sandbox, extract, and upload assets."""
redis_key = f"{_IMPORT_REDIS_PREFIX}{import_id}"
redis_data = redis_client.get(redis_key)
if not redis_data:
raise BundleFormatError("Import session expired or not found")
dsl_content, assets_prefix = AppBundleService._extract_dsl_from_bundle(zip_bytes)
import_meta = json.loads(redis_data)
tenant_id: str = import_meta["tenant_id"]
with Session(db.engine) as session:
dsl_service = AppDslService(session)
import_result = dsl_service.import_app(
if tenant_id != account.current_tenant_id:
raise BundleFormatError("Import session tenant mismatch")
import_path = AssetPath.bundle_import_zip(tenant_id, import_id)
asset_storage = AppAssetService.get_storage()
try:
result = AppBundleService.import_bundle(
tenant_id=tenant_id,
account=account,
import_mode="yaml-content",
yaml_content=dsl_content,
import_path=import_path,
asset_storage=asset_storage,
name=name,
description=description,
icon_type=icon_type,
icon=icon,
icon_background=icon_background,
app_id=None,
)
session.commit()
finally:
redis_client.delete(redis_key)
asset_storage.delete_import_zip(import_path)
if import_result.app_id and assets_prefix:
AppBundleService._import_assets_from_bundle(
zip_bytes=zip_bytes,
assets_prefix=assets_prefix,
app_id=import_result.app_id,
account_id=account.id,
)
return result
@staticmethod
def import_bundle(
*,
tenant_id: str,
account: Account,
import_path: BundleImportZipPath,
asset_storage: AppAssetStorage,
name: str | None,
description: str | None,
icon_type: str | None,
icon: str | None,
icon_background: str | None,
) -> Import:
"""Execute import in sandbox."""
download_url = asset_storage.get_import_download_url(import_path, _IMPORT_TTL_SECONDS)
with ZipSandbox(tenant_id=tenant_id, user_id=account.id, app_id="app-bundle-import") as zs:
zs.download_archive(download_url, path="import.zip")
zs.unzip(archive_path="import.zip", dest_dir="bundle")
manifest_bytes = zs.read_file(f"bundle/{MANIFEST_FILENAME}")
try:
manifest = BundleManifest.model_validate_json(manifest_bytes)
except ValidationError as e:
raise BundleFormatError(f"Invalid manifest.json: {e}") from e
dsl_content = zs.read_file(f"bundle/{manifest.dsl_filename}").decode("utf-8")
with Session(db.engine) as session:
dsl_service = AppDslService(session)
import_result = dsl_service.import_app(
account=account,
import_mode="yaml-content",
yaml_content=dsl_content,
name=name,
description=description,
icon_type=icon_type,
icon=icon,
icon_background=icon_background,
app_id=None,
)
session.commit()
if not import_result.app_id:
return import_result
app_id = import_result.app_id
tree = manifest.app_assets.tree
upload_items: list[SandboxUploadItem] = []
for file_entry in manifest.files:
asset_path = AssetPath.draft(tenant_id, app_id, file_entry.node_id)
file_upload_url = asset_storage.get_upload_url(asset_path, _IMPORT_TTL_SECONDS)
src_path = f"{manifest.assets_prefix}/{file_entry.path}"
upload_items.append(SandboxUploadItem(path=src_path, url=file_upload_url))
if upload_items:
zs.upload_items(upload_items, src_dir="bundle")
# Tree sizes are already set from manifest; no need to update
app_model = db.session.query(App).filter(App.id == app_id).first()
if app_model:
AppAssetService.set_draft_assets(
app_model=app_model,
account_id=account.id,
new_tree=tree,
)
return import_result
@staticmethod
def _extract_dsl_from_bundle(zip_bytes: bytes) -> tuple[str, str | None]:
dsl_content: str | None = None
dsl_filename: str | None = None
with zipfile.ZipFile(io.BytesIO(zip_bytes), "r") as zf:
for info in zf.infolist():
if info.is_dir():
continue
if BUNDLE_DSL_FILENAME_PATTERN.match(info.filename):
if dsl_content is not None:
raise BundleFormatError("Multiple DSL files found in bundle")
dsl_content = zf.read(info).decode("utf-8")
dsl_filename = info.filename
if dsl_content is None or dsl_filename is None:
raise BundleFormatError("No DSL file (*.yml or *.yaml) found in bundle root")
yaml.safe_load(dsl_content)
assets_prefix = dsl_filename.rsplit(".", 1)[0]
has_assets = AppBundleService._check_assets_prefix_exists(zip_bytes, assets_prefix)
return dsl_content, assets_prefix if has_assets else None
@staticmethod
def _check_assets_prefix_exists(zip_bytes: bytes, prefix: str) -> bool:
with zipfile.ZipFile(io.BytesIO(zip_bytes), "r") as zf:
for info in zf.infolist():
if info.filename.startswith(f"{prefix}/"):
return True
return False
@staticmethod
def _import_assets_from_bundle(
zip_bytes: bytes,
assets_prefix: str,
app_id: str,
account_id: str,
) -> None:
app_model = db.session.query(App).filter(App.id == app_id).first()
if not app_model:
logger.warning("App not found for asset import: %s", app_id)
return
asset_storage = AppAssetService.get_storage()
extractor = SourceZipExtractor(asset_storage)
try:
folders, files = extractor.extract_entries(
zip_bytes,
expected_prefix=f"{assets_prefix}/",
)
except ZipSecurityError as e:
logger.warning("Zip security error during asset import: %s", e)
return
if not folders and not files:
return
new_tree = extractor.build_tree_and_save(
folders=folders,
files=files,
tenant_id=app_model.tenant_id,
app_id=app_model.id,
)
AppAssetService.set_draft_assets(
app_model=app_model,
account_id=account_id,
new_tree=new_tree,
)
# ========== Helpers ==========
@staticmethod
def _sanitize_filename(name: str) -> str:
"""Sanitize app name for use as filename."""
safe = re.sub(r'[<>:"/\\|?*\x00-\x1f]', "_", name)
safe = safe.strip(". ")
return safe[:100] if safe else "app"

View File

@ -89,6 +89,7 @@ from tasks.disable_segments_from_index_task import disable_segments_from_index_t
from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.enable_segments_to_index_task import enable_segments_to_index_task
from tasks.recover_document_indexing_task import recover_document_indexing_task
from tasks.regenerate_summary_index_task import regenerate_summary_index_task
from tasks.remove_document_from_index_task import remove_document_from_index_task
from tasks.retry_document_indexing_task import retry_document_indexing_task
from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task
@ -211,6 +212,7 @@ class DatasetService:
embedding_model_provider: str | None = None,
embedding_model_name: str | None = None,
retrieval_model: RetrievalModel | None = None,
summary_index_setting: dict | None = None,
):
# check if dataset name already exists
if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
@ -253,6 +255,8 @@ class DatasetService:
dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None
dataset.permission = permission or DatasetPermissionEnum.ONLY_ME
dataset.provider = provider
if summary_index_setting is not None:
dataset.summary_index_setting = summary_index_setting
db.session.add(dataset)
db.session.flush()
@ -476,6 +480,11 @@ class DatasetService:
if external_retrieval_model:
dataset.retrieval_model = external_retrieval_model
# Update summary index setting if provided
summary_index_setting = data.get("summary_index_setting", None)
if summary_index_setting is not None:
dataset.summary_index_setting = summary_index_setting
# Update basic dataset properties
dataset.name = data.get("name", dataset.name)
dataset.description = data.get("description", dataset.description)
@ -564,6 +573,9 @@ class DatasetService:
# update Retrieval model
if data.get("retrieval_model"):
filtered_data["retrieval_model"] = data["retrieval_model"]
# update summary index setting
if data.get("summary_index_setting"):
filtered_data["summary_index_setting"] = data.get("summary_index_setting")
# update icon info
if data.get("icon_info"):
filtered_data["icon_info"] = data.get("icon_info")
@ -572,12 +584,27 @@ class DatasetService:
db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
db.session.commit()
# Reload dataset to get updated values
db.session.refresh(dataset)
# update pipeline knowledge base node data
DatasetService._update_pipeline_knowledge_base_node_data(dataset, user.id)
# Trigger vector index task if indexing technique changed
if action:
deal_dataset_vector_index_task.delay(dataset.id, action)
# If embedding_model changed, also regenerate summary vectors
if action == "update":
regenerate_summary_index_task.delay(
dataset.id,
regenerate_reason="embedding_model_changed",
regenerate_vectors_only=True,
)
# Note: summary_index_setting changes do not trigger automatic regeneration of existing summaries.
# The new setting will only apply to:
# 1. New documents added after the setting change
# 2. Manual summary generation requests
return dataset
@ -616,6 +643,7 @@ class DatasetService:
knowledge_index_node_data["chunk_structure"] = dataset.chunk_structure
knowledge_index_node_data["indexing_technique"] = dataset.indexing_technique # pyright: ignore[reportAttributeAccessIssue]
knowledge_index_node_data["keyword_number"] = dataset.keyword_number
knowledge_index_node_data["summary_index_setting"] = dataset.summary_index_setting
node["data"] = knowledge_index_node_data
updated = True
except Exception:
@ -854,6 +882,54 @@ class DatasetService:
)
filtered_data["collection_binding_id"] = dataset_collection_binding.id
@staticmethod
def _check_summary_index_setting_model_changed(dataset: Dataset, data: dict[str, Any]) -> bool:
"""
Check if summary_index_setting model (model_name or model_provider_name) has changed.
Args:
dataset: Current dataset object
data: Update data dictionary
Returns:
bool: True if summary model changed, False otherwise
"""
# Check if summary_index_setting is being updated
if "summary_index_setting" not in data or data.get("summary_index_setting") is None:
return False
new_summary_setting = data.get("summary_index_setting")
old_summary_setting = dataset.summary_index_setting
# If new setting is disabled, no need to regenerate
if not new_summary_setting or not new_summary_setting.get("enable"):
return False
# If the old setting doesn't exist, there is nothing to regenerate (no existing summaries).
# Note: the regeneration task only refreshes existing summaries; it does not generate new ones.
if not old_summary_setting:
return False
# Compare model_name and model_provider_name
old_model_name = old_summary_setting.get("model_name")
old_model_provider = old_summary_setting.get("model_provider_name")
new_model_name = new_summary_setting.get("model_name")
new_model_provider = new_summary_setting.get("model_provider_name")
# Check if model changed
if old_model_name != new_model_name or old_model_provider != new_model_provider:
logger.info(
"Summary index setting model changed for dataset %s: old=%s/%s, new=%s/%s",
dataset.id,
old_model_provider,
old_model_name,
new_model_provider,
new_model_name,
)
return True
return False
@staticmethod
def update_rag_pipeline_dataset_settings(
session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False
@ -889,6 +965,9 @@ class DatasetService:
else:
raise ValueError("Invalid index method")
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
# Update summary_index_setting if provided
if knowledge_configuration.summary_index_setting is not None:
dataset.summary_index_setting = knowledge_configuration.summary_index_setting
session.add(dataset)
else:
if dataset.chunk_structure and dataset.chunk_structure != knowledge_configuration.chunk_structure:
@ -994,6 +1073,9 @@ class DatasetService:
if dataset.keyword_number != knowledge_configuration.keyword_number:
dataset.keyword_number = knowledge_configuration.keyword_number
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
# Update summary_index_setting if provided
if knowledge_configuration.summary_index_setting is not None:
dataset.summary_index_setting = knowledge_configuration.summary_index_setting
session.add(dataset)
session.commit()
if action:
@ -1314,6 +1396,50 @@ class DocumentService:
upload_file = DocumentService._get_upload_file_for_upload_file_document(document)
return file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
@staticmethod
def enrich_documents_with_summary_index_status(
documents: Sequence[Document],
dataset: Dataset,
tenant_id: str,
) -> None:
"""
Enrich documents with summary_index_status based on dataset summary index settings.
This method calculates and sets the summary_index_status for each document that needs a summary.
Documents that don't need a summary, or whose dataset has summary indexing disabled, have the status set to None.
Args:
documents: List of Document instances to enrich
dataset: Dataset instance containing summary_index_setting
tenant_id: Tenant ID for summary status lookup
"""
# Check if dataset has summary index enabled
has_summary_index = dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True
# Filter documents that need summary calculation
documents_need_summary = [doc for doc in documents if doc.need_summary is True]
document_ids_need_summary = [str(doc.id) for doc in documents_need_summary]
# Calculate summary_index_status for documents that need summary (only if dataset summary index is enabled)
summary_status_map: dict[str, str | None] = {}
if has_summary_index and document_ids_need_summary:
from services.summary_index_service import SummaryIndexService
summary_status_map = SummaryIndexService.get_documents_summary_index_status(
document_ids=document_ids_need_summary,
dataset_id=dataset.id,
tenant_id=tenant_id,
)
# Add summary_index_status to each document
for document in documents:
if has_summary_index and document.need_summary is True:
# Get status from map, default to None (not queued yet)
document.summary_index_status = summary_status_map.get(str(document.id)) # type: ignore[attr-defined]
else:
# Return null if summary index is not enabled or document doesn't need summary
document.summary_index_status = None # type: ignore[attr-defined]
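Callers enrich a page of documents in place before serialization, e.g. (a sketch, with the surrounding view code assumed):

```python
DocumentService.enrich_documents_with_summary_index_status(
    documents=paginated_documents,
    dataset=dataset,
    tenant_id=current_user.current_tenant_id,
)
# Each document now carries summary_index_status (or None) for the response marshaller.
```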
@staticmethod
def prepare_document_batch_download_zip(
*,
@ -1964,6 +2090,8 @@ class DocumentService:
DuplicateDocumentIndexingTaskProxy(
dataset.tenant_id, dataset.id, duplicate_document_ids
).delay()
# Note: Summary index generation is triggered in document_indexing_task after indexing completes
# to ensure segments are available. See tasks/document_indexing_task.py
except LockNotOwnedError:
pass
@ -2268,6 +2396,11 @@ class DocumentService:
name: str,
batch: str,
):
# Set need_summary based on dataset's summary_index_setting
need_summary = False
if dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True:
need_summary = True
document = Document(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
@ -2281,6 +2414,7 @@ class DocumentService:
created_by=account.id,
doc_form=document_form,
doc_language=document_language,
need_summary=need_summary,
)
doc_metadata = {}
if dataset.built_in_field_enabled:
@ -2505,6 +2639,7 @@ class DocumentService:
embedding_model_provider=knowledge_config.embedding_model_provider,
collection_binding_id=dataset_collection_binding_id,
retrieval_model=retrieval_model.model_dump() if retrieval_model else None,
summary_index_setting=knowledge_config.summary_index_setting,
is_multimodal=knowledge_config.is_multimodal,
)
@ -2686,6 +2821,14 @@ class DocumentService:
if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int):
raise ValueError("Process rule segmentation max_tokens is invalid")
# valid summary index setting
summary_index_setting = args["process_rule"].get("summary_index_setting")
if summary_index_setting and summary_index_setting.get("enable"):
if "model_name" not in summary_index_setting or not summary_index_setting["model_name"]:
raise ValueError("Summary index model name is required")
if "model_provider_name" not in summary_index_setting or not summary_index_setting["model_provider_name"]:
raise ValueError("Summary index model provider name is required")
@staticmethod
def batch_update_document_status(
dataset: Dataset, document_ids: list[str], action: Literal["enable", "disable", "archive", "un_archive"], user
@ -3154,6 +3297,35 @@ class SegmentService:
if args.enabled or keyword_changed:
# update segment vector index
VectorService.update_segment_vector(args.keywords, segment, dataset)
# update summary index if summary is provided and has changed
if args.summary is not None:
# When the user manually provides a summary, allow saving even if summary_index_setting doesn't exist.
# summary_index_setting is only needed for LLM generation, not for manual summary vectorization;
# vectorization uses dataset.embedding_model, which doesn't require summary_index_setting.
if dataset.indexing_technique == "high_quality":
# Query existing summary from database
from models.dataset import DocumentSegmentSummary
existing_summary = (
db.session.query(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.first()
)
# Check if summary has changed
existing_summary_content = existing_summary.summary_content if existing_summary else None
if existing_summary_content != args.summary:
# Summary has changed, update it
from services.summary_index_service import SummaryIndexService
try:
SummaryIndexService.update_summary_for_segment(segment, dataset, args.summary)
except Exception:
logger.exception("Failed to update summary for segment %s", segment.id)
# Don't fail the entire update if summary update fails
else:
segment_hash = helper.generate_text_hash(content)
tokens = 0
@ -3228,6 +3400,73 @@ class SegmentService:
elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX):
# update segment vector index
VectorService.update_segment_vector(args.keywords, segment, dataset)
# Handle summary index when content changed
if dataset.indexing_technique == "high_quality":
from models.dataset import DocumentSegmentSummary
existing_summary = (
db.session.query(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.first()
)
if args.summary is None:
# User didn't provide summary, auto-regenerate if segment previously had summary
# Auto-regeneration only happens if summary_index_setting exists and enable is True
if (
existing_summary
and dataset.summary_index_setting
and dataset.summary_index_setting.get("enable") is True
):
# Segment previously had summary, regenerate it with new content
from services.summary_index_service import SummaryIndexService
try:
SummaryIndexService.generate_and_vectorize_summary(
segment, dataset, dataset.summary_index_setting
)
logger.info("Auto-regenerated summary for segment %s after content change", segment.id)
except Exception:
logger.exception("Failed to auto-regenerate summary for segment %s", segment.id)
# Don't fail the entire update if summary regeneration fails
else:
# User provided summary, check if it has changed
# Manual summary updates are allowed even if summary_index_setting doesn't exist
existing_summary_content = existing_summary.summary_content if existing_summary else None
if existing_summary_content != args.summary:
# Summary has changed, use user-provided summary
from services.summary_index_service import SummaryIndexService
try:
SummaryIndexService.update_summary_for_segment(segment, dataset, args.summary)
logger.info("Updated summary for segment %s with user-provided content", segment.id)
except Exception:
logger.exception("Failed to update summary for segment %s", segment.id)
# Don't fail the entire update if summary update fails
else:
# Summary hasn't changed, regenerate based on new content
# Auto-regeneration only happens if summary_index_setting exists and enable is True
if (
existing_summary
and dataset.summary_index_setting
and dataset.summary_index_setting.get("enable") is True
):
from services.summary_index_service import SummaryIndexService
try:
SummaryIndexService.generate_and_vectorize_summary(
segment, dataset, dataset.summary_index_setting
)
logger.info(
"Regenerated summary for segment %s after content change (summary unchanged)",
segment.id,
)
except Exception:
logger.exception("Failed to regenerate summary for segment %s", segment.id)
# Don't fail the entire update if summary regeneration fails
# update multimodal vector index
VectorService.update_multimodel_vector(segment, args.attachment_ids or [], dataset)
except Exception as e:
@ -3616,6 +3855,39 @@ class SegmentService:
)
return result if isinstance(result, DocumentSegment) else None
@classmethod
def get_segments_by_document_and_dataset(
cls,
document_id: str,
dataset_id: str,
status: str | None = None,
enabled: bool | None = None,
) -> Sequence[DocumentSegment]:
"""
Get segments for a document in a dataset with optional filtering.
Args:
document_id: Document ID
dataset_id: Dataset ID
status: Optional status filter (e.g., "completed")
enabled: Optional enabled filter (True/False)
Returns:
Sequence of DocumentSegment instances
"""
query = select(DocumentSegment).where(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
)
if status is not None:
query = query.where(DocumentSegment.status == status)
if enabled is not None:
query = query.where(DocumentSegment.enabled == enabled)
return db.session.scalars(query).all()
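For instance, fetching only the completed, enabled segments of a document (a sketch):

```python
segments = SegmentService.get_segments_by_document_and_dataset(
    document_id, dataset_id, status="completed", enabled=True
)
```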
class DatasetCollectionBindingService:
@classmethod

View File

@ -119,6 +119,7 @@ class KnowledgeConfig(BaseModel):
data_source: DataSource | None = None
process_rule: ProcessRule | None = None
retrieval_model: RetrievalModel | None = None
summary_index_setting: dict | None = None
doc_form: str = "text_model"
doc_language: str = "English"
embedding_model: str | None = None
@ -141,6 +142,7 @@ class SegmentUpdateArgs(BaseModel):
regenerate_child_chunks: bool = False
enabled: bool | None = None
attachment_ids: list[str] | None = None
summary: str | None = None # Summary content for summary index
class ChildChunkUpdateArgs(BaseModel):

View File

@ -116,6 +116,8 @@ class KnowledgeConfiguration(BaseModel):
embedding_model: str = ""
keyword_number: int | None = 10
retrieval_model: RetrievalSetting
# add summary index setting
summary_index_setting: dict | None = None
@field_validator("embedding_model_provider", mode="before")
@classmethod

View File

@ -343,6 +343,9 @@ class RagPipelineDslService:
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
elif knowledge_configuration.indexing_technique == "economy":
dataset.keyword_number = knowledge_configuration.keyword_number
# Update summary_index_setting if provided
if knowledge_configuration.summary_index_setting is not None:
dataset.summary_index_setting = knowledge_configuration.summary_index_setting
dataset.pipeline_id = pipeline.id
self._session.add(dataset)
self._session.commit()
@ -477,6 +480,9 @@ class RagPipelineDslService:
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
elif knowledge_configuration.indexing_technique == "economy":
dataset.keyword_number = knowledge_configuration.keyword_number
# Update summary_index_setting if provided
if knowledge_configuration.summary_index_setting is not None:
dataset.summary_index_setting = knowledge_configuration.summary_index_setting
dataset.pipeline_id = pipeline.id
self._session.add(dataset)
self._session.commit()

View File

@ -10,6 +10,7 @@ from core.sandbox.initializer.draft_app_assets_initializer import DraftAppAssets
from core.sandbox.initializer.skill_initializer import SkillInitializer
from core.sandbox.sandbox import Sandbox
from core.sandbox.storage.archive_storage import ArchiveSandboxStorage
from extensions.ext_storage import storage
from services.app_asset_package_service import AppAssetPackageService
from services.app_asset_service import AppAssetService
@ -30,7 +31,7 @@ class SandboxService:
if not assets:
raise ValueError(f"No assets found for tid={tenant_id}, app_id={app_id}")
storage = ArchiveSandboxStorage(tenant_id, workflow_execution_id)
archive_storage = ArchiveSandboxStorage(tenant_id, workflow_execution_id, storage.storage_runner)
sandbox = (
SandboxBuilder(tenant_id, SandboxType(sandbox_provider.provider_type))
.options(sandbox_provider.config)
@ -40,7 +41,7 @@ class SandboxService:
.initializer(AppAssetsInitializer(tenant_id, app_id, assets.id))
.initializer(DifyCliInitializer(tenant_id, user_id, app_id, assets.id))
.initializer(SkillInitializer(tenant_id, user_id, app_id, assets.id))
.storage(storage, assets.id)
.storage(archive_storage, assets.id)
.build()
)
@ -49,8 +50,8 @@ class SandboxService:
@classmethod
def delete_draft_storage(cls, tenant_id: str, user_id: str) -> None:
storage = ArchiveSandboxStorage(tenant_id, SandboxBuilder.draft_id(user_id))
storage.delete()
archive_storage = ArchiveSandboxStorage(tenant_id, SandboxBuilder.draft_id(user_id), storage.storage_runner)
archive_storage.delete()
@classmethod
def create_draft(
@ -66,7 +67,9 @@ class SandboxService:
AppAssetPackageService.build_assets(tenant_id, app_id, assets)
sandbox_id = SandboxBuilder.draft_id(user_id)
storage = ArchiveSandboxStorage(tenant_id, sandbox_id, exclude_patterns=[AppAssets.PATH])
archive_storage = ArchiveSandboxStorage(
tenant_id, sandbox_id, storage.storage_runner, exclude_patterns=[AppAssets.PATH]
)
sandbox = (
SandboxBuilder(tenant_id, SandboxType(sandbox_provider.provider_type))
@ -77,7 +80,7 @@ class SandboxService:
.initializer(DraftAppAssetsInitializer(tenant_id, app_id, assets.id))
.initializer(DifyCliInitializer(tenant_id, user_id, app_id, assets.id))
.initializer(SkillInitializer(tenant_id, user_id, app_id, assets.id))
.storage(storage, assets.id)
.storage(archive_storage, assets.id)
.build()
)
@ -98,7 +101,9 @@ class SandboxService:
AppAssetPackageService.build_assets(tenant_id, app_id, assets)
sandbox_id = SandboxBuilder.draft_id(user_id)
storage = ArchiveSandboxStorage(tenant_id, sandbox_id, exclude_patterns=[AppAssets.PATH])
archive_storage = ArchiveSandboxStorage(
tenant_id, sandbox_id, storage.storage_runner, exclude_patterns=[AppAssets.PATH]
)
sandbox = (
SandboxBuilder(tenant_id, SandboxType(sandbox_provider.provider_type))
@ -109,7 +114,7 @@ class SandboxService:
.initializer(DraftAppAssetsInitializer(tenant_id, app_id, assets.id))
.initializer(DifyCliInitializer(tenant_id, user_id, app_id, assets.id))
.initializer(SkillInitializer(tenant_id, user_id, app_id, assets.id))
.storage(storage, assets.id)
.storage(archive_storage, assets.id)
.build()
)

View File

@ -0,0 +1,159 @@
"""Storage ticket service for generating opaque download/upload URLs.
This service provides a ticket-based approach for file access. Instead of exposing
the real storage key in URLs, it generates a random UUID token and stores the mapping
in Redis with a TTL.
Usage:
from services.storage_ticket_service import StorageTicketService
# Generate a download ticket
url = StorageTicketService.create_download_url("path/to/file.txt", expires_in=300)
# Generate an upload ticket
url = StorageTicketService.create_upload_url("path/to/file.txt", expires_in=300, max_bytes=10*1024*1024)
URL format:
{FILES_URL}/files/storage-files/{token}
The token is validated by looking up the Redis key, which contains:
- op: "download" or "upload"
- storage_key: the real storage path
- max_bytes: (upload only) maximum allowed upload size
- filename: suggested filename for Content-Disposition header
"""
import json
import logging
from dataclasses import dataclass
from uuid import uuid4
from configs import dify_config
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
TICKET_KEY_PREFIX = "storage_files"
DEFAULT_DOWNLOAD_TTL = 300 # 5 minutes
DEFAULT_UPLOAD_TTL = 300 # 5 minutes
DEFAULT_MAX_UPLOAD_BYTES = 100 * 1024 * 1024 # 100MB
@dataclass
class StorageTicket:
"""Represents a storage access ticket."""
op: str # "download" or "upload"
storage_key: str
max_bytes: int | None = None # upload only
filename: str | None = None # suggested filename for download
def to_dict(self) -> dict:
data = {"op": self.op, "storage_key": self.storage_key}
if self.max_bytes is not None:
data["max_bytes"] = str(self.max_bytes)
if self.filename is not None:
data["filename"] = self.filename
return data
@classmethod
def from_dict(cls, data: dict) -> "StorageTicket":
raw_max_bytes = data.get("max_bytes")
return cls(
    op=data["op"],
    storage_key=data["storage_key"],
    # to_dict() serializes max_bytes as a string; coerce it back to int
    max_bytes=int(raw_max_bytes) if raw_max_bytes is not None else None,
    filename=data.get("filename"),
)
class StorageTicketService:
"""Service for creating and validating storage access tickets."""
@classmethod
def create_download_url(
cls,
storage_key: str,
*,
expires_in: int = DEFAULT_DOWNLOAD_TTL,
filename: str | None = None,
) -> str:
"""Create a download ticket and return the URL.
Args:
storage_key: The real storage path
expires_in: TTL in seconds (default 300)
filename: Suggested filename for Content-Disposition header
Returns:
Full URL with token
"""
if filename is None:
filename = storage_key.rsplit("/", 1)[-1]
ticket = StorageTicket(op="download", storage_key=storage_key, filename=filename)
token = cls._store_ticket(ticket, expires_in)
return cls._build_url(token)
@classmethod
def create_upload_url(
cls,
storage_key: str,
*,
expires_in: int = DEFAULT_UPLOAD_TTL,
max_bytes: int = DEFAULT_MAX_UPLOAD_BYTES,
) -> str:
"""Create an upload ticket and return the URL.
Args:
storage_key: The real storage path
expires_in: TTL in seconds (default 300)
max_bytes: Maximum allowed upload size in bytes
Returns:
Full URL with token
"""
ticket = StorageTicket(op="upload", storage_key=storage_key, max_bytes=max_bytes)
token = cls._store_ticket(ticket, expires_in)
return cls._build_url(token)
@classmethod
def get_ticket(cls, token: str) -> StorageTicket | None:
"""Retrieve a ticket by token.
Args:
token: The UUID token from the URL
Returns:
StorageTicket if found and valid, None otherwise
"""
key = cls._ticket_key(token)
try:
data = redis_client.get(key)
if data is None:
return None
if isinstance(data, bytes):
data = data.decode("utf-8")
return StorageTicket.from_dict(json.loads(data))
except Exception:
logger.warning("Failed to retrieve storage ticket: %s", token, exc_info=True)
return None
@classmethod
def _store_ticket(cls, ticket: StorageTicket, ttl: int) -> str:
"""Store a ticket in Redis and return the token."""
token = str(uuid4())
key = cls._ticket_key(token)
value = json.dumps(ticket.to_dict())
redis_client.setex(key, ttl, value)
return token
@classmethod
def _ticket_key(cls, token: str) -> str:
"""Generate Redis key for a token."""
return f"{TICKET_KEY_PREFIX}:{token}"
@classmethod
def _build_url(cls, token: str) -> str:
"""Build the full URL for a token."""
base_url = dify_config.FILES_URL
return f"{base_url}/files/storage-files/{token}"

File diff suppressed because it is too large

View File

@ -118,6 +118,19 @@ def add_document_to_index_task(dataset_document_id: str):
)
session.commit()
# Enable summary indexes for all segments in this document
from services.summary_index_service import SummaryIndexService
segment_ids_list = [segment.id for segment in segments]
if segment_ids_list:
try:
SummaryIndexService.enable_summaries_for_segments(
dataset=dataset,
segment_ids=segment_ids_list,
)
except Exception as e:
logger.warning("Failed to enable summaries for document %s: %s", dataset_document.id, str(e))
end_at = time.perf_counter()
logger.info(
click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green")

View File

@ -50,7 +50,9 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)

View File

@ -51,7 +51,9 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)

View File

@ -42,7 +42,9 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
).all()
index_node_ids = [segment.index_node_id for segment in segments]
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)

View File

@ -47,6 +47,7 @@ def delete_segment_from_index_task(
doc_form = dataset_document.doc_form
# Proceed with index cleanup using the index_node_ids directly
# For actual deletion, we should delete summaries (not just disable them)
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(
dataset,
@ -54,6 +55,7 @@ def delete_segment_from_index_task(
with_keywords=True,
delete_child_chunks=True,
precomputed_child_node_ids=child_node_ids,
delete_summaries=True, # Actually delete summaries when segment is deleted
)
if dataset.is_multimodal:
# delete segment attachment binding

View File

@ -60,6 +60,18 @@ def disable_segment_from_index_task(segment_id: str):
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        index_processor.clean(dataset, [segment.index_node_id])
        # Disable summary index for this segment
        from services.summary_index_service import SummaryIndexService
        try:
            SummaryIndexService.disable_summaries_for_segments(
                dataset=dataset,
                segment_ids=[segment.id],
                disabled_by=segment.disabled_by,
            )
        except Exception as e:
            logger.warning("Failed to disable summary for segment %s: %s", segment.id, str(e))
        end_at = time.perf_counter()
        logger.info(
            click.style(

View File

@ -68,6 +68,21 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
            index_node_ids.extend(attachment_ids)
        index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
        # Disable summary indexes for these segments
        from services.summary_index_service import SummaryIndexService
        segment_ids_list = [segment.id for segment in segments]
        try:
            # Get disabled_by from first segment (they should all have the same disabled_by)
            disabled_by = segments[0].disabled_by if segments else None
            SummaryIndexService.disable_summaries_for_segments(
                dataset=dataset,
                segment_ids=segment_ids_list,
                disabled_by=disabled_by,
            )
        except Exception as e:
            logger.warning("Failed to disable summaries for segments: %s", str(e))
        end_at = time.perf_counter()
        logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green"))
    except Exception:
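
The disable path threads a `disabled_by` user through for auditing, while the enable path takes none; summaries are soft-disabled rather than deleted, so re-enabling needs no regeneration. A hedged sketch of the pair as inferred from these call sites — the class body is invented for illustration; only the method names and keyword arguments appear in this diff:

```python
# Hypothetical in-memory model of the enable/disable pair; the real
# SummaryIndexService ships in this commit but its diff is not shown here.
class SummaryIndexServiceSketch:
    # segment_id -> {"enabled": bool, "disabled_by": str | None}
    _summaries: dict[str, dict] = {}

    @classmethod
    def disable_summaries_for_segments(cls, *, dataset, segment_ids: list[str], disabled_by: str | None = None) -> None:
        """Soft-disable: keep the summary rows and record who disabled them."""
        for segment_id in segment_ids:
            summary = cls._summaries.get(segment_id)
            if summary:
                summary["enabled"] = False
                summary["disabled_by"] = disabled_by
        # A real implementation would also remove the summary vectors from
        # the dataset's vector index here.

    @classmethod
    def enable_summaries_for_segments(cls, *, dataset, segment_ids: list[str]) -> None:
        """Re-enable previously generated summaries without regenerating them."""
        for segment_id in segment_ids:
            summary = cls._summaries.get(segment_id)
            if summary:
                summary["enabled"] = True
                summary["disabled_by"] = None
        # ...and re-add the corresponding summary vectors to the index.
```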

View File

@ -14,6 +14,7 @@ from enums.cloud_plan import CloudPlan
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document
from services.feature_service import FeatureService
from tasks.generate_summary_index_task import generate_summary_index_task
logger = logging.getLogger(__name__)
@ -99,6 +100,78 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
        indexing_runner.run(documents)
        end_at = time.perf_counter()
        logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
        # Trigger summary index generation for completed documents if enabled
        # Only generate for high_quality indexing technique and when summary_index_setting is enabled
        # Re-query dataset to get latest summary_index_setting (in case it was updated)
        dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
        if not dataset:
            logger.warning("Dataset %s not found after indexing", dataset_id)
            return
        if dataset.indexing_technique == "high_quality":
            summary_index_setting = dataset.summary_index_setting
            if summary_index_setting and summary_index_setting.get("enable"):
                # Expire all session objects to get each document's latest indexing status
                session.expire_all()
                # Check each document's indexing status and trigger summary generation if completed
                for document_id in document_ids:
                    # Re-query document to get latest status (IndexingRunner may have updated it)
                    document = (
                        session.query(Document)
                        .where(Document.id == document_id, Document.dataset_id == dataset_id)
                        .first()
                    )
                    if document:
                        logger.info(
                            "Checking document %s for summary generation: status=%s, doc_form=%s, need_summary=%s",
                            document_id,
                            document.indexing_status,
                            document.doc_form,
                            document.need_summary,
                        )
                        if (
                            document.indexing_status == "completed"
                            and document.doc_form != "qa_model"
                            and document.need_summary is True
                        ):
                            try:
                                generate_summary_index_task.delay(dataset.id, document_id, None)
                                logger.info(
                                    "Queued summary index generation task for document %s in dataset %s "
                                    "after indexing completed",
                                    document_id,
                                    dataset.id,
                                )
                            except Exception:
                                logger.exception(
                                    "Failed to queue summary index generation task for document %s",
                                    document_id,
                                )
                                # Don't fail the entire indexing process if summary task queuing fails
                        else:
                            logger.info(
                                "Skipping summary generation for document %s: "
                                "status=%s, doc_form=%s, need_summary=%s",
                                document_id,
                                document.indexing_status,
                                document.doc_form,
                                document.need_summary,
                            )
                    else:
                        logger.warning("Document %s not found after indexing", document_id)
            else:
                logger.info(
                    "Summary index generation skipped for dataset %s: summary_index_setting.enable=%s",
                    dataset.id,
                    summary_index_setting.get("enable") if summary_index_setting else None,
                )
        else:
            logger.info(
                "Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')",
                dataset.id,
                dataset.indexing_technique,
            )
    except DocumentIsPausedError as ex:
        logger.info(click.style(str(ex), fg="yellow"))
    except Exception:
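
The nested conditions above amount to a single predicate: summaries are generated only for high-quality datasets with the feature enabled, and only for completed, non-QA documents flagged `need_summary`. Restated as a standalone, unit-testable function (an illustrative refactoring, not part of the commit):

```python
def should_generate_summary(
    indexing_technique: str,
    summary_index_setting: dict | None,
    indexing_status: str,
    doc_form: str,
    need_summary: bool | None,
) -> bool:
    """Mirror of the gate in _document_indexing, extracted for clarity."""
    return (
        indexing_technique == "high_quality"
        and bool(summary_index_setting and summary_index_setting.get("enable"))
        and indexing_status == "completed"
        and doc_form != "qa_model"
        and need_summary is True
    )


assert should_generate_summary("high_quality", {"enable": True}, "completed", "text_model", True)
assert not should_generate_summary("economy", {"enable": True}, "completed", "text_model", True)
```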

View File

@ -106,6 +106,17 @@ def enable_segment_to_index_task(segment_id: str):
        # save vector index
        index_processor.load(dataset, [document], multimodal_documents=multimodel_documents)
        # Enable summary index for this segment
        from services.summary_index_service import SummaryIndexService
        try:
            SummaryIndexService.enable_summaries_for_segments(
                dataset=dataset,
                segment_ids=[segment.id],
            )
        except Exception as e:
            logger.warning("Failed to enable summary for segment %s: %s", segment.id, str(e))
        end_at = time.perf_counter()
        logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green"))
    except Exception as e:

View File

@ -106,6 +106,18 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
        # save vector index
        index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
        # Enable summary indexes for these segments
        from services.summary_index_service import SummaryIndexService
        segment_ids_list = [segment.id for segment in segments]
        try:
            SummaryIndexService.enable_summaries_for_segments(
                dataset=dataset,
                segment_ids=segment_ids_list,
            )
        except Exception as e:
            logger.warning("Failed to enable summaries for segments: %s", str(e))
        end_at = time.perf_counter()
        logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green"))
    except Exception as e:
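
Finally, the queuing side: `_document_indexing` above calls `generate_summary_index_task.delay(dataset.id, document_id, None)`. A minimal Celery shell consistent with that call site — the decorator, queue name, and the meaning of the third argument are assumptions; only the task name and its first two arguments are visible in this diff:

```python
# Hypothetical shell; the real task ships in this commit under api/tasks/.
from celery import shared_task


@shared_task(queue="dataset")
def generate_summary_index_task(dataset_id: str, document_id: str, segment_id: str | None = None):
    """Generate summary-index entries for a document (sketch only)."""
    # 1) load the dataset and its completed segments
    # 2) summarize each segment and persist the summary rows
    # 3) embed the summaries and add them to the dataset's vector index
    ...
```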

Some files were not shown because too many files have changed in this diff