Mirror of https://github.com/langgenius/dify.git (synced 2026-05-09 21:28:25 +08:00)

Commit 1a51f52061: Merge remote-tracking branch 'origin/feat/support-agent-sandbox' into feat/support-agent-sandbox
@@ -480,4 +480,4 @@ const useButtonState = () => {
 ### Related Skills
 
 - `frontend-testing` - For testing refactored components
-- `web/testing/testing.md` - Testing specification
+- `web/docs/test.md` - Testing specification
@@ -7,7 +7,7 @@ description: Generate Vitest + React Testing Library tests for Dify frontend com
 
 This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices.
 
-> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. Use Vitest mock/timer APIs (`vi.*`).
+> **⚠️ Authoritative Source**: This skill is derived from `web/docs/test.md`. Use Vitest mock/timer APIs (`vi.*`).
 
 ## When to Apply This Skill
 
@@ -309,7 +309,7 @@ For more detailed information, refer to:
 
 ### Primary Specification (MUST follow)
 
-- **`web/testing/testing.md`** - The canonical testing specification. This skill is derived from this document.
+- **`web/docs/test.md`** - The canonical testing specification. This skill is derived from this document.
 
 ### Reference Examples in Codebase
 
@@ -4,7 +4,7 @@ This guide defines the workflow for generating tests, especially for complex com
 
 ## Scope Clarification
 
-This guide addresses **multi-file workflow** (how to process multiple test files). For coverage requirements within a single test file, see `web/testing/testing.md` § Coverage Goals.
+This guide addresses **multi-file workflow** (how to process multiple test files). For coverage requirements within a single test file, see `web/docs/test.md` § Coverage Goals.
 
 | Scope | Rule |
 |-------|------|
.github/workflows/api-tests.yml (vendored, 1 change)

@@ -72,6 +72,7 @@ jobs:
         OPENDAL_FS_ROOT: /tmp/dify-storage
       run: |
         uv run --project api pytest \
+          -n auto \
           --timeout "${PYTEST_TIMEOUT:-180}" \
           api/tests/integration_tests/workflow \
           api/tests/integration_tests/tools \
.github/workflows/style.yml (vendored, 8 changes)

@@ -47,13 +47,9 @@ jobs:
         if: steps.changed-files.outputs.any_changed == 'true'
         run: uv run --directory api --dev lint-imports
 
-      - name: Run Basedpyright Checks
+      - name: Run Type Checks
         if: steps.changed-files.outputs.any_changed == 'true'
-        run: dev/basedpyright-check
-
-      - name: Run Mypy Type Checks
-        if: steps.changed-files.outputs.any_changed == 'true'
-        run: uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
+        run: make type-check
 
       - name: Dotenv check
         if: steps.changed-files.outputs.any_changed == 'true'
AGENTS.md (33 changes)

@@ -7,7 +7,7 @@ Dify is an open-source platform for developing LLM applications with an intuitiv
 The codebase is split into:
 
 - **Backend API** (`/api`): Python Flask application organized with Domain-Driven Design
-- **Frontend Web** (`/web`): Next.js 15 application using TypeScript and React 19
+- **Frontend Web** (`/web`): Next.js application using TypeScript and React
 - **Docker deployment** (`/docker`): Containerized deployment configurations
 
 ## Backend Workflow
@@ -18,36 +18,7 @@ The codebase is split into:
 
 ## Frontend Workflow
 
-```bash
-cd web
-pnpm lint:fix
-pnpm type-check:tsgo
-pnpm test
-```
-
-### Frontend Linting
-
-ESLint is used for frontend code quality. Available commands:
-
-```bash
-# Lint all files (report only)
-pnpm lint
-
-# Lint and auto-fix issues
-pnpm lint:fix
-
-# Lint specific files or directories
-pnpm lint:fix app/components/base/button/
-pnpm lint:fix app/components/base/button/index.tsx
-
-# Lint quietly (errors only, no warnings)
-pnpm lint:quiet
-
-# Check code complexity
-pnpm lint:complexity
-```
-
-**Important**: Always run `pnpm lint:fix` before committing. The pre-commit hook runs `lint-staged` which only lints staged files.
+- Read `web/AGENTS.md` for details
 
 ## Testing & Quality Practices
 
@@ -77,7 +77,7 @@ How we prioritize:
 
 For setting up the frontend service, please refer to our comprehensive [guide](https://github.com/langgenius/dify/blob/main/web/README.md) in the `web/README.md` file. This document provides detailed instructions to help you set up the frontend environment properly.
 
-**Testing**: All React components must have comprehensive test coverage. See [web/testing/testing.md](https://github.com/langgenius/dify/blob/main/web/testing/testing.md) for the canonical frontend testing guidelines and follow every requirement described there.
+**Testing**: All React components must have comprehensive test coverage. See [web/docs/test.md](https://github.com/langgenius/dify/blob/main/web/docs/test.md) for the canonical frontend testing guidelines and follow every requirement described there.
 
 #### Backend
 
Makefile (12 changes)

@@ -68,9 +68,11 @@ lint:
 	@echo "✅ Linting complete"
 
 type-check:
-	@echo "📝 Running type check with basedpyright..."
-	@uv run --directory api --dev basedpyright
-	@echo "✅ Type check complete"
+	@echo "📝 Running type checks (basedpyright + mypy + ty)..."
+	@./dev/basedpyright-check $(PATH_TO_CHECK)
+	@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
+	@cd api && uv run ty check
+	@echo "✅ Type checks complete"
 
 test:
 	@echo "🧪 Running backend unit tests..."
@@ -78,7 +80,7 @@ test:
 		echo "Target: $(TARGET_TESTS)"; \
 		uv run --project api --dev pytest $(TARGET_TESTS); \
 	else \
-		uv run --project api --dev dev/pytest/pytest_unit_tests.sh; \
+		PYTEST_XDIST_ARGS="-n auto" uv run --project api --dev dev/pytest/pytest_unit_tests.sh; \
 	fi
 	@echo "✅ Tests complete"
 
@@ -130,7 +132,7 @@ help:
 	@echo " make format - Format code with ruff"
 	@echo " make check - Check code with ruff"
 	@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
-	@echo " make type-check - Run type checking with basedpyright"
+	@echo " make type-check - Run type checks (basedpyright, mypy, ty)"
 	@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
 	@echo ""
 	@echo "Docker Build Targets:"
@@ -620,6 +620,7 @@ PLUGIN_DAEMON_URL=http://127.0.0.1:5002
 PLUGIN_REMOTE_INSTALL_PORT=5003
 PLUGIN_REMOTE_INSTALL_HOST=localhost
 PLUGIN_MAX_PACKAGE_SIZE=15728640
+PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600
 INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
 
 # Marketplace configuration
@@ -227,6 +227,9 @@ ignore_imports =
     core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
     core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods
     core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset
+    core.workflow.nodes.knowledge_index.knowledge_index_node -> services.summary_index_service
+    core.workflow.nodes.knowledge_index.knowledge_index_node -> tasks.generate_summary_index_task
+    core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.processor.paragraph_index_processor
     core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.retrieval_methods
     core.workflow.nodes.llm.node -> models.dataset
     core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer
@ -300,6 +303,58 @@ ignore_imports =
|
||||
core.workflow.nodes.agent.agent_node -> services
|
||||
core.workflow.nodes.tool.tool_node -> services
|
||||
|
||||
[importlinter:contract:model-runtime-no-internal-imports]
|
||||
name = Model Runtime Internal Imports
|
||||
type = forbidden
|
||||
source_modules =
|
||||
core.model_runtime
|
||||
forbidden_modules =
|
||||
configs
|
||||
controllers
|
||||
extensions
|
||||
models
|
||||
services
|
||||
tasks
|
||||
core.agent
|
||||
core.app
|
||||
core.base
|
||||
core.callback_handler
|
||||
core.datasource
|
||||
core.db
|
||||
core.entities
|
||||
core.errors
|
||||
core.extension
|
||||
core.external_data_tool
|
||||
core.file
|
||||
core.helper
|
||||
core.hosting_configuration
|
||||
core.indexing_runner
|
||||
core.llm_generator
|
||||
core.logging
|
||||
core.mcp
|
||||
core.memory
|
||||
core.model_manager
|
||||
core.moderation
|
||||
core.ops
|
||||
core.plugin
|
||||
core.prompt
|
||||
core.provider_manager
|
||||
core.rag
|
||||
core.repositories
|
||||
core.schemas
|
||||
core.tools
|
||||
core.trigger
|
||||
core.variables
|
||||
core.workflow
|
||||
ignore_imports =
|
||||
core.model_runtime.model_providers.__base.ai_model -> configs
|
||||
core.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
|
||||
core.model_runtime.model_providers.__base.large_language_model -> configs
|
||||
core.model_runtime.model_providers.__base.text_embedding_model -> core.entities.embedding_type
|
||||
core.model_runtime.model_providers.model_provider_factory -> configs
|
||||
core.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
|
||||
core.model_runtime.model_providers.model_provider_factory -> models.provider_ids
|
||||
|
||||
[importlinter:contract:rsc]
|
||||
name = RSC
|
||||
type = layers
|
||||
|
||||
@@ -243,6 +243,11 @@ class PluginConfig(BaseSettings):
         default=15728640 * 12,
     )
 
+    PLUGIN_MODEL_SCHEMA_CACHE_TTL: PositiveInt = Field(
+        description="TTL in seconds for caching plugin model schemas in Redis",
+        default=60 * 60,
+    )
+
 
 class CliApiConfig(BaseSettings):
     """
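The `PLUGIN_MODEL_SCHEMA_CACHE_TTL` setting added here (exposed as `3600` in the `.env` example earlier in this commit) controls how long plugin model schemas may be cached in Redis. A minimal sketch of the kind of TTL-based caching this enables; the cache key format and the `fetch_schema_from_daemon` helper are hypothetical, and the `redis_client` import is assumed rather than taken from this diff:

```python
import json

from configs import dify_config                 # exposes PLUGIN_MODEL_SCHEMA_CACHE_TTL
from extensions.ext_redis import redis_client   # assumed Redis client handle


def fetch_schema_from_daemon(provider: str, model: str) -> dict:
    """Hypothetical stand-in for the real plugin-daemon schema request."""
    return {"provider": provider, "model": model}


def get_plugin_model_schema(provider: str, model: str) -> dict:
    """Return a plugin model schema, caching it in Redis for the configured TTL."""
    cache_key = f"plugin_model_schema:{provider}:{model}"   # hypothetical key format
    cached = redis_client.get(cache_key)
    if cached is not None:
        return json.loads(cached)

    schema = fetch_schema_from_daemon(provider, model)
    redis_client.setex(
        cache_key,
        dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL,          # 3600 seconds by default
        json.dumps(schema),
    )
    return schema
```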
|
||||
@ -6,7 +6,6 @@ from contexts.wrapper import RecyclableContextVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.trigger.provider import PluginTriggerProviderController
|
||||
@ -29,12 +28,6 @@ plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
|
||||
ContextVar("plugin_model_providers_lock")
|
||||
)
|
||||
|
||||
plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_model_schema_lock"))
|
||||
|
||||
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
|
||||
ContextVar("plugin_model_schemas")
|
||||
)
|
||||
|
||||
datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = (
|
||||
RecyclableContextVar(ContextVar("datasource_plugin_providers"))
|
||||
)
|
||||
|
||||
@ -243,15 +243,13 @@ class InsertExploreBannerApi(Resource):
|
||||
def post(self):
|
||||
payload = InsertExploreBannerPayload.model_validate(console_ns.payload)
|
||||
|
||||
content = {
|
||||
"category": payload.category,
|
||||
"title": payload.title,
|
||||
"description": payload.description,
|
||||
"img-src": payload.img_src,
|
||||
}
|
||||
|
||||
banner = ExporleBanner(
|
||||
content=content,
|
||||
content={
|
||||
"category": payload.category,
|
||||
"title": payload.title,
|
||||
"description": payload.description,
|
||||
"img-src": payload.img_src,
|
||||
},
|
||||
link=payload.link,
|
||||
sort=payload.sort,
|
||||
language=payload.language,
|
||||
|
||||
@ -51,7 +51,7 @@ class AppImportPayload(BaseModel):
|
||||
app_id: str | None = Field(None)
|
||||
|
||||
|
||||
class AppImportBundlePayload(BaseModel):
|
||||
class AppImportBundleConfirmPayload(BaseModel):
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
icon_type: str | None = None
|
||||
@ -149,15 +149,38 @@ class AppImportCheckDependenciesApi(Resource):
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/imports-bundle")
|
||||
class AppImportBundleApi(Resource):
|
||||
@console_ns.route("/apps/imports-bundle/prepare")
|
||||
class AppImportBundlePrepareApi(Resource):
|
||||
"""Step 1: Get upload URL for bundle import."""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
from services.app_bundle_service import AppBundleService
|
||||
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
result = AppBundleService.prepare_import(
|
||||
tenant_id=current_tenant_id,
|
||||
account_id=current_user.id,
|
||||
)
|
||||
|
||||
return {"import_id": result.import_id, "upload_url": result.upload_url}, 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/imports-bundle/<string:import_id>/confirm")
|
||||
class AppImportBundleConfirmApi(Resource):
|
||||
"""Step 2: Confirm bundle import after upload."""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_import_model)
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
def post(self, import_id: str):
|
||||
from flask import request
|
||||
|
||||
from core.app.entities.app_bundle_entities import BundleFormatError
|
||||
@ -165,22 +188,12 @@ class AppImportBundleApi(Resource):
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if "file" not in request.files:
|
||||
return {"error": "No file provided"}, 400
|
||||
|
||||
file = request.files["file"]
|
||||
if not file.filename or not file.filename.endswith(".zip"):
|
||||
return {"error": "Invalid file format, expected .zip"}, 400
|
||||
|
||||
zip_bytes = file.read()
|
||||
|
||||
form_data = request.form.to_dict()
|
||||
args = AppImportBundlePayload.model_validate(form_data)
|
||||
args = AppImportBundleConfirmPayload.model_validate(request.get_json() or {})
|
||||
|
||||
try:
|
||||
result = AppBundleService.import_bundle(
|
||||
result = AppBundleService.confirm_import(
|
||||
import_id=import_id,
|
||||
account=current_user,
|
||||
zip_bytes=zip_bytes,
|
||||
name=args.name,
|
||||
description=args.description,
|
||||
icon_type=args.icon_type,
|
||||
|
||||
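Read together with the prepare endpoint above, bundle import becomes a three-step flow: POST prepare to obtain an `import_id` and `upload_url`, PUT the ZIP bytes to that URL, then POST confirm. A rough client sketch under assumptions; the base URL, the already-authenticated `requests.Session`, and the bundle path are illustrative and not taken from this diff:

```python
import requests

BASE = "http://localhost:5001/console/api"  # assumed console API base URL
session = requests.Session()                # assumed to already carry auth cookies/headers

# Step 1: ask the server for an import_id and a pre-signed upload URL.
prepare = session.post(f"{BASE}/apps/imports-bundle/prepare").json()
import_id, upload_url = prepare["import_id"], prepare["upload_url"]

# Step 2: upload the raw ZIP bytes to the returned URL.
with open("my_app_bundle.zip", "rb") as f:
    session.put(upload_url, data=f.read()).raise_for_status()

# Step 3: confirm the import, optionally overriding app metadata
# (fields come from AppImportBundleConfirmPayload).
confirm = session.post(
    f"{BASE}/apps/imports-bundle/{import_id}/confirm",
    json={"name": "My imported app"},
)
confirm.raise_for_status()
print(confirm.json())
```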
@ -70,9 +70,7 @@ class ContextGeneratePayload(BaseModel):
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||
available_vars: list[AvailableVarPayload] = Field(..., description="Available variables from upstream nodes")
|
||||
parameter_info: ParameterInfoPayload = Field(..., description="Target parameter metadata from the frontend")
|
||||
code_context: CodeContextPayload = Field(
|
||||
description="Existing code node context for incremental generation"
|
||||
)
|
||||
code_context: CodeContextPayload = Field(description="Existing code node context for incremental generation")
|
||||
|
||||
|
||||
class SuggestedQuestionsPayload(BaseModel):
|
||||
|
||||
@ -148,6 +148,7 @@ class DatasetUpdatePayload(BaseModel):
|
||||
embedding_model: str | None = None
|
||||
embedding_model_provider: str | None = None
|
||||
retrieval_model: dict[str, Any] | None = None
|
||||
summary_index_setting: dict[str, Any] | None = None
|
||||
partial_member_list: list[dict[str, str]] | None = None
|
||||
external_retrieval_model: dict[str, Any] | None = None
|
||||
external_knowledge_id: str | None = None
|
||||
@ -288,7 +289,14 @@ class DatasetListApi(Resource):
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
query = ConsoleDatasetListQuery.model_validate(request.args.to_dict())
|
||||
# Convert query parameters to dict, handling list parameters correctly
|
||||
query_params: dict[str, str | list[str]] = dict(request.args.to_dict())
|
||||
# Handle ids and tag_ids as lists (Flask request.args.getlist returns list even for single value)
|
||||
if "ids" in request.args:
|
||||
query_params["ids"] = request.args.getlist("ids")
|
||||
if "tag_ids" in request.args:
|
||||
query_params["tag_ids"] = request.args.getlist("tag_ids")
|
||||
query = ConsoleDatasetListQuery.model_validate(query_params)
|
||||
# provider = request.args.get("provider", default="vendor")
|
||||
if query.ids:
|
||||
datasets, total = DatasetService.get_datasets_by_ids(query.ids, current_tenant_id)
|
||||
|
||||
@ -45,6 +45,7 @@ from models.dataset import DocumentPipelineExecutionLog
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
|
||||
from services.file_service import FileService
|
||||
from tasks.generate_summary_index_task import generate_summary_index_task
|
||||
|
||||
from ..app.error import (
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
@ -103,6 +104,10 @@ class DocumentRenamePayload(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class GenerateSummaryPayload(BaseModel):
|
||||
document_list: list[str]
|
||||
|
||||
|
||||
class DocumentBatchDownloadZipPayload(BaseModel):
|
||||
"""Request payload for bulk downloading documents as a zip archive."""
|
||||
|
||||
@ -125,6 +130,7 @@ register_schema_models(
|
||||
RetrievalModel,
|
||||
DocumentRetryPayload,
|
||||
DocumentRenamePayload,
|
||||
GenerateSummaryPayload,
|
||||
DocumentBatchDownloadZipPayload,
|
||||
)
|
||||
|
||||
@ -312,6 +318,13 @@ class DatasetDocumentListApi(Resource):
|
||||
|
||||
paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
documents = paginated_documents.items
|
||||
|
||||
DocumentService.enrich_documents_with_summary_index_status(
|
||||
documents=documents,
|
||||
dataset=dataset,
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if fetch:
|
||||
for document in documents:
|
||||
completed_segments = (
|
||||
@ -797,6 +810,7 @@ class DocumentApi(DocumentResource):
|
||||
"display_status": document.display_status,
|
||||
"doc_form": document.doc_form,
|
||||
"doc_language": document.doc_language,
|
||||
"need_summary": document.need_summary if document.need_summary is not None else False,
|
||||
}
|
||||
else:
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
@ -832,6 +846,7 @@ class DocumentApi(DocumentResource):
|
||||
"display_status": document.display_status,
|
||||
"doc_form": document.doc_form,
|
||||
"doc_language": document.doc_language,
|
||||
"need_summary": document.need_summary if document.need_summary is not None else False,
|
||||
}
|
||||
|
||||
return response, 200
|
||||
@ -1255,3 +1270,137 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
|
||||
"input_data": log.input_data,
|
||||
"datasource_node_id": log.datasource_node_id,
|
||||
}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/generate-summary")
|
||||
class DocumentGenerateSummaryApi(Resource):
|
||||
@console_ns.doc("generate_summary_for_documents")
|
||||
@console_ns.doc(description="Generate summary index for documents")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.expect(console_ns.models[GenerateSummaryPayload.__name__])
|
||||
@console_ns.response(200, "Summary generation started successfully")
|
||||
@console_ns.response(400, "Invalid request or dataset configuration")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id):
|
||||
"""
|
||||
Generate summary index for specified documents.
|
||||
|
||||
This endpoint checks if the dataset configuration supports summary generation
|
||||
(indexing_technique must be 'high_quality' and summary_index_setting.enable must be true),
|
||||
then asynchronously generates summary indexes for the provided documents.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
|
||||
# Get dataset
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
# Check permissions
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
# Validate request payload
|
||||
payload = GenerateSummaryPayload.model_validate(console_ns.payload or {})
|
||||
document_list = payload.document_list
|
||||
|
||||
if not document_list:
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
raise BadRequest("document_list cannot be empty.")
|
||||
|
||||
# Check if dataset configuration supports summary generation
|
||||
if dataset.indexing_technique != "high_quality":
|
||||
raise ValueError(
|
||||
f"Summary generation is only available for 'high_quality' indexing technique. "
|
||||
f"Current indexing technique: {dataset.indexing_technique}"
|
||||
)
|
||||
|
||||
summary_index_setting = dataset.summary_index_setting
|
||||
if not summary_index_setting or not summary_index_setting.get("enable"):
|
||||
raise ValueError("Summary index is not enabled for this dataset. Please enable it in the dataset settings.")
|
||||
|
||||
# Verify all documents exist and belong to the dataset
|
||||
documents = DocumentService.get_documents_by_ids(dataset_id, document_list)
|
||||
|
||||
if len(documents) != len(document_list):
|
||||
found_ids = {doc.id for doc in documents}
|
||||
missing_ids = set(document_list) - found_ids
|
||||
raise NotFound(f"Some documents not found: {list(missing_ids)}")
|
||||
|
||||
# Dispatch async tasks for each document
|
||||
for document in documents:
|
||||
# Skip qa_model documents as they don't generate summaries
|
||||
if document.doc_form == "qa_model":
|
||||
logger.info("Skipping summary generation for qa_model document %s", document.id)
|
||||
continue
|
||||
|
||||
# Dispatch async task
|
||||
generate_summary_index_task.delay(dataset_id, document.id)
|
||||
logger.info(
|
||||
"Dispatched summary generation task for document %s in dataset %s",
|
||||
document.id,
|
||||
dataset_id,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/summary-status")
|
||||
class DocumentSummaryStatusApi(DocumentResource):
|
||||
@console_ns.doc("get_document_summary_status")
|
||||
@console_ns.doc(description="Get summary index generation status for a document")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@console_ns.response(200, "Summary status retrieved successfully")
|
||||
@console_ns.response(404, "Document not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id):
|
||||
"""
|
||||
Get summary index generation status for a document.
|
||||
|
||||
Returns:
|
||||
- total_segments: Total number of segments in the document
|
||||
- summary_status: Dictionary with status counts
|
||||
- completed: Number of summaries completed
|
||||
- generating: Number of summaries being generated
|
||||
- error: Number of summaries with errors
|
||||
- not_started: Number of segments without summary records
|
||||
- summaries: List of summary records with status and content preview
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
|
||||
# Get dataset
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
# Check permissions
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
# Get summary status detail from service
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
result = SummaryIndexService.get_document_summary_status_detail(
|
||||
document_id=document_id,
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
|
||||
return result, 200
|
||||
|
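The generate-summary and summary-status endpoints above form a dispatch-then-poll workflow: the client submits a `document_list` for a `high_quality` dataset with `summary_index_setting.enable` turned on, then polls each document until nothing is left generating. A hedged client sketch; the base URL, authentication, and IDs are placeholders:

```python
import time

import requests

BASE = "http://localhost:5001/console/api"  # assumed console API base URL
session = requests.Session()                # assumed to already carry auth

dataset_id = "00000000-0000-0000-0000-000000000000"    # placeholder IDs
document_ids = ["11111111-1111-1111-1111-111111111111"]

# Kick off asynchronous summary index generation for the documents.
resp = session.post(
    f"{BASE}/datasets/{dataset_id}/documents/generate-summary",
    json={"document_list": document_ids},
)
resp.raise_for_status()

# Poll each document's summary status until nothing is left generating.
for doc_id in document_ids:
    while True:
        status = session.get(
            f"{BASE}/datasets/{dataset_id}/documents/{doc_id}/summary-status"
        ).json()
        if status["summary_status"]["generating"] == 0:
            break
        time.sleep(5)
    print(doc_id, status["summary_status"])
```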
||||
@ -41,6 +41,17 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
|
||||
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
||||
|
||||
|
||||
def _get_segment_with_summary(segment, dataset_id):
|
||||
"""Helper function to marshal segment and add summary information."""
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
segment_dict = dict(marshal(segment, segment_fields))
|
||||
# Query summary for this segment (only enabled summaries)
|
||||
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
|
||||
segment_dict["summary"] = summary.summary_content if summary else None
|
||||
return segment_dict
|
||||
|
||||
|
||||
class SegmentListQuery(BaseModel):
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
status: list[str] = Field(default_factory=list)
|
||||
@ -63,6 +74,7 @@ class SegmentUpdatePayload(BaseModel):
|
||||
keywords: list[str] | None = None
|
||||
regenerate_child_chunks: bool = False
|
||||
attachment_ids: list[str] | None = None
|
||||
summary: str | None = None # Summary content for summary index
|
||||
|
||||
|
||||
class BatchImportPayload(BaseModel):
|
||||
@ -181,8 +193,25 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
|
||||
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
|
||||
# Query summaries for all segments in this page (batch query for efficiency)
|
||||
segment_ids = [segment.id for segment in segments.items]
|
||||
summaries = {}
|
||||
if segment_ids:
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
|
||||
# Only include enabled summaries (already filtered by service)
|
||||
summaries = {chunk_id: summary.summary_content for chunk_id, summary in summary_records.items()}
|
||||
|
||||
# Add summary to each segment
|
||||
segments_with_summary = []
|
||||
for segment in segments.items:
|
||||
segment_dict = dict(marshal(segment, segment_fields))
|
||||
segment_dict["summary"] = summaries.get(segment.id)
|
||||
segments_with_summary.append(segment_dict)
|
||||
|
||||
response = {
|
||||
"data": marshal(segments.items, segment_fields),
|
||||
"data": segments_with_summary,
|
||||
"limit": limit,
|
||||
"total": segments.total,
|
||||
"total_pages": segments.pages,
|
||||
@ -328,7 +357,7 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
payload_dict = payload.model_dump(exclude_none=True)
|
||||
SegmentService.segment_create_args_validate(payload_dict, document)
|
||||
segment = SegmentService.create_segment(payload_dict, document, dataset)
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
|
||||
@ -390,10 +419,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
payload_dict = payload.model_dump(exclude_none=True)
|
||||
SegmentService.segment_create_args_validate(payload_dict, document)
|
||||
|
||||
# Update segment (summary update with change detection is handled in SegmentService.update_segment)
|
||||
segment = SegmentService.update_segment(
|
||||
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
|
||||
)
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
|
||||
@ -1,6 +1,13 @@
|
||||
from flask_restx import Resource
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from fields.hit_testing_fields import (
|
||||
child_chunk_fields,
|
||||
document_fields,
|
||||
files_fields,
|
||||
hit_testing_record_fields,
|
||||
segment_fields,
|
||||
)
|
||||
from libs.login import login_required
|
||||
|
||||
from .. import console_ns
|
||||
@ -14,13 +21,45 @@ from ..wraps import (
|
||||
register_schema_model(console_ns, HitTestingPayload)
|
||||
|
||||
|
||||
def _get_or_create_model(model_name: str, field_def):
|
||||
"""Get or create a flask_restx model to avoid dict type issues in Swagger."""
|
||||
existing = console_ns.models.get(model_name)
|
||||
if existing is None:
|
||||
existing = console_ns.model(model_name, field_def)
|
||||
return existing
|
||||
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
document_model = _get_or_create_model("HitTestingDocument", document_fields)
|
||||
|
||||
segment_fields_copy = segment_fields.copy()
|
||||
segment_fields_copy["document"] = fields.Nested(document_model)
|
||||
segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy)
|
||||
|
||||
child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields)
|
||||
files_model = _get_or_create_model("HitTestingFile", files_fields)
|
||||
|
||||
hit_testing_record_fields_copy = hit_testing_record_fields.copy()
|
||||
hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model)
|
||||
hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model))
|
||||
hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model))
|
||||
hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy)
|
||||
|
||||
# Response model for hit testing API
|
||||
hit_testing_response_fields = {
|
||||
"query": fields.String,
|
||||
"records": fields.List(fields.Nested(hit_testing_record_model)),
|
||||
}
|
||||
hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
|
||||
class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
@console_ns.doc("test_dataset_retrieval")
|
||||
@console_ns.doc(description="Test dataset knowledge retrieval")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.expect(console_ns.models[HitTestingPayload.__name__])
|
||||
@console_ns.response(200, "Hit testing completed successfully")
|
||||
@console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model)
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
@setup_required
|
||||
|
||||
@ -15,12 +15,8 @@ api = ExternalApi(
|
||||
files_ns = Namespace("files", description="File operations", path="/")
|
||||
|
||||
from . import (
|
||||
app_assets_download,
|
||||
app_assets_upload,
|
||||
image_preview,
|
||||
sandbox_archive,
|
||||
sandbox_file_downloads,
|
||||
storage_download,
|
||||
storage_files,
|
||||
tool_files,
|
||||
upload,
|
||||
)
|
||||
@ -29,14 +25,10 @@ api.add_namespace(files_ns)
|
||||
|
||||
__all__ = [
|
||||
"api",
|
||||
"app_assets_download",
|
||||
"app_assets_upload",
|
||||
"bp",
|
||||
"files_ns",
|
||||
"image_preview",
|
||||
"sandbox_archive",
|
||||
"sandbox_file_downloads",
|
||||
"storage_download",
|
||||
"storage_files",
|
||||
"tool_files",
|
||||
"upload",
|
||||
]
|
||||
|
||||
@ -1,77 +0,0 @@
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.files import files_ns
|
||||
from core.app_assets.storage import AppAssetSigner, AssetPath
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AppAssetDownloadQuery(BaseModel):
|
||||
expires_at: int = Field(..., description="Unix timestamp when the link expires")
|
||||
nonce: str = Field(..., description="Random string for signature")
|
||||
sign: str = Field(..., description="HMAC signature")
|
||||
|
||||
|
||||
files_ns.schema_model(
|
||||
AppAssetDownloadQuery.__name__,
|
||||
AppAssetDownloadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@files_ns.route("/app-assets/<string:asset_type>/<string:tenant_id>/<string:app_id>/<string:resource_id>/download")
|
||||
@files_ns.route(
|
||||
"/app-assets/<string:asset_type>/<string:tenant_id>/<string:app_id>/<string:resource_id>/<string:sub_resource_id>/download"
|
||||
)
|
||||
class AppAssetDownloadApi(Resource):
|
||||
def get(
|
||||
self,
|
||||
asset_type: str,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
resource_id: str,
|
||||
sub_resource_id: str | None = None,
|
||||
):
|
||||
args = AppAssetDownloadQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
asset_path = AssetPath.from_components(
|
||||
asset_type=asset_type,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
resource_id=resource_id,
|
||||
sub_resource_id=sub_resource_id,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise Forbidden(str(exc)) from exc
|
||||
|
||||
if not AppAssetSigner.verify_download_signature(
|
||||
asset_path=asset_path,
|
||||
expires_at=args.expires_at,
|
||||
nonce=args.nonce,
|
||||
sign=args.sign,
|
||||
):
|
||||
raise Forbidden("Invalid or expired download link")
|
||||
|
||||
storage_key = asset_path.get_storage_key()
|
||||
|
||||
try:
|
||||
generator = storage.load_stream(storage_key)
|
||||
except FileNotFoundError as exc:
|
||||
raise NotFound("File not found") from exc
|
||||
|
||||
encoded_filename = quote(storage_key.split("/")[-1])
|
||||
|
||||
return Response(
|
||||
generator,
|
||||
mimetype="application/octet-stream",
|
||||
direct_passthrough=True,
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
|
||||
},
|
||||
)
|
||||
@ -1,61 +0,0 @@
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.files import files_ns
|
||||
from core.app_assets.storage import AppAssetSigner, AssetPath
|
||||
from services.app_asset_service import AppAssetService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AppAssetUploadQuery(BaseModel):
|
||||
expires_at: int = Field(..., description="Unix timestamp when the link expires")
|
||||
nonce: str = Field(..., description="Random string for signature")
|
||||
sign: str = Field(..., description="HMAC signature")
|
||||
|
||||
|
||||
files_ns.schema_model(
|
||||
AppAssetUploadQuery.__name__,
|
||||
AppAssetUploadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@files_ns.route("/app-assets/<string:asset_type>/<string:tenant_id>/<string:app_id>/<string:resource_id>/upload")
|
||||
@files_ns.route(
|
||||
"/app-assets/<string:asset_type>/<string:tenant_id>/<string:app_id>/<string:resource_id>/<string:sub_resource_id>/upload"
|
||||
)
|
||||
class AppAssetUploadApi(Resource):
|
||||
def put(
|
||||
self,
|
||||
asset_type: str,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
resource_id: str,
|
||||
sub_resource_id: str | None = None,
|
||||
):
|
||||
args = AppAssetUploadQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
asset_path = AssetPath.from_components(
|
||||
asset_type=asset_type,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
resource_id=resource_id,
|
||||
sub_resource_id=sub_resource_id,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise Forbidden(str(exc)) from exc
|
||||
|
||||
if not AppAssetSigner.verify_upload_signature(
|
||||
asset_path=asset_path,
|
||||
expires_at=args.expires_at,
|
||||
nonce=args.nonce,
|
||||
sign=args.sign,
|
||||
):
|
||||
raise Forbidden("Invalid or expired upload link")
|
||||
|
||||
content = request.get_data()
|
||||
AppAssetService.get_storage().save(asset_path, content)
|
||||
return Response(status=204)
|
||||
@ -1,76 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.files import files_ns
|
||||
from core.sandbox.security.archive_signer import SandboxArchivePath, SandboxArchiveSigner
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class SandboxArchiveQuery(BaseModel):
|
||||
expires_at: int = Field(..., description="Unix timestamp when the link expires")
|
||||
nonce: str = Field(..., description="Random string for signature")
|
||||
sign: str = Field(..., description="HMAC signature")
|
||||
|
||||
|
||||
files_ns.schema_model(
|
||||
SandboxArchiveQuery.__name__,
|
||||
SandboxArchiveQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@files_ns.route("/sandbox-archives/<string:tenant_id>/<string:sandbox_id>/download")
|
||||
class SandboxArchiveDownloadApi(Resource):
|
||||
def get(self, tenant_id: str, sandbox_id: str):
|
||||
args = SandboxArchiveQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
archive_path = SandboxArchivePath(tenant_id=UUID(tenant_id), sandbox_id=UUID(sandbox_id))
|
||||
except ValueError as exc:
|
||||
raise Forbidden(str(exc)) from exc
|
||||
|
||||
if not SandboxArchiveSigner.verify_download_signature(
|
||||
archive_path=archive_path,
|
||||
expires_at=args.expires_at,
|
||||
nonce=args.nonce,
|
||||
sign=args.sign,
|
||||
):
|
||||
raise Forbidden("Invalid or expired download link")
|
||||
|
||||
try:
|
||||
generator = storage.load_stream(archive_path.get_storage_key())
|
||||
except FileNotFoundError as exc:
|
||||
raise NotFound("Archive not found") from exc
|
||||
|
||||
return Response(
|
||||
generator,
|
||||
mimetype="application/gzip",
|
||||
direct_passthrough=True,
|
||||
)
|
||||
|
||||
|
||||
@files_ns.route("/sandbox-archives/<string:tenant_id>/<string:sandbox_id>/upload")
|
||||
class SandboxArchiveUploadApi(Resource):
|
||||
def put(self, tenant_id: str, sandbox_id: str):
|
||||
args = SandboxArchiveQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
archive_path = SandboxArchivePath(tenant_id=UUID(tenant_id), sandbox_id=UUID(sandbox_id))
|
||||
except ValueError as exc:
|
||||
raise Forbidden(str(exc)) from exc
|
||||
|
||||
if not SandboxArchiveSigner.verify_upload_signature(
|
||||
archive_path=archive_path,
|
||||
expires_at=args.expires_at,
|
||||
nonce=args.nonce,
|
||||
sign=args.sign,
|
||||
):
|
||||
raise Forbidden("Invalid or expired upload link")
|
||||
|
||||
storage.save(archive_path.get_storage_key(), request.get_data())
|
||||
return Response(status=204)
|
||||
@ -1,96 +0,0 @@
|
||||
from urllib.parse import quote
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.files import files_ns
|
||||
from core.sandbox.security.sandbox_file_signer import SandboxFileDownloadPath, SandboxFileSigner
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class SandboxFileDownloadQuery(BaseModel):
|
||||
expires_at: int = Field(..., description="Unix timestamp when the link expires")
|
||||
nonce: str = Field(..., description="Random string for signature")
|
||||
sign: str = Field(..., description="HMAC signature")
|
||||
|
||||
|
||||
files_ns.schema_model(
|
||||
SandboxFileDownloadQuery.__name__,
|
||||
SandboxFileDownloadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@files_ns.route(
|
||||
"/sandbox-file-downloads/<string:tenant_id>/<string:sandbox_id>/<string:export_id>/<path:filename>/download"
|
||||
)
|
||||
class SandboxFileDownloadDownloadApi(Resource):
|
||||
def get(self, tenant_id: str, sandbox_id: str, export_id: str, filename: str):
|
||||
args = SandboxFileDownloadQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
export_path = SandboxFileDownloadPath(
|
||||
tenant_id=UUID(tenant_id),
|
||||
sandbox_id=UUID(sandbox_id),
|
||||
export_id=export_id,
|
||||
filename=filename,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise Forbidden(str(exc)) from exc
|
||||
|
||||
if not SandboxFileSigner.verify_download_signature(
|
||||
export_path=export_path,
|
||||
expires_at=args.expires_at,
|
||||
nonce=args.nonce,
|
||||
sign=args.sign,
|
||||
):
|
||||
raise Forbidden("Invalid or expired download link")
|
||||
|
||||
try:
|
||||
generator = storage.load_stream(export_path.get_storage_key())
|
||||
except FileNotFoundError as exc:
|
||||
raise NotFound("File not found") from exc
|
||||
|
||||
encoded_filename = quote(filename.split("/")[-1])
|
||||
|
||||
return Response(
|
||||
generator,
|
||||
mimetype="application/octet-stream",
|
||||
direct_passthrough=True,
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@files_ns.route(
|
||||
"/sandbox-file-downloads/<string:tenant_id>/<string:sandbox_id>/<string:export_id>/<path:filename>/upload"
|
||||
)
|
||||
class SandboxFileDownloadUploadApi(Resource):
|
||||
def put(self, tenant_id: str, sandbox_id: str, export_id: str, filename: str):
|
||||
args = SandboxFileDownloadQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
export_path = SandboxFileDownloadPath(
|
||||
tenant_id=UUID(tenant_id),
|
||||
sandbox_id=UUID(sandbox_id),
|
||||
export_id=export_id,
|
||||
filename=filename,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise Forbidden(str(exc)) from exc
|
||||
|
||||
if not SandboxFileSigner.verify_upload_signature(
|
||||
export_path=export_path,
|
||||
expires_at=args.expires_at,
|
||||
nonce=args.nonce,
|
||||
sign=args.sign,
|
||||
):
|
||||
raise Forbidden("Invalid or expired upload link")
|
||||
|
||||
storage.save(export_path.get_storage_key(), request.get_data())
|
||||
return Response(status=204)
|
||||
@ -1,56 +0,0 @@
|
||||
from urllib.parse import quote, unquote
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.files import files_ns
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.file_presign_storage import FilePresignStorage
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class StorageDownloadQuery(BaseModel):
|
||||
timestamp: str = Field(..., description="Unix timestamp used in the signature")
|
||||
nonce: str = Field(..., description="Random string for signature")
|
||||
sign: str = Field(..., description="HMAC signature")
|
||||
|
||||
|
||||
files_ns.schema_model(
|
||||
StorageDownloadQuery.__name__,
|
||||
StorageDownloadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@files_ns.route("/storage/<path:filename>/download")
|
||||
class StorageFileDownloadApi(Resource):
|
||||
def get(self, filename: str):
|
||||
filename = unquote(filename)
|
||||
|
||||
args = StorageDownloadQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
if not FilePresignStorage.verify_signature(
|
||||
filename=filename,
|
||||
timestamp=args.timestamp,
|
||||
nonce=args.nonce,
|
||||
sign=args.sign,
|
||||
):
|
||||
raise Forbidden("Invalid or expired download link")
|
||||
|
||||
try:
|
||||
generator = storage.load_stream(filename)
|
||||
except FileNotFoundError:
|
||||
raise NotFound("File not found")
|
||||
|
||||
encoded_filename = quote(filename.split("/")[-1])
|
||||
|
||||
return Response(
|
||||
generator,
|
||||
mimetype="application/octet-stream",
|
||||
direct_passthrough=True,
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
|
||||
},
|
||||
)
|
||||
api/controllers/files/storage_files.py (new file, 80 lines)

@@ -0,0 +1,80 @@
"""Token-based file proxy controller for storage operations.
|
||||
|
||||
This controller handles file download and upload operations using opaque UUID tokens.
|
||||
The token maps to the real storage key in Redis, so the actual storage path is never
|
||||
exposed in the URL.
|
||||
|
||||
Routes:
|
||||
GET /files/storage-files/{token} - Download a file
|
||||
PUT /files/storage-files/{token} - Upload a file
|
||||
|
||||
The operation type (download/upload) is determined by the ticket stored in Redis,
|
||||
not by the HTTP method. This ensures a download ticket cannot be used for upload
|
||||
and vice versa.
|
||||
"""
|
||||
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import Forbidden, NotFound, RequestEntityTooLarge
|
||||
|
||||
from controllers.files import files_ns
|
||||
from extensions.ext_storage import storage
|
||||
from services.storage_ticket_service import StorageTicketService
|
||||
|
||||
|
||||
@files_ns.route("/storage-files/<string:token>")
|
||||
class StorageFilesApi(Resource):
|
||||
"""Handle file operations through token-based URLs."""
|
||||
|
||||
def get(self, token: str):
|
||||
"""Download a file using a token.
|
||||
|
||||
The ticket must have op="download", otherwise returns 403.
|
||||
"""
|
||||
ticket = StorageTicketService.get_ticket(token)
|
||||
if ticket is None:
|
||||
raise Forbidden("Invalid or expired token")
|
||||
|
||||
if ticket.op != "download":
|
||||
raise Forbidden("This token is not valid for download")
|
||||
|
||||
try:
|
||||
generator = storage.load_stream(ticket.storage_key)
|
||||
except FileNotFoundError:
|
||||
raise NotFound("File not found")
|
||||
|
||||
filename = ticket.filename or ticket.storage_key.rsplit("/", 1)[-1]
|
||||
encoded_filename = quote(filename)
|
||||
|
||||
return Response(
|
||||
generator,
|
||||
mimetype="application/octet-stream",
|
||||
direct_passthrough=True,
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
|
||||
},
|
||||
)
|
||||
|
||||
def put(self, token: str):
|
||||
"""Upload a file using a token.
|
||||
|
||||
The ticket must have op="upload", otherwise returns 403.
|
||||
If the request body exceeds max_bytes, returns 413.
|
||||
"""
|
||||
ticket = StorageTicketService.get_ticket(token)
|
||||
if ticket is None:
|
||||
raise Forbidden("Invalid or expired token")
|
||||
|
||||
if ticket.op != "upload":
|
||||
raise Forbidden("This token is not valid for upload")
|
||||
|
||||
content = request.get_data()
|
||||
|
||||
if ticket.max_bytes is not None and len(content) > ticket.max_bytes:
|
||||
raise RequestEntityTooLarge(f"Upload exceeds maximum size of {ticket.max_bytes} bytes")
|
||||
|
||||
storage.save(ticket.storage_key, content)
|
||||
|
||||
return Response(status=204)
|
||||
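Because the operation type lives in the Redis-backed ticket rather than in the URL or the HTTP method, a client only ever handles an opaque token. A rough usage sketch against the two routes above; the base URL and the way tokens are obtained (from some earlier presign or prepare response) are assumptions:

```python
import requests

BASE = "http://localhost:5001/files"  # assumed base URL for the files routes

# Download: the token must have been issued with op="download".
download_token = "0f0e0d0c-..."       # placeholder token from a presign response
resp = requests.get(f"{BASE}/storage-files/{download_token}", stream=True)
resp.raise_for_status()
with open("downloaded.bin", "wb") as f:
    for chunk in resp.iter_content(chunk_size=8192):
        f.write(chunk)

# Upload: a separate token issued with op="upload"; a download token would get 403.
upload_token = "0a0b0c0d-..."         # placeholder
with open("to_upload.bin", "rb") as f:
    put = requests.put(f"{BASE}/storage-files/{upload_token}", data=f.read())
put.raise_for_status()                # 204 on success, 413 if the body exceeds max_bytes
```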
@ -46,6 +46,7 @@ class DatasetCreatePayload(BaseModel):
|
||||
retrieval_model: RetrievalModel | None = None
|
||||
embedding_model: str | None = None
|
||||
embedding_model_provider: str | None = None
|
||||
summary_index_setting: dict | None = None
|
||||
|
||||
|
||||
class DatasetUpdatePayload(BaseModel):
|
||||
@ -217,6 +218,7 @@ class DatasetListApi(DatasetApiResource):
|
||||
embedding_model_provider=payload.embedding_model_provider,
|
||||
embedding_model_name=payload.embedding_model,
|
||||
retrieval_model=payload.retrieval_model,
|
||||
summary_index_setting=payload.summary_index_setting,
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
@ -45,6 +45,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||
Segmentation,
|
||||
)
|
||||
from services.file_service import FileService
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
|
||||
class DocumentTextCreatePayload(BaseModel):
|
||||
@ -508,6 +509,12 @@ class DocumentListApi(DatasetApiResource):
|
||||
)
|
||||
documents = paginated_documents.items
|
||||
|
||||
DocumentService.enrich_documents_with_summary_index_status(
|
||||
documents=documents,
|
||||
dataset=dataset,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
response = {
|
||||
"data": marshal(documents, document_fields),
|
||||
"has_more": len(documents) == query_params.limit,
|
||||
@ -612,6 +619,16 @@ class DocumentApi(DatasetApiResource):
|
||||
if metadata not in self.METADATA_CHOICES:
|
||||
raise InvalidMetadataError(f"Invalid metadata value: {metadata}")
|
||||
|
||||
# Calculate summary_index_status if needed
|
||||
summary_index_status = None
|
||||
has_summary_index = dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True
|
||||
if has_summary_index and document.need_summary is True:
|
||||
summary_index_status = SummaryIndexService.get_document_summary_index_status(
|
||||
document_id=document_id,
|
||||
dataset_id=dataset_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
if metadata == "only":
|
||||
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
|
||||
elif metadata == "without":
|
||||
@ -646,6 +663,8 @@ class DocumentApi(DatasetApiResource):
|
||||
"display_status": document.display_status,
|
||||
"doc_form": document.doc_form,
|
||||
"doc_language": document.doc_language,
|
||||
"summary_index_status": summary_index_status,
|
||||
"need_summary": document.need_summary if document.need_summary is not None else False,
|
||||
}
|
||||
else:
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
@ -681,6 +700,8 @@ class DocumentApi(DatasetApiResource):
|
||||
"display_status": document.display_status,
|
||||
"doc_form": document.doc_form,
|
||||
"doc_language": document.doc_language,
|
||||
"summary_index_status": summary_index_status,
|
||||
"need_summary": document.need_summary if document.need_summary is not None else False,
|
||||
}
|
||||
|
||||
return response
|
||||
|
||||
@ -79,6 +79,7 @@ class AppGenerateResponseConverter(ABC):
|
||||
"document_name": resource["document_name"],
|
||||
"score": resource["score"],
|
||||
"content": resource["content"],
|
||||
"summary": resource.get("summary"),
|
||||
}
|
||||
)
|
||||
metadata["retriever_resources"] = updated_resources
|
||||
|
||||
@ -1,12 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.app.entities.app_asset_entities import AppAssetFileTree
|
||||
|
||||
# Constants
|
||||
BUNDLE_DSL_FILENAME_PATTERN = re.compile(r"^[^/]+\.ya?ml$")
|
||||
BUNDLE_MAX_SIZE = 50 * 1024 * 1024 # 50MB
|
||||
MANIFEST_FILENAME = "manifest.json"
|
||||
MANIFEST_SCHEMA_VERSION = "1.0"
|
||||
|
||||
|
||||
# Exceptions
|
||||
@ -22,21 +27,70 @@ class ZipSecurityError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# Entities
|
||||
# Manifest DTOs
|
||||
class ManifestFileEntry(BaseModel):
|
||||
"""Maps node_id to file path in the bundle."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
node_id: str
|
||||
path: str
|
||||
|
||||
|
||||
class ManifestIntegrity(BaseModel):
|
||||
"""Basic integrity check fields."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
file_count: int
|
||||
|
||||
|
||||
class ManifestAppAssets(BaseModel):
|
||||
"""App assets section containing the full tree."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
tree: AppAssetFileTree
|
||||
|
||||
|
||||
class BundleManifest(BaseModel):
|
||||
"""
|
||||
Bundle manifest for app asset import/export.
|
||||
|
||||
Schema version 1.0:
|
||||
- dsl_filename: DSL file name in bundle root (e.g. "my_app.yml")
|
||||
- tree: Full AppAssetFileTree (files + folders) for 100% restoration including node IDs
|
||||
- files: Explicit node_id -> path mapping for file nodes only
|
||||
- integrity: Basic file_count validation
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
schema_version: str = Field(default=MANIFEST_SCHEMA_VERSION)
|
||||
generated_at: datetime = Field(default_factory=lambda: datetime.now(tz=UTC))
|
||||
dsl_filename: str = Field(description="DSL file name in bundle root")
|
||||
app_assets: ManifestAppAssets
|
||||
files: list[ManifestFileEntry]
|
||||
integrity: ManifestIntegrity
|
||||
|
||||
@property
|
||||
def assets_prefix(self) -> str:
|
||||
"""Assets directory name (DSL filename without extension)."""
|
||||
return self.dsl_filename.rsplit(".", 1)[0]
|
||||
|
||||
@classmethod
|
||||
def from_tree(cls, tree: AppAssetFileTree, dsl_filename: str) -> BundleManifest:
|
||||
"""Build manifest from an AppAssetFileTree."""
|
||||
files = [ManifestFileEntry(node_id=n.id, path=tree.get_path(n.id)) for n in tree.walk_files()]
|
||||
return cls(
|
||||
dsl_filename=dsl_filename,
|
||||
app_assets=ManifestAppAssets(tree=tree),
|
||||
files=files,
|
||||
integrity=ManifestIntegrity(file_count=len(files)),
|
||||
)
|
||||
|
||||
|
||||
# Export result
|
||||
class BundleExportResult(BaseModel):
|
||||
download_url: str = Field(description="Temporary download URL for the ZIP")
|
||||
filename: str = Field(description="Suggested filename for the ZIP")
|
||||
|
||||
|
||||
class SourceFileEntry(BaseModel):
|
||||
path: str = Field(description="File path within the ZIP")
|
||||
node_id: str = Field(description="Node ID in the asset tree")
|
||||
|
||||
|
||||
class ExtractedFile(BaseModel):
|
||||
path: str = Field(description="Relative path of the extracted file")
|
||||
content: bytes = Field(description="File content as bytes")
|
||||
|
||||
|
||||
class ExtractedFolder(BaseModel):
|
||||
path: str = Field(description="Relative path of the extracted folder")
|
||||
|
||||
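For reference, a `BundleManifest` serialized with `model_dump(mode="json")` would look roughly like the illustrative structure below; the filenames, node IDs, and the elided tree contents are invented examples, and the exact `AppAssetFileTree` serialization is not shown in this diff:

```python
# Illustrative schema_version 1.0 manifest payload; all values are examples only.
example_manifest = {
    "schema_version": "1.0",
    "generated_at": "2026-05-09T13:28:25Z",
    "dsl_filename": "my_app.yml",       # assets then live under the "my_app" prefix
    "app_assets": {
        "tree": {},                     # full AppAssetFileTree (files + folders), elided here
    },
    "files": [
        {"node_id": "node-1", "path": "skills/hello.py"},  # node_id -> path, file nodes only
    ],
    "integrity": {"file_count": 1},
}
```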
@ -1,25 +1,31 @@
|
||||
"""App assets storage layer.
|
||||
|
||||
This module provides storage abstractions for app assets (draft files, build zips,
|
||||
resolved assets, skill bundles, source zips, bundle exports/imports).
|
||||
|
||||
Key components:
|
||||
- AssetPath: Factory for creating typed storage paths
|
||||
- AppAssetStorage: High-level storage operations with presign support
|
||||
|
||||
All presign operations use the unified FilePresignStorage wrapper, which automatically
|
||||
falls back to Dify's file proxy when the underlying storage doesn't support presigned URLs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import time
|
||||
import urllib.parse
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterable
|
||||
from collections.abc import Generator, Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar
|
||||
from uuid import UUID
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
from extensions.storage.cached_presign_storage import CachedPresignStorage
|
||||
from libs import rsa
|
||||
from extensions.storage.file_presign_storage import FilePresignStorage
|
||||
|
||||
_ASSET_BASE = "app_assets"
|
||||
_SILENT_STORAGE_NOT_FOUND = b"File Not Found"
|
||||
_ASSET_PATH_REGISTRY: dict[str, tuple[bool, Callable[..., SignedAssetPath]]] = {}
|
||||
_ASSET_PATH_REGISTRY: dict[str, tuple[bool, Any]] = {}
|
||||
|
||||
|
||||
def _require_uuid(value: str, field_name: str) -> None:
|
||||
@ -29,12 +35,14 @@ def _require_uuid(value: str, field_name: str) -> None:
|
||||
raise ValueError(f"{field_name} must be a UUID") from exc
|
||||
|
||||
|
||||
def register_asset_path(asset_type: str, *, requires_node: bool, factory: Callable[..., SignedAssetPath]) -> None:
|
||||
def register_asset_path(asset_type: str, *, requires_node: bool, factory: Any) -> None:
|
||||
_ASSET_PATH_REGISTRY[asset_type] = (requires_node, factory)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AssetPathBase(ABC):
|
||||
"""Base class for all asset paths."""
|
||||
|
||||
asset_type: ClassVar[str]
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
@ -50,49 +58,24 @@ class AssetPathBase(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SignedAssetPath(AssetPathBase, ABC):
|
||||
@abstractmethod
|
||||
def signature_parts(self) -> tuple[str, str | None]:
|
||||
"""Return (resource_id, sub_resource_id) used for signing.
|
||||
|
||||
sub_resource_id should be None when not applicable.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def proxy_path_parts(self) -> list[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _DraftAssetPath(SignedAssetPath):
|
||||
class _DraftAssetPath(AssetPathBase):
|
||||
asset_type: ClassVar[str] = "draft"
|
||||
|
||||
def get_storage_key(self) -> str:
|
||||
return f"{_ASSET_BASE}/{self.tenant_id}/{self.app_id}/draft/{self.resource_id}"
|
||||
|
||||
def signature_parts(self) -> tuple[str, str | None]:
|
||||
return (self.resource_id, None)
|
||||
|
||||
def proxy_path_parts(self) -> list[str]:
|
||||
return [self.asset_type, self.tenant_id, self.app_id, self.resource_id]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _BuildZipAssetPath(SignedAssetPath):
|
||||
class _BuildZipAssetPath(AssetPathBase):
|
||||
asset_type: ClassVar[str] = "build-zip"
|
||||
|
||||
def get_storage_key(self) -> str:
|
||||
return f"{_ASSET_BASE}/{self.tenant_id}/{self.app_id}/artifacts/{self.resource_id}.zip"
|
||||
|
||||
def signature_parts(self) -> tuple[str, str | None]:
|
||||
return (self.resource_id, None)
|
||||
|
||||
def proxy_path_parts(self) -> list[str]:
|
||||
return [self.asset_type, self.tenant_id, self.app_id, self.resource_id]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _ResolvedAssetPath(SignedAssetPath):
|
||||
class _ResolvedAssetPath(AssetPathBase):
|
||||
asset_type: ClassVar[str] = "resolved"
|
||||
node_id: str
|
||||
|
||||
@ -103,80 +86,76 @@ class _ResolvedAssetPath(SignedAssetPath):
|
||||
def get_storage_key(self) -> str:
|
||||
return f"{_ASSET_BASE}/{self.tenant_id}/{self.app_id}/artifacts/{self.resource_id}/resolved/{self.node_id}"
|
||||
|
||||
def signature_parts(self) -> tuple[str, str | None]:
|
||||
return (self.resource_id, self.node_id)
|
||||
|
||||
def proxy_path_parts(self) -> list[str]:
|
||||
return [self.asset_type, self.tenant_id, self.app_id, self.resource_id, self.node_id]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _SkillBundleAssetPath(SignedAssetPath):
|
||||
class _SkillBundleAssetPath(AssetPathBase):
|
||||
asset_type: ClassVar[str] = "skill-bundle"
|
||||
|
||||
def get_storage_key(self) -> str:
|
||||
return f"{_ASSET_BASE}/{self.tenant_id}/{self.app_id}/artifacts/{self.resource_id}/skill_artifact_set.json"
|
||||
|
||||
def signature_parts(self) -> tuple[str, str | None]:
|
||||
return (self.resource_id, None)
|
||||
|
||||
def proxy_path_parts(self) -> list[str]:
|
||||
return [self.asset_type, self.tenant_id, self.app_id, self.resource_id]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _SourceZipAssetPath(SignedAssetPath):
|
||||
class _SourceZipAssetPath(AssetPathBase):
|
||||
asset_type: ClassVar[str] = "source-zip"
|
||||
|
||||
def get_storage_key(self) -> str:
|
||||
return f"{_ASSET_BASE}/{self.tenant_id}/{self.app_id}/sources/{self.resource_id}.zip"
|
||||
|
||||
def signature_parts(self) -> tuple[str, str | None]:
|
||||
return (self.resource_id, None)
|
||||
|
||||
def proxy_path_parts(self) -> list[str]:
|
||||
return [self.asset_type, self.tenant_id, self.app_id, self.resource_id]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _BundleExportZipAssetPath(SignedAssetPath):
|
||||
class _BundleExportZipAssetPath(AssetPathBase):
|
||||
asset_type: ClassVar[str] = "bundle-export-zip"
|
||||
|
||||
def get_storage_key(self) -> str:
|
||||
return f"{_ASSET_BASE}/{self.tenant_id}/{self.app_id}/bundle_exports/{self.resource_id}.zip"
|
||||
|
||||
def signature_parts(self) -> tuple[str, str | None]:
|
||||
return (self.resource_id, None)
|
||||
|
||||
def proxy_path_parts(self) -> list[str]:
|
||||
return [self.asset_type, self.tenant_id, self.app_id, self.resource_id]
|
||||
@dataclass(frozen=True)
|
||||
class BundleImportZipPath:
|
||||
"""Path for temporary import zip files."""
|
||||
|
||||
tenant_id: str
|
||||
import_id: str
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
_require_uuid(self.tenant_id, "tenant_id")
|
||||
|
||||
def get_storage_key(self) -> str:
|
||||
return f"{_ASSET_BASE}/{self.tenant_id}/imports/{self.import_id}.zip"
|
||||
|
||||
|
||||
class AssetPath:
|
||||
"""Factory for creating typed asset paths."""
|
||||
|
||||
@staticmethod
|
||||
def draft(tenant_id: str, app_id: str, node_id: str) -> SignedAssetPath:
|
||||
def draft(tenant_id: str, app_id: str, node_id: str) -> AssetPathBase:
|
||||
return _DraftAssetPath(tenant_id=tenant_id, app_id=app_id, resource_id=node_id)
|
||||
|
||||
@staticmethod
|
||||
def build_zip(tenant_id: str, app_id: str, assets_id: str) -> SignedAssetPath:
|
||||
def build_zip(tenant_id: str, app_id: str, assets_id: str) -> AssetPathBase:
|
||||
return _BuildZipAssetPath(tenant_id=tenant_id, app_id=app_id, resource_id=assets_id)
|
||||
|
||||
@staticmethod
|
||||
def resolved(tenant_id: str, app_id: str, assets_id: str, node_id: str) -> SignedAssetPath:
|
||||
def resolved(tenant_id: str, app_id: str, assets_id: str, node_id: str) -> AssetPathBase:
|
||||
return _ResolvedAssetPath(tenant_id=tenant_id, app_id=app_id, resource_id=assets_id, node_id=node_id)
|
||||
|
||||
@staticmethod
|
||||
def skill_bundle(tenant_id: str, app_id: str, assets_id: str) -> SignedAssetPath:
|
||||
def skill_bundle(tenant_id: str, app_id: str, assets_id: str) -> AssetPathBase:
|
||||
return _SkillBundleAssetPath(tenant_id=tenant_id, app_id=app_id, resource_id=assets_id)
|
||||
|
||||
@staticmethod
|
||||
def source_zip(tenant_id: str, app_id: str, workflow_id: str) -> SignedAssetPath:
|
||||
def source_zip(tenant_id: str, app_id: str, workflow_id: str) -> AssetPathBase:
|
||||
return _SourceZipAssetPath(tenant_id=tenant_id, app_id=app_id, resource_id=workflow_id)
|
||||
|
||||
@staticmethod
|
||||
def bundle_export_zip(tenant_id: str, app_id: str, export_id: str) -> SignedAssetPath:
|
||||
def bundle_export_zip(tenant_id: str, app_id: str, export_id: str) -> AssetPathBase:
|
||||
return _BundleExportZipAssetPath(tenant_id=tenant_id, app_id=app_id, resource_id=export_id)
|
||||
|
||||
@staticmethod
|
||||
def bundle_import_zip(tenant_id: str, import_id: str) -> BundleImportZipPath:
|
||||
return BundleImportZipPath(tenant_id=tenant_id, import_id=import_id)
|
||||
|
||||
@staticmethod
|
||||
def from_components(
|
||||
asset_type: str,
|
||||
@ -184,7 +163,7 @@ class AssetPath:
|
||||
app_id: str,
|
||||
resource_id: str,
|
||||
sub_resource_id: str | None = None,
|
||||
) -> SignedAssetPath:
|
||||
) -> AssetPathBase:
|
||||
entry = _ASSET_PATH_REGISTRY.get(asset_type)
|
||||
if not entry:
|
||||
raise ValueError(f"Unsupported asset type: {asset_type}")
|
||||
@ -206,120 +185,26 @@ register_asset_path("source-zip", requires_node=False, factory=AssetPath.source_
|
||||
register_asset_path("bundle-export-zip", requires_node=False, factory=AssetPath.bundle_export_zip)
|
||||
|
||||
|
||||
class AppAssetSigner:
|
||||
SIGNATURE_PREFIX = "app-asset"
|
||||
SIGNATURE_VERSION = "v1"
|
||||
OPERATION_DOWNLOAD = "download"
|
||||
OPERATION_UPLOAD = "upload"
|
||||
|
||||
@classmethod
|
||||
def create_download_signature(cls, asset_path: SignedAssetPath, expires_at: int, nonce: str) -> str:
|
||||
return cls._create_signature(
|
||||
asset_path=asset_path,
|
||||
operation=cls.OPERATION_DOWNLOAD,
|
||||
expires_at=expires_at,
|
||||
nonce=nonce,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_upload_signature(cls, asset_path: SignedAssetPath, expires_at: int, nonce: str) -> str:
|
||||
return cls._create_signature(
|
||||
asset_path=asset_path,
|
||||
operation=cls.OPERATION_UPLOAD,
|
||||
expires_at=expires_at,
|
||||
nonce=nonce,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def verify_download_signature(cls, asset_path: SignedAssetPath, expires_at: int, nonce: str, sign: str) -> bool:
|
||||
return cls._verify_signature(
|
||||
asset_path=asset_path,
|
||||
operation=cls.OPERATION_DOWNLOAD,
|
||||
expires_at=expires_at,
|
||||
nonce=nonce,
|
||||
sign=sign,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def verify_upload_signature(cls, asset_path: SignedAssetPath, expires_at: int, nonce: str, sign: str) -> bool:
|
||||
return cls._verify_signature(
|
||||
asset_path=asset_path,
|
||||
operation=cls.OPERATION_UPLOAD,
|
||||
expires_at=expires_at,
|
||||
nonce=nonce,
|
||||
sign=sign,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _verify_signature(
|
||||
cls,
|
||||
*,
|
||||
asset_path: SignedAssetPath,
|
||||
operation: str,
|
||||
expires_at: int,
|
||||
nonce: str,
|
||||
sign: str,
|
||||
) -> bool:
|
||||
if expires_at <= 0:
|
||||
return False
|
||||
|
||||
expected_sign = cls._create_signature(
|
||||
asset_path=asset_path,
|
||||
operation=operation,
|
||||
expires_at=expires_at,
|
||||
nonce=nonce,
|
||||
)
|
||||
if not hmac.compare_digest(sign, expected_sign):
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
if expires_at < current_time:
|
||||
return False
|
||||
|
||||
if expires_at - current_time > dify_config.FILES_ACCESS_TIMEOUT:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def _create_signature(cls, *, asset_path: SignedAssetPath, operation: str, expires_at: int, nonce: str) -> str:
|
||||
key = cls._tenant_key(asset_path.tenant_id)
|
||||
message = cls._signature_message(
|
||||
asset_path=asset_path,
|
||||
operation=operation,
|
||||
expires_at=expires_at,
|
||||
nonce=nonce,
|
||||
)
|
||||
sign = hmac.new(key, message.encode(), hashlib.sha256).digest()
|
||||
return base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
@classmethod
|
||||
def _signature_message(cls, *, asset_path: SignedAssetPath, operation: str, expires_at: int, nonce: str) -> str:
|
||||
resource_id, sub_resource_id = asset_path.signature_parts()
|
||||
return (
|
||||
f"{cls.SIGNATURE_PREFIX}|{cls.SIGNATURE_VERSION}|{operation}|"
|
||||
f"{asset_path.asset_type}|{asset_path.tenant_id}|{asset_path.app_id}|"
|
||||
f"{resource_id}|{sub_resource_id or ''}|{expires_at}|{nonce}"
|
||||
)
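
# For reference, the message signed above has this shape (values illustrative):
#   "app-asset|v1|download|draft|<tenant_id>|<app_id>|<node_id>||1719912345|<nonce>"
# i.e. prefix|version|operation|asset_type|tenant_id|app_id|resource_id|sub_resource_id|expires_at|nonce,
# with sub_resource_id left empty for path types that have none.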
|
||||
|
||||
@classmethod
|
||||
def _tenant_key(cls, tenant_id: str) -> bytes:
|
||||
try:
|
||||
rsa_key, _ = rsa.get_decrypt_decoding(tenant_id)
|
||||
except rsa.PrivkeyNotFoundError as exc:
|
||||
raise ValueError(f"Tenant private key missing for tenant_id={tenant_id}") from exc
|
||||
private_key = rsa_key.export_key()
|
||||
return hashlib.sha256(private_key).digest()
|
||||
|
||||
|
||||
class AppAssetStorage:
|
||||
_base_storage: BaseStorage
|
||||
"""High-level storage operations for app assets.
|
||||
|
||||
Wraps BaseStorage with:
|
||||
- FilePresignStorage for presign fallback support
|
||||
- CachedPresignStorage for URL caching
|
||||
|
||||
Usage:
|
||||
storage = AppAssetStorage(base_storage, redis_client=redis)
|
||||
storage.save(asset_path, content)
|
||||
url = storage.get_download_url(asset_path)
|
||||
"""
|
||||
|
||||
_storage: CachedPresignStorage
|
||||
|
||||
def __init__(self, storage: BaseStorage, *, redis_client: Any, cache_key_prefix: str = "app_assets") -> None:
|
||||
self._base_storage = storage
|
||||
# Wrap with FilePresignStorage for fallback support, then CachedPresignStorage for caching
|
||||
presign_storage = FilePresignStorage(storage)
|
||||
self._storage = CachedPresignStorage(
|
||||
storage=storage,
|
||||
storage=presign_storage,
|
||||
redis_client=redis_client,
|
||||
cache_key_prefix=cache_key_prefix,
|
||||
)
|
||||
@ -329,87 +214,51 @@ class AppAssetStorage:
|
||||
return self._storage
|
||||
|
||||
def save(self, asset_path: AssetPathBase, content: bytes) -> None:
|
||||
self._storage.save(self.get_storage_key(asset_path), content)
|
||||
self._storage.save(asset_path.get_storage_key(), content)
|
||||
|
||||
def load(self, asset_path: AssetPathBase) -> bytes:
|
||||
return self._storage.load_once(self.get_storage_key(asset_path))
|
||||
return self._storage.load_once(asset_path.get_storage_key())
|
||||
|
||||
def load_stream(self, asset_path: AssetPathBase) -> Generator[bytes, None, None]:
|
||||
return self._storage.load_stream(asset_path.get_storage_key())
|
||||
|
||||
def load_or_none(self, asset_path: AssetPathBase) -> bytes | None:
|
||||
try:
|
||||
data = self._storage.load_once(self.get_storage_key(asset_path))
|
||||
data = self._storage.load_once(asset_path.get_storage_key())
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
if data == _SILENT_STORAGE_NOT_FOUND:
|
||||
return None
|
||||
return data
|
||||
|
||||
def exists(self, asset_path: AssetPathBase) -> bool:
|
||||
return self._storage.exists(asset_path.get_storage_key())
|
||||
|
||||
def delete(self, asset_path: AssetPathBase) -> None:
|
||||
self._storage.delete(self.get_storage_key(asset_path))
|
||||
self._storage.delete(asset_path.get_storage_key())
|
||||
|
||||
def get_storage_key(self, asset_path: AssetPathBase) -> str:
|
||||
return asset_path.get_storage_key()
|
||||
def get_download_url(self, asset_path: AssetPathBase, expires_in: int = 3600) -> str:
|
||||
return self._storage.get_download_url(asset_path.get_storage_key(), expires_in)
|
||||
|
||||
def get_download_url(self, asset_path: SignedAssetPath, expires_in: int = 3600) -> str:
|
||||
storage_key = self.get_storage_key(asset_path)
|
||||
def get_download_urls(self, asset_paths: Iterable[AssetPathBase], expires_in: int = 3600) -> list[str]:
|
||||
storage_keys = [p.get_storage_key() for p in asset_paths]
|
||||
return self._storage.get_download_urls(storage_keys, expires_in)
|
||||
|
||||
def get_upload_url(self, asset_path: AssetPathBase, expires_in: int = 3600) -> str:
|
||||
return self._storage.get_upload_url(asset_path.get_storage_key(), expires_in)
|
||||
|
||||
# Bundle import convenience methods
|
||||
def get_import_upload_url(self, path: BundleImportZipPath, expires_in: int = 3600) -> str:
|
||||
return self._storage.get_upload_url(path.get_storage_key(), expires_in)
|
||||
|
||||
def get_import_download_url(self, path: BundleImportZipPath, expires_in: int = 3600) -> str:
|
||||
return self._storage.get_download_url(path.get_storage_key(), expires_in)
|
||||
|
||||
def delete_import_zip(self, path: BundleImportZipPath) -> None:
|
||||
"""Delete import zip file. Errors are logged but not raised."""
|
||||
try:
|
||||
return self._storage.get_download_url(storage_key, expires_in)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
self._storage.delete(path.get_storage_key())
|
||||
except Exception:
|
||||
import logging
|
||||
|
||||
return self._generate_signed_proxy_download_url(asset_path, expires_in)
|
||||
|
||||
def get_download_urls(
|
||||
self,
|
||||
asset_paths: Iterable[SignedAssetPath],
|
||||
expires_in: int = 3600,
|
||||
) -> list[str]:
|
||||
asset_paths_list = list(asset_paths)
|
||||
storage_keys = [self.get_storage_key(asset_path) for asset_path in asset_paths_list]
|
||||
|
||||
try:
|
||||
return self._storage.get_download_urls(storage_keys, expires_in)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
return [self._generate_signed_proxy_download_url(asset_path, expires_in) for asset_path in asset_paths_list]
|
||||
|
||||
def get_upload_url(
|
||||
self,
|
||||
asset_path: SignedAssetPath,
|
||||
expires_in: int = 3600,
|
||||
) -> str:
|
||||
storage_key = self.get_storage_key(asset_path)
|
||||
try:
|
||||
return self._storage.get_upload_url(storage_key, expires_in)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
return self._generate_signed_proxy_upload_url(asset_path, expires_in)
|
||||
|
||||
def _generate_signed_proxy_download_url(self, asset_path: SignedAssetPath, expires_in: int) -> str:
|
||||
expires_in = min(expires_in, dify_config.FILES_ACCESS_TIMEOUT)
|
||||
expires_at = int(time.time()) + max(expires_in, 1)
|
||||
nonce = os.urandom(16).hex()
|
||||
sign = AppAssetSigner.create_download_signature(asset_path=asset_path, expires_at=expires_at, nonce=nonce)
|
||||
|
||||
base_url = dify_config.FILES_URL
|
||||
url = self._build_proxy_url(base_url=base_url, asset_path=asset_path, action="download")
|
||||
query = urllib.parse.urlencode({"expires_at": expires_at, "nonce": nonce, "sign": sign})
|
||||
return f"{url}?{query}"
|
||||
|
||||
def _generate_signed_proxy_upload_url(self, asset_path: SignedAssetPath, expires_in: int) -> str:
|
||||
expires_in = min(expires_in, dify_config.FILES_ACCESS_TIMEOUT)
|
||||
expires_at = int(time.time()) + max(expires_in, 1)
|
||||
nonce = os.urandom(16).hex()
|
||||
sign = AppAssetSigner.create_upload_signature(asset_path=asset_path, expires_at=expires_at, nonce=nonce)
|
||||
|
||||
base_url = dify_config.FILES_URL
|
||||
url = self._build_proxy_url(base_url=base_url, asset_path=asset_path, action="upload")
|
||||
query = urllib.parse.urlencode({"expires_at": expires_at, "nonce": nonce, "sign": sign})
|
||||
return f"{url}?{query}"
|
||||
|
||||
@staticmethod
|
||||
def _build_proxy_url(*, base_url: str, asset_path: SignedAssetPath, action: str) -> str:
|
||||
encoded_parts = [urllib.parse.quote(part, safe="") for part in asset_path.proxy_path_parts()]
|
||||
path = "/".join(encoded_parts)
|
||||
return f"{base_url}/files/app-assets/{path}/{action}"
|
||||
logging.getLogger(__name__).debug("Failed to delete import zip: %s", path.get_storage_key())
|
||||
|
||||
@ -1,5 +1 @@
from .source_zip_extractor import SourceZipExtractor

__all__ = [
"SourceZipExtractor",
]
# App bundle utilities - manifest-driven import/export handled by AppBundleService

@ -1,98 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import zipfile
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_asset_entities import AppAssetFileTree, AppAssetNode
|
||||
from core.app.entities.app_bundle_entities import ExtractedFile, ExtractedFolder, ZipSecurityError
|
||||
from core.app_assets.storage import AssetPath
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.app_assets.storage import AppAssetStorage
|
||||
|
||||
|
||||
class SourceZipExtractor:
|
||||
def __init__(self, storage: AppAssetStorage) -> None:
|
||||
self._storage = storage
|
||||
|
||||
def extract_entries(
|
||||
self, zip_bytes: bytes, *, expected_prefix: str
|
||||
) -> tuple[list[ExtractedFolder], list[ExtractedFile]]:
|
||||
folders: list[ExtractedFolder] = []
|
||||
files: list[ExtractedFile] = []
|
||||
|
||||
with zipfile.ZipFile(io.BytesIO(zip_bytes), "r") as zf:
|
||||
for info in zf.infolist():
|
||||
name = info.filename
|
||||
self._validate_path(name)
|
||||
|
||||
if not name.startswith(expected_prefix):
|
||||
continue
|
||||
|
||||
relative_path = name[len(expected_prefix) :].lstrip("/")
|
||||
if not relative_path:
|
||||
continue
|
||||
|
||||
if info.is_dir():
|
||||
folders.append(ExtractedFolder(path=relative_path.rstrip("/")))
|
||||
else:
|
||||
content = zf.read(info)
|
||||
files.append(ExtractedFile(path=relative_path, content=content))
|
||||
|
||||
return folders, files
|
||||
|
||||
def build_tree_and_save(
|
||||
self,
|
||||
folders: list[ExtractedFolder],
|
||||
files: list[ExtractedFile],
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
) -> AppAssetFileTree:
|
||||
tree = AppAssetFileTree()
|
||||
path_to_node_id: dict[str, str] = {}
|
||||
|
||||
all_folder_paths = {f.path for f in folders}
|
||||
for file in files:
|
||||
self._ensure_parent_folders(file.path, all_folder_paths)
|
||||
|
||||
sorted_folders = sorted(all_folder_paths, key=lambda p: p.count("/"))
|
||||
for folder_path in sorted_folders:
|
||||
node_id = str(uuid4())
|
||||
name = folder_path.rsplit("/", 1)[-1]
|
||||
parent_path = folder_path.rsplit("/", 1)[0] if "/" in folder_path else None
|
||||
parent_id = path_to_node_id.get(parent_path) if parent_path else None
|
||||
|
||||
node = AppAssetNode.create_folder(node_id, name, parent_id)
|
||||
tree.add(node)
|
||||
path_to_node_id[folder_path] = node_id
|
||||
|
||||
sorted_files = sorted(files, key=lambda f: f.path)
|
||||
for file in sorted_files:
|
||||
node_id = str(uuid4())
|
||||
name = file.path.rsplit("/", 1)[-1]
|
||||
parent_path = file.path.rsplit("/", 1)[0] if "/" in file.path else None
|
||||
parent_id = path_to_node_id.get(parent_path) if parent_path else None
|
||||
|
||||
node = AppAssetNode.create_file(node_id, name, parent_id, len(file.content))
|
||||
tree.add(node)
|
||||
|
||||
asset_path = AssetPath.draft(tenant_id, app_id, node_id)
|
||||
self._storage.save(asset_path, file.content)
|
||||
|
||||
return tree
|
||||
|
||||
def _validate_path(self, path: str) -> None:
|
||||
if ".." in path:
|
||||
raise ZipSecurityError(f"Path traversal detected: {path}")
|
||||
if path.startswith("/"):
|
||||
raise ZipSecurityError(f"Absolute path detected: {path}")
|
||||
if "\\" in path:
|
||||
raise ZipSecurityError(f"Backslash in path: {path}")
|
||||
|
||||
def _ensure_parent_folders(self, file_path: str, folder_set: set[str]) -> None:
|
||||
parts = file_path.split("/")[:-1]
|
||||
for i in range(1, len(parts) + 1):
|
||||
parent = "/".join(parts[:i])
|
||||
folder_set.add(parent)
|
||||
@ -3,6 +3,7 @@ from pydantic import BaseModel, Field, field_validator


class PreviewDetail(BaseModel):
    content: str
    summary: str | None = None
    child_chunks: list[str] | None = None

@ -123,6 +123,8 @@ def download(f: File, /):
|
||||
):
|
||||
return _download_file_content(f.storage_key)
|
||||
elif f.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
if f.remote_url is None:
|
||||
raise ValueError("Missing file remote_url")
|
||||
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
@ -153,6 +155,8 @@ def _download_file_content(path: str, /):
|
||||
def _get_encoded_string(f: File, /):
|
||||
match f.transfer_method:
|
||||
case FileTransferMethod.REMOTE_URL:
|
||||
if f.remote_url is None:
|
||||
raise ValueError("Missing file remote_url")
|
||||
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
data = response.content
|
||||
|
||||
@ -4,8 +4,10 @@ Proxy requests to avoid SSRF
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
import httpx
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.http_client_pooling import get_pooled_http_client
|
||||
@ -18,6 +20,9 @@ SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
|
||||
BACKOFF_FACTOR = 0.5
|
||||
STATUS_FORCELIST = [429, 500, 502, 503, 504]
|
||||
|
||||
Headers: TypeAlias = dict[str, str]
|
||||
_HEADERS_ADAPTER = TypeAdapter(Headers)
|
||||
|
||||
_SSL_VERIFIED_POOL_KEY = "ssrf:verified"
|
||||
_SSL_UNVERIFIED_POOL_KEY = "ssrf:unverified"
|
||||
_SSRF_CLIENT_LIMITS = httpx.Limits(
|
||||
@ -76,7 +81,7 @@ def _get_ssrf_client(ssl_verify_enabled: bool) -> httpx.Client:
|
||||
)
|
||||
|
||||
|
||||
def _get_user_provided_host_header(headers: dict | None) -> str | None:
|
||||
def _get_user_provided_host_header(headers: Headers | None) -> str | None:
|
||||
"""
|
||||
Extract the user-provided Host header from the headers dict.
|
||||
|
||||
@ -92,7 +97,7 @@ def _get_user_provided_host_header(headers: dict | None) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _inject_trace_headers(headers: dict | None) -> dict:
|
||||
def _inject_trace_headers(headers: Headers | None) -> Headers:
|
||||
"""
|
||||
Inject W3C traceparent header for distributed tracing.
|
||||
|
||||
@ -125,7 +130,7 @@ def _inject_trace_headers(headers: dict | None) -> dict:
|
||||
return headers
|
||||
|
||||
|
||||
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
def make_request(method: str, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
# Convert requests-style allow_redirects to httpx-style follow_redirects
|
||||
if "allow_redirects" in kwargs:
|
||||
allow_redirects = kwargs.pop("allow_redirects")
|
||||
@ -142,10 +147,15 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
|
||||
# prioritize per-call option, which can be switched on and off inside the HTTP node on the web UI
|
||||
verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
|
||||
if not isinstance(verify_option, bool):
|
||||
raise ValueError("ssl_verify must be a boolean")
|
||||
client = _get_ssrf_client(verify_option)
|
||||
|
||||
# Inject traceparent header for distributed tracing (when OTEL is not enabled)
|
||||
headers = kwargs.get("headers") or {}
|
||||
try:
|
||||
headers: Headers = _HEADERS_ADAPTER.validate_python(kwargs.get("headers") or {})
|
||||
except ValidationError as e:
|
||||
raise ValueError("headers must be a mapping of string keys to string values") from e
|
||||
headers = _inject_trace_headers(headers)
|
||||
kwargs["headers"] = headers
|
||||
|
||||
@ -198,25 +208,25 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")
|
||||
|
||||
|
||||
def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
def get(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
return make_request("GET", url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
def post(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
def post(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
return make_request("POST", url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
def put(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
def put(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
return make_request("PUT", url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
def patch(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
def patch(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
return make_request("PATCH", url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
def delete(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
return make_request("DELETE", url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
def head(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
return make_request("HEAD", url, max_retries=max_retries, **kwargs)
|
||||
|
||||
@ -311,14 +311,18 @@ class IndexingRunner:
|
||||
qa_preview_texts: list[QAPreviewDetail] = []
|
||||
|
||||
total_segments = 0
|
||||
# doc_form represents the segmentation method (general, parent-child, QA)
|
||||
index_type = doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
# one extract_setting is one source document
|
||||
for extract_setting in extract_settings:
|
||||
# extract
|
||||
processing_rule = DatasetProcessRule(
|
||||
mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"])
|
||||
)
|
||||
# Extract document content
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
|
||||
# Cleaning and segmentation
|
||||
documents = index_processor.transform(
|
||||
text_docs,
|
||||
current_user=None,
|
||||
@ -361,6 +365,12 @@ class IndexingRunner:
|
||||
|
||||
if doc_form and doc_form == "qa_model":
|
||||
return IndexingEstimate(total_segments=total_segments * 20, qa_preview=qa_preview_texts, preview=[])
|
||||
|
||||
# Generate summary preview
|
||||
summary_index_setting = tmp_processing_rule.get("summary_index_setting")
|
||||
if summary_index_setting and summary_index_setting.get("enable") and preview_texts:
|
||||
preview_texts = index_processor.generate_summary_preview(tenant_id, preview_texts, summary_index_setting)
|
||||
|
||||
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
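
# Illustrative shape of the processing rule that enables the summary preview above. The keys
# "mode", "rules" and "summary_index_setting" are taken from the code; the concrete values
# are placeholders, and only the "enable" flag is read here.
tmp_processing_rule_example = {
    "mode": "custom",
    "rules": {},
    "summary_index_setting": {"enable": True},
}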
|
||||
|
||||
def _extract(
|
||||
|
||||
@ -471,7 +471,6 @@ class LLMGenerator:
|
||||
prompt_messages=complete_messages,
|
||||
output_model=CodeNodeStructuredOutput,
|
||||
model_parameters=model_parameters,
|
||||
stream=True,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
@ -553,16 +552,10 @@ class LLMGenerator:
|
||||
|
||||
completion_params = model_config.get("completion_params", {}) if model_config else {}
|
||||
try:
|
||||
response = invoke_llm_with_pydantic_model(
|
||||
provider=model_instance.provider,
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
output_model=SuggestedQuestionsOutput,
|
||||
model_parameters=completion_params,
|
||||
stream=True,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
response = invoke_llm_with_pydantic_model(
    provider=model_instance.provider,
    model_schema=model_schema,
    model_instance=model_instance,
    prompt_messages=prompt_messages,
    output_model=SuggestedQuestionsOutput,
    model_parameters=completion_params,
    tenant_id=tenant_id,
)
|
||||
|
||||
return {"questions": response.questions, "error": ""}
|
||||
|
||||
@ -842,15 +835,11 @@ Generate {language} code to extract/transform available variables for the target
|
||||
try:
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_pydantic_model
|
||||
|
||||
response = invoke_llm_with_pydantic_model(
|
||||
provider=model_instance.provider,
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=list(prompt_messages),
|
||||
output_model=InstructionModifyOutput,
|
||||
model_parameters=model_parameters,
|
||||
stream=True,
|
||||
)
|
||||
response = invoke_llm_with_pydantic_model(
    provider=model_instance.provider,
    model_schema=model_schema,
    model_instance=model_instance,
    prompt_messages=list(prompt_messages),
    output_model=InstructionModifyOutput,
    model_parameters=model_parameters,
)
|
||||
return response.model_dump(mode="python")
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from collections.abc import Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from enum import StrEnum
|
||||
from typing import Any, Literal, TypeVar, cast, overload
|
||||
from typing import Any, TypeVar, cast
|
||||
|
||||
import json_repair
|
||||
from pydantic import BaseModel, TypeAdapter, ValidationError
|
||||
@ -14,13 +14,9 @@ from core.model_manager import ModelInstance
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMResultChunkWithStructuredOutput,
|
||||
LLMResultWithStructuredOutput,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
@ -52,7 +48,6 @@ TOOL_CALL_FEATURES = {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, Mode
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
*,
|
||||
provider: str,
|
||||
@ -63,58 +58,10 @@ def invoke_llm_with_structured_output(
|
||||
model_parameters: Mapping | None = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: Literal[True],
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
*,
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Mapping | None = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: Literal[False],
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> LLMResultWithStructuredOutput: ...
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
*,
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Mapping | None = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
def invoke_llm_with_structured_output(
|
||||
*,
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Mapping | None = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
) -> LLMResultWithStructuredOutput:
|
||||
"""
|
||||
Invoke large language model with structured output.
|
||||
|
||||
@ -129,7 +76,6 @@ def invoke_llm_with_structured_output(
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
:param tenant_id: tenant ID for file reference conversion. When provided and
|
||||
@ -165,91 +111,33 @@ def invoke_llm_with_structured_output(
|
||||
model_parameters=model_parameters_with_json_schema,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream=False,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
if isinstance(llm_result, LLMResult):
|
||||
# Non-streaming result
|
||||
structured_output = _extract_structured_output(llm_result)
|
||||
# Non-streaming result
|
||||
structured_output = _extract_structured_output(llm_result)
|
||||
|
||||
# Fill missing fields with default values
|
||||
structured_output = fill_defaults_from_schema(structured_output, json_schema)
|
||||
# Fill missing fields with default values
|
||||
structured_output = fill_defaults_from_schema(structured_output, json_schema)
|
||||
|
||||
# Convert file references if tenant_id is provided
|
||||
if tenant_id is not None:
|
||||
structured_output = convert_file_refs_in_output(
|
||||
output=structured_output,
|
||||
json_schema=json_schema,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
return LLMResultWithStructuredOutput(
|
||||
structured_output=structured_output,
|
||||
model=llm_result.model,
|
||||
message=llm_result.message,
|
||||
usage=llm_result.usage,
|
||||
system_fingerprint=llm_result.system_fingerprint,
|
||||
prompt_messages=llm_result.prompt_messages,
|
||||
# Convert file references if tenant_id is provided
|
||||
if tenant_id is not None:
|
||||
structured_output = convert_file_refs_in_output(
|
||||
output=structured_output,
|
||||
json_schema=json_schema,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
else:
|
||||
|
||||
def generator() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
result_text: str = ""
|
||||
tool_call_args: dict[str, str] = {} # tool_call_id -> arguments
|
||||
prompt_messages: Sequence[PromptMessage] = []
|
||||
system_fingerprint: str | None = None
|
||||
|
||||
for event in llm_result:
|
||||
if isinstance(event, LLMResultChunk):
|
||||
prompt_messages = event.prompt_messages
|
||||
system_fingerprint = event.system_fingerprint
|
||||
|
||||
# Collect text content
|
||||
result_text += event.delta.message.get_text_content()
|
||||
# Collect tool call arguments
|
||||
if event.delta.message.tool_calls:
|
||||
for tool_call in event.delta.message.tool_calls:
|
||||
call_id = tool_call.id or ""
|
||||
if tool_call.function.arguments:
|
||||
tool_call_args[call_id] = tool_call_args.get(call_id, "") + tool_call.function.arguments
|
||||
|
||||
yield LLMResultChunkWithStructuredOutput(
|
||||
model=model_schema.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=system_fingerprint,
|
||||
delta=event.delta,
|
||||
)
|
||||
|
||||
# Extract structured output: prefer tool call, fallback to text
|
||||
structured_output = _extract_structured_output_from_stream(result_text, tool_call_args)
|
||||
|
||||
# Fill missing fields with default values
|
||||
structured_output = fill_defaults_from_schema(structured_output, json_schema)
|
||||
|
||||
# Convert file references if tenant_id is provided
|
||||
if tenant_id is not None:
|
||||
structured_output = convert_file_refs_in_output(
|
||||
output=structured_output,
|
||||
json_schema=json_schema,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
yield LLMResultChunkWithStructuredOutput(
|
||||
structured_output=structured_output,
|
||||
model=model_schema.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
usage=None,
|
||||
finish_reason=None,
|
||||
),
|
||||
)
|
||||
|
||||
return generator()
|
||||
return LLMResultWithStructuredOutput(
|
||||
structured_output=structured_output,
|
||||
model=llm_result.model,
|
||||
message=llm_result.message,
|
||||
usage=llm_result.usage,
|
||||
system_fingerprint=llm_result.system_fingerprint,
|
||||
prompt_messages=llm_result.prompt_messages,
|
||||
)
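
# Minimal non-streaming sketch (illustration only): `model_instance`, `model_schema` and
# `prompt_messages` are assumed to be prepared by the caller; the schema is plain JSON Schema.
example_schema = {"type": "object", "properties": {"title": {"type": "string"}}, "required": ["title"]}
result = invoke_llm_with_structured_output(
    provider=model_instance.provider,
    model_schema=model_schema,
    model_instance=model_instance,
    prompt_messages=prompt_messages,
    json_schema=example_schema,
)
title = (result.structured_output or {}).get("title")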
|
||||
|
||||
|
||||
def invoke_llm_with_pydantic_model(
|
||||
@ -262,7 +150,6 @@ def invoke_llm_with_pydantic_model(
|
||||
model_parameters: Mapping | None = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True, # Some model plugin implementations don't support stream=False
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
tenant_id: str | None = None,
|
||||
@ -281,36 +168,6 @@ def invoke_llm_with_pydantic_model(
|
||||
"""
|
||||
json_schema = _schema_from_pydantic(output_model)
|
||||
|
||||
if stream:
|
||||
result_generator = invoke_llm_with_structured_output(
|
||||
provider=provider,
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=json_schema,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# Consume the generator to get the final chunk with structured_output
|
||||
last_chunk: LLMResultChunkWithStructuredOutput | None = None
|
||||
for chunk in result_generator:
|
||||
last_chunk = chunk
|
||||
|
||||
if last_chunk is None:
|
||||
raise OutputParserError("No chunks received from LLM")
|
||||
|
||||
structured_output = last_chunk.structured_output
|
||||
if structured_output is None:
|
||||
raise OutputParserError("Structured output is empty")
|
||||
|
||||
return _validate_structured_output(output_model, structured_output)
|
||||
|
||||
result = invoke_llm_with_structured_output(
|
||||
provider=provider,
|
||||
model_schema=model_schema,
|
||||
@ -320,7 +177,6 @@ def invoke_llm_with_pydantic_model(
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=False,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
tenant_id=tenant_id,
|
||||
@ -416,7 +272,7 @@ def _parse_tool_call_arguments(arguments: str) -> Mapping[str, Any]:
|
||||
repaired = json_repair.loads(arguments)
|
||||
if not isinstance(repaired, dict):
|
||||
raise OutputParserError(f"Failed to parse tool call arguments: {arguments}")
|
||||
return cast(dict, repaired)
|
||||
return repaired
|
||||
|
||||
|
||||
def _get_default_value_for_type(type_name: str | list[str] | None) -> Any:
|
||||
|
||||
@ -435,3 +435,20 @@ INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as ex
You should edit the prompt according to the IDEAL OUTPUT."""

INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}."""

DEFAULT_GENERATOR_SUMMARY_PROMPT = (
    """Summarize the following content. Extract only the key information and main points. """
    """Remove redundant details.

Requirements:
1. Write a concise summary in plain text
2. Use the same language as the input content
3. Focus on important facts, concepts, and details
4. If images are included, describe their key information
5. Do not use words like "好的", "ok", "I understand", "This text discusses", "The content mentions"
6. Write directly without extra words

Output only the summary text. Start summarizing now:

"""
)

@ -1,10 +1,11 @@
|
||||
import decimal
|
||||
import hashlib
|
||||
from threading import Lock
|
||||
import logging
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||
from redis import RedisError
|
||||
|
||||
import contexts
|
||||
from configs import dify_config
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
@ -24,6 +25,9 @@ from core.model_runtime.errors.invoke import (
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AIModel(BaseModel):
|
||||
@ -144,34 +148,60 @@ class AIModel(BaseModel):
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
|
||||
# sort credentials
|
||||
sorted_credentials = sorted(credentials.items()) if credentials else []
|
||||
cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
|
||||
|
||||
cached_schema_json = None
|
||||
try:
|
||||
contexts.plugin_model_schemas.get()
|
||||
except LookupError:
|
||||
contexts.plugin_model_schemas.set({})
|
||||
contexts.plugin_model_schema_lock.set(Lock())
|
||||
|
||||
with contexts.plugin_model_schema_lock.get():
|
||||
if cache_key in contexts.plugin_model_schemas.get():
|
||||
return contexts.plugin_model_schemas.get()[cache_key]
|
||||
|
||||
schema = plugin_model_manager.get_model_schema(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
model_type=self.model_type.value,
|
||||
model=model,
|
||||
credentials=credentials or {},
|
||||
cached_schema_json = redis_client.get(cache_key)
|
||||
except (RedisError, RuntimeError) as exc:
|
||||
logger.warning(
|
||||
"Failed to read plugin model schema cache for model %s: %s",
|
||||
model,
|
||||
str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
if cached_schema_json:
|
||||
try:
|
||||
return AIModelEntity.model_validate_json(cached_schema_json)
|
||||
except ValidationError:
|
||||
logger.warning(
|
||||
"Failed to validate cached plugin model schema for model %s",
|
||||
model,
|
||||
exc_info=True,
|
||||
)
|
||||
try:
|
||||
redis_client.delete(cache_key)
|
||||
except (RedisError, RuntimeError) as exc:
|
||||
logger.warning(
|
||||
"Failed to delete invalid plugin model schema cache for model %s: %s",
|
||||
model,
|
||||
str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if schema:
|
||||
contexts.plugin_model_schemas.get()[cache_key] = schema
|
||||
schema = plugin_model_manager.get_model_schema(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
provider=self.provider_name,
|
||||
model_type=self.model_type.value,
|
||||
model=model,
|
||||
credentials=credentials or {},
|
||||
)
|
||||
|
||||
return schema
|
||||
if schema:
|
||||
try:
|
||||
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
|
||||
except (RedisError, RuntimeError) as exc:
|
||||
logger.warning(
|
||||
"Failed to write plugin model schema cache for model %s: %s",
|
||||
model,
|
||||
str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return schema
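
# Cache key layout for reference (values illustrative):
#   "<tenant_id>:<plugin_id>:<provider_name>:<model_type>:<model>" followed by an md5 digest
#   per sorted "key:value" credential pair. Entries expire after
#   dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL seconds, after which the schema is fetched
#   from the plugin daemon again.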
|
||||
|
||||
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
|
||||
@ -5,7 +5,11 @@ import logging
|
||||
from collections.abc import Sequence
|
||||
from threading import Lock
|
||||
|
||||
from pydantic import ValidationError
|
||||
from redis import RedisError
|
||||
|
||||
import contexts
|
||||
from configs import dify_config
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
@ -18,6 +22,7 @@ from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
|
||||
from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -175,34 +180,60 @@ class ModelProviderFactory:
|
||||
"""
|
||||
plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider)
|
||||
cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}"
|
||||
# sort credentials
|
||||
sorted_credentials = sorted(credentials.items()) if credentials else []
|
||||
cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
|
||||
|
||||
cached_schema_json = None
|
||||
try:
|
||||
contexts.plugin_model_schemas.get()
|
||||
except LookupError:
|
||||
contexts.plugin_model_schemas.set({})
|
||||
contexts.plugin_model_schema_lock.set(Lock())
|
||||
|
||||
with contexts.plugin_model_schema_lock.get():
|
||||
if cache_key in contexts.plugin_model_schemas.get():
|
||||
return contexts.plugin_model_schemas.get()[cache_key]
|
||||
|
||||
schema = self.plugin_model_manager.get_model_schema(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="unknown",
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model_type=model_type.value,
|
||||
model=model,
|
||||
credentials=credentials or {},
|
||||
cached_schema_json = redis_client.get(cache_key)
|
||||
except (RedisError, RuntimeError) as exc:
|
||||
logger.warning(
|
||||
"Failed to read plugin model schema cache for model %s: %s",
|
||||
model,
|
||||
str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
if cached_schema_json:
|
||||
try:
|
||||
return AIModelEntity.model_validate_json(cached_schema_json)
|
||||
except ValidationError:
|
||||
logger.warning(
|
||||
"Failed to validate cached plugin model schema for model %s",
|
||||
model,
|
||||
exc_info=True,
|
||||
)
|
||||
try:
|
||||
redis_client.delete(cache_key)
|
||||
except (RedisError, RuntimeError) as exc:
|
||||
logger.warning(
|
||||
"Failed to delete invalid plugin model schema cache for model %s: %s",
|
||||
model,
|
||||
str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if schema:
|
||||
contexts.plugin_model_schemas.get()[cache_key] = schema
|
||||
schema = self.plugin_model_manager.get_model_schema(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="unknown",
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
model_type=model_type.value,
|
||||
model=model,
|
||||
credentials=credentials or {},
|
||||
)
|
||||
|
||||
return schema
|
||||
if schema:
|
||||
try:
|
||||
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
|
||||
except (RedisError, RuntimeError) as exc:
|
||||
logger.warning(
|
||||
"Failed to write plugin model schema cache for model %s: %s",
|
||||
model,
|
||||
str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return schema
|
||||
|
||||
def get_models(
|
||||
self,
|
||||
|
||||
@ -114,46 +114,32 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
model_instance=model_instance,
|
||||
prompt_messages=payload.prompt_messages,
|
||||
json_schema=payload.structured_output_schema,
|
||||
model_parameters=payload.completion_params,
|
||||
tools=payload.tools,
|
||||
stop=payload.stop,
|
||||
stream=True if payload.stream is None else payload.stream,
|
||||
user=user_id,
|
||||
model_parameters=payload.completion_params,
|
||||
user=user_id
|
||||
)
|
||||
|
||||
if isinstance(response, Generator):
|
||||
if response.usage:
|
||||
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||
|
||||
def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
for chunk in response:
|
||||
if chunk.delta.usage:
|
||||
llm_utils.deduct_llm_quota(
|
||||
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
|
||||
)
|
||||
chunk.prompt_messages = []
|
||||
yield chunk
|
||||
def handle_non_streaming(
|
||||
response: LLMResultWithStructuredOutput,
|
||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
yield LLMResultChunkWithStructuredOutput(
|
||||
model=response.model,
|
||||
prompt_messages=[],
|
||||
system_fingerprint=response.system_fingerprint,
|
||||
structured_output=response.structured_output,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=response.message,
|
||||
usage=response.usage,
|
||||
finish_reason="",
|
||||
),
|
||||
)
|
||||
|
||||
return handle()
|
||||
else:
|
||||
if response.usage:
|
||||
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||
|
||||
def handle_non_streaming(
|
||||
response: LLMResultWithStructuredOutput,
|
||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
yield LLMResultChunkWithStructuredOutput(
|
||||
model=response.model,
|
||||
prompt_messages=[],
|
||||
system_fingerprint=response.system_fingerprint,
|
||||
structured_output=response.structured_output,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=response.message,
|
||||
usage=response.usage,
|
||||
finish_reason="",
|
||||
),
|
||||
)
|
||||
|
||||
return handle_non_streaming(response)
|
||||
return handle_non_streaming(response)
|
||||
|
||||
@classmethod
|
||||
def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
|
||||
|
||||
@ -24,7 +24,13 @@ from core.rag.rerank.rerank_type import RerankMode
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.signature import sign_upload_file
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
|
||||
from models.dataset import (
|
||||
ChildChunk,
|
||||
Dataset,
|
||||
DocumentSegment,
|
||||
DocumentSegmentSummary,
|
||||
SegmentAttachmentBinding,
|
||||
)
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.model import UploadFile
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
@ -389,15 +395,15 @@ class RetrievalService:
|
||||
.all()
|
||||
}
|
||||
|
||||
records = []
|
||||
include_segment_ids = set()
|
||||
segment_child_map = {}
|
||||
|
||||
valid_dataset_documents = {}
|
||||
image_doc_ids: list[Any] = []
|
||||
child_index_node_ids = []
|
||||
index_node_ids = []
|
||||
doc_to_document_map = {}
|
||||
summary_segment_ids = set() # Track segments retrieved via summary
|
||||
summary_score_map: dict[str, float] = {} # Map original_chunk_id to summary score
|
||||
|
||||
# First pass: collect all document IDs and identify summary documents
|
||||
for document in documents:
|
||||
document_id = document.metadata.get("document_id")
|
||||
if document_id not in dataset_documents:
|
||||
@ -408,16 +414,39 @@ class RetrievalService:
|
||||
continue
|
||||
valid_dataset_documents[document_id] = dataset_document
|
||||
|
||||
doc_id = document.metadata.get("doc_id") or ""
|
||||
doc_to_document_map[doc_id] = document
|
||||
|
||||
# Check if this is a summary document
|
||||
is_summary = document.metadata.get("is_summary", False)
|
||||
if is_summary:
|
||||
# For summary documents, find the original chunk via original_chunk_id
|
||||
original_chunk_id = document.metadata.get("original_chunk_id")
|
||||
if original_chunk_id:
|
||||
summary_segment_ids.add(original_chunk_id)
|
||||
# Save summary's score for later use
|
||||
summary_score = document.metadata.get("score")
|
||||
if summary_score is not None:
|
||||
try:
|
||||
summary_score_float = float(summary_score)
|
||||
# If the same segment has multiple summary hits, take the highest score
|
||||
if original_chunk_id not in summary_score_map:
|
||||
summary_score_map[original_chunk_id] = summary_score_float
|
||||
else:
|
||||
summary_score_map[original_chunk_id] = max(
|
||||
summary_score_map[original_chunk_id], summary_score_float
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
# Skip invalid score values
|
||||
pass
|
||||
continue # Skip adding to other lists for summary documents
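
# Worked example (illustrative): if the same original chunk is hit by two summary documents
# with scores 0.61 and 0.84, summary_score_map[original_chunk_id] ends up as 0.84.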
|
||||
|
||||
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
doc_id = document.metadata.get("doc_id") or ""
|
||||
doc_to_document_map[doc_id] = document
|
||||
if document.metadata.get("doc_type") == DocType.IMAGE:
|
||||
image_doc_ids.append(doc_id)
|
||||
else:
|
||||
child_index_node_ids.append(doc_id)
|
||||
else:
|
||||
doc_id = document.metadata.get("doc_id") or ""
|
||||
doc_to_document_map[doc_id] = document
|
||||
if document.metadata.get("doc_type") == DocType.IMAGE:
|
||||
image_doc_ids.append(doc_id)
|
||||
else:
|
||||
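To make the first-pass bookkeeping above easier to follow, here is a minimal sketch of the two kinds of retrieval hits it distinguishes and how duplicate summary hits are merged. The concrete metadata values are illustrative assumptions, not taken from this changeset:

```python
# Hypothetical retrieval hits (shapes assumed from the metadata keys used above).
regular_hit = {"doc_id": "node-1", "document_id": "doc-A", "score": 0.71}
summary_hit_a = {"is_summary": True, "original_chunk_id": "seg-9", "score": 0.80}
summary_hit_b = {"is_summary": True, "original_chunk_id": "seg-9", "score": 0.64}

summary_score_map: dict[str, float] = {}
for hit in (summary_hit_a, summary_hit_b):
    chunk_id = hit["original_chunk_id"]
    # Same segment hit via multiple summaries: keep the best score.
    summary_score_map[chunk_id] = max(summary_score_map.get(chunk_id, 0.0), float(hit["score"]))

print(summary_score_map)  # {'seg-9': 0.8}
```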
@ -433,6 +462,7 @@ class RetrievalService:
attachment_map: dict[str, list[dict[str, Any]]] = {}
child_chunk_map: dict[str, list[ChildChunk]] = {}
doc_segment_map: dict[str, list[str]] = {}
segment_summary_map: dict[str, str] = {}  # Map segment_id to summary content

with session_factory.create_session() as session:
attachments = cls.get_segment_attachment_infos(image_doc_ids, session)
@ -447,6 +477,7 @@ class RetrievalService:
doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"])
else:
doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]]

child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids))
child_index_nodes = session.execute(child_chunk_stmt).scalars().all()

@ -470,6 +501,7 @@ class RetrievalService:
index_node_segments = session.execute(document_segment_stmt).scalars().all()  # type: ignore
for index_node_segment in index_node_segments:
doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id]

if segment_ids:
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.enabled == True,
@ -481,6 +513,40 @@ class RetrievalService:
if index_node_segments:
segments.extend(index_node_segments)

# Handle summary documents: query segments by original_chunk_id
if summary_segment_ids:
summary_segment_ids_list = list(summary_segment_ids)
summary_segment_stmt = select(DocumentSegment).where(
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id.in_(summary_segment_ids_list),
)
summary_segments = session.execute(summary_segment_stmt).scalars().all()  # type: ignore
segments.extend(summary_segments)
# Add summary segment IDs to segment_ids for summary query
for seg in summary_segments:
if seg.id not in segment_ids:
segment_ids.append(seg.id)

# Batch query summaries for segments retrieved via summary (only enabled summaries)
if summary_segment_ids:
summaries = (
session.query(DocumentSegmentSummary)
.filter(
DocumentSegmentSummary.chunk_id.in_(list(summary_segment_ids)),
DocumentSegmentSummary.status == "completed",
DocumentSegmentSummary.enabled == True,  # Only retrieve enabled summaries
)
.all()
)
for summary in summaries:
if summary.summary_content:
segment_summary_map[summary.chunk_id] = summary.summary_content

include_segment_ids = set()
segment_child_map: dict[str, dict[str, Any]] = {}
records: list[dict[str, Any]] = []

for segment in segments:
child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
@ -489,30 +555,44 @@ class RetrievalService:
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
# Check if this segment was retrieved via summary
# Use summary score as base score if available, otherwise 0.0
max_score = summary_score_map.get(segment.id, 0.0)

if child_chunks or attachment_infos:
child_chunk_details = []
max_score = 0.0
for child_chunk in child_chunks:
document = doc_to_document_map[child_chunk.index_node_id]
child_document: Document | None = doc_to_document_map.get(child_chunk.index_node_id)
if child_document:
child_score = child_document.metadata.get("score", 0.0)
else:
child_score = 0.0
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0) if document else 0.0,
"score": child_score,
}
child_chunk_details.append(child_chunk_detail)
max_score = max(max_score, document.metadata.get("score", 0.0) if document else 0.0)
max_score = max(max_score, child_score)
for attachment_info in attachment_infos:
file_document = doc_to_document_map[attachment_info["id"]]
max_score = max(
max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0
)
file_document = doc_to_document_map.get(attachment_info["id"])
if file_document:
max_score = max(max_score, file_document.metadata.get("score", 0.0))

map_detail = {
"max_score": max_score,
"child_chunks": child_chunk_details,
}
segment_child_map[segment.id] = map_detail
else:
# No child chunks or attachments, use summary score if available
summary_score = summary_score_map.get(segment.id)
if summary_score is not None:
segment_child_map[segment.id] = {
"max_score": summary_score,
"child_chunks": [],
}
record: dict[str, Any] = {
"segment": segment,
}
@ -520,14 +600,23 @@ class RetrievalService:
else:
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
max_score = 0.0
segment_document = doc_to_document_map.get(segment.index_node_id)
if segment_document:
max_score = max(max_score, segment_document.metadata.get("score", 0.0))

# Check if this segment was retrieved via summary
# Use summary score if available (summary retrieval takes priority)
max_score = summary_score_map.get(segment.id, 0.0)

# If not retrieved via summary, use original segment's score
if segment.id not in summary_score_map:
segment_document = doc_to_document_map.get(segment.index_node_id)
if segment_document:
max_score = max(max_score, segment_document.metadata.get("score", 0.0))

# Also consider attachment scores
for attachment_info in attachment_infos:
file_doc = doc_to_document_map.get(attachment_info["id"])
if file_doc:
max_score = max(max_score, file_doc.metadata.get("score", 0.0))

record = {
"segment": segment,
"score": max_score,
@ -576,9 +665,16 @@ class RetrievalService:
else None
)

# Extract summary if this segment was retrieved via summary
summary_content = segment_summary_map.get(segment.id)

# Create RetrievalSegments object
retrieval_segment = RetrievalSegments(
segment=segment, child_chunks=child_chunks_list, score=score, files=files
segment=segment,
child_chunks=child_chunks_list,
score=score,
files=files,
summary=summary_content,
)
result.append(retrieval_segment)

@ -20,3 +20,4 @@ class RetrievalSegments(BaseModel):
child_chunks: list[RetrievalChildChunk] | None = None
score: float | None = None
files: list[dict[str, str | int]] | None = None
summary: str | None = None  # Summary content if retrieved via summary index

@ -22,3 +22,4 @@ class RetrievalSourceMetadata(BaseModel):
doc_metadata: dict[str, Any] | None = None
title: str | None = None
files: list[dict[str, Any]] | None = None
summary: str | None = None
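With the new optional `summary` field, a consumer of the retrieval result can check whether a record carries a generated summary before building context. A minimal sketch (field values are hypothetical):

```python
# Assumed usage of the extended RetrievalSegments model shown above.
for record in result:
    prefix = f"{record.summary}\n" if record.summary else ""
    context_text = prefix + record.segment.get_sign_content()
```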
@ -1,4 +1,7 @@
"""Abstract interface for document loader implementations."""
"""Word (.docx) document extractor used for RAG ingestion.

Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`).
"""

import logging
import mimetypes
@ -8,7 +11,6 @@ import tempfile
import uuid
from urllib.parse import urlparse

import httpx
from docx import Document as DocxDocument
from docx.oxml.ns import qn
from docx.text.run import Run
@ -44,7 +46,7 @@ class WordExtractor(BaseExtractor):

# If the file is a web path, download it to a temporary file, and use that
if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path):
response = httpx.get(self.file_path, timeout=None)
response = ssrf_proxy.get(self.file_path)

if response.status_code != 200:
response.close()
@ -55,6 +57,7 @@ class WordExtractor(BaseExtractor):
self.temp_file = tempfile.NamedTemporaryFile()  # noqa SIM115
try:
self.temp_file.write(response.content)
self.temp_file.flush()
finally:
response.close()
self.file_path = self.temp_file.name

@ -13,6 +13,7 @@ from urllib.parse import unquote, urlparse
import httpx

from configs import dify_config
from core.entities.knowledge_entities import PreviewDetail
from core.helper import ssrf_proxy
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.doc_type import DocType
@ -45,6 +46,17 @@ class BaseIndexProcessor(ABC):
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
raise NotImplementedError

@abstractmethod
def generate_summary_preview(
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
) -> list[PreviewDetail]:
"""
For each segment in preview_texts, generate a summary using LLM and attach it to the segment.
The summary can be stored in a new attribute, e.g., summary.
This method should be implemented by subclasses.
"""
raise NotImplementedError

@abstractmethod
def load(
self,
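The new abstract hook lets an indexing-estimate (preview) flow ask whichever processor is in use to enrich its preview chunks with summaries. A hedged sketch of such a call site, with the factory name and setting keys assumed rather than taken from this diff:

```python
# Hypothetical preview flow; IndexProcessorFactory and the setting keys are assumptions.
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory

summary_index_setting = {
    "enable": True,
    "model_provider_name": "openai",   # illustrative values
    "model_name": "gpt-4o-mini",
}

processor = IndexProcessorFactory(doc_form).init_index_processor()
preview_texts = processor.generate_summary_preview(tenant_id, preview_texts, summary_index_setting)
for preview in preview_texts:
    print(preview.summary)  # summary attached per chunk, or None for QA-style processors
```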
@ -1,9 +1,27 @@
"""Paragraph index processor."""

import logging
import re
import uuid
from collections.abc import Mapping
from typing import Any
from typing import Any, cast

logger = logging.getLogger(__name__)

from core.entities.knowledge_entities import PreviewDetail
from core.file import File, FileTransferMethod, FileType, file_manager
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentUnionTypes,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.provider_manager import ProviderManager
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrieval_service import RetrievalService
@ -17,12 +35,17 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.workflow.nodes.llm import llm_utils
from extensions.ext_database import db
from factories.file_factory import build_from_mapping
from libs import helper
from models import UploadFile
from models.account import Account
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import Rule
from services.summary_index_service import SummaryIndexService


class ParagraphIndexProcessor(BaseIndexProcessor):
@ -108,6 +131,29 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
keyword.add_texts(documents)

def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
# For disable operations, disable_summaries_for_segments is called directly in the task.
# Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
delete_summaries = kwargs.get("delete_summaries", False)
if delete_summaries:
if node_ids:
# Find segments by index_node_id
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
)
.all()
)
segment_ids = [segment.id for segment in segments]
if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
else:
# Delete all summaries for the dataset
SummaryIndexService.delete_summaries_for_segments(dataset, None)

if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
if node_ids:
@ -227,3 +273,322 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
}
else:
raise ValueError("Chunks is not a list")

def generate_summary_preview(
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
) -> list[PreviewDetail]:
"""
For each segment, concurrently call generate_summary to generate a summary
and write it to the summary attribute of PreviewDetail.
In preview mode (indexing-estimate), if any summary generation fails, the method will raise an exception.
"""
import concurrent.futures

from flask import current_app

# Capture Flask app context for worker threads
flask_app = None
try:
flask_app = current_app._get_current_object()  # type: ignore
except RuntimeError:
logger.warning("No Flask application context available, summary generation may fail")

def process(preview: PreviewDetail) -> None:
"""Generate summary for a single preview item."""
if flask_app:
# Ensure Flask app context in worker thread
with flask_app.app_context():
summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting)
preview.summary = summary
else:
# Fallback: try without app context (may fail)
summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting)
preview.summary = summary

# Generate summaries concurrently using ThreadPoolExecutor
# Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
timeout_seconds = min(300, 60 * len(preview_texts))
errors: list[Exception] = []

with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_texts))) as executor:
futures = [executor.submit(process, preview) for preview in preview_texts]
# Wait for all tasks to complete with timeout
done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)

# Cancel tasks that didn't complete in time
if not_done:
timeout_error_msg = (
f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s"
)
logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg)
# In preview mode, timeout is also an error
errors.append(TimeoutError(timeout_error_msg))
for future in not_done:
future.cancel()
# Wait a bit for cancellation to take effect
concurrent.futures.wait(not_done, timeout=5)

# Collect exceptions from completed futures
for future in done:
try:
future.result()  # This will raise any exception that occurred
except Exception as e:
logger.exception("Error in summary generation future")
errors.append(e)

# In preview mode (indexing-estimate), if there are any errors, fail the request
if errors:
error_messages = [str(e) for e in errors]
error_summary = (
f"Failed to generate summaries for {len(errors)} chunk(s). "
f"Errors: {'; '.join(error_messages[:3])}"  # Show first 3 errors
)
if len(errors) > 3:
error_summary += f" (and {len(errors) - 3} more)"
logger.error("Summary generation failed in preview mode: %s", error_summary)
raise ValueError(error_summary)

return preview_texts
@staticmethod
def generate_summary(
tenant_id: str,
text: str,
summary_index_setting: dict | None = None,
segment_id: str | None = None,
) -> tuple[str, LLMUsage]:
"""
Generate summary for the given text using ModelInstance.invoke_llm and the default or custom summary prompt,
and supports vision models by including images from the segment attachments or text content.

Args:
tenant_id: Tenant ID
text: Text content to summarize
summary_index_setting: Summary index configuration
segment_id: Optional segment ID to fetch attachments from SegmentAttachmentBinding table

Returns:
Tuple of (summary_content, llm_usage) where llm_usage is LLMUsage object
"""
if not summary_index_setting or not summary_index_setting.get("enable"):
raise ValueError("summary_index_setting is required and must be enabled to generate summary.")

model_name = summary_index_setting.get("model_name")
model_provider_name = summary_index_setting.get("model_provider_name")
summary_prompt = summary_index_setting.get("summary_prompt")

if not model_name or not model_provider_name:
raise ValueError("model_name and model_provider_name are required in summary_index_setting")

# Import default summary prompt
if not summary_prompt:
summary_prompt = DEFAULT_GENERATOR_SUMMARY_PROMPT

provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id, model_provider_name, ModelType.LLM
)
model_instance = ModelInstance(provider_model_bundle, model_name)

# Get model schema to check if vision is supported
model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials)
supports_vision = model_schema and model_schema.features and ModelFeature.VISION in model_schema.features

# Extract images if model supports vision
image_files = []
if supports_vision:
# First, try to get images from SegmentAttachmentBinding (preferred method)
if segment_id:
image_files = ParagraphIndexProcessor._extract_images_from_segment_attachments(tenant_id, segment_id)

# If no images from attachments, fall back to extracting from text
if not image_files:
image_files = ParagraphIndexProcessor._extract_images_from_text(tenant_id, text)

# Build prompt messages
prompt_messages = []

if image_files:
# If we have images, create a UserPromptMessage with both text and images
prompt_message_contents: list[PromptMessageContentUnionTypes] = []

# Add images first
for file in image_files:
try:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=ImagePromptMessageContent.DETAIL.LOW
)
prompt_message_contents.append(file_content)
except Exception as e:
logger.warning("Failed to convert image file to prompt message content: %s", str(e))
continue

# Add text content
if prompt_message_contents:  # Only add text if we successfully added images
prompt_message_contents.append(TextPromptMessageContent(data=f"{summary_prompt}\n{text}"))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
# If image conversion failed, fall back to text-only
prompt = f"{summary_prompt}\n{text}"
prompt_messages.append(UserPromptMessage(content=prompt))
else:
# No images, use simple text prompt
prompt = f"{summary_prompt}\n{text}"
prompt_messages.append(UserPromptMessage(content=prompt))

result = model_instance.invoke_llm(
prompt_messages=cast(list[PromptMessage], prompt_messages), model_parameters={}, stream=False
)

# Type assertion: when stream=False, invoke_llm returns LLMResult, not Generator
if not isinstance(result, LLMResult):
raise ValueError("Expected LLMResult when stream=False")

summary_content = getattr(result.message, "content", "")
usage = result.usage

# Deduct quota for summary generation (same as workflow nodes)
try:
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
except Exception as e:
# Log but don't fail summary generation if quota deduction fails
logger.warning("Failed to deduct quota for summary generation: %s", str(e))

return summary_content, usage

@staticmethod
def _extract_images_from_text(tenant_id: str, text: str) -> list[File]:
"""
Extract images from markdown text and convert them to File objects.

Args:
tenant_id: Tenant ID
text: Text content that may contain markdown image links

Returns:
List of File objects representing images found in the text
"""
# Extract markdown images using regex pattern
pattern = r"!\[.*?\]\((.*?)\)"
images = re.findall(pattern, text)

if not images:
return []

upload_file_id_list = []

for image in images:
# For data before v0.10.0
pattern = r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?"
match = re.search(pattern, image)
if match:
upload_file_id = match.group(1)
upload_file_id_list.append(upload_file_id)
continue

# For data after v0.10.0
pattern = r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?"
match = re.search(pattern, image)
if match:
upload_file_id = match.group(1)
upload_file_id_list.append(upload_file_id)
continue

# For tools directory - direct file formats (e.g., .png, .jpg, etc.)
pattern = r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?"
match = re.search(pattern, image)
if match:
# Tool files are handled differently, skip for now
continue

if not upload_file_id_list:
return []

# Get unique IDs for database query
unique_upload_file_ids = list(set(upload_file_id_list))
upload_files = (
db.session.query(UploadFile)
.where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id)
.all()
)

# Create File objects from UploadFile records
file_objects = []
for upload_file in upload_files:
# Only process image files
if not upload_file.mime_type or "image" not in upload_file.mime_type:
continue

mapping = {
"upload_file_id": upload_file.id,
"transfer_method": FileTransferMethod.LOCAL_FILE.value,
"type": FileType.IMAGE.value,
}

try:
file_obj = build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
file_objects.append(file_obj)
except Exception as e:
logger.warning("Failed to create File object from UploadFile %s: %s", upload_file.id, str(e))
continue

return file_objects

@staticmethod
def _extract_images_from_segment_attachments(tenant_id: str, segment_id: str) -> list[File]:
"""
Extract images from SegmentAttachmentBinding table (preferred method).
This matches how DatasetRetrieval gets segment attachments.

Args:
tenant_id: Tenant ID
segment_id: Segment ID to fetch attachments for

Returns:
List of File objects representing images found in segment attachments
"""
from sqlalchemy import select

# Query attachments from SegmentAttachmentBinding table
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.segment_id == segment_id,
SegmentAttachmentBinding.tenant_id == tenant_id,
)
).all()

if not attachments_with_bindings:
return []

file_objects = []
for _, upload_file in attachments_with_bindings:
# Only process image files
if not upload_file.mime_type or "image" not in upload_file.mime_type:
continue

try:
# Create File object directly (similar to DatasetRetrieval)
file_obj = File(
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
size=upload_file.size,
storage_key=upload_file.key,
)
file_objects.append(file_obj)
except Exception as e:
logger.warning("Failed to create File object from UploadFile %s: %s", upload_file.id, str(e))
continue

return file_objects
@ -1,11 +1,14 @@
"""Paragraph index processor."""

import json
import logging
import uuid
from collections.abc import Mapping
from typing import Any

from configs import dify_config
from core.db.session_factory import session_factory
from core.entities.knowledge_entities import PreviewDetail
from core.model_manager import ModelInstance
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService
@ -25,6 +28,9 @@ from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegm
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
from services.summary_index_service import SummaryIndexService

logger = logging.getLogger(__name__)


class ParentChildIndexProcessor(BaseIndexProcessor):
@ -135,6 +141,30 @@ class ParentChildIndexProcessor(BaseIndexProcessor):

def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# node_ids is segment's node_ids
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
# For disable operations, disable_summaries_for_segments is called directly in the task.
# Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
delete_summaries = kwargs.get("delete_summaries", False)
if delete_summaries:
if node_ids:
# Find segments by index_node_id
with session_factory.create_session() as session:
segments = (
session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
)
.all()
)
segment_ids = [segment.id for segment in segments]
if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
else:
# Delete all summaries for the dataset
SummaryIndexService.delete_summaries_for_segments(dataset, None)

if dataset.indexing_technique == "high_quality":
delete_child_chunks = kwargs.get("delete_child_chunks") or False
precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids")
@ -326,3 +356,91 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
"preview": preview,
"total_segments": len(parent_childs.parent_child_chunks),
}

def generate_summary_preview(
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
) -> list[PreviewDetail]:
"""
For each parent chunk in preview_texts, concurrently call generate_summary to generate a summary
and write it to the summary attribute of PreviewDetail.
In preview mode (indexing-estimate), if any summary generation fails, the method will raise an exception.

Note: For parent-child structure, we only generate summaries for parent chunks.
"""
import concurrent.futures

from flask import current_app

# Capture Flask app context for worker threads
flask_app = None
try:
flask_app = current_app._get_current_object()  # type: ignore
except RuntimeError:
logger.warning("No Flask application context available, summary generation may fail")

def process(preview: PreviewDetail) -> None:
"""Generate summary for a single preview item (parent chunk)."""
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor

if flask_app:
# Ensure Flask app context in worker thread
with flask_app.app_context():
summary, _ = ParagraphIndexProcessor.generate_summary(
tenant_id=tenant_id,
text=preview.content,
summary_index_setting=summary_index_setting,
)
preview.summary = summary
else:
# Fallback: try without app context (may fail)
summary, _ = ParagraphIndexProcessor.generate_summary(
tenant_id=tenant_id,
text=preview.content,
summary_index_setting=summary_index_setting,
)
preview.summary = summary

# Generate summaries concurrently using ThreadPoolExecutor
# Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
timeout_seconds = min(300, 60 * len(preview_texts))
errors: list[Exception] = []

with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_texts))) as executor:
futures = [executor.submit(process, preview) for preview in preview_texts]
# Wait for all tasks to complete with timeout
done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)

# Cancel tasks that didn't complete in time
if not_done:
timeout_error_msg = (
f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s"
)
logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg)
# In preview mode, timeout is also an error
errors.append(TimeoutError(timeout_error_msg))
for future in not_done:
future.cancel()
# Wait a bit for cancellation to take effect
concurrent.futures.wait(not_done, timeout=5)

# Collect exceptions from completed futures
for future in done:
try:
future.result()  # This will raise any exception that occurred
except Exception as e:
logger.exception("Error in summary generation future")
errors.append(e)

# In preview mode (indexing-estimate), if there are any errors, fail the request
if errors:
error_messages = [str(e) for e in errors]
error_summary = (
f"Failed to generate summaries for {len(errors)} chunk(s). "
f"Errors: {'; '.join(error_messages[:3])}"  # Show first 3 errors
)
if len(errors) > 3:
error_summary += f" (and {len(errors) - 3} more)"
logger.error("Summary generation failed in preview mode: %s", error_summary)
raise ValueError(error_summary)

return preview_texts
@ -11,6 +11,8 @@ import pandas as pd
from flask import Flask, current_app
from werkzeug.datastructures import FileStorage

from core.db.session_factory import session_factory
from core.entities.knowledge_entities import PreviewDetail
from core.llm_generator.llm_generator import LLMGenerator
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService
@ -25,9 +27,10 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.account import Account
from models.dataset import Dataset
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule
from services.summary_index_service import SummaryIndexService

logger = logging.getLogger(__name__)

@ -144,6 +147,31 @@ class QAIndexProcessor(BaseIndexProcessor):
vector.create_multimodal(multimodal_documents)

def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
# For disable operations, disable_summaries_for_segments is called directly in the task.
# Note: qa_model doesn't generate summaries, but we clean them for completeness
# Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
delete_summaries = kwargs.get("delete_summaries", False)
if delete_summaries:
if node_ids:
# Find segments by index_node_id
with session_factory.create_session() as session:
segments = (
session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
)
.all()
)
segment_ids = [segment.id for segment in segments]
if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
else:
# Delete all summaries for the dataset
SummaryIndexService.delete_summaries_for_segments(dataset, None)

vector = Vector(dataset)
if node_ids:
vector.delete_by_ids(node_ids)
@ -212,6 +240,17 @@ class QAIndexProcessor(BaseIndexProcessor):
"total_segments": len(qa_chunks.qa_chunks),
}

def generate_summary_preview(
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
) -> list[PreviewDetail]:
"""
QA model doesn't generate summaries, so this method returns preview_texts unchanged.

Note: QA model uses question-answer pairs, which don't require summary generation.
"""
# QA model doesn't generate summaries, return as-is
return preview_texts

def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
format_documents = []
if document_node.page_content is None or not document_node.page_content.strip():
@ -236,20 +236,24 @@ class DatasetRetrieval:
if records:
for record in records:
segment = record.segment
# Build content: if summary exists, add it before the segment content
if segment.answer:
document_context_list.append(
DocumentContext(
content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
score=record.score,
)
)
segment_content = f"question:{segment.get_sign_content()} answer:{segment.answer}"
else:
document_context_list.append(
DocumentContext(
content=segment.get_sign_content(),
score=record.score,
)
segment_content = segment.get_sign_content()

# If summary exists, prepend it to the content
if record.summary:
final_content = f"{record.summary}\n{segment_content}"
else:
final_content = segment_content

document_context_list.append(
DocumentContext(
content=final_content,
score=record.score,
)
)
if vision_enabled:
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
@ -316,6 +320,9 @@ class DatasetRetrieval:
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source.content = segment.content
# Add summary if this segment was retrieved via summary
if hasattr(record, "summary") and record.summary:
source.summary = record.summary
retrieval_resource_list.append(source)
if hit_callback and retrieval_resource_list:
retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True)
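The prepend step above is easy to picture with concrete values. A tiny illustration (the strings are made up):

```python
# Illustrative only: how a summary-bearing record shapes the context text.
summary = "Overview of the refund policy."
segment_content = "Customers may request a refund within 30 days of purchase..."
final_content = f"{summary}\n{segment_content}" if summary else segment_content
```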
@ -7,9 +7,9 @@ from uuid import UUID, uuid4

from core.sandbox.entities.files import SandboxFileDownloadTicket, SandboxFileNode
from core.sandbox.inspector.base import SandboxFileSource
from core.sandbox.security.archive_signer import SandboxArchivePath, SandboxArchiveSigner
from core.sandbox.security.sandbox_file_signer import SandboxFileDownloadPath
from core.sandbox.storage import sandbox_file_storage
from core.sandbox.storage.archive_storage import SandboxArchivePath
from core.sandbox.storage.sandbox_file_storage import SandboxFileDownloadPath
from core.virtual_environment.__base.exec import CommandExecutionError
from core.virtual_environment.__base.helpers import execute
from extensions.ext_storage import storage
@ -68,15 +68,14 @@ print(json.dumps(entries))

def _get_archive_download_url(self) -> str:
"""Get a pre-signed download URL for the sandbox archive."""
from extensions.storage.file_presign_storage import FilePresignStorage

archive_path = SandboxArchivePath(tenant_id=UUID(self._tenant_id), sandbox_id=UUID(self._sandbox_id))
storage_key = archive_path.get_storage_key()
if not storage.exists(storage_key):
raise ValueError("Sandbox archive not found")
return SandboxArchiveSigner.build_signed_url(
archive_path=archive_path,
expires_in=self._EXPORT_EXPIRES_IN_SECONDS,
action=SandboxArchiveSigner.OPERATION_DOWNLOAD,
)
presign_storage = FilePresignStorage(storage.storage_runner)
return presign_storage.get_download_url(storage_key, self._EXPORT_EXPIRES_IN_SECONDS)

def _create_zip_sandbox(self) -> ZipSandbox:
"""Create a ZipSandbox instance for archive operations."""

@ -7,8 +7,8 @@ from uuid import UUID, uuid4

from core.sandbox.entities.files import SandboxFileDownloadTicket, SandboxFileNode
from core.sandbox.inspector.base import SandboxFileSource
from core.sandbox.security.sandbox_file_signer import SandboxFileDownloadPath
from core.sandbox.storage import sandbox_file_storage
from core.sandbox.storage.sandbox_file_storage import SandboxFileDownloadPath
from core.virtual_environment.__base.exec import CommandExecutionError
from core.virtual_environment.__base.helpers import execute
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment

@ -1 +0,0 @@
"""Sandbox security helpers."""
@ -1,152 +0,0 @@
from __future__ import annotations

import base64
import hashlib
import hmac
import os
import time
import urllib.parse
from dataclasses import dataclass
from uuid import UUID

from configs import dify_config
from libs import rsa


@dataclass(frozen=True)
class SandboxArchivePath:
tenant_id: UUID
sandbox_id: UUID

def get_storage_key(self) -> str:
return f"sandbox/{self.tenant_id}/{self.sandbox_id}.tar.gz"

def proxy_path(self) -> str:
return f"{self.tenant_id}/{self.sandbox_id}"


class SandboxArchiveSigner:
SIGNATURE_PREFIX = "sandbox-archive"
SIGNATURE_VERSION = "v1"
OPERATION_DOWNLOAD = "download"
OPERATION_UPLOAD = "upload"

@classmethod
def create_download_signature(cls, archive_path: SandboxArchivePath, expires_at: int, nonce: str) -> str:
return cls._create_signature(
archive_path=archive_path,
operation=cls.OPERATION_DOWNLOAD,
expires_at=expires_at,
nonce=nonce,
)

@classmethod
def create_upload_signature(cls, archive_path: SandboxArchivePath, expires_at: int, nonce: str) -> str:
return cls._create_signature(
archive_path=archive_path,
operation=cls.OPERATION_UPLOAD,
expires_at=expires_at,
nonce=nonce,
)

@classmethod
def verify_download_signature(
cls, archive_path: SandboxArchivePath, expires_at: int, nonce: str, sign: str
) -> bool:
return cls._verify_signature(
archive_path=archive_path,
operation=cls.OPERATION_DOWNLOAD,
expires_at=expires_at,
nonce=nonce,
sign=sign,
)

@classmethod
def verify_upload_signature(cls, archive_path: SandboxArchivePath, expires_at: int, nonce: str, sign: str) -> bool:
return cls._verify_signature(
archive_path=archive_path,
operation=cls.OPERATION_UPLOAD,
expires_at=expires_at,
nonce=nonce,
sign=sign,
)

@classmethod
def _verify_signature(
cls,
*,
archive_path: SandboxArchivePath,
operation: str,
expires_at: int,
nonce: str,
sign: str,
) -> bool:
if expires_at <= 0:
return False

expected_sign = cls._create_signature(
archive_path=archive_path,
operation=operation,
expires_at=expires_at,
nonce=nonce,
)
if not hmac.compare_digest(sign, expected_sign):
return False

current_time = int(time.time())
if expires_at < current_time:
return False

if expires_at - current_time > dify_config.FILES_ACCESS_TIMEOUT:
return False

return True

@classmethod
def build_signed_url(
cls,
*,
archive_path: SandboxArchivePath,
expires_in: int,
action: str,
) -> str:
expires_in = min(expires_in, dify_config.FILES_ACCESS_TIMEOUT)
expires_at = int(time.time()) + max(expires_in, 1)
nonce = os.urandom(16).hex()
sign = cls._create_signature(
archive_path=archive_path,
operation=action,
expires_at=expires_at,
nonce=nonce,
)

base_url = dify_config.FILES_URL
url = f"{base_url}/files/sandbox-archives/{archive_path.proxy_path()}/{action}"
query = urllib.parse.urlencode({"expires_at": expires_at, "nonce": nonce, "sign": sign})
return f"{url}?{query}"

@classmethod
def _create_signature(
cls,
*,
archive_path: SandboxArchivePath,
operation: str,
expires_at: int,
nonce: str,
) -> str:
key = cls._tenant_key(str(archive_path.tenant_id))
message = (
f"{cls.SIGNATURE_PREFIX}|{cls.SIGNATURE_VERSION}|{operation}|"
f"{archive_path.tenant_id}|{archive_path.sandbox_id}|{expires_at}|{nonce}"
)
sign = hmac.new(key, message.encode(), hashlib.sha256).digest()
return base64.urlsafe_b64encode(sign).decode()

@classmethod
def _tenant_key(cls, tenant_id: str) -> bytes:
try:
rsa_key, _ = rsa.get_decrypt_decoding(tenant_id)
except rsa.PrivkeyNotFoundError as exc:
raise ValueError(f"Tenant private key missing for tenant_id={tenant_id}") from exc
private_key = rsa_key.export_key()
return hashlib.sha256(private_key).digest()
@ -1,155 +0,0 @@
from __future__ import annotations

import base64
import hashlib
import hmac
import os
import time
import urllib.parse
from dataclasses import dataclass
from uuid import UUID

from configs import dify_config
from libs import rsa


@dataclass(frozen=True)
class SandboxFileDownloadPath:
tenant_id: UUID
sandbox_id: UUID
export_id: str
filename: str

def get_storage_key(self) -> str:
return f"sandbox_file_downloads/{self.tenant_id}/{self.sandbox_id}/{self.export_id}/{self.filename}"

def proxy_path(self) -> str:
encoded_parts = [
urllib.parse.quote(str(self.tenant_id), safe=""),
urllib.parse.quote(str(self.sandbox_id), safe=""),
urllib.parse.quote(self.export_id, safe=""),
urllib.parse.quote(self.filename, safe=""),
]
return "/".join(encoded_parts)


class SandboxFileSigner:
SIGNATURE_PREFIX = "sandbox-file-download"
SIGNATURE_VERSION = "v1"
OPERATION_DOWNLOAD = "download"
OPERATION_UPLOAD = "upload"

@classmethod
def build_signed_url(
cls,
*,
export_path: SandboxFileDownloadPath,
expires_in: int,
action: str,
) -> str:
expires_in = min(expires_in, dify_config.FILES_ACCESS_TIMEOUT)
expires_at = int(time.time()) + max(expires_in, 1)
nonce = os.urandom(16).hex()
sign = cls._create_signature(
export_path=export_path,
operation=action,
expires_at=expires_at,
nonce=nonce,
)

base_url = dify_config.FILES_URL
url = f"{base_url}/files/sandbox-file-downloads/{export_path.proxy_path()}/{action}"
query = urllib.parse.urlencode({"expires_at": expires_at, "nonce": nonce, "sign": sign})
return f"{url}?{query}"

@classmethod
def verify_download_signature(
cls,
*,
export_path: SandboxFileDownloadPath,
expires_at: int,
nonce: str,
sign: str,
) -> bool:
return cls._verify_signature(
export_path=export_path,
operation=cls.OPERATION_DOWNLOAD,
expires_at=expires_at,
nonce=nonce,
sign=sign,
)

@classmethod
def verify_upload_signature(
cls,
*,
export_path: SandboxFileDownloadPath,
expires_at: int,
nonce: str,
sign: str,
) -> bool:
return cls._verify_signature(
export_path=export_path,
operation=cls.OPERATION_UPLOAD,
expires_at=expires_at,
nonce=nonce,
sign=sign,
)

@classmethod
def _verify_signature(
cls,
*,
export_path: SandboxFileDownloadPath,
operation: str,
expires_at: int,
nonce: str,
sign: str,
) -> bool:
if expires_at <= 0:
return False

expected_sign = cls._create_signature(
export_path=export_path,
operation=operation,
expires_at=expires_at,
nonce=nonce,
)
if not hmac.compare_digest(sign, expected_sign):
return False

current_time = int(time.time())
if expires_at < current_time:
return False

if expires_at - current_time > dify_config.FILES_ACCESS_TIMEOUT:
return False

return True

@classmethod
def _create_signature(
cls,
*,
export_path: SandboxFileDownloadPath,
operation: str,
expires_at: int,
nonce: str,
) -> str:
key = cls._tenant_key(str(export_path.tenant_id))
message = (
f"{cls.SIGNATURE_PREFIX}|{cls.SIGNATURE_VERSION}|{operation}|"
f"{export_path.tenant_id}|{export_path.sandbox_id}|{export_path.export_id}|{export_path.filename}|"
f"{expires_at}|{nonce}"
)
digest = hmac.new(key, message.encode(), hashlib.sha256).digest()
return base64.urlsafe_b64encode(digest).decode()

@classmethod
def _tenant_key(cls, tenant_id: str) -> bytes:
try:
rsa_key, _ = rsa.get_decrypt_decoding(tenant_id)
except rsa.PrivkeyNotFoundError as exc:
raise ValueError(f"Tenant private key missing for tenant_id={tenant_id}") from exc
private_key = rsa_key.export_key()
return hashlib.sha256(private_key).digest()
@ -1,11 +1,13 @@
from .archive_storage import ArchiveSandboxStorage
from .archive_storage import ArchiveSandboxStorage, SandboxArchivePath
from .noop_storage import NoopSandboxStorage
from .sandbox_file_storage import SandboxFileStorage, sandbox_file_storage
from .sandbox_file_storage import SandboxFileDownloadPath, SandboxFileStorage, sandbox_file_storage
from .sandbox_storage import SandboxStorage

__all__ = [
"ArchiveSandboxStorage",
"NoopSandboxStorage",
"SandboxArchivePath",
"SandboxFileDownloadPath",
"SandboxFileStorage",
"SandboxStorage",
"sandbox_file_storage",
@ -1,18 +1,31 @@
"""Archive-based sandbox storage for persisting sandbox state.

This module provides storage operations for sandbox workspace archives (tar.gz),
enabling state persistence across sandbox sessions.

Storage key format: sandbox/{tenant_id}/{sandbox_id}.tar.gz

All presign operations use the unified FilePresignStorage wrapper, which automatically
falls back to Dify's file proxy when the underlying storage doesn't support presigned URLs.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from uuid import UUID

from core.sandbox.security.archive_signer import SandboxArchivePath, SandboxArchiveSigner
from core.virtual_environment.__base.exec import PipelineExecutionError
from core.virtual_environment.__base.helpers import pipeline
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from extensions.ext_storage import storage
from extensions.storage.base_storage import BaseStorage
from extensions.storage.file_presign_storage import FilePresignStorage

from .sandbox_storage import SandboxStorage

logger = logging.getLogger(__name__)

WORKSPACE_DIR = "."

ARCHIVE_DOWNLOAD_TIMEOUT = 60 * 5
ARCHIVE_UPLOAD_TIMEOUT = 60 * 5
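The docstring above describes a presign wrapper that falls back to Dify's file proxy when the backend cannot mint presigned URLs. The wrapper's internals are not part of this diff, so the following is only a sketch of the assumed shape, not the actual `FilePresignStorage` implementation:

```python
# Assumed shape of a presign wrapper with proxy fallback (illustrative, not the real class).
class PresignWithProxyFallback:
    def __init__(self, backend, proxy_url_builder):
        self._backend = backend                  # any BaseStorage-like object
        self._build_proxy_url = proxy_url_builder  # callable(key, expires_in) -> str

    def get_download_url(self, key: str, expires_in: int) -> str:
        try:
            # Object stores such as S3 can presign directly.
            return self._backend.get_download_url(key, expires_in)
        except NotImplementedError:
            # Local/OpenDAL-style backends: route through the application's file proxy instead.
            return self._build_proxy_url(key, expires_in)
```

The point of the design is that callers (archive mount/unmount, file exports) only ever ask the wrapper for a URL and never need to know which backend is configured.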
@ -21,40 +34,67 @@ def build_tar_exclude_args(patterns: list[str]) -> list[str]:
return [f"--exclude={p}" for p in patterns]


@dataclass(frozen=True)
class SandboxArchivePath:
"""Path for sandbox workspace archives."""

tenant_id: UUID
sandbox_id: UUID

def get_storage_key(self) -> str:
return f"sandbox/{self.tenant_id}/{self.sandbox_id}.tar.gz"


class ArchiveSandboxStorage(SandboxStorage):
"""Archive-based storage for sandbox workspace persistence.

Uses tar.gz archives to save and restore sandbox workspace state.
Requires a presign-capable storage wrapper for generating download/upload URLs.
"""

_tenant_id: str
_sandbox_id: str
_exclude_patterns: list[str]
_storage: FilePresignStorage

def __init__(self, tenant_id: str, sandbox_id: str, exclude_patterns: list[str] | None = None):
def __init__(
self,
tenant_id: str,
sandbox_id: str,
storage: BaseStorage,
exclude_patterns: list[str] | None = None,
):
self._tenant_id = tenant_id
self._sandbox_id = sandbox_id
self._exclude_patterns = exclude_patterns or []
# Wrap with FilePresignStorage for presign fallback support
self._storage = FilePresignStorage(storage)

@property
def _archive_path(self) -> SandboxArchivePath:
return SandboxArchivePath(UUID(self._tenant_id), UUID(self._sandbox_id))

@property
def _storage_key(self) -> str:
return SandboxArchivePath(UUID(self._tenant_id), UUID(self._sandbox_id)).get_storage_key()
return self._archive_path.get_storage_key()

@property
def _archive_name(self) -> str:
return f"{self._sandbox_id}.tar.gz"

@property
def _archive_path(self) -> str:
def _archive_tmp_path(self) -> str:
return f"/tmp/{self._archive_name}"

def mount(self, sandbox: VirtualEnvironment) -> bool:
"""Load archive from storage into sandbox workspace."""
if not self.exists():
logger.debug("No archive found for sandbox %s, skipping mount", self._sandbox_id)
return False

archive_path = SandboxArchivePath(UUID(self._tenant_id), UUID(self._sandbox_id))
download_url = SandboxArchiveSigner.build_signed_url(
archive_path=archive_path,
expires_in=ARCHIVE_DOWNLOAD_TIMEOUT,
action=SandboxArchiveSigner.OPERATION_DOWNLOAD,
)
download_url = self._storage.get_download_url(self._storage_key, ARCHIVE_DOWNLOAD_TIMEOUT)
archive_name = self._archive_name

try:
(
pipeline(sandbox)
@ -74,13 +114,10 @@ class ArchiveSandboxStorage(SandboxStorage):
return True

def unmount(self, sandbox: VirtualEnvironment) -> bool:
archive_path = SandboxArchivePath(UUID(self._tenant_id), UUID(self._sandbox_id))
upload_url = SandboxArchiveSigner.build_signed_url(
archive_path=archive_path,
expires_in=ARCHIVE_UPLOAD_TIMEOUT,
action=SandboxArchiveSigner.OPERATION_UPLOAD,
)
archive_path = self._archive_path
"""Save sandbox workspace to storage as archive."""
upload_url = self._storage.get_upload_url(self._storage_key, ARCHIVE_UPLOAD_TIMEOUT)
archive_path = self._archive_tmp_path

(
pipeline(sandbox)
.add(
@ -105,11 +142,13 @@ class ArchiveSandboxStorage(SandboxStorage):
return True

def exists(self) -> bool:
return storage.exists(self._storage_key)
"""Check if archive exists in storage."""
return self._storage.exists(self._storage_key)

def delete(self) -> None:
"""Delete archive from storage."""
try:
storage.delete(self._storage_key)
self._storage.delete(self._storage_key)
logger.info("Deleted archive for sandbox %s", self._sandbox_id)
except Exception:
logger.exception("Failed to delete archive for sandbox %s", self._sandbox_id)
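Since the constructor now takes the storage backend as an argument instead of reaching for the global, callers construct the archive storage roughly like this (the surrounding wiring is assumed, not shown in this diff):

```python
# Hedged usage sketch: tenant/sandbox IDs and the storage source are illustrative.
from extensions.ext_storage import storage as ext_storage

archive_storage = ArchiveSandboxStorage(
    tenant_id="3f1c9a1e-0000-0000-0000-000000000000",
    sandbox_id="7d2b4c5a-0000-0000-0000-000000000000",
    storage=ext_storage.storage_runner,       # injected backend, easy to swap in tests
    exclude_patterns=[".venv", "node_modules"],
)
if archive_storage.exists():
    archive_storage.mount(sandbox)            # restore workspace before running commands
```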
|
||||
@@ -1,23 +1,58 @@
"""Sandbox file storage for exporting files from sandbox environments.

This module provides storage operations for files exported from sandbox environments,
including download tickets for both runtime and archive-based file sources.

Storage key format: sandbox_file_downloads/{tenant_id}/{sandbox_id}/{export_id}/{filename}

All presign operations use the unified FilePresignStorage wrapper, which automatically
falls back to Dify's file proxy when the underlying storage doesn't support presigned URLs.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any
from uuid import UUID

from core.sandbox.security.sandbox_file_signer import SandboxFileDownloadPath, SandboxFileSigner
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from extensions.storage.base_storage import BaseStorage
from extensions.storage.cached_presign_storage import CachedPresignStorage
from extensions.storage.silent_storage import SilentStorage
from extensions.storage.file_presign_storage import FilePresignStorage


@dataclass(frozen=True)
class SandboxFileDownloadPath:
    """Path for sandbox file exports."""

    tenant_id: UUID
    sandbox_id: UUID
    export_id: str
    filename: str

    def get_storage_key(self) -> str:
        return f"sandbox_file_downloads/{self.tenant_id}/{self.sandbox_id}/{self.export_id}/{self.filename}"
class SandboxFileStorage:
|
||||
_base_storage: BaseStorage
|
||||
"""Storage operations for sandbox file exports.
|
||||
|
||||
Wraps BaseStorage with:
|
||||
- FilePresignStorage for presign fallback support
|
||||
- CachedPresignStorage for URL caching
|
||||
|
||||
Usage:
|
||||
storage = SandboxFileStorage(base_storage, redis_client=redis)
|
||||
storage.save(download_path, content)
|
||||
url = storage.get_download_url(download_path)
|
||||
"""
|
||||
|
||||
_storage: CachedPresignStorage
|
||||
|
||||
def __init__(self, storage: BaseStorage, *, redis_client: Any) -> None:
|
||||
self._base_storage = storage
|
||||
# Wrap with FilePresignStorage for fallback support, then CachedPresignStorage for caching
|
||||
presign_storage = FilePresignStorage(storage)
|
||||
self._storage = CachedPresignStorage(
|
||||
storage=storage,
|
||||
storage=presign_storage,
|
||||
redis_client=redis_client,
|
||||
cache_key_prefix="sandbox_file_downloads",
|
||||
)
|
||||
@ -26,29 +61,19 @@ class SandboxFileStorage:
|
||||
self._storage.save(download_path.get_storage_key(), content)
|
||||
|
||||
def get_download_url(self, download_path: SandboxFileDownloadPath, expires_in: int = 3600) -> str:
|
||||
storage_key = download_path.get_storage_key()
|
||||
try:
|
||||
return self._storage.get_download_url(storage_key, expires_in)
|
||||
except NotImplementedError:
|
||||
return SandboxFileSigner.build_signed_url(
|
||||
export_path=download_path,
|
||||
expires_in=expires_in,
|
||||
action=SandboxFileSigner.OPERATION_DOWNLOAD,
|
||||
)
|
||||
return self._storage.get_download_url(download_path.get_storage_key(), expires_in)
|
||||
|
||||
def get_upload_url(self, download_path: SandboxFileDownloadPath, expires_in: int = 3600) -> str:
|
||||
storage_key = download_path.get_storage_key()
|
||||
try:
|
||||
return self._storage.get_upload_url(storage_key, expires_in)
|
||||
except NotImplementedError:
|
||||
return SandboxFileSigner.build_signed_url(
|
||||
export_path=download_path,
|
||||
expires_in=expires_in,
|
||||
action=SandboxFileSigner.OPERATION_UPLOAD,
|
||||
)
|
||||
return self._storage.get_upload_url(download_path.get_storage_key(), expires_in)
|
||||
|
||||
|
||||
class _LazySandboxFileStorage:
    """Lazy initializer for singleton SandboxFileStorage.

    Delays storage initialization until first access, ensuring Flask app
    context is available.
    """

    _instance: SandboxFileStorage | None

    def __init__(self) -> None:

@@ -56,12 +81,16 @@ class _LazySandboxFileStorage:

    def _get_instance(self) -> SandboxFileStorage:
        if self._instance is None:
            from extensions.ext_redis import redis_client
            from extensions.ext_storage import storage

            if not hasattr(storage, "storage_runner"):
                raise RuntimeError(
                    "Storage is not initialized; call storage.init_app before using sandbox_file_storage"
                )
            self._instance = SandboxFileStorage(
                storage=SilentStorage(storage.storage_runner), redis_client=redis_client
                storage=storage.storage_runner,
                redis_client=redis_client,
            )
        return self._instance

@@ -69,4 +98,4 @@ class _LazySandboxFileStorage:
        return getattr(self._get_instance(), name)


sandbox_file_storage = _LazySandboxFileStorage()
sandbox_file_storage: SandboxFileStorage = _LazySandboxFileStorage()  # type: ignore[assignment]
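The lazy wrapper keeps module import cheap and only touches `storage.storage_runner` (and Redis) on first attribute access, once `init_app` has run. A stripped-down sketch of the same proxy pattern, independent of this diff:

    class _LazyProxy:
        """Builds the real object on first attribute access, then delegates to it."""

        def __init__(self, factory):
            self._factory = factory
            self._obj = None

        def __getattr__(self, name):
            if self._obj is None:
                self._obj = self._factory()  # runs once, the first time any attribute is used
            return getattr(self._obj, name)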
@@ -2,21 +2,40 @@ import logging

from core.app_assets.storage import AssetPath
from core.skill.entities.skill_bundle import SkillBundle
from extensions.ext_redis import redis_client
from services.app_asset_service import AppAssetService

logger = logging.getLogger(__name__)


class SkillManager:
    _CACHE_KEY_PREFIX = "skill_bundle"
    _CACHE_TTL_SECONDS = 60 * 60 * 24

    @staticmethod
    def get_cache_key(
        tenant_id: str,
        app_id: str,
        assets_id: str,
    ) -> str:
        return f"{SkillManager._CACHE_KEY_PREFIX}:{tenant_id}:{app_id}:{assets_id}"

    @staticmethod
    def load_bundle(
        tenant_id: str,
        app_id: str,
        assets_id: str,
    ) -> SkillBundle:
        cache_key = SkillManager.get_cache_key(tenant_id, app_id, assets_id)
        data = redis_client.get(cache_key)
        if data:
            return SkillBundle.model_validate_json(data)

        asset_path = AssetPath.skill_bundle(tenant_id, app_id, assets_id)
        data = AppAssetService.get_storage().load(asset_path)
        return SkillBundle.model_validate_json(data)
        bundle = SkillBundle.model_validate_json(data)
        redis_client.setex(cache_key, SkillManager._CACHE_TTL_SECONDS, bundle.model_dump_json(indent=2).encode("utf-8"))
        return bundle

    @staticmethod
    def save_bundle(

@@ -30,3 +49,5 @@ class SkillManager:
            asset_path,
            bundle.model_dump_json(indent=2).encode("utf-8"),
        )
        cache_key = SkillManager.get_cache_key(tenant_id, app_id, assets_id)
        redis_client.delete(cache_key)
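This is a straightforward cache-aside layer: reads go through Redis with a 24-hour TTL and repopulate on miss, while save_bundle invalidates the key so the next load refetches. A hedged caller-side sketch (save_bundle's full signature is truncated in the hunk above, so the argument order is assumed):

    # Read path: a Redis hit skips the asset-storage round trip.
    bundle = SkillManager.load_bundle(tenant_id, app_id, assets_id)

    # Write path: persists the bundle, then deletes
    # "skill_bundle:{tenant_id}:{app_id}:{assets_id}" so the next load_bundle refills the cache.
    SkillManager.save_bundle(tenant_id, app_id, assets_id, bundle)  # argument order assumed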
@@ -169,20 +169,24 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
        if records:
            for record in records:
                segment = record.segment
                # Build content: if summary exists, add it before the segment content
                if segment.answer:
                    document_context_list.append(
                        DocumentContext(
                            content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
                            score=record.score,
                        )
                    )
                    segment_content = f"question:{segment.get_sign_content()} answer:{segment.answer}"
                else:
                    document_context_list.append(
                        DocumentContext(
                            content=segment.get_sign_content(),
                            score=record.score,
                        )
                    segment_content = segment.get_sign_content()

                # If summary exists, prepend it to the content
                if record.summary:
                    final_content = f"{record.summary}\n{segment_content}"
                else:
                    final_content = segment_content

                document_context_list.append(
                    DocumentContext(
                        content=final_content,
                        score=record.score,
                    )
                )

            if self.return_resource:
                for record in records:

@@ -216,6 +220,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
                        source.content = f"question:{segment.content} \nanswer:{segment.answer}"
                    else:
                        source.content = segment.content
                    # Add summary if this segment was retrieved via summary
                    if hasattr(record, "summary") and record.summary:
                        source.summary = record.summary
                    retrieval_resource_list.append(source)

            if self.return_resource and retrieval_resource_list:
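Net effect: when a retrieval record carries a summary, the summary is prepended to the chunk text that ends up in the LLM context and is also exposed on the retrieval resource. Illustratively (values made up):

    record.summary  = "Q3 revenue grew 12%, driven by APAC."
    segment_content = "Revenue for Q3 was 4.2M USD, up from ..."
    final_content   = f"{record.summary}\n{segment_content}"
    # -> "Q3 revenue grew 12%, driven by APAC.\nRevenue for Q3 was 4.2M USD, up from ..."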
@ -148,7 +148,8 @@ class DockerDemuxer:
|
||||
to periodically check for errors and closed state instead of blocking forever.
|
||||
"""
|
||||
if self._error:
|
||||
raise TransportEOFError(f"Demuxer error: {self._error}") from self._error
|
||||
error = cast(BaseException, self._error)
|
||||
raise TransportEOFError(f"Demuxer error: {error}") from error
|
||||
|
||||
while True:
|
||||
try:
|
||||
@ -163,7 +164,8 @@ class DockerDemuxer:
|
||||
if self._closed:
|
||||
raise TransportEOFError("Demuxer closed")
|
||||
if self._error:
|
||||
raise TransportEOFError(f"Demuxer error: {self._error}") from self._error
|
||||
error = cast(BaseException, self._error)
|
||||
raise TransportEOFError(f"Demuxer error: {error}") from error
|
||||
# No error, continue waiting
|
||||
|
||||
def close(self) -> None:
|
||||
@@ -292,6 +294,8 @@ class DockerDaemonEnvironment(VirtualEnvironment):
    @classmethod
    def validate(cls, options: Mapping[str, Any]) -> None:
        # Import Docker SDK lazily so it is loaded after gevent monkey-patching.
        import docker.errors

        import docker

        docker_sock = options.get(cls.OptionsKey.DOCKER_SOCK, cls._DEFAULT_DOCKER_SOCK)

@@ -364,6 +368,7 @@ class DockerDaemonEnvironment(VirtualEnvironment):
        NOTE: I guess nobody will use more than 5 different docker sockets in practice....
        """
        import docker

        return docker.DockerClient(base_url=docker_sock)

    @classmethod

@@ -373,6 +378,7 @@ class DockerDaemonEnvironment(VirtualEnvironment):
        Get the Docker low-level API client.
        """
        import docker

        return docker.APIClient(base_url=docker_sock)

    def get_docker_sock(self) -> str:
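Moving `import docker` into the method bodies keeps the SDK from being imported at module load time, which matters because gevent's monkey-patching has to run before any library that opens sockets. A general-purpose sketch of the same guard (not taken from this diff):

    # gevent must patch the stdlib before socket-using libraries are imported.
    from gevent import monkey

    monkey.patch_all()

    def make_docker_client(base_url: str):
        import docker  # deferred import: resolved only after patching

        return docker.DockerClient(base_url=base_url)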
@ -431,6 +437,12 @@ class DockerDaemonEnvironment(VirtualEnvironment):
|
||||
return self._container_path(path)
|
||||
|
||||
def upload_file(self, path: str, content: BytesIO) -> None:
|
||||
"""Upload a file to the container.
|
||||
|
||||
Files and intermediate directories are created with world-writable permissions
|
||||
(0o777 for directories, 0o666 for files) to avoid permission issues when the container
|
||||
runs as a non-root user but Docker's put_archive creates files as root.
|
||||
"""
|
||||
container = self._get_container()
|
||||
normalized = PurePosixPath(path)
|
||||
|
||||
@ -442,6 +454,7 @@ class DockerDaemonEnvironment(VirtualEnvironment):
|
||||
with tarfile.open(fileobj=tar_stream, mode="w") as tar:
|
||||
tar_info = tarfile.TarInfo(name=file_name)
|
||||
tar_info.size = len(payload)
|
||||
tar_info.mode = 0o666
|
||||
tar.addfile(tar_info, BytesIO(payload))
|
||||
tar_stream.seek(0)
|
||||
container.put_archive(parent_dir, tar_stream.read()) # pyright: ignore[reportUnknownMemberType] #
|
||||
@ -454,8 +467,18 @@ class DockerDaemonEnvironment(VirtualEnvironment):
|
||||
payload = content.getvalue()
|
||||
tar_stream = BytesIO()
|
||||
with tarfile.open(fileobj=tar_stream, mode="w") as tar:
|
||||
# Add intermediate directories with proper permissions
|
||||
for i in range(len(relative_path.parts) - 1):
|
||||
dir_path = PurePosixPath(*relative_path.parts[: i + 1])
|
||||
dir_info = tarfile.TarInfo(name=dir_path.as_posix() + "/")
|
||||
dir_info.type = tarfile.DIRTYPE
|
||||
dir_info.mode = 0o777
|
||||
tar.addfile(dir_info)
|
||||
|
||||
# Add the file
|
||||
tar_info = tarfile.TarInfo(name=relative_path.as_posix())
|
||||
tar_info.size = len(payload)
|
||||
tar_info.mode = 0o666
|
||||
tar.addfile(tar_info, BytesIO(payload))
|
||||
tar_stream.seek(0)
|
||||
container.put_archive(self._working_dir, tar_stream.read()) # pyright: ignore[reportUnknownMemberType] #
|
||||
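If you want to sanity-check the permission handling in upload_file outside a container, the in-memory tar can be inspected directly. A standalone snippet (not part of the diff):

    import io
    import tarfile

    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w") as tar:
        info = tarfile.TarInfo(name="workspace/output.txt")
        info.size = 5
        info.mode = 0o666  # world-writable file, as in the diff above
        tar.addfile(info, io.BytesIO(b"hello"))
    buf.seek(0)
    with tarfile.open(fileobj=buf) as tar:
        assert tar.getmember("workspace/output.txt").mode == 0o666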
@ -479,7 +502,7 @@ class DockerDaemonEnvironment(VirtualEnvironment):
|
||||
return BytesIO(extracted.read())
|
||||
|
||||
def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
|
||||
import docker
|
||||
import docker.errors
|
||||
|
||||
container = self._get_container()
|
||||
container_path = self._container_path(directory_path)
|
||||
@ -525,7 +548,7 @@ class DockerDaemonEnvironment(VirtualEnvironment):
|
||||
pass
|
||||
|
||||
def release_environment(self) -> None:
|
||||
import docker
|
||||
import docker.errors
|
||||
|
||||
try:
|
||||
container = self._get_container()
|
||||
|
||||
@@ -158,3 +158,5 @@ class KnowledgeIndexNodeData(BaseNodeData):
    type: str = "knowledge-index"
    chunk_structure: str
    index_chunk_variable_selector: list[str]
    indexing_technique: str | None = None
    summary_index_setting: dict | None = None
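The two new optional fields let the workflow graph carry the summary-index configuration itself, with the dataset as fallback. Judging by the `dataset_summary_index_fields` marshaling added later in this change, a plausible value looks like this (model names are placeholders, not from this diff):

    summary_index_setting = {
        "enable": True,
        "model_provider_name": "openai",
        "model_name": "gpt-4o-mini",
        "summary_prompt": "Summarize the following chunk in two sentences.",
    }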
@ -1,9 +1,11 @@
|
||||
import concurrent.futures
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask import current_app
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -16,7 +18,9 @@ from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.template import Template
|
||||
from core.workflow.runtime import VariablePool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
from tasks.generate_summary_index_task import generate_summary_index_task
|
||||
|
||||
from .entities import KnowledgeIndexNodeData
|
||||
from .exc import (
|
||||
@ -67,7 +71,20 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
# index knowledge
|
||||
try:
|
||||
if is_preview:
|
||||
outputs = self._get_preview_output(node_data.chunk_structure, chunks)
|
||||
# Preview mode: generate summaries for chunks directly without saving to database
|
||||
# Format preview and generate summaries on-the-fly
|
||||
# Get indexing_technique and summary_index_setting from node_data (workflow graph config)
|
||||
# or fallback to dataset if not available in node_data
|
||||
indexing_technique = node_data.indexing_technique or dataset.indexing_technique
|
||||
summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting
|
||||
|
||||
outputs = self._get_preview_output_with_summaries(
|
||||
node_data.chunk_structure,
|
||||
chunks,
|
||||
dataset=dataset,
|
||||
indexing_technique=indexing_technique,
|
||||
summary_index_setting=summary_index_setting,
|
||||
)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
@ -148,6 +165,11 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
)
|
||||
.scalar()
|
||||
)
|
||||
# Update need_summary based on dataset's summary_index_setting
|
||||
if dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True:
|
||||
document.need_summary = True
|
||||
else:
|
||||
document.need_summary = False
|
||||
db.session.add(document)
|
||||
# update document segment status
|
||||
db.session.query(DocumentSegment).where(
|
||||
@ -163,6 +185,9 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
|
||||
db.session.commit()
|
||||
|
||||
# Generate summary index if enabled
|
||||
self._handle_summary_index_generation(dataset, document, variable_pool)
|
||||
|
||||
return {
|
||||
"dataset_id": ds_id_value,
|
||||
"dataset_name": dataset_name_value,
|
||||
@ -173,9 +198,304 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
"display_status": "completed",
|
||||
}
|
||||
|
||||
def _get_preview_output(self, chunk_structure: str, chunks: Any) -> Mapping[str, Any]:
|
||||
def _handle_summary_index_generation(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
document: Document,
|
||||
variable_pool: VariablePool,
|
||||
) -> None:
|
||||
"""
|
||||
Handle summary index generation based on mode (debug/preview or production).
|
||||
|
||||
Args:
|
||||
dataset: Dataset containing the document
|
||||
document: Document to generate summaries for
|
||||
variable_pool: Variable pool to check invoke_from
|
||||
"""
|
||||
# Only generate summary index for high_quality indexing technique
|
||||
if dataset.indexing_technique != "high_quality":
|
||||
return
|
||||
|
||||
# Check if summary index is enabled
|
||||
summary_index_setting = dataset.summary_index_setting
|
||||
if not summary_index_setting or not summary_index_setting.get("enable"):
|
||||
return
|
||||
|
||||
# Skip qa_model documents
|
||||
if document.doc_form == "qa_model":
|
||||
return
|
||||
|
||||
# Determine if in preview/debug mode
|
||||
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
|
||||
is_preview = invoke_from and invoke_from.value == InvokeFrom.DEBUGGER
|
||||
|
||||
if is_preview:
|
||||
try:
|
||||
# Query segments that need summary generation
|
||||
query = db.session.query(DocumentSegment).filter_by(
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
status="completed",
|
||||
enabled=True,
|
||||
)
|
||||
segments = query.all()
|
||||
|
||||
if not segments:
|
||||
logger.info("No segments found for document %s", document.id)
|
||||
return
|
||||
|
||||
# Filter segments based on mode
|
||||
segments_to_process = []
|
||||
for segment in segments:
|
||||
# Skip if summary already exists
|
||||
existing_summary = (
|
||||
db.session.query(DocumentSegmentSummary)
|
||||
.filter_by(chunk_id=segment.id, dataset_id=dataset.id, status="completed")
|
||||
.first()
|
||||
)
|
||||
if existing_summary:
|
||||
continue
|
||||
|
||||
# For parent-child mode, all segments are parent chunks, so process all
|
||||
segments_to_process.append(segment)
|
||||
|
||||
if not segments_to_process:
|
||||
logger.info("No segments need summary generation for document %s", document.id)
|
||||
return
|
||||
|
||||
# Use ThreadPoolExecutor for concurrent generation
|
||||
flask_app = current_app._get_current_object() # type: ignore
|
||||
max_workers = min(10, len(segments_to_process)) # Limit to 10 workers
|
||||
|
||||
def process_segment(segment: DocumentSegment) -> None:
|
||||
"""Process a single segment in a thread with Flask app context."""
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
SummaryIndexService.generate_and_vectorize_summary(segment, dataset, summary_index_setting)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to generate summary for segment %s",
|
||||
segment.id,
|
||||
)
|
||||
# Continue processing other segments
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = [executor.submit(process_segment, segment) for segment in segments_to_process]
|
||||
# Wait for all tasks to complete
|
||||
concurrent.futures.wait(futures)
|
||||
|
||||
logger.info(
|
||||
"Successfully generated summary index for %s segments in document %s",
|
||||
len(segments_to_process),
|
||||
document.id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to generate summary index for document %s", document.id)
|
||||
# Don't fail the entire indexing process if summary generation fails
|
||||
else:
|
||||
# Production mode: asynchronous generation
|
||||
logger.info(
|
||||
"Queuing summary index generation task for document %s (production mode)",
|
||||
document.id,
|
||||
)
|
||||
try:
|
||||
generate_summary_index_task.delay(dataset.id, document.id, None)
|
||||
logger.info("Summary index generation task queued for document %s", document.id)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to queue summary index generation task for document %s",
|
||||
document.id,
|
||||
)
|
||||
# Don't fail the entire indexing process if task queuing fails
|
||||
|
||||
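In short, _handle_summary_index_generation above generates summaries synchronously with a bounded thread pool when the node runs from the debugger, and hands off to the Celery task otherwise; either branch is best-effort and never fails the indexing run. A reduced sketch of that dispatch (condensed from the method, not a drop-in replacement):

    if invoke_from and invoke_from.value == InvokeFrom.DEBUGGER:
        with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(segments_to_process))) as pool:
            concurrent.futures.wait([pool.submit(process_segment, s) for s in segments_to_process])
    else:
        generate_summary_index_task.delay(dataset.id, document.id, None)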
def _get_preview_output_with_summaries(
|
||||
self,
|
||||
chunk_structure: str,
|
||||
chunks: Any,
|
||||
dataset: Dataset,
|
||||
indexing_technique: str | None = None,
|
||||
summary_index_setting: dict | None = None,
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Generate preview output with summaries for chunks in preview mode.
|
||||
This method generates summaries on-the-fly without saving to database.
|
||||
|
||||
Args:
|
||||
chunk_structure: Chunk structure type
|
||||
chunks: Chunks to generate preview for
|
||||
dataset: Dataset object (for tenant_id)
|
||||
indexing_technique: Indexing technique from node config or dataset
|
||||
summary_index_setting: Summary index setting from node config or dataset
|
||||
"""
|
||||
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
|
||||
return index_processor.format_preview(chunks)
|
||||
preview_output = index_processor.format_preview(chunks)
|
||||
|
||||
# Check if summary index is enabled
|
||||
if indexing_technique != "high_quality":
|
||||
return preview_output
|
||||
|
||||
if not summary_index_setting or not summary_index_setting.get("enable"):
|
||||
return preview_output
|
||||
|
||||
# Generate summaries for chunks
|
||||
if "preview" in preview_output and isinstance(preview_output["preview"], list):
|
||||
chunk_count = len(preview_output["preview"])
|
||||
logger.info(
|
||||
"Generating summaries for %s chunks in preview mode (dataset: %s)",
|
||||
chunk_count,
|
||||
dataset.id,
|
||||
)
|
||||
# Use ParagraphIndexProcessor's generate_summary method
|
||||
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
|
||||
|
||||
# Get Flask app for application context in worker threads
|
||||
flask_app = None
|
||||
try:
|
||||
flask_app = current_app._get_current_object() # type: ignore
|
||||
except RuntimeError:
|
||||
logger.warning("No Flask application context available, summary generation may fail")
|
||||
|
||||
def generate_summary_for_chunk(preview_item: dict) -> None:
|
||||
"""Generate summary for a single chunk."""
|
||||
if "content" in preview_item:
|
||||
# Set Flask application context in worker thread
|
||||
if flask_app:
|
||||
with flask_app.app_context():
|
||||
summary, _ = ParagraphIndexProcessor.generate_summary(
|
||||
tenant_id=dataset.tenant_id,
|
||||
text=preview_item["content"],
|
||||
summary_index_setting=summary_index_setting,
|
||||
)
|
||||
if summary:
|
||||
preview_item["summary"] = summary
|
||||
else:
|
||||
# Fallback: try without app context (may fail)
|
||||
summary, _ = ParagraphIndexProcessor.generate_summary(
|
||||
tenant_id=dataset.tenant_id,
|
||||
text=preview_item["content"],
|
||||
summary_index_setting=summary_index_setting,
|
||||
)
|
||||
if summary:
|
||||
preview_item["summary"] = summary
|
||||
|
||||
# Generate summaries concurrently using ThreadPoolExecutor
|
||||
# Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
|
||||
timeout_seconds = min(300, 60 * len(preview_output["preview"]))
|
||||
errors: list[Exception] = []
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_output["preview"]))) as executor:
|
||||
futures = [
|
||||
executor.submit(generate_summary_for_chunk, preview_item)
|
||||
for preview_item in preview_output["preview"]
|
||||
]
|
||||
# Wait for all tasks to complete with timeout
|
||||
done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)
|
||||
|
||||
# Cancel tasks that didn't complete in time
|
||||
if not_done:
|
||||
timeout_error_msg = (
|
||||
f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s"
|
||||
)
|
||||
logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg)
|
||||
# In preview mode, timeout is also an error
|
||||
errors.append(TimeoutError(timeout_error_msg))
|
||||
for future in not_done:
|
||||
future.cancel()
|
||||
# Wait a bit for cancellation to take effect
|
||||
concurrent.futures.wait(not_done, timeout=5)
|
||||
|
||||
# Collect exceptions from completed futures
|
||||
for future in done:
|
||||
try:
|
||||
future.result() # This will raise any exception that occurred
|
||||
except Exception as e:
|
||||
logger.exception("Error in summary generation future")
|
||||
errors.append(e)
|
||||
|
||||
# In preview mode, if there are any errors, fail the request
|
||||
if errors:
|
||||
error_messages = [str(e) for e in errors]
|
||||
error_summary = (
|
||||
f"Failed to generate summaries for {len(errors)} chunk(s). "
|
||||
f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors
|
||||
)
|
||||
if len(errors) > 3:
|
||||
error_summary += f" (and {len(errors) - 3} more)"
|
||||
logger.error("Summary generation failed in preview mode: %s", error_summary)
|
||||
raise KnowledgeIndexNodeError(error_summary)
|
||||
|
||||
completed_count = sum(1 for item in preview_output["preview"] if item.get("summary") is not None)
|
||||
logger.info(
|
||||
"Completed summary generation for preview chunks: %s/%s succeeded",
|
||||
completed_count,
|
||||
len(preview_output["preview"]),
|
||||
)
|
||||
|
||||
return preview_output
|
||||
|
||||
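Both preview helpers (the summary-generating variant above and the enrichment variant below) return the index processor's preview structure with an optional per-chunk `summary` key. Roughly (field values illustrative):

    {
        "preview": [
            {"content": "First chunk text ...", "summary": "One-line summary of the first chunk."},
            {"content": "Second chunk text ..."},  # no summary when generation is disabled or nothing matched
        ]
    }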
def _get_preview_output(
|
||||
self,
|
||||
chunk_structure: str,
|
||||
chunks: Any,
|
||||
dataset: Dataset | None = None,
|
||||
variable_pool: VariablePool | None = None,
|
||||
) -> Mapping[str, Any]:
|
||||
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
|
||||
preview_output = index_processor.format_preview(chunks)
|
||||
|
||||
# If dataset is provided, try to enrich preview with summaries
|
||||
if dataset and variable_pool:
|
||||
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
|
||||
if document_id:
|
||||
document = db.session.query(Document).filter_by(id=document_id.value).first()
|
||||
if document:
|
||||
# Query summaries for this document
|
||||
summaries = (
|
||||
db.session.query(DocumentSegmentSummary)
|
||||
.filter_by(
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
status="completed",
|
||||
enabled=True,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if summaries:
|
||||
# Create a map of segment content to summary for matching
|
||||
# Use content matching as chunks in preview might not be indexed yet
|
||||
summary_by_content = {}
|
||||
for summary in summaries:
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter_by(id=summary.chunk_id, dataset_id=dataset.id)
|
||||
.first()
|
||||
)
|
||||
if segment:
|
||||
# Normalize content for matching (strip whitespace)
|
||||
normalized_content = segment.content.strip()
|
||||
summary_by_content[normalized_content] = summary.summary_content
|
||||
|
||||
# Enrich preview with summaries by content matching
|
||||
if "preview" in preview_output and isinstance(preview_output["preview"], list):
|
||||
matched_count = 0
|
||||
for preview_item in preview_output["preview"]:
|
||||
if "content" in preview_item:
|
||||
# Normalize content for matching
|
||||
normalized_chunk_content = preview_item["content"].strip()
|
||||
if normalized_chunk_content in summary_by_content:
|
||||
preview_item["summary"] = summary_by_content[normalized_chunk_content]
|
||||
matched_count += 1
|
||||
|
||||
if matched_count > 0:
|
||||
logger.info(
|
||||
"Enriched preview with %s existing summaries (dataset: %s, document: %s)",
|
||||
matched_count,
|
||||
dataset.id,
|
||||
document.id,
|
||||
)
|
||||
|
||||
return preview_output
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
|
||||
@ -419,6 +419,9 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
|
||||
else:
|
||||
source["content"] = segment.get_sign_content()
|
||||
# Add summary if available
|
||||
if record.summary:
|
||||
source["summary"] = record.summary
|
||||
retrieval_resource_list.append(source)
|
||||
if retrieval_resource_list:
|
||||
retrieval_resource_list = sorted(
|
||||
|
||||
@ -522,7 +522,6 @@ class LLMNode(Node[LLMNodeData]):
|
||||
json_schema=output_schema,
|
||||
model_parameters=node_data_model.completion_params,
|
||||
stop=list(stop or []),
|
||||
stream=False,
|
||||
user=user_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
@ -1093,6 +1092,8 @@ class LLMNode(Node[LLMNodeData]):
|
||||
if "content" not in item:
|
||||
raise InvalidContextStructureError(f"Invalid context structure: {item}")
|
||||
|
||||
if item.get("summary"):
|
||||
context_str += item["summary"] + "\n"
|
||||
context_str += item["content"] + "\n"
|
||||
|
||||
retriever_resource = self._convert_to_original_retriever_resource(item)
|
||||
@ -1154,6 +1155,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
page=metadata.get("page"),
|
||||
doc_metadata=metadata.get("doc_metadata"),
|
||||
files=context_dict.get("files"),
|
||||
summary=context_dict.get("summary"),
|
||||
)
|
||||
|
||||
return source
|
||||
@ -1915,6 +1917,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
) -> Generator[NodeEventBase, None, LLMGenerationData]:
|
||||
result: LLMGenerationData | None = None
|
||||
|
||||
# FIXME(Mairuis): Async processing for bash session.
|
||||
with SandboxBashSession(sandbox=sandbox, node_id=self.id, tools=tool_dependencies) as session:
|
||||
prompt_files = self._extract_prompt_files(variable_pool)
|
||||
model_features = self._get_model_features(model_instance)
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
from .zip_sandbox import SandboxDownloadItem, SandboxFile, ZipSandbox
|
||||
from .zip_sandbox import SandboxDownloadItem, SandboxFile, SandboxUploadItem, ZipSandbox
|
||||
|
||||
__all__ = [
|
||||
"SandboxDownloadItem",
|
||||
"SandboxFile",
|
||||
"SandboxUploadItem",
|
||||
"ZipSandbox",
|
||||
]
|
||||
|
||||
@@ -27,10 +27,20 @@ from .strategy import ZipStrategy

@dataclass(frozen=True)
class SandboxDownloadItem:
    """Item for downloading: URL -> sandbox path."""

    url: str
    path: str


@dataclass(frozen=True)
class SandboxUploadItem:
    """Item for uploading: sandbox path -> URL."""

    path: str
    url: str


@dataclass(frozen=True)
class SandboxFile:
    """A handle to a file in the sandbox."""
@ -210,25 +220,6 @@ class ZipSandbox:
|
||||
|
||||
# ========== Download operations ==========
|
||||
|
||||
def download(self, urls: list[str], *, dest_dir: str = ".") -> list[str]:
|
||||
if not urls:
|
||||
return []
|
||||
|
||||
dest_dir = self._normalize_path(dest_dir)
|
||||
paths = [self._dest_path_for_url(dest_dir, u) for u in urls]
|
||||
|
||||
p = pipeline(self.vm)
|
||||
p.add(["mkdir", "-p", dest_dir], error_message="Failed to create download directory")
|
||||
for url, out_path in zip(urls, paths, strict=True):
|
||||
p.add(["curl", "-fsSL", url, "-o", out_path], error_message="Failed to download file")
|
||||
|
||||
try:
|
||||
p.execute(timeout=self._DEFAULT_TIMEOUT_SECONDS, raise_on_error=True)
|
||||
except Exception as exc:
|
||||
raise RuntimeError(str(exc)) from exc
|
||||
|
||||
return paths
|
||||
|
||||
def download_items(self, items: list[SandboxDownloadItem], *, dest_dir: str = ".") -> list[str]:
|
||||
if not items:
|
||||
return []
|
||||
@@ -286,6 +277,32 @@ class ZipSandbox:
        except CommandExecutionError as exc:
            raise RuntimeError(str(exc)) from exc

    def upload_items(self, items: list[SandboxUploadItem], *, src_dir: str = ".") -> None:
        """Upload multiple files from sandbox to target URLs.

        Args:
            items: List of SandboxUploadItem(path, url)
            src_dir: Base directory containing the files
        """
        if not items:
            return

        src_dir = self._normalize_path(src_dir)
        p = pipeline(self.vm)

        for item in items:
            rel = self._normalize_path(item.path)
            src_path = posixpath.join(src_dir, rel) if src_dir not in ("", ".") else rel
            p.add(
                ["curl", "-fsSL", "-X", "PUT", "-T", src_path, item.url],
                error_message=f"Failed to upload {item.path}",
            )

        try:
            p.execute(timeout=self._DEFAULT_TIMEOUT_SECONDS, raise_on_error=True)
        except Exception as exc:
            raise RuntimeError(str(exc)) from exc
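A hedged usage sketch of the new upload path; the presigned URLs would typically come from get_upload_url on one of the storage wrappers elsewhere in this change (the URL variables are assumptions):

    items = [
        SandboxUploadItem(path="reports/summary.pdf", url=presigned_pdf_url),
        SandboxUploadItem(path="reports/data.csv", url=presigned_csv_url),
    ]
    zs.upload_items(items, src_dir="output")  # runs `curl -T <file> <url>` inside the sandbox per item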
||||
# ========== Archive operations ==========
|
||||
|
||||
def zip(self, src: str = ".", *, include_base: bool = True) -> SandboxFile:
|
||||
|
||||
@ -102,6 +102,8 @@ def init_app(app: DifyApp) -> Celery:
|
||||
imports = [
|
||||
"tasks.async_workflow_tasks", # trigger workers
|
||||
"tasks.trigger_processing_tasks", # async trigger processing
|
||||
"tasks.generate_summary_index_task", # summary index generation
|
||||
"tasks.regenerate_summary_index_task", # summary index regeneration
|
||||
]
|
||||
day = dify_config.CELERY_BEAT_SCHEDULER_TIME
|
||||
|
||||
|
||||
@@ -1,71 +1,56 @@
"""Storage wrapper that provides presigned URL support with fallback to signed proxy URLs."""
"""Storage wrapper that provides presigned URL support with fallback to ticket-based URLs.

import base64
import hashlib
import hmac
import os
import time
import urllib.parse

This is the unified presign wrapper for all storage operations. When the underlying
storage backend doesn't support presigned URLs (raises NotImplementedError), it falls
back to generating ticket-based URLs that route through Dify's file proxy endpoints.

Usage:
    from extensions.storage.file_presign_storage import FilePresignStorage

    # Wrap any BaseStorage to add presign support
    presign_storage = FilePresignStorage(base_storage)
    download_url = presign_storage.get_download_url("path/to/file.txt", expires_in=3600)
    upload_url = presign_storage.get_upload_url("path/to/file.txt", expires_in=3600)

When the underlying storage doesn't support presigned URLs, the fallback URLs follow the format:
    {FILES_URL}/files/storage-tickets/{token}

The token is a UUID that maps to the real storage key in Redis.
"""
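The ticket fallback named in the docstring is implemented by StorageTicketService, which is not part of this hunk; conceptually it amounts to minting a UUID token, stashing the real storage key in Redis with the requested TTL, and returning the proxy URL. A sketch under those assumptions (the Redis key name and signature are guesses):

    import uuid

    def create_download_url(storage_key: str, *, expires_in: int) -> str:
        token = uuid.uuid4().hex
        redis_client.setex(f"storage_ticket:{token}", expires_in, storage_key)  # key format assumed
        return f"{dify_config.FILES_URL}/files/storage-tickets/{token}"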
from configs import dify_config
|
||||
from extensions.storage.storage_wrapper import StorageWrapper
|
||||
|
||||
|
||||
class FilePresignStorage(StorageWrapper):
|
||||
"""Storage wrapper that provides presigned URL support.
|
||||
"""Storage wrapper that provides presigned URL support with ticket fallback.
|
||||
|
||||
If the wrapped storage supports presigned URLs, delegates to it.
|
||||
Otherwise, generates signed proxy URLs for download.
|
||||
Otherwise, generates ticket-based URLs for both download and upload operations.
|
||||
"""
|
||||
|
||||
SIGNATURE_PREFIX = "storage-download"
|
||||
|
||||
def get_download_url(self, filename: str, expires_in: int = 3600) -> str:
|
||||
"""Get a presigned download URL, falling back to ticket URL if not supported."""
|
||||
try:
|
||||
return super().get_download_url(filename, expires_in)
|
||||
return self._storage.get_download_url(filename, expires_in)
|
||||
except NotImplementedError:
|
||||
return self._generate_signed_proxy_url(filename, expires_in)
|
||||
from services.storage_ticket_service import StorageTicketService
|
||||
|
||||
def get_upload_url(self, filename: str, expires_in: int = 3600) -> str:
|
||||
try:
|
||||
return super().get_upload_url(filename, expires_in)
|
||||
except NotImplementedError:
|
||||
return self._generate_signed_upload_url(filename)
|
||||
return StorageTicketService.create_download_url(filename, expires_in=expires_in)
|
||||
|
||||
def get_download_urls(self, filenames: list[str], expires_in: int = 3600) -> list[str]:
|
||||
"""Get presigned download URLs for multiple files."""
|
||||
try:
|
||||
return super().get_download_urls(filenames, expires_in)
|
||||
return self._storage.get_download_urls(filenames, expires_in)
|
||||
except NotImplementedError:
|
||||
return [self._generate_signed_proxy_url(filename, expires_in) for filename in filenames]
|
||||
from services.storage_ticket_service import StorageTicketService
|
||||
|
||||
def _generate_signed_upload_url(self, filename: str) -> str:
|
||||
# TODO: Implement this
|
||||
raise NotImplementedError("This storage backend doesn't support pre-signed URLs")
|
||||
return [StorageTicketService.create_download_url(f, expires_in=expires_in) for f in filenames]
|
||||
|
||||
def _generate_signed_proxy_url(self, filename: str, expires_in: int = 3600) -> str:
|
||||
base_url = dify_config.FILES_URL
|
||||
encoded_filename = urllib.parse.quote(filename, safe="")
|
||||
url = f"{base_url}/files/storage/{encoded_filename}/download"
|
||||
def get_upload_url(self, filename: str, expires_in: int = 3600) -> str:
|
||||
"""Get a presigned upload URL, falling back to ticket URL if not supported."""
|
||||
try:
|
||||
return self._storage.get_upload_url(filename, expires_in)
|
||||
except NotImplementedError:
|
||||
from services.storage_ticket_service import StorageTicketService
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = os.urandom(16).hex()
|
||||
sign = self._create_signature(filename, timestamp, nonce)
|
||||
|
||||
query = urllib.parse.urlencode({"timestamp": timestamp, "nonce": nonce, "sign": sign})
|
||||
return f"{url}?{query}"
|
||||
|
||||
@classmethod
|
||||
def _create_signature(cls, filename: str, timestamp: str, nonce: str) -> str:
|
||||
key = dify_config.SECRET_KEY.encode()
|
||||
msg = f"{cls.SIGNATURE_PREFIX}|{filename}|{timestamp}|{nonce}"
|
||||
sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
|
||||
return base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
@classmethod
|
||||
def verify_signature(cls, *, filename: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
expected_sign = cls._create_signature(filename, timestamp, nonce)
|
||||
if sign != expected_sign:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
return StorageTicketService.create_upload_url(filename, expires_in=expires_in)
|
||||
|
||||
@ -39,6 +39,14 @@ dataset_retrieval_model_fields = {
|
||||
"score_threshold_enabled": fields.Boolean,
|
||||
"score_threshold": fields.Float,
|
||||
}
|
||||
|
||||
dataset_summary_index_fields = {
|
||||
"enable": fields.Boolean,
|
||||
"model_name": fields.String,
|
||||
"model_provider_name": fields.String,
|
||||
"summary_prompt": fields.String,
|
||||
}
|
||||
|
||||
external_retrieval_model_fields = {
|
||||
"top_k": fields.Integer,
|
||||
"score_threshold": fields.Float,
|
||||
@ -83,6 +91,7 @@ dataset_detail_fields = {
|
||||
"embedding_model_provider": fields.String,
|
||||
"embedding_available": fields.Boolean,
|
||||
"retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields),
|
||||
"summary_index_setting": fields.Nested(dataset_summary_index_fields),
|
||||
"tags": fields.List(fields.Nested(tag_fields)),
|
||||
"doc_form": fields.String,
|
||||
"external_knowledge_info": fields.Nested(external_knowledge_info_fields),
|
||||
|
||||
@ -33,6 +33,11 @@ document_fields = {
|
||||
"hit_count": fields.Integer,
|
||||
"doc_form": fields.String,
|
||||
"doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"),
|
||||
# Summary index generation status:
|
||||
# "SUMMARIZING" (when task is queued and generating)
|
||||
"summary_index_status": fields.String,
|
||||
# Whether this document needs summary index generation
|
||||
"need_summary": fields.Boolean,
|
||||
}
|
||||
|
||||
document_with_segments_fields = {
|
||||
@ -60,6 +65,10 @@ document_with_segments_fields = {
|
||||
"completed_segments": fields.Integer,
|
||||
"total_segments": fields.Integer,
|
||||
"doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"),
|
||||
# Summary index generation status:
|
||||
# "SUMMARIZING" (when task is queued and generating)
|
||||
"summary_index_status": fields.String,
|
||||
"need_summary": fields.Boolean, # Whether this document needs summary index generation
|
||||
}
|
||||
|
||||
dataset_and_document_fields = {
|
||||
|
||||
@ -58,4 +58,5 @@ hit_testing_record_fields = {
|
||||
"score": fields.Float,
|
||||
"tsne_position": fields.Raw,
|
||||
"files": fields.List(fields.Nested(files_fields)),
|
||||
"summary": fields.String, # Summary content if retrieved via summary index
|
||||
}
|
||||
|
||||
@ -36,6 +36,7 @@ class RetrieverResource(ResponseModel):
|
||||
segment_position: int | None = None
|
||||
index_node_hash: str | None = None
|
||||
content: str | None = None
|
||||
summary: str | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
|
||||
@ -49,4 +49,5 @@ segment_fields = {
|
||||
"stopped_at": TimestampField,
|
||||
"child_chunks": fields.List(fields.Nested(child_chunk_fields)),
|
||||
"attachments": fields.List(fields.Nested(attachment_fields)),
|
||||
"summary": fields.String, # Summary content for the segment
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
"""sandbox_providers

Revision ID: aab323465866
Revises: 9d77545f524e
Revises: 788d3099ae3a
Create Date: 2026-01-08 10:31:05.062722

"""

@@ -11,7 +11,7 @@ import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = 'aab323465866'
down_revision = '9d77545f524e'
down_revision = '788d3099ae3a'
branch_labels = None
depends_on = None
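Re-parenting the existing sandbox_providers migration keeps the Alembic history linear; with the new summary-index revision inserted, the chain reads:

    9d77545f524e  ->  788d3099ae3a (add summary index feature)  ->  aab323465866 (sandbox_providers)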
@ -0,0 +1,107 @@
|
||||
"""add summary index feature
|
||||
|
||||
Revision ID: 788d3099ae3a
|
||||
Revises: 9d77545f524e
|
||||
Create Date: 2026-01-27 18:15:45.277928
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '788d3099ae3a'
|
||||
down_revision = '9d77545f524e'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
if _is_pg(conn):
|
||||
op.create_table('document_segment_summaries',
|
||||
sa.Column('id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('document_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('chunk_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('summary_content', models.types.LongText(), nullable=True),
|
||||
sa.Column('summary_index_node_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('summary_index_node_hash', sa.String(length=255), nullable=True),
|
||||
sa.Column('tokens', sa.Integer(), nullable=True),
|
||||
sa.Column('status', sa.String(length=32), server_default=sa.text("'generating'"), nullable=False),
|
||||
sa.Column('error', models.types.LongText(), nullable=True),
|
||||
sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
|
||||
sa.Column('disabled_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('disabled_by', models.types.StringUUID(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='document_segment_summaries_pkey')
|
||||
)
|
||||
with op.batch_alter_table('document_segment_summaries', schema=None) as batch_op:
|
||||
batch_op.create_index('document_segment_summaries_chunk_id_idx', ['chunk_id'], unique=False)
|
||||
batch_op.create_index('document_segment_summaries_dataset_id_idx', ['dataset_id'], unique=False)
|
||||
batch_op.create_index('document_segment_summaries_document_id_idx', ['document_id'], unique=False)
|
||||
batch_op.create_index('document_segment_summaries_status_idx', ['status'], unique=False)
|
||||
|
||||
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('summary_index_setting', models.types.AdjustedJSON(), nullable=True))
|
||||
|
||||
with op.batch_alter_table('documents', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=True))
|
||||
else:
|
||||
# MySQL: Use compatible syntax
|
||||
op.create_table(
|
||||
'document_segment_summaries',
|
||||
sa.Column('id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('document_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('chunk_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('summary_content', models.types.LongText(), nullable=True),
|
||||
sa.Column('summary_index_node_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('summary_index_node_hash', sa.String(length=255), nullable=True),
|
||||
sa.Column('tokens', sa.Integer(), nullable=True),
|
||||
sa.Column('status', sa.String(length=32), server_default=sa.text("'generating'"), nullable=False),
|
||||
sa.Column('error', models.types.LongText(), nullable=True),
|
||||
sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
|
||||
sa.Column('disabled_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('disabled_by', models.types.StringUUID(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='document_segment_summaries_pkey'),
|
||||
)
|
||||
with op.batch_alter_table('document_segment_summaries', schema=None) as batch_op:
|
||||
batch_op.create_index('document_segment_summaries_chunk_id_idx', ['chunk_id'], unique=False)
|
||||
batch_op.create_index('document_segment_summaries_dataset_id_idx', ['dataset_id'], unique=False)
|
||||
batch_op.create_index('document_segment_summaries_document_id_idx', ['document_id'], unique=False)
|
||||
batch_op.create_index('document_segment_summaries_status_idx', ['status'], unique=False)
|
||||
|
||||
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('summary_index_setting', models.types.AdjustedJSON(), nullable=True))
|
||||
|
||||
with op.batch_alter_table('documents', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=True))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
with op.batch_alter_table('documents', schema=None) as batch_op:
|
||||
batch_op.drop_column('need_summary')
|
||||
|
||||
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||
batch_op.drop_column('summary_index_setting')
|
||||
|
||||
with op.batch_alter_table('document_segment_summaries', schema=None) as batch_op:
|
||||
batch_op.drop_index('document_segment_summaries_status_idx')
|
||||
batch_op.drop_index('document_segment_summaries_document_id_idx')
|
||||
batch_op.drop_index('document_segment_summaries_dataset_id_idx')
|
||||
batch_op.drop_index('document_segment_summaries_chunk_id_idx')
|
||||
|
||||
op.drop_table('document_segment_summaries')
|
||||
# ### end Alembic commands ###
|
||||
@ -72,6 +72,7 @@ class Dataset(Base):
|
||||
keyword_number = mapped_column(sa.Integer, nullable=True, server_default=sa.text("10"))
|
||||
collection_binding_id = mapped_column(StringUUID, nullable=True)
|
||||
retrieval_model = mapped_column(AdjustedJSON, nullable=True)
|
||||
summary_index_setting = mapped_column(AdjustedJSON, nullable=True)
|
||||
built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
icon_info = mapped_column(AdjustedJSON, nullable=True)
|
||||
runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'"))
|
||||
@ -419,6 +420,7 @@ class Document(Base):
|
||||
doc_metadata = mapped_column(AdjustedJSON, nullable=True)
|
||||
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
|
||||
doc_language = mapped_column(String(255), nullable=True)
|
||||
need_summary: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
|
||||
|
||||
DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"]
|
||||
|
||||
@@ -1575,3 +1577,36 @@ class SegmentAttachmentBinding(Base):
    segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    attachment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())


class DocumentSegmentSummary(Base):
    __tablename__ = "document_segment_summaries"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="document_segment_summaries_pkey"),
        sa.Index("document_segment_summaries_dataset_id_idx", "dataset_id"),
        sa.Index("document_segment_summaries_document_id_idx", "document_id"),
        sa.Index("document_segment_summaries_chunk_id_idx", "chunk_id"),
        sa.Index("document_segment_summaries_status_idx", "status"),
    )

    id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
    dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    # corresponds to DocumentSegment.id or parent chunk id
    chunk_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    summary_content: Mapped[str] = mapped_column(LongText, nullable=True)
    summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True)
    summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True)
    tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
    status: Mapped[str] = mapped_column(String(32), nullable=False, server_default=sa.text("'generating'"))
    error: Mapped[str] = mapped_column(LongText, nullable=True)
    enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
    disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    disabled_by = mapped_column(StringUUID, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
    )

    def __repr__(self):
        return f"<DocumentSegmentSummary id={self.id} chunk_id={self.chunk_id} status={self.status}>"
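A typical lookup against the new table, mirroring the check used in the knowledge-index node earlier in this diff (at most one completed summary row is expected per chunk):

    existing_summary = (
        db.session.query(DocumentSegmentSummary)
        .filter_by(chunk_id=segment.id, dataset_id=dataset.id, status="completed")
        .first()
    )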
@ -659,16 +659,22 @@ class AccountTrialAppRecord(Base):
|
||||
return user
|
||||
|
||||
|
||||
class ExporleBanner(Base):
|
||||
class ExporleBanner(TypeBase):
|
||||
__tablename__ = "exporle_banners"
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
content = mapped_column(sa.JSON, nullable=False)
|
||||
link = mapped_column(String(255), nullable=False)
|
||||
sort = mapped_column(sa.Integer, nullable=False)
|
||||
status = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"))
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'::character varying"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
|
||||
link: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
sort: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
status: Mapped[str] = mapped_column(
|
||||
sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"), default="enabled"
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
language: Mapped[str] = mapped_column(
|
||||
String(255), nullable=False, server_default=sa.text("'en-US'::character varying"), default="en-US"
|
||||
)
|
||||
|
||||
|
||||
class OAuthProviderApp(TypeBase):
|
||||
|
||||
@@ -181,6 +181,7 @@ dev = [
    # "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved.
    "sseclient-py>=1.8.0",
    "pytest-timeout>=2.4.0",
    "pytest-xdist>=3.8.0",
]

############################################################
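pytest-xdist is the plugin that provides pytest's `-n` option for spreading tests across worker processes, for example:

    pytest -n auto  # one worker per available CPU core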
@ -18,7 +18,6 @@ from core.app_assets.storage import AppAssetStorage, AssetPath
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.silent_storage import SilentStorage
|
||||
from models.app_asset import AppAssets
|
||||
from models.model import App
|
||||
|
||||
@ -43,9 +42,12 @@ class AppAssetService:
|
||||
|
||||
This method creates an AppAssetStorage each time it's called,
|
||||
ensuring storage.storage_runner is only accessed after init_app.
|
||||
|
||||
The storage is wrapped with FilePresignStorage for presign fallback support
|
||||
and CachedPresignStorage for URL caching.
|
||||
"""
|
||||
return AppAssetStorage(
|
||||
storage=SilentStorage(storage.storage_runner),
|
||||
storage=storage.storage_runner,
|
||||
redis_client=redis_client,
|
||||
cache_key_prefix="app_assets",
|
||||
)
|
||||
@ -54,6 +56,22 @@ class AppAssetService:
|
||||
def _lock(app_id: str):
|
||||
return redis_client.lock(f"app_asset:lock:{app_id}", timeout=AppAssetService._LOCK_TIMEOUT_SECONDS)
|
||||
|
||||
@staticmethod
|
||||
def get_assets_by_version(tenant_id: str, app_id: str, workflow_id: str | None = None) -> AppAssets:
|
||||
"""Get asset tree by workflow_id (published) or draft if workflow_id is None."""
|
||||
with Session(db.engine) as session:
|
||||
version = workflow_id or AppAssets.VERSION_DRAFT
|
||||
assets = (
|
||||
session.query(AppAssets)
|
||||
.filter(
|
||||
AppAssets.tenant_id == tenant_id,
|
||||
AppAssets.app_id == app_id,
|
||||
AppAssets.version == version,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return assets or AppAssets(tenant_id=tenant_id, app_id=app_id, version=version)
|
||||
|
||||
@staticmethod
|
||||
def get_draft_assets(tenant_id: str, app_id: str) -> list[AssetItem]:
|
||||
with Session(db.engine) as session:
|
||||
|
||||
@@ -1,26 +1,47 @@
"""Service for exporting and importing App Bundles (DSL + assets).

Bundle structure:
    bundle.zip/
        {app_name}.yml    # DSL file
        manifest.json     # Asset manifest (required for import)
        {app_name}/       # Asset files
            folder/file.txt
            ...

Import flow (sandbox-based):
    1. prepare_import: Frontend gets upload URL, stores import_id in Redis
    2. Frontend uploads zip to storage
    3. confirm_import: Sandbox downloads zip, extracts, uploads assets via presigned URLs

Manifest format (schema_version 1.0):
    - app_assets.tree: Full AppAssetFileTree for 100% ID restoration
    - files: node_id -> path mapping for file nodes
    - integrity.file_count: Basic validation
"""
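From the API caller's side, the import flow above boils down to three steps; signatures and payloads are assumptions (only the names prepare_import/confirm_import and the ImportPrepareResult fields appear in this diff):

    # 1. prepare: server returns an ImportPrepareResult(import_id, upload_url)
    #    and remembers import_id in Redis for _IMPORT_TTL_SECONDS.
    prep = AppBundleService.prepare_import(app_model=app, account=account)  # signature assumed

    # 2. client uploads the bundle zip with an HTTP PUT to prep.upload_url

    # 3. confirm: a ZipSandbox downloads the zip, validates manifest.json,
    #    then pushes each asset back to storage through presigned upload URLs.
    AppBundleService.confirm_import(import_id=prep.import_id, app_model=app, account=account)  # signature assumed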
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import zipfile
|
||||
from dataclasses import dataclass
|
||||
from uuid import uuid4
|
||||
|
||||
import yaml
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_bundle_entities import (
|
||||
BUNDLE_DSL_FILENAME_PATTERN,
|
||||
BUNDLE_MAX_SIZE,
|
||||
MANIFEST_FILENAME,
|
||||
BundleExportResult,
|
||||
BundleFormatError,
|
||||
ZipSecurityError,
|
||||
BundleManifest,
|
||||
)
|
||||
from core.app_assets.storage import AssetPath
|
||||
from core.app_bundle import SourceZipExtractor
|
||||
from core.zip_sandbox import SandboxDownloadItem, ZipSandbox
|
||||
from core.app_assets.storage import AppAssetStorage, AssetPath, BundleImportZipPath
|
||||
from core.zip_sandbox import SandboxDownloadItem, SandboxUploadItem, ZipSandbox
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import Account
|
||||
from models.model import App
|
||||
|
||||
from .app_asset_package_service import AppAssetPackageService
|
||||
from .app_asset_service import AppAssetService
|
||||
@ -28,6 +49,15 @@ from .app_dsl_service import AppDslService, Import
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_IMPORT_REDIS_PREFIX = "app_bundle:import:"
|
||||
_IMPORT_TTL_SECONDS = 3600 # 1 hour
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImportPrepareResult:
|
||||
import_id: str
|
||||
upload_url: str
|
||||
|
||||
|
||||
class AppBundleService:
|
||||
@staticmethod
|
||||
@ -38,14 +68,10 @@ class AppBundleService:
|
||||
marked_name: str = "",
|
||||
marked_comment: str = "",
|
||||
):
|
||||
"""
|
||||
Publish App Bundle (workflow + assets).
|
||||
Coordinates WorkflowService and AppAssetService publishing in a single transaction.
|
||||
"""
|
||||
"""Publish App Bundle (workflow + assets) in a single transaction."""
|
||||
from models.workflow import Workflow
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
# 1. Publish workflow
|
||||
workflow: Workflow = WorkflowService().publish_workflow(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
@ -53,17 +79,16 @@ class AppBundleService:
|
||||
marked_name=marked_name,
|
||||
marked_comment=marked_comment,
|
||||
)
|
||||
|
||||
# 2. Publish assets (bound to workflow_id)
|
||||
AppAssetPackageService.publish(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
account_id=account.id,
|
||||
workflow_id=workflow.id,
|
||||
)
|
||||
|
||||
return workflow
|
||||
|
||||
# ========== Export ==========
|
||||
|
||||
@staticmethod
|
||||
def export_bundle(
|
||||
*,
|
||||
@ -73,14 +98,14 @@ class AppBundleService:
|
||||
workflow_id: str | None = None,
|
||||
expires_in: int = 10 * 60,
|
||||
) -> BundleExportResult:
|
||||
"""Export bundle and return a temporary download URL.
|
||||
|
||||
Uses sandbox VM to build the ZIP, avoiding memory pressure in API process.
|
||||
"""
|
||||
"""Export bundle with manifest.json and return a temporary download URL."""
|
||||
tenant_id = app_model.tenant_id
|
||||
app_id = app_model.id
|
||||
safe_name = AppBundleService._sanitize_filename(app_model.name)
|
||||
filename = f"{safe_name}.zip"
|
||||
|
||||
dsl_filename = f"{safe_name}.yml"
|
||||
app_assets = AppAssetService.get_assets_by_version(tenant_id, app_id, workflow_id)
|
||||
manifest = BundleManifest.from_tree(app_assets.asset_tree, dsl_filename)
|
||||
|
||||
export_id = uuid4().hex
|
||||
export_path = AssetPath.bundle_export_zip(tenant_id, app_id, export_id)
|
||||
@ -95,147 +120,170 @@ class AppBundleService:
|
||||
|
||||
with ZipSandbox(tenant_id=tenant_id, user_id=account_id, app_id="app-bundle-export") as zs:
|
||||
zs.write_file(f"bundle_root/{safe_name}.yml", dsl_content.encode("utf-8"))
|
||||
zs.write_file(f"bundle_root/{MANIFEST_FILENAME}", manifest.model_dump_json(indent=2).encode("utf-8"))
|
||||
|
||||
# Published assets: use stored source zip and unzip into <safe_name>/...
|
||||
if workflow_id is not None:
|
||||
source_zip_path = AssetPath.source_zip(tenant_id, app_id, workflow_id)
|
||||
source_url = asset_storage.get_download_url(source_zip_path, expires_in)
|
||||
zs.download_archive(source_url, path="tmp/source_assets.zip")
|
||||
zs.unzip(archive_path="tmp/source_assets.zip", dest_dir=f"bundle_root/{safe_name}")
|
||||
else:
|
||||
# Draft assets: download individual files and place under <safe_name>/...
|
||||
asset_items = AppAssetService.get_draft_assets(tenant_id, app_id)
|
||||
asset_urls = asset_storage.get_download_urls(
|
||||
[AssetPath.draft(tenant_id, app_id, a.asset_id) for a in asset_items], expires_in
|
||||
)
|
||||
zs.download_items(
|
||||
[
|
||||
SandboxDownloadItem(url=url, path=f"{safe_name}/{a.path}")
|
||||
for a, url in zip(asset_items, asset_urls, strict=True)
|
||||
],
|
||||
dest_dir="bundle_root",
|
||||
)
|
||||
if asset_items:
|
||||
asset_urls = asset_storage.get_download_urls(
|
||||
[AssetPath.draft(tenant_id, app_id, a.asset_id) for a in asset_items], expires_in
|
||||
)
|
||||
zs.download_items(
|
||||
[
|
||||
SandboxDownloadItem(url=url, path=f"{safe_name}/{a.path}")
|
||||
for a, url in zip(asset_items, asset_urls, strict=True)
|
||||
],
|
||||
dest_dir="bundle_root",
|
||||
)
|
||||
|
||||
archive = zs.zip(src="bundle_root", include_base=False)
|
||||
zs.upload(archive, upload_url)
|
||||
|
||||
download_url = asset_storage.get_download_url(export_path, expires_in)
|
||||
return BundleExportResult(download_url=download_url, filename=filename)
|
||||
return BundleExportResult(download_url=download_url, filename=f"{safe_name}.zip")
|
||||
|
||||
# ========== Import ==========
|
||||
|
||||
@staticmethod
|
||||
def import_bundle(
|
||||
def prepare_import(tenant_id: str, account_id: str) -> ImportPrepareResult:
|
||||
"""Prepare import: generate import_id and upload URL."""
|
||||
import_id = uuid4().hex
|
||||
import_path = AssetPath.bundle_import_zip(tenant_id, import_id)
|
||||
asset_storage = AppAssetService.get_storage()
|
||||
upload_url = asset_storage.get_import_upload_url(import_path, _IMPORT_TTL_SECONDS)
|
||||
|
||||
redis_client.setex(
|
||||
f"{_IMPORT_REDIS_PREFIX}{import_id}",
|
||||
_IMPORT_TTL_SECONDS,
|
||||
json.dumps({"tenant_id": tenant_id, "account_id": account_id}),
|
||||
)
|
||||
|
||||
return ImportPrepareResult(import_id=import_id, upload_url=upload_url)
|
||||
|
||||
@staticmethod
|
||||
def confirm_import(
|
||||
import_id: str,
|
||||
account: Account,
|
||||
zip_bytes: bytes,
|
||||
*,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
icon_type: str | None = None,
|
||||
icon: str | None = None,
|
||||
icon_background: str | None = None,
|
||||
) -> Import:
|
||||
if len(zip_bytes) > BUNDLE_MAX_SIZE:
|
||||
raise BundleFormatError(f"Bundle size exceeds limit: {BUNDLE_MAX_SIZE} bytes")
|
||||
"""Confirm import: download zip in sandbox, extract, and upload assets."""
|
||||
redis_key = f"{_IMPORT_REDIS_PREFIX}{import_id}"
|
||||
redis_data = redis_client.get(redis_key)
|
||||
if not redis_data:
|
||||
raise BundleFormatError("Import session expired or not found")
|
||||
|
||||
dsl_content, assets_prefix = AppBundleService._extract_dsl_from_bundle(zip_bytes)
|
||||
import_meta = json.loads(redis_data)
|
||||
tenant_id: str = import_meta["tenant_id"]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
dsl_service = AppDslService(session)
|
||||
import_result = dsl_service.import_app(
|
||||
if tenant_id != account.current_tenant_id:
|
||||
raise BundleFormatError("Import session tenant mismatch")
|
||||
|
||||
import_path = AssetPath.bundle_import_zip(tenant_id, import_id)
|
||||
asset_storage = AppAssetService.get_storage()
|
||||
|
||||
try:
|
||||
result = AppBundleService.import_bundle(
|
||||
tenant_id=tenant_id,
|
||||
account=account,
|
||||
import_mode="yaml-content",
|
||||
yaml_content=dsl_content,
|
||||
import_path=import_path,
|
||||
asset_storage=asset_storage,
|
||||
name=name,
|
||||
description=description,
|
||||
icon_type=icon_type,
|
||||
icon=icon,
|
||||
icon_background=icon_background,
|
||||
app_id=None,
|
||||
)
|
||||
session.commit()
|
||||
finally:
|
||||
redis_client.delete(redis_key)
|
||||
asset_storage.delete_import_zip(import_path)
|
||||
|
||||
if import_result.app_id and assets_prefix:
|
||||
AppBundleService._import_assets_from_bundle(
|
||||
zip_bytes=zip_bytes,
|
||||
assets_prefix=assets_prefix,
|
||||
app_id=import_result.app_id,
|
||||
account_id=account.id,
|
||||
)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def import_bundle(
|
||||
*,
|
||||
tenant_id: str,
|
||||
account: Account,
|
||||
import_path: BundleImportZipPath,
|
||||
asset_storage: AppAssetStorage,
|
||||
name: str | None,
|
||||
description: str | None,
|
||||
icon_type: str | None,
|
||||
icon: str | None,
|
||||
icon_background: str | None,
|
||||
) -> Import:
|
||||
"""Execute import in sandbox."""
|
||||
download_url = asset_storage.get_import_download_url(import_path, _IMPORT_TTL_SECONDS)
|
||||
|
||||
with ZipSandbox(tenant_id=tenant_id, user_id=account.id, app_id="app-bundle-import") as zs:
|
||||
zs.download_archive(download_url, path="import.zip")
|
||||
zs.unzip(archive_path="import.zip", dest_dir="bundle")
|
||||
|
||||
manifest_bytes = zs.read_file(f"bundle/{MANIFEST_FILENAME}")
|
||||
try:
|
||||
manifest = BundleManifest.model_validate_json(manifest_bytes)
|
||||
except ValidationError as e:
|
||||
raise BundleFormatError(f"Invalid manifest.json: {e}") from e
|
||||
|
||||
dsl_content = zs.read_file(f"bundle/{manifest.dsl_filename}").decode("utf-8")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
dsl_service = AppDslService(session)
|
||||
import_result = dsl_service.import_app(
|
||||
account=account,
|
||||
import_mode="yaml-content",
|
||||
yaml_content=dsl_content,
|
||||
name=name,
|
||||
description=description,
|
||||
icon_type=icon_type,
|
||||
icon=icon,
|
||||
icon_background=icon_background,
|
||||
app_id=None,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
if not import_result.app_id:
|
||||
return import_result
|
||||
|
||||
app_id = import_result.app_id
|
||||
tree = manifest.app_assets.tree
|
||||
|
||||
upload_items: list[SandboxUploadItem] = []
|
||||
for file_entry in manifest.files:
|
||||
asset_path = AssetPath.draft(tenant_id, app_id, file_entry.node_id)
|
||||
file_upload_url = asset_storage.get_upload_url(asset_path, _IMPORT_TTL_SECONDS)
|
||||
src_path = f"{manifest.assets_prefix}/{file_entry.path}"
|
||||
upload_items.append(SandboxUploadItem(path=src_path, url=file_upload_url))
|
||||
|
||||
if upload_items:
|
||||
zs.upload_items(upload_items, src_dir="bundle")
|
||||
|
||||
# Tree sizes are already set from manifest; no need to update
|
||||
app_model = db.session.query(App).filter(App.id == app_id).first()
|
||||
if app_model:
|
||||
AppAssetService.set_draft_assets(
|
||||
app_model=app_model,
|
||||
account_id=account.id,
|
||||
new_tree=tree,
|
||||
)
|
||||
|
||||
return import_result
|
||||
|
||||
@staticmethod
|
||||
def _extract_dsl_from_bundle(zip_bytes: bytes) -> tuple[str, str | None]:
|
||||
dsl_content: str | None = None
|
||||
dsl_filename: str | None = None
|
||||
|
||||
with zipfile.ZipFile(io.BytesIO(zip_bytes), "r") as zf:
|
||||
for info in zf.infolist():
|
||||
if info.is_dir():
|
||||
continue
|
||||
if BUNDLE_DSL_FILENAME_PATTERN.match(info.filename):
|
||||
if dsl_content is not None:
|
||||
raise BundleFormatError("Multiple DSL files found in bundle")
|
||||
dsl_content = zf.read(info).decode("utf-8")
|
||||
dsl_filename = info.filename
|
||||
|
||||
if dsl_content is None or dsl_filename is None:
|
||||
raise BundleFormatError("No DSL file (*.yml or *.yaml) found in bundle root")
|
||||
|
||||
yaml.safe_load(dsl_content)
|
||||
|
||||
assets_prefix = dsl_filename.rsplit(".", 1)[0]
|
||||
has_assets = AppBundleService._check_assets_prefix_exists(zip_bytes, assets_prefix)
|
||||
|
||||
return dsl_content, assets_prefix if has_assets else None
|
||||
|
||||
@staticmethod
|
||||
def _check_assets_prefix_exists(zip_bytes: bytes, prefix: str) -> bool:
|
||||
with zipfile.ZipFile(io.BytesIO(zip_bytes), "r") as zf:
|
||||
for info in zf.infolist():
|
||||
if info.filename.startswith(f"{prefix}/"):
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _import_assets_from_bundle(
|
||||
zip_bytes: bytes,
|
||||
assets_prefix: str,
|
||||
app_id: str,
|
||||
account_id: str,
|
||||
) -> None:
|
||||
app_model = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not app_model:
|
||||
logger.warning("App not found for asset import: %s", app_id)
|
||||
return
|
||||
|
||||
asset_storage = AppAssetService.get_storage()
|
||||
extractor = SourceZipExtractor(asset_storage)
|
||||
try:
|
||||
folders, files = extractor.extract_entries(
|
||||
zip_bytes,
|
||||
expected_prefix=f"{assets_prefix}/",
|
||||
)
|
||||
except ZipSecurityError as e:
|
||||
logger.warning("Zip security error during asset import: %s", e)
|
||||
return
|
||||
|
||||
if not folders and not files:
|
||||
return
|
||||
|
||||
new_tree = extractor.build_tree_and_save(
|
||||
folders=folders,
|
||||
files=files,
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
)
|
||||
|
||||
AppAssetService.set_draft_assets(
|
||||
app_model=app_model,
|
||||
account_id=account_id,
|
||||
new_tree=new_tree,
|
||||
)
|
||||
# ========== Helpers ==========
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_filename(name: str) -> str:
|
||||
"""Sanitize app name for use as filename."""
|
||||
safe = re.sub(r'[<>:"/\\|?*\x00-\x1f]', "_", name)
|
||||
safe = safe.strip(". ")
|
||||
return safe[:100] if safe else "app"
|
||||
|
||||
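Taken together, the hunks above implement the prepare/confirm import flow described in the module docstring of this service. Below is a minimal caller-side sketch, illustrative only: it assumes an `account` bound to the right tenant and a bundle zip already on disk, uses `requests.put` as a stand-in for the frontend's upload to the presigned URL, and omits the optional metadata keyword arguments (name, icon, ...) and all error handling.

```python
# Illustrative sketch only -- not part of this diff.
import requests

from services.app_bundle_service import AppBundleService


def import_bundle_from_disk(account, zip_path: str):
    # 1. prepare_import: reserve an import_id and get a presigned upload URL
    prepared = AppBundleService.prepare_import(
        tenant_id=account.current_tenant_id,
        account_id=account.id,
    )

    # 2. upload the bundle zip to storage (normally done by the frontend)
    with open(zip_path, "rb") as f:
        requests.put(prepared.upload_url, data=f, timeout=60)

    # 3. confirm_import: the sandbox downloads, extracts, and re-uploads the assets
    return AppBundleService.confirm_import(prepared.import_id, account)
```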
@@ -89,6 +89,7 @@ from tasks.disable_segments_from_index_task import disable_segments_from_index_t
from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.enable_segments_to_index_task import enable_segments_to_index_task
from tasks.recover_document_indexing_task import recover_document_indexing_task
from tasks.regenerate_summary_index_task import regenerate_summary_index_task
from tasks.remove_document_from_index_task import remove_document_from_index_task
from tasks.retry_document_indexing_task import retry_document_indexing_task
from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task
@@ -211,6 +212,7 @@ class DatasetService:
        embedding_model_provider: str | None = None,
        embedding_model_name: str | None = None,
        retrieval_model: RetrievalModel | None = None,
        summary_index_setting: dict | None = None,
    ):
        # check if dataset name already exists
        if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
@@ -253,6 +255,8 @@ class DatasetService:
        dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None
        dataset.permission = permission or DatasetPermissionEnum.ONLY_ME
        dataset.provider = provider
        if summary_index_setting is not None:
            dataset.summary_index_setting = summary_index_setting
        db.session.add(dataset)
        db.session.flush()

@@ -476,6 +480,11 @@ class DatasetService:
        if external_retrieval_model:
            dataset.retrieval_model = external_retrieval_model

        # Update summary index setting if provided
        summary_index_setting = data.get("summary_index_setting", None)
        if summary_index_setting is not None:
            dataset.summary_index_setting = summary_index_setting

        # Update basic dataset properties
        dataset.name = data.get("name", dataset.name)
        dataset.description = data.get("description", dataset.description)
@@ -564,6 +573,9 @@ class DatasetService:
        # update Retrieval model
        if data.get("retrieval_model"):
            filtered_data["retrieval_model"] = data["retrieval_model"]
        # update summary index setting
        if data.get("summary_index_setting"):
            filtered_data["summary_index_setting"] = data.get("summary_index_setting")
        # update icon info
        if data.get("icon_info"):
            filtered_data["icon_info"] = data.get("icon_info")
@@ -572,12 +584,27 @@ class DatasetService:
        db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
        db.session.commit()

        # Reload dataset to get updated values
        db.session.refresh(dataset)

        # update pipeline knowledge base node data
        DatasetService._update_pipeline_knowledge_base_node_data(dataset, user.id)

        # Trigger vector index task if indexing technique changed
        if action:
            deal_dataset_vector_index_task.delay(dataset.id, action)
            # If embedding_model changed, also regenerate summary vectors
            if action == "update":
                regenerate_summary_index_task.delay(
                    dataset.id,
                    regenerate_reason="embedding_model_changed",
                    regenerate_vectors_only=True,
                )

        # Note: summary_index_setting changes do not trigger automatic regeneration of existing summaries.
        # The new setting will only apply to:
        # 1. New documents added after the setting change
        # 2. Manual summary generation requests

        return dataset

@ -616,6 +643,7 @@ class DatasetService:
|
||||
knowledge_index_node_data["chunk_structure"] = dataset.chunk_structure
|
||||
knowledge_index_node_data["indexing_technique"] = dataset.indexing_technique # pyright: ignore[reportAttributeAccessIssue]
|
||||
knowledge_index_node_data["keyword_number"] = dataset.keyword_number
|
||||
knowledge_index_node_data["summary_index_setting"] = dataset.summary_index_setting
|
||||
node["data"] = knowledge_index_node_data
|
||||
updated = True
|
||||
except Exception:
|
||||
@ -854,6 +882,54 @@ class DatasetService:
|
||||
)
|
||||
filtered_data["collection_binding_id"] = dataset_collection_binding.id
|
||||
|
||||
@staticmethod
|
||||
def _check_summary_index_setting_model_changed(dataset: Dataset, data: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Check if summary_index_setting model (model_name or model_provider_name) has changed.
|
||||
|
||||
Args:
|
||||
dataset: Current dataset object
|
||||
data: Update data dictionary
|
||||
|
||||
Returns:
|
||||
bool: True if summary model changed, False otherwise
|
||||
"""
|
||||
# Check if summary_index_setting is being updated
|
||||
if "summary_index_setting" not in data or data.get("summary_index_setting") is None:
|
||||
return False
|
||||
|
||||
new_summary_setting = data.get("summary_index_setting")
|
||||
old_summary_setting = dataset.summary_index_setting
|
||||
|
||||
# If new setting is disabled, no need to regenerate
|
||||
if not new_summary_setting or not new_summary_setting.get("enable"):
|
||||
return False
|
||||
|
||||
# If old setting doesn't exist, no need to regenerate (no existing summaries to regenerate)
|
||||
# Note: This task only regenerates existing summaries, not generates new ones
|
||||
if not old_summary_setting:
|
||||
return False
|
||||
|
||||
# Compare model_name and model_provider_name
|
||||
old_model_name = old_summary_setting.get("model_name")
|
||||
old_model_provider = old_summary_setting.get("model_provider_name")
|
||||
new_model_name = new_summary_setting.get("model_name")
|
||||
new_model_provider = new_summary_setting.get("model_provider_name")
|
||||
|
||||
# Check if model changed
|
||||
if old_model_name != new_model_name or old_model_provider != new_model_provider:
|
||||
logger.info(
|
||||
"Summary index setting model changed for dataset %s: old=%s/%s, new=%s/%s",
|
||||
dataset.id,
|
||||
old_model_provider,
|
||||
old_model_name,
|
||||
new_model_provider,
|
||||
new_model_name,
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def update_rag_pipeline_dataset_settings(
|
||||
session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False
|
||||
@ -889,6 +965,9 @@ class DatasetService:
|
||||
else:
|
||||
raise ValueError("Invalid index method")
|
||||
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
|
||||
# Update summary_index_setting if provided
|
||||
if knowledge_configuration.summary_index_setting is not None:
|
||||
dataset.summary_index_setting = knowledge_configuration.summary_index_setting
|
||||
session.add(dataset)
|
||||
else:
|
||||
if dataset.chunk_structure and dataset.chunk_structure != knowledge_configuration.chunk_structure:
|
||||
@ -994,6 +1073,9 @@ class DatasetService:
|
||||
if dataset.keyword_number != knowledge_configuration.keyword_number:
|
||||
dataset.keyword_number = knowledge_configuration.keyword_number
|
||||
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
|
||||
# Update summary_index_setting if provided
|
||||
if knowledge_configuration.summary_index_setting is not None:
|
||||
dataset.summary_index_setting = knowledge_configuration.summary_index_setting
|
||||
session.add(dataset)
|
||||
session.commit()
|
||||
if action:
|
||||
@ -1314,6 +1396,50 @@ class DocumentService:
|
||||
upload_file = DocumentService._get_upload_file_for_upload_file_document(document)
|
||||
return file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
|
||||
|
||||
@staticmethod
|
||||
def enrich_documents_with_summary_index_status(
|
||||
documents: Sequence[Document],
|
||||
dataset: Dataset,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Enrich documents with summary_index_status based on dataset summary index settings.
|
||||
|
||||
This method calculates and sets the summary_index_status for each document that needs summary.
|
||||
Documents that don't need summary or when summary index is disabled will have status set to None.
|
||||
|
||||
Args:
|
||||
documents: List of Document instances to enrich
|
||||
dataset: Dataset instance containing summary_index_setting
|
||||
tenant_id: Tenant ID for summary status lookup
|
||||
"""
|
||||
# Check if dataset has summary index enabled
|
||||
has_summary_index = dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True
|
||||
|
||||
# Filter documents that need summary calculation
|
||||
documents_need_summary = [doc for doc in documents if doc.need_summary is True]
|
||||
document_ids_need_summary = [str(doc.id) for doc in documents_need_summary]
|
||||
|
||||
# Calculate summary_index_status for documents that need summary (only if dataset summary index is enabled)
|
||||
summary_status_map: dict[str, str | None] = {}
|
||||
if has_summary_index and document_ids_need_summary:
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
summary_status_map = SummaryIndexService.get_documents_summary_index_status(
|
||||
document_ids=document_ids_need_summary,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# Add summary_index_status to each document
|
||||
for document in documents:
|
||||
if has_summary_index and document.need_summary is True:
|
||||
# Get status from map, default to None (not queued yet)
|
||||
document.summary_index_status = summary_status_map.get(str(document.id)) # type: ignore[attr-defined]
|
||||
else:
|
||||
# Return null if summary index is not enabled or document doesn't need summary
|
||||
document.summary_index_status = None # type: ignore[attr-defined]
|
||||
|
||||
@staticmethod
|
||||
def prepare_document_batch_download_zip(
|
||||
*,
|
||||
@ -1964,6 +2090,8 @@ class DocumentService:
|
||||
DuplicateDocumentIndexingTaskProxy(
|
||||
dataset.tenant_id, dataset.id, duplicate_document_ids
|
||||
).delay()
|
||||
# Note: Summary index generation is triggered in document_indexing_task after indexing completes
|
||||
# to ensure segments are available. See tasks/document_indexing_task.py
|
||||
except LockNotOwnedError:
|
||||
pass
|
||||
|
||||
@ -2268,6 +2396,11 @@ class DocumentService:
|
||||
name: str,
|
||||
batch: str,
|
||||
):
|
||||
# Set need_summary based on dataset's summary_index_setting
|
||||
need_summary = False
|
||||
if dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True:
|
||||
need_summary = True
|
||||
|
||||
document = Document(
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
@ -2281,6 +2414,7 @@ class DocumentService:
|
||||
created_by=account.id,
|
||||
doc_form=document_form,
|
||||
doc_language=document_language,
|
||||
need_summary=need_summary,
|
||||
)
|
||||
doc_metadata = {}
|
||||
if dataset.built_in_field_enabled:
|
||||
@ -2505,6 +2639,7 @@ class DocumentService:
|
||||
embedding_model_provider=knowledge_config.embedding_model_provider,
|
||||
collection_binding_id=dataset_collection_binding_id,
|
||||
retrieval_model=retrieval_model.model_dump() if retrieval_model else None,
|
||||
summary_index_setting=knowledge_config.summary_index_setting,
|
||||
is_multimodal=knowledge_config.is_multimodal,
|
||||
)
|
||||
|
||||
@ -2686,6 +2821,14 @@ class DocumentService:
|
||||
if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int):
|
||||
raise ValueError("Process rule segmentation max_tokens is invalid")
|
||||
|
||||
# valid summary index setting
|
||||
summary_index_setting = args["process_rule"].get("summary_index_setting")
|
||||
if summary_index_setting and summary_index_setting.get("enable"):
|
||||
if "model_name" not in summary_index_setting or not summary_index_setting["model_name"]:
|
||||
raise ValueError("Summary index model name is required")
|
||||
if "model_provider_name" not in summary_index_setting or not summary_index_setting["model_provider_name"]:
|
||||
raise ValueError("Summary index model provider name is required")
|
||||
|
||||
@staticmethod
|
||||
def batch_update_document_status(
|
||||
dataset: Dataset, document_ids: list[str], action: Literal["enable", "disable", "archive", "un_archive"], user
|
||||
@ -3154,6 +3297,35 @@ class SegmentService:
|
||||
if args.enabled or keyword_changed:
|
||||
# update segment vector index
|
||||
VectorService.update_segment_vector(args.keywords, segment, dataset)
|
||||
# update summary index if summary is provided and has changed
|
||||
if args.summary is not None:
|
||||
# When user manually provides summary, allow saving even if summary_index_setting doesn't exist
|
||||
# summary_index_setting is only needed for LLM generation, not for manual summary vectorization
|
||||
# Vectorization uses dataset.embedding_model, which doesn't require summary_index_setting
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
# Query existing summary from database
|
||||
from models.dataset import DocumentSegmentSummary
|
||||
|
||||
existing_summary = (
|
||||
db.session.query(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Check if summary has changed
|
||||
existing_summary_content = existing_summary.summary_content if existing_summary else None
|
||||
if existing_summary_content != args.summary:
|
||||
# Summary has changed, update it
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
try:
|
||||
SummaryIndexService.update_summary_for_segment(segment, dataset, args.summary)
|
||||
except Exception:
|
||||
logger.exception("Failed to update summary for segment %s", segment.id)
|
||||
# Don't fail the entire update if summary update fails
|
||||
else:
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
tokens = 0
|
||||
@ -3228,6 +3400,73 @@ class SegmentService:
|
||||
elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX):
|
||||
# update segment vector index
|
||||
VectorService.update_segment_vector(args.keywords, segment, dataset)
|
||||
# Handle summary index when content changed
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
from models.dataset import DocumentSegmentSummary
|
||||
|
||||
existing_summary = (
|
||||
db.session.query(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if args.summary is None:
|
||||
# User didn't provide summary, auto-regenerate if segment previously had summary
|
||||
# Auto-regeneration only happens if summary_index_setting exists and enable is True
|
||||
if (
|
||||
existing_summary
|
||||
and dataset.summary_index_setting
|
||||
and dataset.summary_index_setting.get("enable") is True
|
||||
):
|
||||
# Segment previously had summary, regenerate it with new content
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
try:
|
||||
SummaryIndexService.generate_and_vectorize_summary(
|
||||
segment, dataset, dataset.summary_index_setting
|
||||
)
|
||||
logger.info("Auto-regenerated summary for segment %s after content change", segment.id)
|
||||
except Exception:
|
||||
logger.exception("Failed to auto-regenerate summary for segment %s", segment.id)
|
||||
# Don't fail the entire update if summary regeneration fails
|
||||
else:
|
||||
# User provided summary, check if it has changed
|
||||
# Manual summary updates are allowed even if summary_index_setting doesn't exist
|
||||
existing_summary_content = existing_summary.summary_content if existing_summary else None
|
||||
if existing_summary_content != args.summary:
|
||||
# Summary has changed, use user-provided summary
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
try:
|
||||
SummaryIndexService.update_summary_for_segment(segment, dataset, args.summary)
|
||||
logger.info("Updated summary for segment %s with user-provided content", segment.id)
|
||||
except Exception:
|
||||
logger.exception("Failed to update summary for segment %s", segment.id)
|
||||
# Don't fail the entire update if summary update fails
|
||||
else:
|
||||
# Summary hasn't changed, regenerate based on new content
|
||||
# Auto-regeneration only happens if summary_index_setting exists and enable is True
|
||||
if (
|
||||
existing_summary
|
||||
and dataset.summary_index_setting
|
||||
and dataset.summary_index_setting.get("enable") is True
|
||||
):
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
try:
|
||||
SummaryIndexService.generate_and_vectorize_summary(
|
||||
segment, dataset, dataset.summary_index_setting
|
||||
)
|
||||
logger.info(
|
||||
"Regenerated summary for segment %s after content change (summary unchanged)",
|
||||
segment.id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to regenerate summary for segment %s", segment.id)
|
||||
# Don't fail the entire update if summary regeneration fails
|
||||
# update multimodel vector index
|
||||
VectorService.update_multimodel_vector(segment, args.attachment_ids or [], dataset)
|
||||
except Exception as e:
|
||||
@ -3616,6 +3855,39 @@ class SegmentService:
|
||||
)
|
||||
return result if isinstance(result, DocumentSegment) else None
|
||||
|
||||
@classmethod
|
||||
def get_segments_by_document_and_dataset(
|
||||
cls,
|
||||
document_id: str,
|
||||
dataset_id: str,
|
||||
status: str | None = None,
|
||||
enabled: bool | None = None,
|
||||
) -> Sequence[DocumentSegment]:
|
||||
"""
|
||||
Get segments for a document in a dataset with optional filtering.
|
||||
|
||||
Args:
|
||||
document_id: Document ID
|
||||
dataset_id: Dataset ID
|
||||
status: Optional status filter (e.g., "completed")
|
||||
enabled: Optional enabled filter (True/False)
|
||||
|
||||
Returns:
|
||||
Sequence of DocumentSegment instances
|
||||
"""
|
||||
query = select(DocumentSegment).where(
|
||||
DocumentSegment.document_id == document_id,
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
)
|
||||
|
||||
if status is not None:
|
||||
query = query.where(DocumentSegment.status == status)
|
||||
|
||||
if enabled is not None:
|
||||
query = query.where(DocumentSegment.enabled == enabled)
|
||||
|
||||
return db.session.scalars(query).all()
|
||||
|
||||
|
||||
class DatasetCollectionBindingService:
|
||||
@classmethod
|
||||
|
||||
@@ -119,6 +119,7 @@ class KnowledgeConfig(BaseModel):
    data_source: DataSource | None = None
    process_rule: ProcessRule | None = None
    retrieval_model: RetrievalModel | None = None
    summary_index_setting: dict | None = None
    doc_form: str = "text_model"
    doc_language: str = "English"
    embedding_model: str | None = None
@@ -141,6 +142,7 @@ class SegmentUpdateArgs(BaseModel):
    regenerate_child_chunks: bool = False
    enabled: bool | None = None
    attachment_ids: list[str] | None = None
    summary: str | None = None  # Summary content for summary index


class ChildChunkUpdateArgs(BaseModel):

@@ -116,6 +116,8 @@ class KnowledgeConfiguration(BaseModel):
    embedding_model: str = ""
    keyword_number: int | None = 10
    retrieval_model: RetrievalSetting
    # add summary index setting
    summary_index_setting: dict | None = None

    @field_validator("embedding_model_provider", mode="before")
    @classmethod

@@ -343,6 +343,9 @@ class RagPipelineDslService:
            dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
        elif knowledge_configuration.indexing_technique == "economy":
            dataset.keyword_number = knowledge_configuration.keyword_number
        # Update summary_index_setting if provided
        if knowledge_configuration.summary_index_setting is not None:
            dataset.summary_index_setting = knowledge_configuration.summary_index_setting
        dataset.pipeline_id = pipeline.id
        self._session.add(dataset)
        self._session.commit()
@@ -477,6 +480,9 @@ class RagPipelineDslService:
            dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
        elif knowledge_configuration.indexing_technique == "economy":
            dataset.keyword_number = knowledge_configuration.keyword_number
        # Update summary_index_setting if provided
        if knowledge_configuration.summary_index_setting is not None:
            dataset.summary_index_setting = knowledge_configuration.summary_index_setting
        dataset.pipeline_id = pipeline.id
        self._session.add(dataset)
        self._session.commit()

@@ -10,6 +10,7 @@ from core.sandbox.initializer.draft_app_assets_initializer import DraftAppAssets
from core.sandbox.initializer.skill_initializer import SkillInitializer
from core.sandbox.sandbox import Sandbox
from core.sandbox.storage.archive_storage import ArchiveSandboxStorage
from extensions.ext_storage import storage
from services.app_asset_package_service import AppAssetPackageService
from services.app_asset_service import AppAssetService

@@ -30,7 +31,7 @@ class SandboxService:
        if not assets:
            raise ValueError(f"No assets found for tid={tenant_id}, app_id={app_id}")

        storage = ArchiveSandboxStorage(tenant_id, workflow_execution_id)
        archive_storage = ArchiveSandboxStorage(tenant_id, workflow_execution_id, storage.storage_runner)
        sandbox = (
            SandboxBuilder(tenant_id, SandboxType(sandbox_provider.provider_type))
            .options(sandbox_provider.config)
@@ -40,7 +41,7 @@ class SandboxService:
            .initializer(AppAssetsInitializer(tenant_id, app_id, assets.id))
            .initializer(DifyCliInitializer(tenant_id, user_id, app_id, assets.id))
            .initializer(SkillInitializer(tenant_id, user_id, app_id, assets.id))
            .storage(storage, assets.id)
            .storage(archive_storage, assets.id)
            .build()
        )

@@ -49,8 +50,8 @@ class SandboxService:

    @classmethod
    def delete_draft_storage(cls, tenant_id: str, user_id: str) -> None:
        storage = ArchiveSandboxStorage(tenant_id, SandboxBuilder.draft_id(user_id))
        storage.delete()
        archive_storage = ArchiveSandboxStorage(tenant_id, SandboxBuilder.draft_id(user_id), storage.storage_runner)
        archive_storage.delete()

    @classmethod
    def create_draft(
@@ -66,7 +67,9 @@ class SandboxService:

        AppAssetPackageService.build_assets(tenant_id, app_id, assets)
        sandbox_id = SandboxBuilder.draft_id(user_id)
        storage = ArchiveSandboxStorage(tenant_id, sandbox_id, exclude_patterns=[AppAssets.PATH])
        archive_storage = ArchiveSandboxStorage(
            tenant_id, sandbox_id, storage.storage_runner, exclude_patterns=[AppAssets.PATH]
        )

        sandbox = (
            SandboxBuilder(tenant_id, SandboxType(sandbox_provider.provider_type))
@@ -77,7 +80,7 @@ class SandboxService:
            .initializer(DraftAppAssetsInitializer(tenant_id, app_id, assets.id))
            .initializer(DifyCliInitializer(tenant_id, user_id, app_id, assets.id))
            .initializer(SkillInitializer(tenant_id, user_id, app_id, assets.id))
            .storage(storage, assets.id)
            .storage(archive_storage, assets.id)
            .build()
        )

@@ -98,7 +101,9 @@ class SandboxService:

        AppAssetPackageService.build_assets(tenant_id, app_id, assets)
        sandbox_id = SandboxBuilder.draft_id(user_id)
        storage = ArchiveSandboxStorage(tenant_id, sandbox_id, exclude_patterns=[AppAssets.PATH])
        archive_storage = ArchiveSandboxStorage(
            tenant_id, sandbox_id, storage.storage_runner, exclude_patterns=[AppAssets.PATH]
        )

        sandbox = (
            SandboxBuilder(tenant_id, SandboxType(sandbox_provider.provider_type))
@@ -109,7 +114,7 @@ class SandboxService:
            .initializer(DraftAppAssetsInitializer(tenant_id, app_id, assets.id))
            .initializer(DifyCliInitializer(tenant_id, user_id, app_id, assets.id))
            .initializer(SkillInitializer(tenant_id, user_id, app_id, assets.id))
            .storage(storage, assets.id)
            .storage(archive_storage, assets.id)
            .build()
        )

api/services/storage_ticket_service.py (new file, 159 lines)
@@ -0,0 +1,159 @@
"""Storage ticket service for generating opaque download/upload URLs.

This service provides a ticket-based approach for file access. Instead of exposing
the real storage key in URLs, it generates a random UUID token and stores the mapping
in Redis with a TTL.

Usage:
    from services.storage_ticket_service import StorageTicketService

    # Generate a download ticket
    url = StorageTicketService.create_download_url("path/to/file.txt", expires_in=300)

    # Generate an upload ticket
    url = StorageTicketService.create_upload_url("path/to/file.txt", expires_in=300, max_bytes=10*1024*1024)

URL format:
    {FILES_URL}/files/storage-tickets/{token}

The token is validated by looking up the Redis key, which contains:
- op: "download" or "upload"
- storage_key: the real storage path
- max_bytes: (upload only) maximum allowed upload size
- filename: suggested filename for Content-Disposition header
"""

import json
import logging
from dataclasses import dataclass
from uuid import uuid4

from configs import dify_config
from extensions.ext_redis import redis_client

logger = logging.getLogger(__name__)

TICKET_KEY_PREFIX = "storage_files"
DEFAULT_DOWNLOAD_TTL = 300  # 5 minutes
DEFAULT_UPLOAD_TTL = 300  # 5 minutes
DEFAULT_MAX_UPLOAD_BYTES = 100 * 1024 * 1024  # 100MB


@dataclass
class StorageTicket:
    """Represents a storage access ticket."""

    op: str  # "download" or "upload"
    storage_key: str
    max_bytes: int | None = None  # upload only
    filename: str | None = None  # suggested filename for download

    def to_dict(self) -> dict:
        data = {"op": self.op, "storage_key": self.storage_key}
        if self.max_bytes is not None:
            data["max_bytes"] = str(self.max_bytes)
        if self.filename is not None:
            data["filename"] = self.filename
        return data

    @classmethod
    def from_dict(cls, data: dict) -> "StorageTicket":
        return cls(
            op=data["op"],
            storage_key=data["storage_key"],
            max_bytes=data.get("max_bytes"),
            filename=data.get("filename"),
        )


class StorageTicketService:
    """Service for creating and validating storage access tickets."""

    @classmethod
    def create_download_url(
        cls,
        storage_key: str,
        *,
        expires_in: int = DEFAULT_DOWNLOAD_TTL,
        filename: str | None = None,
    ) -> str:
        """Create a download ticket and return the URL.

        Args:
            storage_key: The real storage path
            expires_in: TTL in seconds (default 300)
            filename: Suggested filename for Content-Disposition header

        Returns:
            Full URL with token
        """
        if filename is None:
            filename = storage_key.rsplit("/", 1)[-1]

        ticket = StorageTicket(op="download", storage_key=storage_key, filename=filename)
        token = cls._store_ticket(ticket, expires_in)
        return cls._build_url(token)

    @classmethod
    def create_upload_url(
        cls,
        storage_key: str,
        *,
        expires_in: int = DEFAULT_UPLOAD_TTL,
        max_bytes: int = DEFAULT_MAX_UPLOAD_BYTES,
    ) -> str:
        """Create an upload ticket and return the URL.

        Args:
            storage_key: The real storage path
            expires_in: TTL in seconds (default 300)
            max_bytes: Maximum allowed upload size in bytes

        Returns:
            Full URL with token
        """
        ticket = StorageTicket(op="upload", storage_key=storage_key, max_bytes=max_bytes)
        token = cls._store_ticket(ticket, expires_in)
        return cls._build_url(token)

    @classmethod
    def get_ticket(cls, token: str) -> StorageTicket | None:
        """Retrieve a ticket by token.

        Args:
            token: The UUID token from the URL

        Returns:
            StorageTicket if found and valid, None otherwise
        """
        key = cls._ticket_key(token)
        try:
            data = redis_client.get(key)
            if data is None:
                return None
            if isinstance(data, bytes):
                data = data.decode("utf-8")
            return StorageTicket.from_dict(json.loads(data))
        except Exception:
            logger.warning("Failed to retrieve storage ticket: %s", token, exc_info=True)
            return None

    @classmethod
    def _store_ticket(cls, ticket: StorageTicket, ttl: int) -> str:
        """Store a ticket in Redis and return the token."""
        token = str(uuid4())
        key = cls._ticket_key(token)
        value = json.dumps(ticket.to_dict())
        redis_client.setex(key, ttl, value)
        return token

    @classmethod
    def _ticket_key(cls, token: str) -> str:
        """Generate Redis key for a token."""
        return f"{TICKET_KEY_PREFIX}:{token}"

    @classmethod
    def _build_url(cls, token: str) -> str:
        """Build the full URL for a token."""
        base_url = dify_config.FILES_URL
        return f"{base_url}/files/storage-files/{token}"

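A hedged sketch of the consuming side of this new service: how a files controller could resolve a ticket token and serve the underlying object. Only `StorageTicketService.get_ticket` and the ticket fields come from the file above; the blueprint, route, and the `storage.load_stream` helper are assumptions for illustration, not the actual controller wiring in this diff.

```python
# Illustrative sketch only -- endpoint wiring is assumed, not taken from this diff.
from flask import Blueprint, Response, abort

from extensions.ext_storage import storage  # assumed to expose a load_stream() generator
from services.storage_ticket_service import StorageTicketService

bp = Blueprint("storage_tickets_example", __name__)


@bp.get("/files/storage-files/<token>")
def download_by_ticket(token: str):
    ticket = StorageTicketService.get_ticket(token)
    if ticket is None or ticket.op != "download":
        abort(404)  # expired, unknown, or wrong operation

    # Stream the real object without ever exposing its storage key in the URL.
    generator = storage.load_stream(ticket.storage_key)
    headers = {}
    if ticket.filename:
        headers["Content-Disposition"] = f'attachment; filename="{ticket.filename}"'
    return Response(generator, mimetype="application/octet-stream", headers=headers)
```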
api/services/summary_index_service.py (new file, 1432 lines)
File diff suppressed because it is too large.
@@ -118,6 +118,19 @@ def add_document_to_index_task(dataset_document_id: str):
            )
            session.commit()

        # Enable summary indexes for all segments in this document
        from services.summary_index_service import SummaryIndexService

        segment_ids_list = [segment.id for segment in segments]
        if segment_ids_list:
            try:
                SummaryIndexService.enable_summaries_for_segments(
                    dataset=dataset,
                    segment_ids=segment_ids_list,
                )
            except Exception as e:
                logger.warning("Failed to enable summaries for document %s: %s", dataset_document.id, str(e))

        end_at = time.perf_counter()
        logger.info(
            click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green")

@@ -50,7 +50,9 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
        if segments:
            index_node_ids = [segment.index_node_id for segment in segments]
            index_processor = IndexProcessorFactory(doc_form).init_index_processor()
            index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
            index_processor.clean(
                dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
            )

            for segment in segments:
                image_upload_file_ids = get_image_upload_file_ids(segment.content)

@@ -51,7 +51,9 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
        if segments:
            index_node_ids = [segment.index_node_id for segment in segments]
            index_processor = IndexProcessorFactory(doc_form).init_index_processor()
            index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
            index_processor.clean(
                dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
            )

            for segment in segments:
                image_upload_file_ids = get_image_upload_file_ids(segment.content)

@@ -42,7 +42,9 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
            ).all()
            index_node_ids = [segment.index_node_id for segment in segments]

            index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
            index_processor.clean(
                dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
            )
            segment_ids = [segment.id for segment in segments]
            segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
            session.execute(segment_delete_stmt)

@@ -47,6 +47,7 @@ def delete_segment_from_index_task(
        doc_form = dataset_document.doc_form

        # Proceed with index cleanup using the index_node_ids directly
        # For actual deletion, we should delete summaries (not just disable them)
        index_processor = IndexProcessorFactory(doc_form).init_index_processor()
        index_processor.clean(
            dataset,
@@ -54,6 +55,7 @@ def delete_segment_from_index_task(
            with_keywords=True,
            delete_child_chunks=True,
            precomputed_child_node_ids=child_node_ids,
            delete_summaries=True,  # Actually delete summaries when segment is deleted
        )
        if dataset.is_multimodal:
            # delete segment attachment binding

@@ -60,6 +60,18 @@ def disable_segment_from_index_task(segment_id: str):
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        index_processor.clean(dataset, [segment.index_node_id])

        # Disable summary index for this segment
        from services.summary_index_service import SummaryIndexService

        try:
            SummaryIndexService.disable_summaries_for_segments(
                dataset=dataset,
                segment_ids=[segment.id],
                disabled_by=segment.disabled_by,
            )
        except Exception as e:
            logger.warning("Failed to disable summary for segment %s: %s", segment.id, str(e))

        end_at = time.perf_counter()
        logger.info(
            click.style(

@@ -68,6 +68,21 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
            index_node_ids.extend(attachment_ids)
        index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)

        # Disable summary indexes for these segments
        from services.summary_index_service import SummaryIndexService

        segment_ids_list = [segment.id for segment in segments]
        try:
            # Get disabled_by from first segment (they should all have the same disabled_by)
            disabled_by = segments[0].disabled_by if segments else None
            SummaryIndexService.disable_summaries_for_segments(
                dataset=dataset,
                segment_ids=segment_ids_list,
                disabled_by=disabled_by,
            )
        except Exception as e:
            logger.warning("Failed to disable summaries for segments: %s", str(e))

        end_at = time.perf_counter()
        logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green"))
    except Exception:

@ -14,6 +14,7 @@ from enums.cloud_plan import CloudPlan
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import Dataset, Document
|
||||
from services.feature_service import FeatureService
|
||||
from tasks.generate_summary_index_task import generate_summary_index_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -99,6 +100,78 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
||||
indexing_runner.run(documents)
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
|
||||
|
||||
# Trigger summary index generation for completed documents if enabled
|
||||
# Only generate for high_quality indexing technique and when summary_index_setting is enabled
|
||||
# Re-query dataset to get latest summary_index_setting (in case it was updated)
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
logger.warning("Dataset %s not found after indexing", dataset_id)
|
||||
return
|
||||
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
summary_index_setting = dataset.summary_index_setting
|
||||
if summary_index_setting and summary_index_setting.get("enable"):
|
||||
# expire all session to get latest document's indexing status
|
||||
session.expire_all()
|
||||
# Check each document's indexing status and trigger summary generation if completed
|
||||
for document_id in document_ids:
|
||||
# Re-query document to get latest status (IndexingRunner may have updated it)
|
||||
document = (
|
||||
session.query(Document)
|
||||
.where(Document.id == document_id, Document.dataset_id == dataset_id)
|
||||
.first()
|
||||
)
|
||||
if document:
|
||||
logger.info(
|
||||
"Checking document %s for summary generation: status=%s, doc_form=%s, need_summary=%s",
|
||||
document_id,
|
||||
document.indexing_status,
|
||||
document.doc_form,
|
||||
document.need_summary,
|
||||
)
|
||||
if (
|
||||
document.indexing_status == "completed"
|
||||
and document.doc_form != "qa_model"
|
||||
and document.need_summary is True
|
||||
):
|
||||
try:
|
||||
generate_summary_index_task.delay(dataset.id, document_id, None)
|
||||
logger.info(
|
||||
"Queued summary index generation task for document %s in dataset %s "
|
||||
"after indexing completed",
|
||||
document_id,
|
||||
dataset.id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to queue summary index generation task for document %s",
|
||||
document_id,
|
||||
)
|
||||
# Don't fail the entire indexing process if summary task queuing fails
|
||||
else:
|
||||
logger.info(
|
||||
"Skipping summary generation for document %s: "
|
||||
"status=%s, doc_form=%s, need_summary=%s",
|
||||
document_id,
|
||||
document.indexing_status,
|
||||
document.doc_form,
|
||||
document.need_summary,
|
||||
)
|
||||
else:
|
||||
logger.warning("Document %s not found after indexing", document_id)
|
||||
else:
|
||||
logger.info(
|
||||
"Summary index generation skipped for dataset %s: summary_index_setting.enable=%s",
|
||||
dataset.id,
|
||||
summary_index_setting.get("enable") if summary_index_setting else None,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')",
|
||||
dataset.id,
|
||||
dataset.indexing_technique,
|
||||
)
|
||||
except DocumentIsPausedError as ex:
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
|
||||
@@ -106,6 +106,17 @@ def enable_segment_to_index_task(segment_id: str):
        # save vector index
        index_processor.load(dataset, [document], multimodal_documents=multimodel_documents)

        # Enable summary index for this segment
        from services.summary_index_service import SummaryIndexService

        try:
            SummaryIndexService.enable_summaries_for_segments(
                dataset=dataset,
                segment_ids=[segment.id],
            )
        except Exception as e:
            logger.warning("Failed to enable summary for segment %s: %s", segment.id, str(e))

        end_at = time.perf_counter()
        logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green"))
    except Exception as e:

@@ -106,6 +106,18 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
        # save vector index
        index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)

        # Enable summary indexes for these segments
        from services.summary_index_service import SummaryIndexService

        segment_ids_list = [segment.id for segment in segments]
        try:
            SummaryIndexService.enable_summaries_for_segments(
                dataset=dataset,
                segment_ids=segment_ids_list,
            )
        except Exception as e:
            logger.warning("Failed to enable summaries for segments: %s", str(e))

        end_at = time.perf_counter()
        logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green"))
    except Exception as e:

Some files were not shown because too many files have changed in this diff.