Merge branch 'main' into jzh

JzoNg 2026-03-25 15:30:21 +08:00
commit 43f0c780c3
462 changed files with 36000 additions and 8934 deletions

.gemini/config.yaml (new file, +13 lines)

@ -0,0 +1,13 @@
have_fun: false
memory_config:
disabled: false
code_review:
disable: true
comment_severity_threshold: MEDIUM
max_review_comments: -1
pull_request_opened:
help: false
summary: false
code_review: false
include_drafts: false
ignore_patterns: []


@ -4,10 +4,9 @@ runs:
using: composite
steps:
- name: Setup Vite+
uses: voidzero-dev/setup-vp@4a524139920f87f9f7080d3b8545acac019e1852 # v1.0.0
uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0
with:
node-version-file: web/.nvmrc
working-directory: web
node-version-file: .nvmrc
cache: true
cache-dependency-path: web/pnpm-lock.yaml
run-install: |
cwd: ./web
run-install: true


@ -84,20 +84,20 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
uses: ./.github/actions/setup-web
- name: Restore ESLint cache
if: steps.changed-files.outputs.any_changed == 'true'
id: eslint-cache-restore
uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
with:
path: web/.eslintcache
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
restore-keys: |
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
- name: Web style check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: |
vp run lint:ci
# pnpm run lint:report
# continue-on-error: true
# - name: Annotate Code
# if: steps.changed-files.outputs.any_changed == 'true' && github.event_name == 'pull_request'
# uses: DerLev/eslint-annotations@51347b3a0abfb503fc8734d5ae31c4b151297fae
# with:
# eslint-report: web/eslint_report.json
# github-token: ${{ secrets.GITHUB_TOKEN }}
run: vp run lint:ci
- name: Web tsslint
if: steps.changed-files.outputs.any_changed == 'true'
@ -114,6 +114,13 @@ jobs:
working-directory: ./web
run: vp run knip
- name: Save ESLint cache
if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
with:
path: web/.eslintcache
key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }}
superlinter:
name: SuperLinter
runs-on: ubuntu-latest


@ -120,7 +120,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.detect_changes.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@6062f3709600659be5e47fcddf2cf76993c235c2 # v1.0.76
uses: anthropics/claude-code-action@ff9acae5886d41a99ed4ec14b7dc147d55834722 # v1.0.77
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}


@ -10,6 +10,7 @@ from configs import dify_config
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
@ -85,7 +86,7 @@ def migrate_annotation_vector_database():
dataset = Dataset(
id=app.id,
tenant_id=app.tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id,
@ -177,7 +178,9 @@ def migrate_knowledge_vector_database():
while True:
try:
stmt = (
select(Dataset).where(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc())
select(Dataset)
.where(Dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY)
.order_by(Dataset.created_at.desc())
)
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
@ -269,7 +272,7 @@ def migrate_knowledge_vector_database():
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == "hierarchical_model":
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
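
This merge replaces bare "high_quality"/"economy" strings with IndexTechniqueType members throughout. The unchanged == comparisons against string-valued columns suggest the enum derives from str; a minimal sketch under that assumption (Python 3.11's StrEnum, or an equivalent str/Enum mixin):

from enum import StrEnum

class IndexTechniqueType(StrEnum):  # assumed definition, for illustration
    HIGH_QUALITY = "high_quality"
    ECONOMY = "economy"

# str-derived members compare equal to the legacy literals stored in the
# database, so call sites and persisted rows keep working unchanged:
assert IndexTechniqueType.HIGH_QUALITY == "high_quality"
assert "economy" == IndexTechniqueType.ECONOMY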


@ -95,7 +95,7 @@ class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
icon_type: str | None = Field(default=None, description="Icon type")
icon_type: IconType | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@ -103,7 +103,7 @@ class CreateAppPayload(BaseModel):
class UpdateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
icon_type: str | None = Field(default=None, description="Icon type")
icon_type: IconType | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
@ -113,7 +113,7 @@ class UpdateAppPayload(BaseModel):
class CopyAppPayload(BaseModel):
name: str | None = Field(default=None, description="Name for the copied app")
description: str | None = Field(default=None, description="Description for the copied app", max_length=400)
icon_type: str | None = Field(default=None, description="Icon type")
icon_type: IconType | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@ -594,7 +594,7 @@ class AppApi(Resource):
args_dict: AppService.ArgsDict = {
"name": args.name,
"description": args.description or "",
"icon_type": args.icon_type or "",
"icon_type": args.icon_type,
"icon": args.icon or "",
"icon_background": args.icon_background or "",
"use_icon_as_answer_icon": args.use_icon_as_answer_icon or False,


@ -1,7 +1,7 @@
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from constants.languages import languages
@ -73,7 +73,7 @@ class EmailRegisterSendEmailApi(Resource):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountInFreezeError()
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
return {"result": "success", "data": token}
@ -145,7 +145,7 @@ class EmailRegisterResetApi(Resource):
email = register_data.get("email", "")
normalized_email = email.lower()
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:


@ -4,7 +4,7 @@ import secrets
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
@ -102,7 +102,7 @@ class ForgotPasswordSendEmailApi(Resource):
else:
language = "en-US"
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
token = AccountService.send_reset_password_email(
@ -201,7 +201,7 @@ class ForgotPasswordResetApi(Resource):
password_hashed = hash_password(args.new_password, salt)
email = reset_data.get("email", "")
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
@ -215,7 +215,6 @@ class ForgotPasswordResetApi(Resource):
# Update existing account credentials
account.password = base64.b64encode(password_hashed).decode()
account.password_salt = base64.b64encode(salt).decode()
session.commit()
# Create workspace if needed
if (


@ -4,7 +4,7 @@ import urllib.parse
import httpx
from flask import current_app, redirect, request
from flask_restx import Resource
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Unauthorized
from configs import dify_config
@ -180,7 +180,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
account: Account | None = Account.get_by_openid(provider, user_info.id)
if not account:
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
return account


@ -3,7 +3,7 @@ from typing import Any, cast
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy import func, select
from werkzeug.exceptions import Forbidden, NotFound
import services
@ -29,6 +29,7 @@ from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
@ -355,7 +356,7 @@ class DatasetListApi(Resource):
for item in data:
# convert embedding_model_provider to plugin standard format
if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
if item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY and item["embedding_model_provider"]:
item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
if item_model in model_names:
@ -436,7 +437,7 @@ class DatasetApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
if dataset.embedding_model_provider:
provider_id = ModelProviderID(dataset.embedding_model_provider)
data["embedding_model_provider"] = str(provider_id)
@ -454,7 +455,7 @@ class DatasetApi(Resource):
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
if data["indexing_technique"] == "high_quality":
if data["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY:
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
if item_model in model_names:
data["embedding_available"] = True
@ -485,7 +486,7 @@ class DatasetApi(Resource):
current_user, current_tenant_id = current_account_with_tenant()
# check embedding model setting
if (
payload.indexing_technique == "high_quality"
payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY
and payload.embedding_model_provider is not None
and payload.embedding_model is not None
):
@ -738,20 +739,23 @@ class DatasetIndexingStatusApi(Resource):
documents_status = []
for document in documents:
completed_segments = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
.count()
or 0
)
total_segments = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
.count()
or 0
)
# Create a dictionary with document attributes and additional fields
document_dict = {
@ -802,9 +806,12 @@ class DatasetApiKeyApi(Resource):
_, current_tenant_id = current_account_with_tenant()
current_key_count = (
db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
.count()
db.session.scalar(
select(func.count(ApiToken.id)).where(
ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id
)
)
or 0
)
if current_key_count >= self.max_keys:
@ -839,14 +846,14 @@ class DatasetApiDeleteApi(Resource):
def delete(self, api_key_id):
_, current_tenant_id = current_account_with_tenant()
api_key_id = str(api_key_id)
key = (
db.session.query(ApiToken)
key = db.session.scalar(
select(ApiToken)
.where(
ApiToken.tenant_id == current_tenant_id,
ApiToken.type == self.resource_type,
ApiToken.id == api_key_id,
)
.first()
.limit(1)
)
if key is None:
@ -857,7 +864,7 @@ class DatasetApiDeleteApi(Resource):
assert key is not None # nosec - for type checker only
ApiTokenCache.delete(key.token, key.type)
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.delete(key)
db.session.commit()
return {"result": "success"}, 204


@ -10,7 +10,7 @@ import sqlalchemy as sa
from flask import request, send_file
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import asc, desc, select
from sqlalchemy import asc, desc, func, select
from werkzeug.exceptions import Forbidden, NotFound
import services
@ -27,6 +27,7 @@ from core.model_manager import ModelManager
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
from extensions.ext_database import db
@ -211,12 +212,11 @@ class GetProcessRuleApi(Resource):
raise Forbidden(str(e))
# get the latest process rule
dataset_process_rule = (
db.session.query(DatasetProcessRule)
dataset_process_rule = db.session.scalar(
select(DatasetProcessRule)
.where(DatasetProcessRule.dataset_id == document.dataset_id)
.order_by(DatasetProcessRule.created_at.desc())
.limit(1)
.one_or_none()
)
if dataset_process_rule:
mode = dataset_process_rule.mode
@ -330,21 +330,23 @@ class DatasetDocumentListApi(Resource):
if fetch:
for document in documents:
completed_segments = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
.count()
or 0
)
total_segments = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
.count()
or 0
)
document.completed_segments = completed_segments
document.total_segments = total_segments
@ -448,7 +450,7 @@ class DatasetInitApi(Resource):
raise Forbidden()
knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
try:
@ -462,7 +464,7 @@ class DatasetInitApi(Resource):
is_multimodal = DatasetService.check_is_multimodal_model(
current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model
)
knowledge_config.is_multimodal = is_multimodal
knowledge_config.is_multimodal = is_multimodal # pyrefly: ignore[bad-assignment]
except InvokeAuthorizationError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
@ -521,10 +523,10 @@ class DocumentIndexingEstimateApi(DocumentResource):
if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"]
file = (
db.session.query(UploadFile)
file = db.session.scalar(
select(UploadFile)
.where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
.first()
.limit(1)
)
# raise error if file not found
@ -586,10 +588,10 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
if not data_source_info:
continue
file_id = data_source_info["upload_file_id"]
file_detail = (
db.session.query(UploadFile)
file_detail = db.session.scalar(
select(UploadFile)
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
.first()
.limit(1)
)
if file_detail is None:
@ -672,20 +674,23 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
documents_status = []
for document in documents:
completed_segments = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
.count()
or 0
)
total_segments = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
.count()
or 0
)
# Create a dictionary with document attributes and additional fields
document_dict = {
@ -723,18 +728,23 @@ class DocumentIndexingStatusApi(DocumentResource):
document = self.get_document(dataset_id, document_id)
completed_segments = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
.count()
or 0
)
total_segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != SegmentStatus.RE_SEGMENT)
.count()
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.document_id == str(document_id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
or 0
)
# Create a dictionary with document attributes and additional fields
@ -1258,11 +1268,11 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
log = (
db.session.query(DocumentPipelineExecutionLog)
.filter_by(document_id=document_id)
log = db.session.scalar(
select(DocumentPipelineExecutionLog)
.where(DocumentPipelineExecutionLog.document_id == document_id)
.order_by(DocumentPipelineExecutionLog.created_at.desc())
.first()
.limit(1)
)
if not log:
return {
@ -1328,7 +1338,7 @@ class DocumentGenerateSummaryApi(Resource):
raise BadRequest("document_list cannot be empty.")
# Check if dataset configuration supports summary generation
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
raise ValueError(
f"Summary generation is only available for 'high_quality' indexing technique. "
f"Current indexing technique: {dataset.indexing_technique}"


@ -26,6 +26,7 @@ from controllers.console.wraps import (
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@ -45,7 +46,7 @@ def _get_segment_with_summary(segment, dataset_id):
"""Helper function to marshal segment and add summary information."""
from services.summary_index_service import SummaryIndexService
segment_dict = dict(marshal(segment, segment_fields))
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore
# Query summary for this segment (only enabled summaries)
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
segment_dict["summary"] = summary.summary_content if summary else None
@ -206,7 +207,7 @@ class DatasetDocumentSegmentListApi(Resource):
# Add summary to each segment
segments_with_summary = []
for segment in segments.items:
segment_dict = dict(marshal(segment, segment_fields))
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore
segment_dict["summary"] = summaries.get(segment.id)
segments_with_summary.append(segment_dict)
@ -279,7 +280,7 @@ class DatasetDocumentSegmentApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# check embedding model setting
try:
model_manager = ModelManager()
@ -333,7 +334,7 @@ class DatasetDocumentSegmentAddApi(Resource):
if not current_user.is_dataset_editor:
raise Forbidden()
# check embedding model setting
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
try:
model_manager = ModelManager()
model_manager.get_model_instance(
@ -383,7 +384,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# check embedding model setting
try:
model_manager = ModelManager()
@ -401,10 +402,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
raise ProviderNotInitializeError(ex.description)
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
@ -447,10 +448,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
@ -494,7 +495,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
payload = BatchImportPayload.model_validate(console_ns.payload or {})
upload_file_id = payload.upload_file_id
upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == upload_file_id).limit(1))
if not upload_file:
raise NotFound("UploadFile not found.")
@ -559,17 +560,17 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
if not current_user.is_dataset_editor:
raise Forbidden()
# check embedding model setting
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
try:
model_manager = ModelManager()
model_manager.get_model_instance(
@ -616,10 +617,10 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
@ -666,10 +667,10 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
@ -714,24 +715,24 @@ class ChildChunkUpdateApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id = str(child_chunk_id)
child_chunk = (
db.session.query(ChildChunk)
child_chunk = db.session.scalar(
select(ChildChunk)
.where(
ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id,
)
.first()
.limit(1)
)
if not child_chunk:
raise NotFound("Child chunk not found.")
@ -771,24 +772,24 @@ class ChildChunkUpdateApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id = str(child_chunk_id)
child_chunk = (
db.session.query(ChildChunk)
child_chunk = db.session.scalar(
select(ChildChunk)
.where(
ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id,
)
.first()
.limit(1)
)
if not child_chunk:
raise NotFound("Child chunk not found.")


@ -2,6 +2,8 @@ from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from sqlalchemy import select
from controllers.console.datasets.error import PipelineNotFoundError
from extensions.ext_database import db
from libs.login import current_account_with_tenant
@ -24,10 +26,8 @@ def get_rag_pipeline(view_func: Callable[P, R]):
del kwargs["pipeline_id"]
pipeline = (
db.session.query(Pipeline)
.where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id)
.first()
pipeline = db.session.scalar(
select(Pipeline).where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id).limit(1)
)
if not pipeline:


@ -15,6 +15,7 @@ from controllers.service_api.wraps import (
cloud_edition_billing_rate_limit_check,
)
from core.provider_manager import ProviderManager
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from dify_graph.model_runtime.entities.model_entities import ModelType
from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import DataSetTag
@ -153,15 +154,20 @@ class DatasetListApi(DatasetApiResource):
data = marshal(datasets, dataset_detail_fields)
for item in data:
if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
if (
item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY # pyrefly: ignore[bad-index]
and item["embedding_model_provider"] # pyrefly: ignore[bad-index]
):
item["embedding_model_provider"] = str( # pyrefly: ignore[unsupported-operation]
ModelProviderID(item["embedding_model_provider"]) # pyrefly: ignore[bad-index]
)
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" # pyrefly: ignore[bad-index]
if item_model in model_names:
item["embedding_available"] = True
item["embedding_available"] = True # type: ignore
else:
item["embedding_available"] = False
item["embedding_available"] = False # type: ignore
else:
item["embedding_available"] = True
item["embedding_available"] = True # type: ignore
response = {
"data": data,
"has_more": len(datasets) == query.limit,
@ -265,7 +271,7 @@ class DatasetApi(DatasetApiResource):
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
if data.get("indexing_technique") == "high_quality":
if data.get("indexing_technique") == IndexTechniqueType.HIGH_QUALITY:
item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}"
if item_model in model_names:
data["embedding_available"] = True
@ -315,7 +321,7 @@ class DatasetApi(DatasetApiResource):
# check embedding model setting
embedding_model_provider = payload.embedding_model_provider
embedding_model = payload.embedding_model
if payload.indexing_technique == "high_quality" or embedding_model_provider:
if payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY or embedding_model_provider:
if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(
dataset.tenant_id, embedding_model_provider, embedding_model


@ -17,6 +17,7 @@ from controllers.service_api.wraps import (
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from fields.segment_fields import child_chunk_fields, segment_fields
@ -103,7 +104,7 @@ class SegmentApi(DatasetApiResource):
if not document.enabled:
raise NotFound("Document is disabled.")
# check embedding model setting
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
try:
model_manager = ModelManager()
model_manager.get_model_instance(
@ -157,7 +158,7 @@ class SegmentApi(DatasetApiResource):
if not document:
raise NotFound("Document not found.")
# check embedding model setting
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
try:
model_manager = ModelManager()
model_manager.get_model_instance(
@ -262,7 +263,7 @@ class DatasetSegmentApi(DatasetApiResource):
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# check embedding model setting
try:
model_manager = ModelManager()
@ -358,7 +359,7 @@ class ChildChunkApi(DatasetApiResource):
raise NotFound("Segment not found.")
# check embedding model setting
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
try:
model_manager = ModelManager()
model_manager.get_model_instance(


@ -4,6 +4,7 @@ from sqlalchemy import select
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from extensions.ext_database import db
from models.dataset import Dataset
from models.enums import CollectionBindingType, ConversationFromSource
@ -50,7 +51,7 @@ class AnnotationReplyFeature:
dataset = Dataset(
id=app_record.id,
tenant_id=app_record.tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider=embedding_provider_name,
embedding_model=embedding_model_name,
collection_binding_id=dataset_collection_binding.id,


@ -19,6 +19,7 @@ class RateLimit:
_REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
_instance_dict: dict[str, "RateLimit"] = {}
max_active_requests: int
def __new__(cls, client_id: str, max_active_requests: int):
if client_id not in cls._instance_dict:
@ -27,7 +28,13 @@ class RateLimit:
return cls._instance_dict[client_id]
def __init__(self, client_id: str, max_active_requests: int):
flush_cache = hasattr(self, "max_active_requests") and self.max_active_requests != max_active_requests
self.max_active_requests = max_active_requests
# Only flush here if this instance has already been fully initialized,
# i.e. the Redis key attributes exist. Otherwise, rely on the flush at
# the end of initialization below.
if flush_cache and hasattr(self, "active_requests_key") and hasattr(self, "max_active_requests_key"):
self.flush_cache(use_local_value=True)
# must be called after max_active_requests is set
if self.disabled():
return
@ -41,8 +48,6 @@ class RateLimit:
self.flush_cache(use_local_value=True)
def flush_cache(self, use_local_value=False):
if self.disabled():
return
self.last_recalculate_time = time.time()
# flush max active requests
if use_local_value or not redis_client.exists(self.max_active_requests_key):
@ -50,7 +55,8 @@ class RateLimit:
else:
self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode("utf-8"))
redis_client.expire(self.max_active_requests_key, timedelta(days=1))
if self.disabled():
return
# flush max active requests (in-transit request list)
if not redis_client.exists(self.active_requests_key):
return
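
The guard exists because RateLimit caches instances in __new__ while Python still invokes __init__ on every construction, so a second RateLimit(client_id, n) re-runs __init__ on the cached object with its old attributes intact. A minimal sketch of that interplay:

class Cached:
    _instances: dict[str, "Cached"] = {}

    def __new__(cls, key: str, value: int):
        if key not in cls._instances:
            cls._instances[key] = super().__new__(cls)
        return cls._instances[key]

    def __init__(self, key: str, value: int):
        # On a cache hit, attributes from the previous __init__ are still
        # present here, which is how a changed value can be detected.
        changed = hasattr(self, "value") and self.value != value
        self.value = value
        if changed:
            print(f"{key}: limit changed, refresh derived state")

Cached("a", 1)
Cached("a", 2)  # prints: a: limit changed, refresh derived state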


@ -6,16 +6,23 @@ from dify_graph.graph_events.graph import GraphRunPausedEvent
class SuspendLayer(GraphEngineLayer):
""" """
def __init__(self) -> None:
super().__init__()
self._paused = False
def on_graph_start(self):
pass
self._paused = False
def on_event(self, event: GraphEngineEvent):
"""
Handle the paused event, stash runtime state into storage and wait for resume.
"""
if isinstance(event, GraphRunPausedEvent):
pass
self._paused = True
def on_graph_end(self, error: Exception | None):
""" """
pass
self._paused = False
def is_paused(self) -> bool:
return self._paused


@ -128,14 +128,14 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
self._handle_graph_run_paused(event)
return
if isinstance(event, NodeRunStartedEvent):
self._handle_node_started(event)
return
if isinstance(event, NodeRunRetryEvent):
self._handle_node_retry(event)
return
if isinstance(event, NodeRunStartedEvent):
self._handle_node_started(event)
return
if isinstance(event, NodeRunSucceededEvent):
self._handle_node_succeeded(event)
return
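
Moving the NodeRunRetryEvent check ahead of NodeRunStartedEvent matters if the retry event subclasses the started event (assumed from the shape of the fix): testing the base class first would capture retries too. A minimal sketch:

class Started:          # illustrative stand-ins for the event classes
    pass

class Retry(Started):   # assumed: retry specializes started
    pass

def dispatch(event: object) -> str:
    # The subclass must be tested first; isinstance(Retry(), Started) is
    # True, so the base-class branch would otherwise shadow retries.
    if isinstance(event, Retry):
        return "retry"
    if isinstance(event, Started):
        return "started"
    return "other"

assert dispatch(Retry()) == "retry"
assert dispatch(Started()) == "started"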


@ -21,7 +21,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
@ -271,7 +271,7 @@ class IndexingRunner:
doc_form: str | None = None,
doc_language: str = "English",
dataset_id: str | None = None,
indexing_technique: str = "economy",
indexing_technique: str = IndexTechniqueType.ECONOMY,
) -> IndexingEstimate:
"""
Estimate the indexing for the document.
@ -289,7 +289,7 @@ class IndexingRunner:
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise ValueError("Dataset not found.")
if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality":
if IndexTechniqueType.HIGH_QUALITY in {dataset.indexing_technique, indexing_technique}:
if dataset.embedding_model_provider:
embedding_model_instance = self.model_manager.get_model_instance(
tenant_id=tenant_id,
@ -303,7 +303,7 @@ class IndexingRunner:
model_type=ModelType.TEXT_EMBEDDING,
)
else:
if indexing_technique == "high_quality":
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
embedding_model_instance = self.model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
@ -573,7 +573,7 @@ class IndexingRunner:
"""
embedding_model_instance = None
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
embedding_model_instance = self.model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
@ -587,7 +587,7 @@ class IndexingRunner:
create_keyword_thread = None
if (
dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX
and dataset.indexing_technique == "economy"
and dataset.indexing_technique == IndexTechniqueType.ECONOMY
):
# create keyword index
create_keyword_thread = threading.Thread(
@ -597,7 +597,7 @@ class IndexingRunner:
create_keyword_thread.start()
max_workers = 10
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = []
@ -628,7 +628,7 @@ class IndexingRunner:
tokens += future.result()
if (
dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX
and dataset.indexing_technique == "economy"
and dataset.indexing_technique == IndexTechniqueType.ECONOMY
and create_keyword_thread is not None
):
create_keyword_thread.join()
@ -654,7 +654,7 @@ class IndexingRunner:
raise ValueError("no dataset found")
keyword = Keyword(dataset)
keyword.create(documents)
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
document_ids = [document.metadata["doc_id"] for document in documents]
db.session.query(DocumentSegment).where(
DocumentSegment.document_id == document_id,
@ -764,7 +764,7 @@ class IndexingRunner:
) -> list[Document]:
# get embedding model instance
embedding_model_instance = None
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
if dataset.embedding_model_provider:
embedding_model_instance = self.model_manager.get_model_instance(
tenant_id=dataset.tenant_id,


@ -67,7 +67,8 @@ class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
if field_name == "inputs":
data = {
"messages": [
dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) for msg in v
dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) # type: ignore
for msg in v
]
if isinstance(v, list)
else v,


@ -209,8 +209,7 @@ class PluginInstaller(BasePluginClient):
"GET",
f"plugin/{tenant_id}/management/decode/from_identifier",
PluginDecodeResponse,
data={"plugin_unique_identifier": plugin_unique_identifier},
headers={"Content-Type": "application/json"},
params={"plugin_unique_identifier": plugin_unique_identifier},
)
def fetch_plugin_installation_by_ids(
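
The request now carries the identifier as a query parameter rather than a JSON body, which servers commonly ignore on GET. A rough httpx sketch (URL and values are illustrative); building the Request shows where the parameter ends up without sending anything:

import httpx

# GET parameters belong in the query string:
request = httpx.Request(
    "GET",
    "http://plugin-daemon/plugin/t1/management/decode/from_identifier",
    params={"plugin_unique_identifier": "example-id"},
)
print(request.url)  # ...?plugin_unique_identifier=example-id
# versus the old call, which sent the identifier in the body of a GET.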


@ -124,13 +124,13 @@ class HuaweiCloudVector(BaseVector):
)
)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
docs = []
for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score >= score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)
docs.append(doc)
return docs


@ -6,6 +6,7 @@ from typing import Any
from sqlalchemy import func, select
from core.model_manager import ModelManager
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import AttachmentDocument, Document
from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
@ -71,7 +72,7 @@ class DatasetDocumentStore:
if max_position is None:
max_position = 0
embedding_model = None
if self._dataset.indexing_technique == "high_quality":
if self._dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,


@ -9,6 +9,7 @@ from flask import current_app
from sqlalchemy import delete, func, select
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview
@ -159,7 +160,7 @@ class IndexProcessor:
tenant_id = dataset.tenant_id
preview_output = self.format_preview(chunk_structure, chunks)
if indexing_technique != "high_quality":
if indexing_technique != IndexTechniqueType.HIGH_QUALITY:
return preview_output
if not summary_index_setting or not summary_index_setting.get("enable"):


@ -22,7 +22,7 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@ -117,7 +117,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
with_keywords: bool = True,
**kwargs,
) -> None:
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
vector = Vector(dataset)
vector.create(documents)
if multimodal_documents and dataset.is_multimodal:
@ -155,7 +155,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
# Delete all summaries for the dataset
SummaryIndexService.delete_summaries_for_segments(dataset, None)
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
vector = Vector(dataset)
if node_ids:
vector.delete_by_ids(node_ids)
@ -253,12 +253,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
# add document segments
doc_store.add_documents(docs=documents, save_child=False)
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
vector = Vector(dataset)
vector.create(documents)
if all_multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(all_multimodal_documents)
elif dataset.indexing_technique == "economy":
elif dataset.indexing_technique == IndexTechniqueType.ECONOMY:
keyword = Keyword(dataset)
keyword.add_texts(documents)


@ -18,7 +18,7 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@ -128,7 +128,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
with_keywords: bool = True,
**kwargs,
) -> None:
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
vector = Vector(dataset)
for document in documents:
child_documents = document.children
@ -166,7 +166,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
# Delete all summaries for the dataset
SummaryIndexService.delete_summaries_for_segments(dataset, None)
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
delete_child_chunks = kwargs.get("delete_child_chunks") or False
precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids")
vector = Vector(dataset)
@ -332,7 +332,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
# add document segments
doc_store.add_documents(docs=documents, save_child=True)
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
all_child_documents = []
all_multimodal_documents = []
for doc in documents:


@ -21,7 +21,7 @@ from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@ -141,7 +141,7 @@ class QAIndexProcessor(BaseIndexProcessor):
with_keywords: bool = True,
**kwargs,
) -> None:
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
vector = Vector(dataset)
vector.create(documents)
if multimodal_documents and dataset.is_multimodal:
@ -224,7 +224,7 @@ class QAIndexProcessor(BaseIndexProcessor):
# save node to document segment
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
doc_store.add_documents(docs=documents, save_child=False)
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
vector = Vector(dataset)
vector.create(documents)
else:


@ -675,7 +675,7 @@ class DatasetRetrieval:
# get top k
top_k = retrieval_model_config["top_k"]
# get retrieval method
if selected_dataset.indexing_technique == "economy":
if selected_dataset.indexing_technique == IndexTechniqueType.ECONOMY:
retrieval_method = RetrievalMethod.KEYWORD_SEARCH
else:
retrieval_method = retrieval_model_config["search_method"]
@ -752,7 +752,7 @@ class DatasetRetrieval:
"The configured knowledge base list have different indexing technique, please set reranking model."
)
index_type = available_datasets[0].indexing_technique
if index_type == "high_quality":
if index_type == IndexTechniqueType.HIGH_QUALITY:
embedding_model_check = all(
item.embedding_model == available_datasets[0].embedding_model for item in available_datasets
)
@ -1068,7 +1068,7 @@ class DatasetRetrieval:
else default_retrieval_model
)
if dataset.indexing_technique == "economy":
if dataset.indexing_technique == IndexTechniqueType.ECONOMY:
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,


@ -2,6 +2,7 @@ import concurrent.futures
import logging
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
from services.summary_index_service import SummaryIndexService
@ -21,7 +22,7 @@ class SummaryIndex:
if is_preview:
with session_factory.create_session() as session:
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset or dataset.indexing_technique != "high_quality":
if not dataset or dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
return
if summary_index_setting is None:


@ -8,6 +8,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
from core.model_manager import ModelManager
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document as RagDocument
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@ -169,7 +170,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
# get the retrieval model; if it is not set, use the default
retrieval_model = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy":
if dataset.indexing_technique == IndexTechniqueType.ECONOMY:
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,


@ -8,6 +8,7 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict,
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document as RetrievalDocument
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@ -140,7 +141,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
# get the retrieval model; if it is not set, use the default
retrieval_model = dataset.retrieval_model or default_retrieval_model
retrieval_resource_list: list[RetrievalSourceMetadata] = []
if dataset.indexing_technique == "economy":
if dataset.indexing_technique == IndexTechniqueType.ECONOMY:
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
@ -173,7 +174,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
for hit_callback in self.hit_callbacks:
hit_callback.on_tool_end(documents)
document_score_list = {}
if dataset.indexing_technique != "economy":
if dataset.indexing_technique != IndexTechniqueType.ECONOMY:
for item in documents:
if item.metadata is not None and item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]


@ -1,6 +1,9 @@
from __future__ import annotations
from collections.abc import Sequence
import json
import logging
import re
from collections.abc import Mapping, Sequence
from typing import Any, cast
from core.model_manager import ModelInstance
@ -36,6 +39,11 @@ from .exc import (
)
from .protocols import TemplateRenderer
logger = logging.getLogger(__name__)
VARIABLE_PATTERN = re.compile(r"\{\{#[^#]+#\}\}")
MAX_RESOLVED_VALUE_LENGTH = 1024
def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity:
model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema(
@ -475,3 +483,61 @@ def _append_file_prompts(
prompt_messages[-1] = UserPromptMessage(content=file_prompts + existing_contents)
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
def _coerce_resolved_value(raw: str) -> int | float | bool | str:
"""Try to restore the original type from a resolved template string.
Variable references are always resolved to text, but completion params may
expect numeric or boolean values (e.g. a variable that holds "0.7" mapped to
the ``temperature`` parameter). This helper attempts a JSON parse so that
``"0.7"`` ``0.7``, ``"true"`` ``True``, etc. Plain strings that are not
valid JSON literals are returned as-is.
"""
stripped = raw.strip()
if not stripped:
return raw
try:
parsed: object = json.loads(stripped)
except (json.JSONDecodeError, ValueError):
return raw
if isinstance(parsed, (int, float, bool)):
return parsed
return raw
def resolve_completion_params_variables(
completion_params: Mapping[str, Any],
variable_pool: VariablePool,
) -> dict[str, Any]:
"""Resolve variable references (``{{#node_id.var#}}``) in string-typed completion params.
Security notes:
- Resolved values are length-capped to ``MAX_RESOLVED_VALUE_LENGTH`` to
prevent denial-of-service through excessively large variable payloads.
- This follows the same ``VariablePool.convert_template`` pattern used across
Dify (Answer Node, HTTP Request Node, Agent Node, etc.). The downstream
model plugin receives these values as structured JSON key-value pairs; they
are never concatenated into raw HTTP headers or SQL queries.
- Numeric/boolean coercion is applied so that variables holding ``"0.7"`` are
restored to their native type rather than sent as a bare string.
"""
resolved: dict[str, Any] = {}
for key, value in completion_params.items():
if isinstance(value, str) and VARIABLE_PATTERN.search(value):
segment_group = variable_pool.convert_template(value)
text = segment_group.text
if len(text) > MAX_RESOLVED_VALUE_LENGTH:
logger.warning(
"Resolved value for param '%s' truncated from %d to %d chars",
key,
len(text),
MAX_RESOLVED_VALUE_LENGTH,
)
text = text[:MAX_RESOLVED_VALUE_LENGTH]
resolved[key] = _coerce_resolved_value(text)
else:
resolved[key] = value
return resolved
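
A runnable usage sketch with a stand-in variable pool (the real VariablePool is assumed only to expose convert_template returning an object with a .text attribute, exactly as used above):

class _SegmentGroup:
    def __init__(self, text: str):
        self.text = text

class _StubVariablePool:
    # stand-in: renders every template reference to the string "0.7"
    def convert_template(self, value: str) -> _SegmentGroup:
        return _SegmentGroup("0.7")

params = {"temperature": "{{#start.temp#}}", "max_tokens": 256}
resolved = resolve_completion_params_variables(params, _StubVariablePool())
assert resolved == {"temperature": 0.7, "max_tokens": 256}  # "0.7" -> 0.7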


@ -202,6 +202,10 @@ class LLMNode(Node[LLMNodeData]):
# fetch model config
model_instance = self._model_instance
# Resolve variable references in string-typed completion params
model_instance.parameters = llm_utils.resolve_completion_params_variables(
model_instance.parameters, variable_pool
)
model_name = model_instance.model_name
model_provider = model_instance.provider
model_stop = model_instance.stop


@ -164,6 +164,10 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
)
model_instance = self._model_instance
# Resolve variable references in string-typed completion params
model_instance.parameters = llm_utils.resolve_completion_params_variables(
model_instance.parameters, variable_pool
)
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
raise InvalidModelTypeError("Model is not a Large Language Model")

View File

@ -114,6 +114,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
variables = {"query": query}
# fetch model instance
model_instance = self._model_instance
# Resolve variable references in string-typed completion params
model_instance.parameters = llm_utils.resolve_completion_params_variables(
model_instance.parameters, variable_pool
)
memory = self._memory
# fetch instruction
node_data.instruction = node_data.instruction or ""

View File

@ -125,7 +125,8 @@ class BroadcastChannel(Protocol):
a specific topic, all subscriptions should receive the published message.
There are no restrictions on the persistence of messages. Once a subscription is created, it
should receive all subsequent messages published.
should receive all subsequent messages published. However, a subscription should not receive
any message published before the subscription is established.
`BroadcastChannel` implementations must be thread-safe and support concurrent use by multiple threads.
"""

View File

@ -64,7 +64,10 @@ class _StreamsSubscription(Subscription):
self._client = client
self._key = key
self._closed = threading.Event()
self._last_id = "0-0"
# Set the initial last id to `$` to signal Redis that we only want new messages.
#
# ref: https://redis.io/docs/latest/commands/xread/#the-special--id
self._last_id = "$"
self._queue: queue.Queue[object] = queue.Queue()
self._start_lock = threading.Lock()
self._listener: threading.Thread | None = None
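
To see why `$` matters, here is a standalone sketch with redis-py against a local Redis server (the key name is illustrative): XREAD starting from `$` returns only entries added after the call, matching the subscription contract described above.

import redis

r = redis.Redis()
r.xadd("demo:stream", {"msg": "before"})  # published before we start reading
last_id = "$"
# Blocks up to 1s; the pre-existing entry above is never returned.
for _stream, messages in r.xread({"demo:stream": last_id}, block=1000) or []:
    for message_id, fields in messages:
        last_id = message_id  # advance past the special "$" once real IDs arrive
        print(message_id, fields)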

View File

@ -18,15 +18,23 @@ if TYPE_CHECKING:
from models.model import EndUser
def _resolve_current_user() -> EndUser | Account | None:
"""
Resolve the current user proxy to its underlying user object.
This keeps unit tests working when they patch `current_user` directly
instead of bootstrapping a full Flask-Login manager.
"""
user_proxy = current_user
get_current_object = getattr(user_proxy, "_get_current_object", None)
return get_current_object() if callable(get_current_object) else user_proxy # type: ignore
def current_account_with_tenant():
"""
Resolve the underlying account for the current user proxy and ensure tenant context exists.
Allows tests to supply plain Account mocks without the LocalProxy helper.
"""
user_proxy = current_user
get_current_object = getattr(user_proxy, "_get_current_object", None)
user = get_current_object() if callable(get_current_object) else user_proxy # type: ignore
user = _resolve_current_user()
if not isinstance(user, Account):
raise ValueError("current_user must be an Account instance")
@ -79,9 +87,10 @@ def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
return current_app.ensure_sync(func)(*args, **kwargs)
user = _get_user()
user = _resolve_current_user()
if user is None or not user.is_authenticated:
return current_app.login_manager.unauthorized() # type: ignore
g._login_user = user
# we put csrf validation here to reduce conflicts
# TODO: maybe find a better place for it.
check_csrf_token(request, user.id)

View File

@ -20,7 +20,7 @@ from sqlalchemy.orm import Mapped, Session, mapped_column
from configs import dify_config
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.signature import sign_upload_file
@ -137,7 +137,7 @@ class Dataset(Base):
default=DatasetPermissionEnum.ONLY_ME,
)
data_source_type = mapped_column(EnumText(DataSourceType, length=255))
indexing_technique: Mapped[str | None] = mapped_column(String(255))
indexing_technique: Mapped[IndexTechniqueType | None] = mapped_column(EnumText(IndexTechniqueType, length=255))
index_struct = mapped_column(LongText, nullable=True)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@ -496,7 +496,9 @@ class Document(Base):
)
doc_type = mapped_column(EnumText(DocumentDocType, length=40), nullable=True)
doc_metadata = mapped_column(AdjustedJSON, nullable=True)
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
doc_form: Mapped[IndexStructureType] = mapped_column(
EnumText(IndexStructureType, length=255), nullable=False, server_default=sa.text("'text_model'")
)
doc_language = mapped_column(String(255), nullable=True)
need_summary: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
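
EnumText is Dify's own column type; as a rough standalone equivalent, plain SQLAlchemy can persist a str-backed enum in a VARCHAR with validation. The enum values below are assumptions mirroring the strings used elsewhere in this diff:

import enum

from sqlalchemy import Enum
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class IndexTechnique(enum.StrEnum):  # stand-in for IndexTechniqueType (Python 3.11+)
    HIGH_QUALITY = "high_quality"
    ECONOMY = "economy"

class Base(DeclarativeBase):
    pass

class DatasetSketch(Base):
    __tablename__ = "dataset_sketch"

    id: Mapped[int] = mapped_column(primary_key=True)
    indexing_technique: Mapped[IndexTechnique | None] = mapped_column(
        # persist the .value ("high_quality"), not the member name
        Enum(IndexTechnique, native_enum=False, length=255,
             values_callable=lambda e: [m.value for m in e]),
        nullable=True,
    )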

View File

@ -589,7 +589,9 @@ class AppModelConfig(TypeBase):
__tablename__ = "app_model_configs"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="app_model_config_pkey"), sa.Index("app_app_id_idx", "app_id"))
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
model_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
@ -938,7 +940,9 @@ class AccountTrialAppRecord(Base):
class ExporleBanner(TypeBase):
__tablename__ = "exporle_banners"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv4_string, init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False
)
content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
link: Mapped[str] = mapped_column(String(255), nullable=False)
sort: Mapped[int] = mapped_column(sa.Integer, nullable=False)
@ -1847,7 +1851,9 @@ class AppAnnotationHitHistory(TypeBase):
sa.Index("app_annotation_hit_histories_message_idx", "message_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
source: Mapped[str] = mapped_column(LongText, nullable=False)
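
The repeated id changes above pair `default_factory` (populates the dataclass field when the object is constructed in Python) with `insert_default` (the Core-level default for inserts that bypass the ORM object). A minimal sketch of the pattern with SQLAlchemy 2.0's MappedAsDataclass:

import uuid

from sqlalchemy import String
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column

class Base(MappedAsDataclass, DeclarativeBase):
    pass

class Demo(Base):
    __tablename__ = "demo"

    id: Mapped[str] = mapped_column(
        String(36),
        primary_key=True,
        insert_default=lambda: str(uuid.uuid4()),   # Core insert default
        default_factory=lambda: str(uuid.uuid4()),  # dataclass-level default
        init=False,  # generated, never passed to __init__
    )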

View File

@ -145,7 +145,9 @@ class ApiToolProvider(TypeBase):
icon: Mapped[str] = mapped_column(String(255), nullable=False)
# original schema
schema: Mapped[str] = mapped_column(LongText, nullable=False)
schema_type_str: Mapped[str] = mapped_column(String(40), nullable=False)
schema_type_str: Mapped[ApiProviderSchemaType] = mapped_column(
EnumText(ApiProviderSchemaType, length=40), nullable=False
)
# who created this tool
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# tenant id

View File

@ -174,7 +174,7 @@ dev = [
"sseclient-py>=1.8.0",
"pytest-timeout>=2.4.0",
"pytest-xdist>=3.8.0",
"pyrefly>=0.55.0",
"pyrefly>=0.57.1",
]
############################################################

View File

@ -241,7 +241,7 @@ class AppService:
class ArgsDict(TypedDict):
name: str
description: str
icon_type: str
icon_type: IconType | str | None
icon: str
icon_background: str
use_icon_as_answer_icon: bool
@ -257,7 +257,13 @@ class AppService:
assert current_user is not None
app.name = args["name"]
app.description = args["description"]
app.icon_type = IconType(args["icon_type"]) if args["icon_type"] else None
icon_type = args.get("icon_type")
if icon_type is None:
resolved_icon_type = app.icon_type
else:
resolved_icon_type = IconType(icon_type)
app.icon_type = resolved_icon_type
app.icon = args["icon"]
app.icon_background = args["icon_background"]
app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False)

View File

@ -1,8 +1,16 @@
from abc import ABC, abstractmethod
from typing import Any
from typing_extensions import TypedDict
class AuthCredentials(TypedDict):
auth_type: str
config: dict[str, Any]
class ApiKeyAuthBase(ABC):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
self.credentials = credentials
@abstractmethod

View File

@ -1,9 +1,9 @@
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
from services.auth.auth_type import AuthType
class ApiKeyAuthFactory:
def __init__(self, provider: str, credentials: dict):
def __init__(self, provider: str, credentials: AuthCredentials):
auth_factory = self.get_apikey_auth_factory(provider)
self.auth = auth_factory(credentials)

View File

@ -2,11 +2,11 @@ import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
class FirecrawlAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":

View File

@ -2,11 +2,11 @@ import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
class JinaAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":

View File

@ -3,11 +3,11 @@ from urllib.parse import urljoin
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
class WatercrawlAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "x-api-key":

View File

@ -21,7 +21,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.helper.name_generator import generate_incremental_name
from core.model_manager import ModelManager
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from dify_graph.file import helpers as file_helpers
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType
@ -228,7 +228,7 @@ class DatasetService:
if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
embedding_model = None
if indexing_technique == "high_quality":
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
if embedding_model_provider and embedding_model_name:
# check if embedding model setting is valid
@ -254,7 +254,10 @@ class DatasetService:
retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.reranking_model.reranking_model_name,
)
dataset = Dataset(name=name, indexing_technique=indexing_technique)
dataset = Dataset(
name=name,
indexing_technique=IndexTechniqueType(indexing_technique) if indexing_technique else None,
)
# dataset = Dataset(name=name, provider=provider, config=config)
dataset.description = description
dataset.created_by = account.id
@ -349,7 +352,7 @@ class DatasetService:
@staticmethod
def check_dataset_model_setting(dataset):
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
try:
model_manager = ModelManager()
model_manager.get_model_instance(
@ -717,13 +720,13 @@ class DatasetService:
if "indexing_technique" not in data:
return None
if dataset.indexing_technique != data["indexing_technique"]:
if data["indexing_technique"] == "economy":
if data["indexing_technique"] == IndexTechniqueType.ECONOMY:
# Remove embedding model configuration for economy mode
filtered_data["embedding_model"] = None
filtered_data["embedding_model_provider"] = None
filtered_data["collection_binding_id"] = None
return "remove"
elif data["indexing_technique"] == "high_quality":
elif data["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY:
# Configure embedding model for high quality mode
DatasetService._configure_embedding_model_for_high_quality(data, filtered_data)
return "add"
@ -953,8 +956,8 @@ class DatasetService:
dataset = session.merge(dataset)
if not has_published:
dataset.chunk_structure = knowledge_configuration.chunk_structure
dataset.indexing_technique = knowledge_configuration.indexing_technique
if knowledge_configuration.indexing_technique == "high_quality":
dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique)
if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, # ignore type error
@ -976,7 +979,7 @@ class DatasetService:
embedding_model_name,
)
dataset.collection_binding_id = dataset_collection_binding.id
elif knowledge_configuration.indexing_technique == "economy":
elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY:
dataset.keyword_number = knowledge_configuration.keyword_number
else:
raise ValueError("Invalid index method")
@ -991,9 +994,9 @@ class DatasetService:
action = None
if dataset.indexing_technique != knowledge_configuration.indexing_technique:
# if update indexing_technique
if knowledge_configuration.indexing_technique == "economy":
if knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY:
raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.")
elif knowledge_configuration.indexing_technique == "high_quality":
elif knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
action = "add"
# get embedding model setting
try:
@ -1018,7 +1021,7 @@ class DatasetService:
)
dataset.is_multimodal = is_multimodal
dataset.collection_binding_id = dataset_collection_binding.id
dataset.indexing_technique = knowledge_configuration.indexing_technique
dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique)
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
@ -1029,7 +1032,7 @@ class DatasetService:
else:
# add default plugin id to both setting sets, to make sure the plugin model provider is consistent
# Skip embedding model checks if not provided in the update request
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
skip_embedding_update = False
try:
# Handle existing model provider
@ -1089,7 +1092,7 @@ class DatasetService:
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
elif dataset.indexing_technique == "economy":
elif dataset.indexing_technique == IndexTechniqueType.ECONOMY:
if dataset.keyword_number != knowledge_configuration.keyword_number:
dataset.keyword_number = knowledge_configuration.keyword_number
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
@ -1440,7 +1443,7 @@ class DocumentService:
.filter(
Document.id.in_(document_id_list),
Document.dataset_id == dataset_id,
Document.doc_form != "qa_model", # Skip qa_model documents
Document.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents
)
.update({Document.need_summary: need_summary}, synchronize_session=False)
)
@ -1907,8 +1910,8 @@ class DocumentService:
if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
raise ValueError("Indexing technique is invalid")
dataset.indexing_technique = knowledge_config.indexing_technique
if knowledge_config.indexing_technique == "high_quality":
dataset.indexing_technique = IndexTechniqueType(knowledge_config.indexing_technique)
if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
dataset_embedding_model = knowledge_config.embedding_model
@ -2040,7 +2043,7 @@ class DocumentService:
document.dataset_process_rule_id = dataset_process_rule.id
document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
document.doc_form = IndexStructureType(knowledge_config.doc_form)
document.doc_language = knowledge_config.doc_language
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
@ -2640,7 +2643,7 @@ class DocumentService:
document.splitting_completed_at = None
document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = document_data.doc_form
document.doc_form = IndexStructureType(document_data.doc_form)
db.session.add(document)
db.session.commit()
# update document segment
@ -2689,7 +2692,7 @@ class DocumentService:
dataset_collection_binding_id = None
retrieval_model = None
if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
assert knowledge_config.embedding_model_provider
assert knowledge_config.embedding_model
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
@ -2712,7 +2715,7 @@ class DocumentService:
tenant_id=tenant_id,
name="",
data_source_type=knowledge_config.data_source.info_list.data_source_type,
indexing_technique=knowledge_config.indexing_technique,
indexing_technique=IndexTechniqueType(knowledge_config.indexing_technique),
created_by=account.id,
embedding_model=knowledge_config.embedding_model,
embedding_model_provider=knowledge_config.embedding_model_provider,
@ -3101,7 +3104,7 @@ class DocumentService:
class SegmentService:
@classmethod
def segment_create_args_validate(cls, args: dict, document: Document):
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
if "answer" not in args or not args["answer"]:
raise ValueError("Answer is required")
if not args["answer"].strip():
@ -3125,7 +3128,7 @@ class SegmentService:
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
tokens = 0
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
@ -3158,7 +3161,7 @@ class SegmentService:
completed_at=naive_utc_now(),
created_by=current_user.id,
)
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
segment_document.word_count += len(args["answer"])
segment_document.answer = args["answer"]
@ -3208,7 +3211,7 @@ class SegmentService:
try:
with redis_client.lock(lock_name, timeout=600):
embedding_model = None
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
@ -3230,9 +3233,9 @@ class SegmentService:
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
tokens = 0
if dataset.indexing_technique == "high_quality" and embedding_model:
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY and embedding_model:
# calc embedding use tokens
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
tokens = embedding_model.get_text_embedding_num_tokens(
texts=[content + segment_item["answer"]]
)[0]
@ -3255,7 +3258,7 @@ class SegmentService:
completed_at=naive_utc_now(),
created_by=current_user.id,
)
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
segment_document.answer = segment_item["answer"]
segment_document.word_count += len(segment_item["answer"])
increment_word_count += segment_document.word_count
@ -3322,7 +3325,7 @@ class SegmentService:
content = args.content or segment.content
if segment.content == content:
segment.word_count = len(content)
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
segment.answer = args.answer
segment.word_count += len(args.answer) if args.answer else 0
word_count_change = segment.word_count - word_count_change
@ -3345,7 +3348,7 @@ class SegmentService:
if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
# regenerate child chunks
# get embedding model instance
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# check embedding model setting
model_manager = ModelManager()
@ -3382,7 +3385,7 @@ class SegmentService:
# When user manually provides summary, allow saving even if summary_index_setting doesn't exist
# summary_index_setting is only needed for LLM generation, not for manual summary vectorization
# Vectorization uses dataset.embedding_model, which doesn't require summary_index_setting
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# Query existing summary from database
from models.dataset import DocumentSegmentSummary
@ -3409,7 +3412,7 @@ class SegmentService:
else:
segment_hash = helper.generate_text_hash(content)
tokens = 0
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
@ -3419,7 +3422,7 @@ class SegmentService:
)
# calc embedding use tokens
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
segment.answer = args.answer
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0] # type: ignore
else:
@ -3436,7 +3439,7 @@ class SegmentService:
segment.enabled = True
segment.disabled_at = None
segment.disabled_by = None
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
segment.answer = args.answer
segment.word_count += len(args.answer) if args.answer else 0
word_count_change = segment.word_count - word_count_change
@ -3449,7 +3452,7 @@ class SegmentService:
db.session.commit()
if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
# get embedding model instance
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# check embedding model setting
model_manager = ModelManager()
@ -3481,7 +3484,7 @@ class SegmentService:
# update segment vector index
VectorService.update_segment_vector(args.keywords, segment, dataset)
# Handle summary index when content changed
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
from models.dataset import DocumentSegmentSummary
existing_summary = (

View File

@ -22,6 +22,7 @@ from sqlalchemy.orm import Session
from core.helper import ssrf_proxy
from core.helper.name_generator import generate_incremental_name
from core.plugin.entities.plugin import PluginDependency
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.workflow.nodes.datasource.entities import DatasourceNodeData
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
@ -311,13 +312,13 @@ class RagPipelineDslService:
"icon_background": icon_background,
"icon_url": icon_url,
},
indexing_technique=knowledge_configuration.indexing_technique,
indexing_technique=IndexTechniqueType(knowledge_configuration.indexing_technique),
created_by=account.id,
retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
runtime_mode=DatasetRuntimeMode.RAG_PIPELINE,
chunk_structure=knowledge_configuration.chunk_structure,
)
if knowledge_configuration.indexing_technique == "high_quality":
if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
dataset_collection_binding = (
self._session.query(DatasetCollectionBinding)
.where(
@ -343,7 +344,7 @@ class RagPipelineDslService:
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = knowledge_configuration.embedding_model
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
elif knowledge_configuration.indexing_technique == "economy":
elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY:
dataset.keyword_number = knowledge_configuration.keyword_number
# Update summary_index_setting if provided
if knowledge_configuration.summary_index_setting is not None:
@ -443,18 +444,18 @@ class RagPipelineDslService:
"icon_background": icon_background,
"icon_url": icon_url,
},
indexing_technique=knowledge_configuration.indexing_technique,
indexing_technique=IndexTechniqueType(knowledge_configuration.indexing_technique),
created_by=account.id,
retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
runtime_mode=DatasetRuntimeMode.RAG_PIPELINE,
chunk_structure=knowledge_configuration.chunk_structure,
)
else:
dataset.indexing_technique = knowledge_configuration.indexing_technique
dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique)
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE
dataset.chunk_structure = knowledge_configuration.chunk_structure
if knowledge_configuration.indexing_technique == "high_quality":
if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
dataset_collection_binding = (
self._session.query(DatasetCollectionBinding)
.where(
@ -480,7 +481,7 @@ class RagPipelineDslService:
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = knowledge_configuration.embedding_model
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
elif knowledge_configuration.indexing_technique == "economy":
elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY:
dataset.keyword_number = knowledge_configuration.keyword_number
# Update summary_index_setting if provided
if knowledge_configuration.summary_index_setting is not None:
@ -772,7 +773,7 @@ class RagPipelineDslService:
)
case _ if typ == KNOWLEDGE_INDEX_NODE_TYPE:
knowledge_index_entity = KnowledgeConfiguration.model_validate(node["data"])
if knowledge_index_entity.indexing_technique == "high_quality":
if knowledge_index_entity.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
if knowledge_index_entity.embedding_model_provider:
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(

View File

@ -9,6 +9,7 @@ from flask_login import current_user
from constants import DOCUMENT_EXTENSIONS
from core.plugin.impl.plugin import PluginInstaller
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from factories import variable_factory
@ -79,9 +80,9 @@ class RagPipelineTransformService:
pipeline = self._create_pipeline(pipeline_yaml)
# save chunk structure to dataset
if doc_form == "hierarchical_model":
if doc_form == IndexStructureType.PARENT_CHILD_INDEX:
dataset.chunk_structure = "hierarchical_model"
elif doc_form == "text_model":
elif doc_form == IndexStructureType.PARAGRAPH_INDEX:
dataset.chunk_structure = "text_model"
else:
raise ValueError("Unsupported doc form")
@ -101,38 +102,38 @@ class RagPipelineTransformService:
def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str | None):
pipeline_yaml = {}
if doc_form == "text_model":
if doc_form == IndexStructureType.PARAGRAPH_INDEX:
match datasource_type:
case DataSourceType.UPLOAD_FILE:
if indexing_technique == "high_quality":
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# get graph from transform.file-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml") as f:
pipeline_yaml = yaml.safe_load(f)
if indexing_technique == "economy":
if indexing_technique == IndexTechniqueType.ECONOMY:
# get graph from transform.file-general-economy.yml
with open(f"{Path(__file__).parent}/transform/file-general-economy.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case DataSourceType.NOTION_IMPORT:
if indexing_technique == "high_quality":
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# get graph from transform.notion-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml") as f:
pipeline_yaml = yaml.safe_load(f)
if indexing_technique == "economy":
if indexing_technique == IndexTechniqueType.ECONOMY:
# get graph from transform.notion-general-economy.yml
with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case DataSourceType.WEBSITE_CRAWL:
if indexing_technique == "high_quality":
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# get graph from transform.website-crawl-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml") as f:
pipeline_yaml = yaml.safe_load(f)
if indexing_technique == "economy":
if indexing_technique == IndexTechniqueType.ECONOMY:
# get graph from transform.website-crawl-general-economy.yml
with open(f"{Path(__file__).parent}/transform/website-crawl-general-economy.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case _:
raise ValueError("Unsupported datasource type")
elif doc_form == "hierarchical_model":
elif doc_form == IndexStructureType.PARENT_CHILD_INDEX:
match datasource_type:
case DataSourceType.UPLOAD_FILE:
# get graph from transform.file-parentchild.yml
@ -169,11 +170,11 @@ class RagPipelineTransformService:
):
knowledge_configuration_dict = node.get("data", {})
if indexing_technique == "high_quality":
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
knowledge_configuration.embedding_model = dataset.embedding_model
knowledge_configuration.embedding_model_provider = dataset.embedding_model_provider
if retrieval_model:
if indexing_technique == "economy":
if indexing_technique == IndexTechniqueType.ECONOMY:
retrieval_model.search_method = RetrievalMethod.KEYWORD_SEARCH
knowledge_configuration.retrieval_model = retrieval_model
else:

View File

@ -12,6 +12,7 @@ from core.db.session_factory import session_factory
from core.model_manager import ModelManager
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from core.rag.models.document import Document
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
@ -140,7 +141,7 @@ class SummaryIndexService:
session: Optional SQLAlchemy session. If provided, uses this session instead of creating a new one.
If not provided, creates a new session and commits automatically.
"""
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
logger.warning(
"Summary vectorization skipped for dataset %s: indexing_technique is not high_quality",
dataset.id,
@ -724,7 +725,7 @@ class SummaryIndexService:
List of created DocumentSegmentSummary instances
"""
# Only generate summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
logger.info(
"Skipping summary generation for dataset %s: indexing_technique is %s, not 'high_quality'",
dataset.id,
@ -851,7 +852,7 @@ class SummaryIndexService:
)
# Remove from vector database (but keep records)
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id]
if summary_node_ids:
try:
@ -889,7 +890,7 @@ class SummaryIndexService:
segment_ids: List of segment IDs to enable summaries for. If None, enable all.
"""
# Only enable summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
return
with session_factory.create_session() as session:
@ -981,7 +982,7 @@ class SummaryIndexService:
return
# Delete from vector database
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id]
if summary_node_ids:
vector = Vector(dataset)
@ -1012,7 +1013,7 @@ class SummaryIndexService:
Updated DocumentSegmentSummary instance, or None if indexing technique is not high_quality
"""
# Only update summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
return None
# When user manually provides summary, allow saving even if summary_index_setting doesn't exist

View File

@ -4,7 +4,7 @@ from core.model_manager import ModelInstance, ModelManager
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, Document
@ -45,7 +45,7 @@ class VectorService:
if not processing_rule:
raise ValueError("No processing rule found.")
# get embedding model instance
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# check embedding model setting
model_manager = ModelManager()
@ -112,7 +112,7 @@ class VectorService:
"dataset_id": segment.dataset_id,
},
)
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# update vector index
vector = Vector(dataset=dataset)
vector.delete_by_ids([segment.index_node_id])
@ -197,7 +197,7 @@ class VectorService:
"dataset_id": child_segment.dataset_id,
},
)
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# save vector index
vector = Vector(dataset=dataset)
vector.add_texts([child_document], duplicate_check=True)
@ -237,7 +237,7 @@ class VectorService:
delete_node_ids.append(update_child_chunk.index_node_id)
for delete_child_chunk in delete_child_chunks:
delete_node_ids.append(delete_child_chunk.index_node_id)
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# update vector index
vector = Vector(dataset=dataset)
if delete_node_ids:
@ -252,7 +252,7 @@ class VectorService:
@classmethod
def update_multimodel_vector(cls, segment: DocumentSegment, attachment_ids: list[str], dataset: Dataset):
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
return
attachments = segment.attachments

View File

@ -5,6 +5,7 @@ import click
from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@ -36,7 +37,7 @@ def add_annotation_to_index_task(
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id,

View File

@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound
from core.db.session_factory import session_factory
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
@ -67,7 +68,7 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id:
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id,

View File

@ -5,6 +5,7 @@ import click
from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@ -26,7 +27,7 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
collection_binding_id=dataset_collection_binding.id,
)

View File

@ -7,6 +7,7 @@ from sqlalchemy import exists, select
from core.db.session_factory import session_factory
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, MessageAnnotation
@ -44,7 +45,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
collection_binding_id=app_annotation_setting.collection_binding_id,
)

View File

@ -7,6 +7,7 @@ from sqlalchemy import select
from core.db.session_factory import session_factory
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
@ -64,7 +65,7 @@ def enable_annotation_reply_task(
old_dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider=old_dataset_collection_binding.provider_name,
embedding_model=old_dataset_collection_binding.model_name,
collection_binding_id=old_dataset_collection_binding.id,
@ -93,7 +94,7 @@ def enable_annotation_reply_task(
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider=embedding_provider_name,
embedding_model=embedding_model_name,
collection_binding_id=dataset_collection_binding.id,

View File

@ -5,6 +5,7 @@ import click
from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@ -37,7 +38,7 @@ def update_annotation_to_index_task(
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id,

View File

@ -11,6 +11,7 @@ from sqlalchemy import func
from core.db.session_factory import session_factory
from core.model_manager import ModelManager
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
@ -109,7 +110,7 @@ def batch_create_segment_to_index_task(
df = pd.read_csv(file_path)
content = []
for _, row in df.iterrows():
if document_config["doc_form"] == "qa_model":
if document_config["doc_form"] == IndexStructureType.QA_INDEX:
data = {"content": row.iloc[0], "answer": row.iloc[1]}
else:
data = {"content": row.iloc[0]}
@ -119,7 +120,7 @@ def batch_create_segment_to_index_task(
document_segments = []
embedding_model = None
if dataset_config["indexing_technique"] == "high_quality":
if dataset_config["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset_config["tenant_id"],
@ -159,7 +160,7 @@ def batch_create_segment_to_index_task(
status="completed",
completed_at=naive_utc_now(),
)
if document_config["doc_form"] == "qa_model":
if document_config["doc_form"] == IndexStructureType.QA_INDEX:
segment_document.answer = segment["answer"]
segment_document.word_count += len(segment["answer"])
word_count_change += segment_document.word_count

View File

@ -10,6 +10,7 @@ from configs import dify_config
from core.db.session_factory import session_factory
from core.entities.document_task import DocumentTask
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from libs.datetime_utils import naive_utc_now
@ -126,7 +127,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
logger.warning("Dataset %s not found after indexing", dataset_id)
return
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
summary_index_setting = dataset.summary_index_setting
if summary_index_setting and summary_index_setting.get("enable"):
# expire all session to get latest document's indexing status
@ -150,7 +151,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
)
if (
document.indexing_status == IndexingStatus.COMPLETED
and document.doc_form != "qa_model"
and document.doc_form != IndexStructureType.QA_INDEX
and document.need_summary is True
):
try:

View File

@ -7,6 +7,7 @@ import click
from celery import shared_task
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.summary_index_service import SummaryIndexService
@ -59,7 +60,7 @@ def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids:
return
# Only generate summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
logger.info(
click.style(
f"Skipping summary generation for dataset {dataset_id}: "

View File

@ -9,6 +9,7 @@ from celery import shared_task
from sqlalchemy import or_, select
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
from models.dataset import Document as DatasetDocument
from services.summary_index_service import SummaryIndexService
@ -52,7 +53,7 @@ def regenerate_summary_index_task(
return
# Only regenerate summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
logger.info(
click.style(
f"Skipping summary regeneration for dataset {dataset_id}: "
@ -106,7 +107,7 @@ def regenerate_summary_index_task(
),
DatasetDocument.enabled == True, # Document must be enabled
DatasetDocument.archived == False, # Document must not be archived
DatasetDocument.doc_form != "qa_model", # Skip qa_model documents
DatasetDocument.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents
)
.order_by(DocumentSegment.document_id.asc(), DocumentSegment.position.asc())
.all()
@ -209,7 +210,7 @@ def regenerate_summary_index_task(
for dataset_document in dataset_documents:
# Skip qa_model documents
if dataset_document.doc_form == "qa_model":
if dataset_document.doc_form == IndexStructureType.QA_INDEX:
continue
try:

View File

@ -68,7 +68,7 @@ def init_tool_node(config: dict):
return node
def test_tool_variable_invoke():
def test_tool_variable_invoke(monkeypatch):
node = init_tool_node(
config={
"id": "1",
@ -103,7 +103,7 @@ def test_tool_variable_invoke():
assert item.node_run_result.outputs.get("text") is not None
def test_tool_mixed_invoke():
def test_tool_mixed_invoke(monkeypatch):
node = init_tool_node(
config={
"id": "1",

View File

@ -1,8 +1,11 @@
"""Testcontainers integration tests for email register controller endpoints."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.email_register import (
EmailRegisterCheckApi,
@ -13,14 +16,11 @@ from services.account_service import AccountService
@pytest.fixture
def app():
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
return flask_app
def app(flask_app_with_containers):
return flask_app_with_containers
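
These tests now lean on a shared `flask_app_with_containers` fixture instead of building a bare Flask app inline; conceptually it provides something like the sketch below (the real fixture, presumably defined in the suite's conftest, also wires up testcontainers-backed services):

import pytest
from flask import Flask

@pytest.fixture
def flask_app_with_containers() -> Flask:
    app = Flask(__name__)
    app.config["TESTING"] = True
    # real fixture: plus database/Redis containers and extension setup
    return app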
class TestEmailRegisterSendEmailApi:
@patch("controllers.console.auth.email_register.Session")
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.email_register.AccountService.send_email_register_email")
@patch("controllers.console.auth.email_register.BillingService.is_email_in_freeze")
@ -33,20 +33,15 @@ class TestEmailRegisterSendEmailApi:
mock_is_freeze,
mock_send_mail,
mock_get_account,
mock_session_cls,
app,
):
mock_send_mail.return_value = "token-123"
mock_is_freeze.return_value = False
mock_account = MagicMock()
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
mock_get_account.return_value = mock_account
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
with (
patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
patch("controllers.console.auth.email_register.dify_config", SimpleNamespace(BILLING_ENABLED=True)),
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
@ -61,7 +56,6 @@ class TestEmailRegisterSendEmailApi:
assert response == {"result": "success", "data": "token-123"}
mock_is_freeze.assert_called_once_with("invitee@example.com")
mock_send_mail.assert_called_once_with(email="invitee@example.com", account=mock_account, language="en-US")
mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session)
mock_extract_ip.assert_called_once()
mock_is_email_send_ip_limit.assert_called_once_with("127.0.0.1")
@ -89,7 +83,6 @@ class TestEmailRegisterCheckApi:
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
with (
patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
):
@ -114,7 +107,6 @@ class TestEmailRegisterResetApi:
@patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit")
@patch("controllers.console.auth.email_register.AccountService.login")
@patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account")
@patch("controllers.console.auth.email_register.Session")
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
@patch("controllers.console.auth.email_register.AccountService.get_email_register_data")
@ -125,7 +117,6 @@ class TestEmailRegisterResetApi:
mock_get_data,
mock_revoke_token,
mock_get_account,
mock_session_cls,
mock_create_account,
mock_login,
mock_reset_login_rate,
@ -136,14 +127,10 @@ class TestEmailRegisterResetApi:
token_pair = MagicMock()
token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"}
mock_login.return_value = token_pair
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
mock_get_account.return_value = None
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
with (
patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
):
@ -159,19 +146,19 @@ class TestEmailRegisterResetApi:
mock_reset_login_rate.assert_called_once_with("invitee@example.com")
mock_revoke_token.assert_called_once_with("token-123")
mock_extract_ip.assert_called_once()
mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session)
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
"""Test that case fallback tries lowercase when exact match fails."""
mock_session = MagicMock()
first_query = MagicMock()
first_query.scalar_one_or_none.return_value = None
first_result = MagicMock()
first_result.scalar_one_or_none.return_value = None
expected_account = MagicMock()
second_query = MagicMock()
second_query.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_query, second_query]
second_result = MagicMock()
second_result.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_result, second_result]
account = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
assert account is expected_account
assert result is expected_account
assert mock_session.execute.call_count == 2
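
The contract this test pins down can be sketched independently of SQLAlchemy (the lookup helper below is a stand-in mirroring AccountService's behavior, not its actual code):

from unittest.mock import MagicMock

def lookup_with_case_fallback(session, email: str):
    account = session.lookup(email)  # exact-case match first
    if account is None and email != email.lower():
        account = session.lookup(email.lower())  # lowercase fallback
    return account

session = MagicMock()
session.lookup.side_effect = [None, "account"]
assert lookup_with_case_fallback(session, "Case@Test.com") == "account"
assert session.lookup.call_count == 2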

View File

@ -1,8 +1,11 @@
"""Testcontainers integration tests for forgot password controller endpoints."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.forgot_password import (
ForgotPasswordCheckApi,
@ -13,14 +16,11 @@ from services.account_service import AccountService
@pytest.fixture
def app():
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
return flask_app
def app(flask_app_with_containers):
return flask_app_with_containers
class TestForgotPasswordSendEmailApi:
@patch("controllers.console.auth.forgot_password.Session")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
@ -31,19 +31,15 @@ class TestForgotPasswordSendEmailApi:
mock_is_ip_limit,
mock_send_email,
mock_get_account,
mock_session_cls,
app,
):
mock_account = MagicMock()
mock_get_account.return_value = mock_account
mock_send_email.return_value = "token-123"
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
wraps_features = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
controller_features = SimpleNamespace(is_allow_register=True)
with (
patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")),
patch(
"controllers.console.auth.forgot_password.FeatureService.get_system_features",
return_value=controller_features,
@ -59,7 +55,6 @@ class TestForgotPasswordSendEmailApi:
response = ForgotPasswordSendEmailApi().post()
assert response == {"result": "success", "data": "token-123"}
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
mock_send_email.assert_called_once_with(
account=mock_account,
email="user@example.com",
@ -117,7 +112,6 @@ class TestForgotPasswordCheckApi:
class TestForgotPasswordResetApi:
@patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account")
@patch("controllers.console.auth.forgot_password.Session")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@ -126,7 +120,6 @@ class TestForgotPasswordResetApi:
mock_get_reset_data,
mock_revoke_token,
mock_get_account,
mock_session_cls,
mock_update_account,
app,
):
@ -134,12 +127,8 @@ class TestForgotPasswordResetApi:
mock_account = MagicMock()
mock_get_account.return_value = mock_account
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
wraps_features = SimpleNamespace(enable_email_password_login=True)
with (
patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")),
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
):
@ -157,20 +146,22 @@ class TestForgotPasswordResetApi:
assert response == {"result": "success"}
mock_get_reset_data.assert_called_once_with("token-123")
mock_revoke_token.assert_called_once_with("token-123")
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
mock_update_account.assert_called_once()
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
"""Test that case fallback tries lowercase when exact match fails."""
from unittest.mock import MagicMock
mock_session = MagicMock()
first_query = MagicMock()
first_query.scalar_one_or_none.return_value = None
first_result = MagicMock()
first_result.scalar_one_or_none.return_value = None
expected_account = MagicMock()
second_query = MagicMock()
second_query.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_query, second_query]
second_result = MagicMock()
second_result.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_result, second_result]
account = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session)
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session)
assert account is expected_account
assert result is expected_account
assert mock_session.execute.call_count == 2

View File

@ -1,7 +1,10 @@
"""Testcontainers integration tests for OAuth controller endpoints."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.oauth import (
OAuthCallback,
@ -18,10 +21,8 @@ from services.errors.account import AccountRegisterError
class TestGetOAuthProviders:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
def app(self, flask_app_with_containers):
return flask_app_with_containers
@pytest.mark.parametrize(
("github_config", "google_config", "expected_github", "expected_google"),
@ -64,10 +65,8 @@ class TestOAuthLogin:
return OAuthLogin()
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
def app(self, flask_app_with_containers):
return flask_app_with_containers
@pytest.fixture
def mock_oauth_provider(self):
@ -131,10 +130,8 @@ class TestOAuthCallback:
return OAuthCallback()
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
def app(self, flask_app_with_containers):
return flask_app_with_containers
@pytest.fixture
def oauth_setup(self):
@ -190,15 +187,8 @@ class TestOAuthCallback:
(KeyError("Missing key"), "OAuth process failed"),
],
)
@patch("controllers.console.auth.oauth.db")
@patch("controllers.console.auth.oauth.get_oauth_providers")
def test_should_handle_oauth_exceptions(
self, mock_get_providers, mock_db, resource, app, exception, expected_error
):
# Mock database session
mock_db.session = MagicMock()
mock_db.session.rollback = MagicMock()
def test_should_handle_oauth_exceptions(self, mock_get_providers, resource, app, exception, expected_error):
# Import the real httpx module to create a proper exception
import httpx
@ -258,7 +248,6 @@ class TestOAuthCallback:
)
@patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.db")
@patch("controllers.console.auth.oauth.dify_config")
@patch("controllers.console.auth.oauth.get_oauth_providers")
@patch("controllers.console.auth.oauth._generate_account")
@ -269,7 +258,6 @@ class TestOAuthCallback:
mock_generate_account,
mock_get_providers,
mock_config,
mock_db,
mock_tenant_service,
mock_account_service,
resource,
@ -278,10 +266,6 @@ class TestOAuthCallback:
account_status,
expected_redirect,
):
# Mock database session
mock_db.session = MagicMock()
mock_db.session.rollback = MagicMock()
mock_db.session.commit = MagicMock()
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
@ -306,14 +290,12 @@ class TestOAuthCallback:
@patch("controllers.console.auth.oauth.dify_config")
@patch("controllers.console.auth.oauth.get_oauth_providers")
@patch("controllers.console.auth.oauth._generate_account")
@patch("controllers.console.auth.oauth.db")
@patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.AccountService")
def test_should_activate_pending_account(
self,
mock_account_service,
mock_tenant_service,
mock_db,
mock_generate_account,
mock_get_providers,
mock_config,
@ -338,12 +320,10 @@ class TestOAuthCallback:
assert mock_account.status == AccountStatus.ACTIVE
assert mock_account.initialized_at is not None
mock_db.session.commit.assert_called_once()
@patch("controllers.console.auth.oauth.dify_config")
@patch("controllers.console.auth.oauth.get_oauth_providers")
@patch("controllers.console.auth.oauth._generate_account")
@patch("controllers.console.auth.oauth.db")
@patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.redirect")
@ -352,7 +332,6 @@ class TestOAuthCallback:
mock_redirect,
mock_account_service,
mock_tenant_service,
mock_db,
mock_generate_account,
mock_get_providers,
mock_config,
@ -414,6 +393,10 @@ class TestOAuthCallback:
class TestAccountGeneration:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
@pytest.fixture
def user_info(self):
return OAuthUserInfo(id="123", name="Test User", email="test@example.com")
@ -425,15 +408,10 @@ class TestAccountGeneration:
return account
@patch("controllers.console.auth.oauth.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.oauth.Session")
@patch("controllers.console.auth.oauth.Account")
@patch("controllers.console.auth.oauth.db")
def test_should_get_account_by_openid_or_email(
self, mock_db, mock_account_model, mock_session, mock_get_account, user_info, mock_account
self, mock_account_model, mock_get_account, flask_req_ctx_with_containers, user_info, mock_account
):
# Mock db.engine for Session creation
mock_db.engine = MagicMock()
# Test OpenID found
mock_account_model.get_by_openid.return_value = mock_account
result = _get_account_by_openid_or_email("github", user_info)
@ -443,15 +421,14 @@ class TestAccountGeneration:
# Test fallback to email lookup
mock_account_model.get_by_openid.return_value = None
mock_session_instance = MagicMock()
mock_session.return_value.__enter__.return_value = mock_session_instance
mock_get_account.return_value = mock_account
result = _get_account_by_openid_or_email("github", user_info)
assert result == mock_account
mock_get_account.assert_called_once_with(user_info.email, session=mock_session_instance)
mock_get_account.assert_called_once()
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(self):
def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(self):
"""Test that case fallback tries lowercase when exact match fails."""
mock_session = MagicMock()
first_result = MagicMock()
first_result.scalar_one_or_none.return_value = None
@ -462,7 +439,7 @@ class TestAccountGeneration:
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
assert result == expected_account
assert result is expected_account
assert mock_session.execute.call_count == 2
@pytest.mark.parametrize(
@ -478,10 +455,8 @@ class TestAccountGeneration:
@patch("controllers.console.auth.oauth.RegisterService")
@patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.db")
def test_should_handle_account_generation_scenarios(
self,
mock_db,
mock_tenant_service,
mock_account_service,
mock_register_service,
@ -519,10 +494,8 @@ class TestAccountGeneration:
@patch("controllers.console.auth.oauth.RegisterService")
@patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.db")
def test_should_register_with_lowercase_email(
self,
mock_db,
mock_tenant_service,
mock_account_service,
mock_register_service,

View File

@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest
from models.dataset import Dataset, Document
@ -38,7 +39,7 @@ class TestGetAvailableDatasetsIntegration:
provider="dify",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
@ -55,7 +56,7 @@ class TestGetAvailableDatasetsIntegration:
name=f"Document {i}",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -112,7 +113,7 @@ class TestGetAvailableDatasetsIntegration:
created_from=DocumentCreatedFrom.WEB,
name=f"Archived Document {i}",
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=True, # Archived
@ -165,7 +166,7 @@ class TestGetAvailableDatasetsIntegration:
created_from=DocumentCreatedFrom.WEB,
name=f"Disabled Document {i}",
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
indexing_status=IndexingStatus.COMPLETED,
enabled=False, # Disabled
archived=False,
@ -218,7 +219,7 @@ class TestGetAvailableDatasetsIntegration:
created_from=DocumentCreatedFrom.WEB,
name=f"Document {status}",
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
indexing_status=status, # Not completed
enabled=True,
archived=False,
@ -336,7 +337,7 @@ class TestGetAvailableDatasetsIntegration:
created_from=DocumentCreatedFrom.WEB,
name=f"Document for {dataset.name}",
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
@ -416,7 +417,7 @@ class TestGetAvailableDatasetsIntegration:
created_from=DocumentCreatedFrom.WEB,
name=f"Document {i}",
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
@ -459,7 +460,7 @@ class TestKnowledgeRetrievalIntegration:
provider="dify",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)
db_session_with_containers.add(dataset)
@ -476,7 +477,7 @@ class TestKnowledgeRetrievalIntegration:
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
db_session_with_containers.add(document)
db_session_with_containers.commit()

View File

@ -0,0 +1,227 @@
"""
Integration tests for Redis Streams broadcast channel implementation using TestContainers.
This suite focuses on the semantics that differ from Redis Pub/Sub:
- Every active subscription should receive each newly published message.
- Each subscription should only observe messages published after its listener starts.
"""
import threading
import time
import uuid
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor, as_completed
import pytest
import redis
from testcontainers.redis import RedisContainer
from libs.broadcast_channel.channel import BroadcastChannel, Subscription, Topic
from libs.broadcast_channel.exc import SubscriptionClosedError
from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel
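
A minimal usage sketch of the API surface these tests exercise; connection details are illustrative:

client = redis.Redis(host="localhost", port=6379, decode_responses=False)
channel = StreamsBroadcastChannel(client)

topic = channel.topic("example-topic")
subscription = topic.subscribe()
try:
    # The first receive() starts the background listener; messages published
    # before this point are never replayed to this subscription.
    assert subscription.receive(timeout=0.05) is None
    topic.as_producer().publish(b"hello")  # every live subscription sees this
    assert subscription.receive(timeout=1.0) == b"hello"
finally:
    subscription.close()  # further receive() calls raise SubscriptionClosedError
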
class TestRedisStreamsBroadcastChannelIntegration:
"""Integration tests for Redis Streams broadcast channel with a real Redis instance."""
@pytest.fixture(scope="class")
def redis_container(self) -> Iterator[RedisContainer]:
"""Create a Redis container for integration testing."""
with RedisContainer(image="redis:6-alpine") as container:
yield container
@pytest.fixture(scope="class")
def redis_client(self, redis_container: RedisContainer) -> redis.Redis:
"""Create a Redis client connected to the test container."""
host = redis_container.get_container_host_ip()
port = redis_container.get_exposed_port(6379)
return redis.Redis(host=host, port=port, decode_responses=False)
@pytest.fixture
def broadcast_channel(self, redis_client: redis.Redis) -> BroadcastChannel:
"""Create a StreamsBroadcastChannel instance with a real Redis client."""
return StreamsBroadcastChannel(redis_client)
@classmethod
def _get_test_topic_name(cls) -> str:
return f"test_streams_topic_{uuid.uuid4()}"
@staticmethod
def _start_subscription(subscription: Subscription) -> None:
"""Start the background listener and confirm the subscription queue is empty."""
assert subscription.receive(timeout=0.05) is None
@staticmethod
def _receive_message(subscription: Subscription, *, timeout_seconds: float = 2.0) -> bytes:
"""Poll until a message is received or the timeout expires."""
deadline = time.monotonic() + timeout_seconds
while time.monotonic() < deadline:
message = subscription.receive(timeout=0.1)
if message is not None:
return message
pytest.fail("Timed out waiting for a message")
def test_close_an_active_subscription_should_stop_iteration(self, broadcast_channel: BroadcastChannel) -> None:
"""Closing an active subscription should terminate the iterator cleanly."""
topic = broadcast_channel.topic(self._get_test_topic_name())
subscription = topic.subscribe()
consuming_event = threading.Event()
def consume() -> list[bytes]:
messages: list[bytes] = []
consuming_event.set()
for message in subscription:
messages.append(message)
return messages
with ThreadPoolExecutor(max_workers=1) as executor:
consumer_future = executor.submit(consume)
assert consuming_event.wait(timeout=1.0)
subscription.close()
assert consumer_future.result(timeout=2.0) == []
def test_end_to_end_messaging(self, broadcast_channel: BroadcastChannel) -> None:
"""A producer should publish a message that a live subscription can consume."""
topic = broadcast_channel.topic(self._get_test_topic_name())
producer = topic.as_producer()
subscription = topic.subscribe()
message = b"hello streams"
try:
self._start_subscription(subscription)
producer.publish(message)
assert self._receive_message(subscription) == message
assert subscription.receive(timeout=0.1) is None
finally:
subscription.close()
def test_multiple_subscriptions_each_receive_each_new_message(self, broadcast_channel: BroadcastChannel) -> None:
"""Each active subscription should receive the same newly published message."""
topic = broadcast_channel.topic(self._get_test_topic_name())
subscriptions = [topic.subscribe() for _ in range(3)]
new_message = b"message-visible-to-every-subscriber"
try:
for subscription in subscriptions:
self._start_subscription(subscription)
topic.publish(new_message)
for subscription in subscriptions:
assert self._receive_message(subscription) == new_message
assert subscription.receive(timeout=0.1) is None
finally:
for subscription in subscriptions:
subscription.close()
def test_each_subscription_only_receives_messages_published_after_it_starts(
self,
broadcast_channel: BroadcastChannel,
) -> None:
"""A late subscription should not replay messages that existed before its listener started."""
topic = broadcast_channel.topic(self._get_test_topic_name())
first_subscription = topic.subscribe()
second_subscription = topic.subscribe()
message_before_any_subscription = b"before-any-subscription"
message_after_first_subscription = b"after-first-subscription"
message_after_second_subscription = b"after-second-subscription"
try:
topic.publish(message_before_any_subscription)
self._start_subscription(first_subscription)
topic.publish(message_after_first_subscription)
assert self._receive_message(first_subscription) == message_after_first_subscription
assert first_subscription.receive(timeout=0.1) is None
self._start_subscription(second_subscription)
topic.publish(message_after_second_subscription)
assert self._receive_message(first_subscription) == message_after_second_subscription
assert self._receive_message(second_subscription) == message_after_second_subscription
assert first_subscription.receive(timeout=0.1) is None
assert second_subscription.receive(timeout=0.1) is None
finally:
first_subscription.close()
second_subscription.close()
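
This no-replay behaviour matches plain Redis Streams consumption without consumer groups: a reader that starts from the stream's current last entry ID (the special "$" cursor in a blocking loop) only sees entries appended afterwards. A minimal redis-py sketch under that assumption, with stream and field names illustrative:

import redis

r = redis.Redis()
stream = "example-stream"

# Entries appended before the cursor is taken are never replayed.
cursor = r.xadd(stream, {"payload": b"before-subscribe"})  # xadd returns the new entry ID
r.xadd(stream, {"payload": b"after-subscribe"})

# XREAD returns entries strictly after the cursor (a live listener would
# loop here with block= and a cursor initialised to "$").
for _name, entries in r.xread({stream: cursor}, count=10) or []:
    for _entry_id, fields in entries:
        assert fields[b"payload"] == b"after-subscribe"
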
def test_topic_isolation(self, broadcast_channel: BroadcastChannel) -> None:
"""Messages from different topics should remain isolated."""
topic1 = broadcast_channel.topic(self._get_test_topic_name())
topic2 = broadcast_channel.topic(self._get_test_topic_name())
message1 = b"message-for-topic-1"
message2 = b"message-for-topic-2"
def consume_single_message(topic: Topic) -> bytes:
subscription = topic.subscribe()
try:
self._start_subscription(subscription)
return self._receive_message(subscription)
finally:
subscription.close()
with ThreadPoolExecutor(max_workers=3) as executor:
consumer1_future = executor.submit(consume_single_message, topic1)
consumer2_future = executor.submit(consume_single_message, topic2)
time.sleep(0.1)  # give both consumers time to start their subscriptions
topic1.publish(message1)
topic2.publish(message2)
assert consumer1_future.result(timeout=5.0) == message1
assert consumer2_future.result(timeout=5.0) == message2
def test_concurrent_producers_publish_all_messages(self, broadcast_channel: BroadcastChannel) -> None:
"""Concurrent producers should not lose messages for a live subscription."""
topic = broadcast_channel.topic(self._get_test_topic_name())
subscription = topic.subscribe()
producer_count = 4
messages_per_producer = 4
expected_total = producer_count * messages_per_producer
consumer_ready = threading.Event()
def produce_messages(producer_idx: int) -> set[bytes]:
producer = topic.as_producer()
produced: set[bytes] = set()
for message_idx in range(messages_per_producer):
payload = f"producer-{producer_idx}-message-{message_idx}".encode()
produced.add(payload)
producer.publish(payload)
time.sleep(0.001)
return produced
def consume_messages() -> set[bytes]:
received: set[bytes] = set()
try:
self._start_subscription(subscription)
consumer_ready.set()
while len(received) < expected_total:
message = subscription.receive(timeout=0.2)
if message is not None:
received.add(message)
return received
finally:
subscription.close()
with ThreadPoolExecutor(max_workers=producer_count + 1) as executor:
consumer_future = executor.submit(consume_messages)
assert consumer_ready.wait(timeout=2.0)
producer_futures = [executor.submit(produce_messages, idx) for idx in range(producer_count)]
expected_messages: set[bytes] = set()
for future in as_completed(producer_futures, timeout=10.0):
expected_messages.update(future.result())
assert consumer_future.result(timeout=10.0) == expected_messages
def test_receive_raises_subscription_closed_after_close(self, broadcast_channel: BroadcastChannel) -> None:
"""Calling receive on a closed subscription should raise SubscriptionClosedError."""
topic = broadcast_channel.topic(self._get_test_topic_name())
subscription = topic.subscribe()
self._start_subscription(subscription)
subscription.close()
with pytest.raises(SubscriptionClosedError):
subscription.receive(timeout=0.1)

View File

@ -13,6 +13,7 @@ import pytest
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import AppDatasetJoin, Dataset, DatasetPermissionEnum
from models.enums import DataSourceType
@ -74,7 +75,7 @@ class DatasetUpdateDeleteTestDataFactory:
name=name,
description="Test description",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=created_by,
permission=permission,
provider="vendor",

View File

@ -13,6 +13,7 @@ from uuid import uuid4
import pytest
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.storage.storage_type import StorageType
from models import Account
from models.dataset import Dataset, Document
@ -91,7 +92,7 @@ class DocumentStatusTestDataFactory:
name=name,
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
document.id = document_id
document.indexing_status = indexing_status

View File

@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from constants.model_template import default_app_templates
from models import Account
from models.model import App, Site
from models.model import App, IconType, Site
from services.account_service import AccountService, TenantService
from tests.test_containers_integration_tests.helpers import generate_valid_password
@ -463,6 +463,109 @@ class TestAppService:
assert updated_app.tenant_id == app.tenant_id
assert updated_app.created_by == app.created_by
def test_update_app_should_preserve_icon_type_when_omitted(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test update_app keeps the persisted icon_type when the update payload passes icon_type as None.
"""
fake = Faker()
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(
tenant.id,
{
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🎯",
"icon_background": "#45B7D1",
},
account,
)
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.id = account.id
mock_current_user.current_tenant_id = account.current_tenant_id
with patch("services.app_service.current_user", mock_current_user):
updated_app = app_service.update_app(
app,
{
"name": "Updated App Name",
"description": "Updated app description",
"icon_type": None,
"icon": "🔄",
"icon_background": "#FF8C42",
"use_icon_as_answer_icon": True,
},
)
assert updated_app.icon_type == IconType.EMOJI
def test_update_app_should_reject_empty_icon_type(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test update_app rejects an explicitly empty icon_type.
"""
fake = Faker()
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(
tenant.id,
{
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🎯",
"icon_background": "#45B7D1",
},
account,
)
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.id = account.id
mock_current_user.current_tenant_id = account.current_tenant_id
with patch("services.app_service.current_user", mock_current_user):
with pytest.raises(ValueError):
app_service.update_app(
app,
{
"name": "Updated App Name",
"description": "Updated app description",
"icon_type": "",
"icon": "🔄",
"icon_background": "#FF8C42",
"use_icon_as_answer_icon": True,
},
)
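
Taken together, the two tests above pin down a small update rule; a hedged sketch (resolve_icon_type is a hypothetical name, the real logic sits inside AppService.update_app):

def resolve_icon_type(existing: str, incoming: str | None) -> str:
    if incoming is None:
        return existing  # left unset in the payload: keep the persisted value
    if incoming == "":
        raise ValueError("icon_type must not be empty")  # explicit empty string is rejected
    return incoming
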
def test_update_app_name_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful app name update.
@ -1142,3 +1245,51 @@ class TestAppService:
assert paginated_apps is not None
assert paginated_apps.total == 1
assert all("50%" in app.name for app in paginated_apps.items)
def test_get_app_code_by_id_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test get_app_code_by_id raises ValueError when site is missing."""
from uuid import uuid4
from services.app_service import AppService
with pytest.raises(ValueError, match="not found"):
AppService.get_app_code_by_id(str(uuid4()))
def test_get_app_id_by_code_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test get_app_id_by_code raises ValueError when code does not exist."""
from services.app_service import AppService
with pytest.raises(ValueError, match="not found"):
AppService.get_app_id_by_code("nonexistent-code")
def test_get_app_meta_returns_empty_when_workflow_missing(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test get_app_meta returns empty tool_icons when workflow is None."""
from types import SimpleNamespace
from services.app_service import AppService
app_service = AppService()
workflow_app = SimpleNamespace(mode="workflow", workflow=None)
meta = app_service.get_app_meta(workflow_app)
assert meta == {"tool_icons": {}}
def test_get_app_meta_returns_empty_when_model_config_missing(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test get_app_meta returns empty tool_icons when app_model_config is None."""
from types import SimpleNamespace
from services.app_service import AppService
app_service = AppService()
chat_app = SimpleNamespace(mode="chat", app_model_config=None)
meta = app_service.get_app_meta(chat_app)
assert meta == {"tool_icons": {}}

View File

@ -9,6 +9,7 @@ from uuid import uuid4
import pytest
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from extensions.ext_database import db
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import (
@ -69,7 +70,7 @@ class DatasetPermissionTestDataFactory:
name=name,
description="desc",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=created_by,
permission=permission,
provider="vendor",

View File

@ -11,6 +11,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -62,7 +63,7 @@ class DatasetServiceIntegrationDataFactory:
name: str = "Test Dataset",
description: str | None = "Test description",
provider: str = "vendor",
indexing_technique: str | None = "high_quality",
indexing_technique: str | None = IndexTechniqueType.HIGH_QUALITY,
permission: str = DatasetPermissionEnum.ONLY_ME,
retrieval_model: dict | None = None,
embedding_model_provider: str | None = None,
@ -106,7 +107,7 @@ class DatasetServiceIntegrationDataFactory:
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
indexing_status=IndexingStatus.COMPLETED,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
db_session_with_containers.add(document)
db_session_with_containers.flush()
@ -156,13 +157,13 @@ class TestDatasetServiceCreateDataset:
tenant_id=tenant.id,
name="Economy Dataset",
description=None,
indexing_technique="economy",
indexing_technique=IndexTechniqueType.ECONOMY,
account=account,
)
# Assert
db_session_with_containers.refresh(result)
assert result.indexing_technique == "economy"
assert result.indexing_technique == IndexTechniqueType.ECONOMY
assert result.embedding_model_provider is None
assert result.embedding_model is None
@ -180,13 +181,13 @@ class TestDatasetServiceCreateDataset:
tenant_id=tenant.id,
name="High Quality Dataset",
description=None,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
account=account,
)
# Assert
db_session_with_containers.refresh(result)
assert result.indexing_technique == "high_quality"
assert result.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert result.embedding_model_provider == embedding_model.provider
assert result.embedding_model == embedding_model.model_name
mock_model_manager.return_value.get_default_model_instance.assert_called_once_with(
@ -272,7 +273,7 @@ class TestDatasetServiceCreateDataset:
tenant_id=tenant.id,
name="Dataset With Reranking",
description=None,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
account=account,
retrieval_model=retrieval_model,
)
@ -305,7 +306,7 @@ class TestDatasetServiceCreateDataset:
tenant_id=tenant.id,
name="Custom Embedding Dataset",
description=None,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
account=account,
embedding_model_provider=embedding_provider,
embedding_model_name=embedding_model_name,
@ -313,7 +314,7 @@ class TestDatasetServiceCreateDataset:
# Assert
db_session_with_containers.refresh(result)
assert result.indexing_technique == "high_quality"
assert result.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert result.embedding_model_provider == embedding_provider
assert result.embedding_model == embedding_model_name
mock_check_embedding.assert_called_once_with(tenant.id, embedding_provider, embedding_model_name)
@ -588,7 +589,7 @@ class TestDatasetServiceUpdateAndDeleteDataset:
db_session_with_containers,
tenant_id=tenant.id,
created_by=account.id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
chunk_structure="text_model",
)
DatasetServiceIntegrationDataFactory.create_document(
@ -684,14 +685,14 @@ class TestDatasetServiceRetrievalConfiguration:
db_session_with_containers,
tenant_id=tenant.id,
created_by=account.id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
retrieval_model={"search_method": "semantic_search", "top_k": 2, "score_threshold": 0.0},
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=str(uuid4()),
)
update_data = {
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"retrieval_model": {
"search_method": "full_text_search",
"top_k": 10,
@ -706,3 +707,104 @@ class TestDatasetServiceRetrievalConfiguration:
db_session_with_containers.refresh(dataset)
assert result.id == dataset.id
assert dataset.retrieval_model == update_data["retrieval_model"]
class TestDocumentServicePauseRecoverRetry:
"""Tests for pause/recover/retry orchestration using real DB and Redis."""
def _create_indexing_document(self, db_session_with_containers, indexing_status="indexing"):
factory = DatasetServiceIntegrationDataFactory
account, tenant = factory.create_account_with_tenant(db_session_with_containers)
dataset = factory.create_dataset(db_session_with_containers, tenant.id, account.id)
doc = factory.create_document(db_session_with_containers, dataset, account.id)
doc.indexing_status = indexing_status
db_session_with_containers.commit()
return doc, account
def test_pause_document_success(self, db_session_with_containers):
from extensions.ext_redis import redis_client
from services.dataset_service import DocumentService
doc, account = self._create_indexing_document(db_session_with_containers, indexing_status="indexing")
with patch("services.dataset_service.current_user") as mock_user:
mock_user.id = account.id
DocumentService.pause_document(doc)
db_session_with_containers.refresh(doc)
assert doc.is_paused is True
assert doc.paused_by == account.id
assert doc.paused_at is not None
cache_key = f"document_{doc.id}_is_paused"
assert redis_client.get(cache_key) is not None
redis_client.delete(cache_key)
def test_pause_document_invalid_status_error(self, db_session_with_containers):
from services.dataset_service import DocumentService
from services.errors.document import DocumentIndexingError
doc, account = self._create_indexing_document(db_session_with_containers, indexing_status="completed")
with patch("services.dataset_service.current_user") as mock_user:
mock_user.id = account.id
with pytest.raises(DocumentIndexingError):
DocumentService.pause_document(doc)
def test_recover_document_success(self, db_session_with_containers):
from extensions.ext_redis import redis_client
from services.dataset_service import DocumentService
doc, account = self._create_indexing_document(db_session_with_containers, indexing_status="indexing")
# Pause first
with patch("services.dataset_service.current_user") as mock_user:
mock_user.id = account.id
DocumentService.pause_document(doc)
# Recover
with patch("services.dataset_service.recover_document_indexing_task") as recover_task:
DocumentService.recover_document(doc)
db_session_with_containers.refresh(doc)
assert doc.is_paused is False
assert doc.paused_by is None
assert doc.paused_at is None
cache_key = f"document_{doc.id}_is_paused"
assert redis_client.get(cache_key) is None
recover_task.delay.assert_called_once_with(doc.dataset_id, doc.id)
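
A hedged, condensed reconstruction of the orchestration these two tests assert; the real DocumentService adds permission checks and session handling, and the set of pausable statuses is an assumption:

from datetime import UTC, datetime

from extensions.ext_redis import redis_client
from services.dataset_service import recover_document_indexing_task  # as patched above
from services.errors.document import DocumentIndexingError


def pause_document_sketch(doc, user_id, db):
    if doc.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}:
        raise DocumentIndexingError("document is not being indexed")
    doc.is_paused = True
    doc.paused_by = user_id
    doc.paused_at = datetime.now(UTC)
    db.commit()
    redis_client.setnx(f"document_{doc.id}_is_paused", "1")  # flag polled by the indexing worker


def recover_document_sketch(doc, db):
    doc.is_paused = False
    doc.paused_by = None
    doc.paused_at = None
    db.commit()
    redis_client.delete(f"document_{doc.id}_is_paused")
    recover_document_indexing_task.delay(doc.dataset_id, doc.id)  # resume asynchronously
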
def test_retry_document_indexing_success(self, db_session_with_containers):
from extensions.ext_redis import redis_client
from services.dataset_service import DocumentService
factory = DatasetServiceIntegrationDataFactory
account, tenant = factory.create_account_with_tenant(db_session_with_containers)
dataset = factory.create_dataset(db_session_with_containers, tenant.id, account.id)
doc1 = factory.create_document(db_session_with_containers, dataset, account.id, name="doc1.txt")
doc2 = factory.create_document(db_session_with_containers, dataset, account.id, name="doc2.txt")
doc2.position = 2
doc1.indexing_status = "error"
doc2.indexing_status = "error"
db_session_with_containers.commit()
with (
patch("services.dataset_service.current_user") as mock_user,
patch("services.dataset_service.retry_document_indexing_task") as retry_task,
):
mock_user.id = account.id
DocumentService.retry_document(dataset.id, [doc1, doc2])
db_session_with_containers.refresh(doc1)
db_session_with_containers.refresh(doc2)
assert doc1.indexing_status == "waiting"
assert doc2.indexing_status == "waiting"
# Verify redis keys were set
assert redis_client.get(f"document_{doc1.id}_is_retried") is not None
assert redis_client.get(f"document_{doc2.id}_is_retried") is not None
retry_task.delay.assert_called_once_with(dataset.id, [doc1.id, doc2.id], account.id)
# Cleanup
redis_client.delete(f"document_{doc1.id}_is_retried", f"document_{doc2.id}_is_retried")

View File

@ -13,6 +13,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
from services.dataset_service import DocumentService
@ -79,7 +80,7 @@ class DocumentBatchUpdateIntegrationDataFactory:
name=name,
created_from=DocumentCreatedFrom.WEB,
created_by=created_by or str(uuid4()),
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
document.id = document_id or str(uuid4())
document.enabled = enabled

View File

@ -3,6 +3,7 @@
from unittest.mock import patch
from uuid import uuid4
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom
@ -78,7 +79,7 @@ class DatasetDeleteIntegrationDataFactory:
tenant_id: str,
dataset_id: str,
created_by: str,
doc_form: str = "text_model",
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
) -> Document:
"""Persist a document so dataset.doc_form resolves through the real document path."""
document = Document(
@ -108,7 +109,7 @@ class TestDatasetServiceDeleteDataset:
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
chunk_structure=None,
index_struct='{"type": "paragraph"}',
collection_binding_id=str(uuid4()),
@ -119,7 +120,7 @@ class TestDatasetServiceDeleteDataset:
tenant_id=tenant.id,
dataset_id=dataset.id,
created_by=owner.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
# Act
@ -207,7 +208,7 @@ class TestDatasetServiceDeleteDataset:
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
chunk_structure=None,
index_struct='{"type": "paragraph"}',
collection_binding_id=str(uuid4()),

View File

@ -12,6 +12,7 @@ from uuid import uuid4
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom
@ -64,7 +65,7 @@ class SegmentServiceTestDataFactory:
name=f"Test Dataset {uuid4()}",
description="Test description",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=created_by,
permission=DatasetPermissionEnum.ONLY_ME,
provider="vendor",

View File

@ -15,6 +15,7 @@ from uuid import uuid4
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import (
AppDatasetJoin,
@ -102,7 +103,7 @@ class DatasetRetrievalTestDataFactory:
name=name,
description="desc",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=created_by,
permission=permission,
provider="vendor",

View File

@ -4,6 +4,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, ExternalKnowledgeBindings
@ -53,7 +54,7 @@ class DatasetUpdateTestDataFactory:
provider: str = "vendor",
name: str = "old_name",
description: str = "old_description",
indexing_technique: str = "high_quality",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
retrieval_model: str = "old_model",
permission: str = "only_me",
embedding_model_provider: str | None = None,
@ -241,7 +242,7 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=existing_binding_id,
@ -250,7 +251,7 @@ class TestDatasetServiceUpdateDataset:
update_data = {
"name": "new_name",
"description": "new_description",
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"retrieval_model": "new_model",
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
@ -261,7 +262,7 @@ class TestDatasetServiceUpdateDataset:
assert dataset.name == "new_name"
assert dataset.description == "new_description"
assert dataset.indexing_technique == "high_quality"
assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert dataset.retrieval_model == "new_model"
assert dataset.embedding_model_provider == "openai"
assert dataset.embedding_model == "text-embedding-ada-002"
@ -276,7 +277,7 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=existing_binding_id,
@ -285,7 +286,7 @@ class TestDatasetServiceUpdateDataset:
update_data = {
"name": "new_name",
"description": None,
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"retrieval_model": "new_model",
"embedding_model_provider": None,
"embedding_model": None,
@ -312,14 +313,14 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=existing_binding_id,
)
update_data = {
"indexing_technique": "economy",
"indexing_technique": IndexTechniqueType.ECONOMY,
"retrieval_model": "new_model",
}
@ -328,7 +329,7 @@ class TestDatasetServiceUpdateDataset:
mock_task.delay.assert_called_once_with(dataset.id, "remove")
db_session_with_containers.refresh(dataset)
assert dataset.indexing_technique == "economy"
assert dataset.indexing_technique == IndexTechniqueType.ECONOMY
assert dataset.embedding_model is None
assert dataset.embedding_model_provider is None
assert dataset.collection_binding_id is None
@ -343,7 +344,7 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="economy",
indexing_technique=IndexTechniqueType.ECONOMY,
)
embedding_model = Mock()
@ -354,7 +355,7 @@ class TestDatasetServiceUpdateDataset:
binding.id = str(uuid4())
update_data = {
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
"retrieval_model": "new_model",
@ -383,7 +384,7 @@ class TestDatasetServiceUpdateDataset:
mock_task.delay.assert_called_once_with(dataset.id, "add")
db_session_with_containers.refresh(dataset)
assert dataset.indexing_technique == "high_quality"
assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert dataset.embedding_model == "text-embedding-ada-002"
assert dataset.embedding_model_provider == "openai"
assert dataset.collection_binding_id == binding.id
@ -403,7 +404,7 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=existing_binding_id,
@ -411,7 +412,7 @@ class TestDatasetServiceUpdateDataset:
update_data = {
"name": "new_name",
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"retrieval_model": "new_model",
}
@ -419,7 +420,7 @@ class TestDatasetServiceUpdateDataset:
db_session_with_containers.refresh(dataset)
assert dataset.name == "new_name"
assert dataset.indexing_technique == "high_quality"
assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert dataset.embedding_model_provider == "openai"
assert dataset.embedding_model == "text-embedding-ada-002"
assert dataset.collection_binding_id == existing_binding_id
@ -435,7 +436,7 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=existing_binding_id,
@ -449,7 +450,7 @@ class TestDatasetServiceUpdateDataset:
binding.id = str(uuid4())
update_data = {
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-3-small",
"retrieval_model": "new_model",
@ -531,11 +532,11 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="economy",
indexing_technique=IndexTechniqueType.ECONOMY,
)
update_data = {
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"embedding_model_provider": "invalid_provider",
"embedding_model": "invalid_model",
"retrieval_model": "new_model",

View File

@ -141,3 +141,73 @@ class TestArchivedWorkflowRunDeletion:
db_session_with_containers.expunge_all()
deleted_run = db_session_with_containers.get(WorkflowRun, run_id)
assert deleted_run is None
def test_delete_run_dry_run(self, db_session_with_containers):
"""Dry run should return success without actually deleting."""
tenant_id = str(uuid4())
run = self._create_workflow_run(
db_session_with_containers,
tenant_id=tenant_id,
created_at=datetime.now(UTC),
)
run_id = run.id
deleter = ArchivedWorkflowRunDeletion(dry_run=True)
result = deleter._delete_run(run)
assert result.success is True
assert result.run_id == run_id
# Run should still exist because it's a dry run
db_session_with_containers.expire_all()
assert db_session_with_containers.get(WorkflowRun, run_id) is not None
def test_delete_run_exception_returns_error(self, db_session_with_containers):
"""Exception during deletion should return failure result."""
from unittest.mock import MagicMock, patch
tenant_id = str(uuid4())
run = self._create_workflow_run(
db_session_with_containers,
tenant_id=tenant_id,
created_at=datetime.now(UTC),
)
deleter = ArchivedWorkflowRunDeletion(dry_run=False)
with patch.object(deleter, "_get_workflow_run_repo") as mock_get_repo:
mock_repo = MagicMock()
mock_get_repo.return_value = mock_repo
mock_repo.delete_runs_with_related.side_effect = Exception("Database error")
result = deleter._delete_run(run)
assert result.success is False
assert result.error == "Database error"
def test_delete_by_run_id_success(self, db_session_with_containers):
"""Successfully delete an archived workflow run by ID."""
tenant_id = str(uuid4())
base_time = datetime.now(UTC)
run = self._create_workflow_run(
db_session_with_containers,
tenant_id=tenant_id,
created_at=base_time,
)
self._create_archive_log(db_session_with_containers, run=run)
run_id = run.id
deleter = ArchivedWorkflowRunDeletion()
result = deleter.delete_by_run_id(run_id)
assert result.success is True
db_session_with_containers.expunge_all()
assert db_session_with_containers.get(WorkflowRun, run_id) is None
def test_get_workflow_run_repo_caches_instance(self, db_session_with_containers):
"""_get_workflow_run_repo should return a cached repo on subsequent calls."""
deleter = ArchivedWorkflowRunDeletion()
repo1 = deleter._get_workflow_run_repo()
repo2 = deleter._get_workflow_run_repo()
assert repo1 is repo2
assert deleter.workflow_run_repo is repo1

View File

@ -3,6 +3,7 @@ from uuid import uuid4
from sqlalchemy import select
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
from services.dataset_service import DocumentService
@ -42,7 +43,7 @@ def _create_document(
name=f"doc-{uuid4()}",
created_from=DocumentCreatedFrom.WEB,
created_by=str(uuid4()),
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
document.id = str(uuid4())
document.indexing_status = indexing_status

View File

@ -7,6 +7,7 @@ from uuid import uuid4
import pytest
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.storage.storage_type import StorageType
from models import Account
from models.dataset import Dataset, Document
@ -69,7 +70,7 @@ def make_document(
name=name,
created_from=DocumentCreatedFrom.WEB,
created_by=str(uuid4()),
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
doc.id = document_id
doc.indexing_status = "completed"

View File

@ -5,6 +5,7 @@ from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexStructureType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding, Document
from models.enums import DatasetMetadataType, DataSourceType, DocumentCreatedFrom
@ -139,7 +140,7 @@ class TestMetadataService:
name=fake.file_name(),
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
)

View File

@ -7,6 +7,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset
from models.enums import DataSourceType, TagType
@ -102,7 +103,7 @@ class TestTagService:
provider="vendor",
permission="only_me",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
tenant_id=tenant_id,
created_by=mock_external_service_dependencies["current_user"].id,
)

View File

@ -1,6 +1,9 @@
from __future__ import annotations
import json
import uuid
from datetime import UTC, datetime, timedelta
from types import SimpleNamespace
from unittest.mock import patch
import pytest
@ -8,14 +11,14 @@ from faker import Faker
from sqlalchemy.orm import Session
from dify_graph.entities.workflow_execution import WorkflowExecutionStatus
from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun
from models.enums import CreatorUserRole
from models import EndUser, Workflow, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun
from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom
from models.workflow import WorkflowAppLogCreatedFrom
from services.account_service import AccountService, TenantService
# Delay import of AppService to avoid circular dependency
# from services.app_service import AppService
from services.workflow_app_service import WorkflowAppService
from services.workflow_app_service import LogView, WorkflowAppService
from tests.test_containers_integration_tests.helpers import generate_valid_password
@ -1525,3 +1528,168 @@ class TestWorkflowAppService:
# Should not find tenant2's data when searching from tenant1's context
assert result_cross_tenant["total"] == 0
def test_get_paginate_workflow_app_logs_raises_when_account_filter_email_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
service = WorkflowAppService()
with pytest.raises(ValueError, match="Account not found: nonexistent@example.com"):
service.get_paginate_workflow_app_logs(
session=db_session_with_containers,
app_model=app,
created_by_account="nonexistent@example.com",
)
def test_get_paginate_workflow_app_logs_filters_by_account(
self, db_session_with_containers, mock_external_service_dependencies
):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
service = WorkflowAppService()
workflow, workflow_run, _log = self._create_test_workflow_data(db_session_with_containers, app, account)
result = service.get_paginate_workflow_app_logs(
session=db_session_with_containers,
app_model=app,
created_by_account=account.email,
)
assert result["total"] >= 0
assert isinstance(result["data"], list)
def test_get_paginate_workflow_archive_logs(self, db_session_with_containers, mock_external_service_dependencies):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
service = WorkflowAppService()
end_user = EndUser(
tenant_id=app.tenant_id,
app_id=app.id,
type="browser",
is_anonymous=False,
session_id="session-1",
)
db_session_with_containers.add(end_user)
db_session_with_containers.commit()
now = datetime.now(UTC)
archive_defaults = {
"workflow_id": str(uuid.uuid4()),
"run_version": "1.0.0",
"run_status": WorkflowExecutionStatus.SUCCEEDED,
"run_triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
"run_error": None,
"run_elapsed_time": 1.0,
"run_total_tokens": 0,
"run_total_steps": 0,
"run_created_at": now,
"run_finished_at": now,
"run_exceptions_count": 0,
"trigger_metadata": '{"type":"trigger-webhook"}',
"log_created_at": now,
"log_created_from": WorkflowAppLogCreatedFrom.SERVICE_API,
}
archive_account = WorkflowArchiveLog(
tenant_id=app.tenant_id,
app_id=app.id,
workflow_run_id=str(uuid.uuid4()),
log_id=str(uuid.uuid4()),
created_by=account.id,
created_by_role=CreatorUserRole.ACCOUNT,
**archive_defaults,
)
archive_end_user = WorkflowArchiveLog(
tenant_id=app.tenant_id,
app_id=app.id,
workflow_run_id=str(uuid.uuid4()),
log_id=str(uuid.uuid4()),
created_by=end_user.id,
created_by_role=CreatorUserRole.END_USER,
**archive_defaults,
)
db_session_with_containers.add_all([archive_account, archive_end_user])
db_session_with_containers.commit()
result = service.get_paginate_workflow_archive_logs(
session=db_session_with_containers,
app_model=app,
page=1,
limit=20,
)
assert result["total"] == 2
assert len(result["data"]) == 2
account_item = next(d for d in result["data"] if d["created_by_account"] is not None)
end_user_item = next(d for d in result["data"] if d["created_by_end_user"] is not None)
assert account_item["created_by_account"].id == account.id
assert end_user_item["created_by_end_user"].id == end_user.id
class TestLogView:
def test_details_and_proxy_attributes(self):
log = SimpleNamespace(id="log-1", status="succeeded")
view = LogView(log=log, details={"trigger_metadata": {"type": "plugin"}})
assert view.details == {"trigger_metadata": {"type": "plugin"}}
assert view.status == "succeeded"
class TestHandleTriggerMetadata:
def test_returns_empty_dict_when_metadata_missing(self):
service = WorkflowAppService()
assert service.handle_trigger_metadata("tenant-1", None) == {}
def test_enriches_plugin_icons(self):
service = WorkflowAppService()
meta = {
"type": AppTriggerType.TRIGGER_PLUGIN.value,
"icon_filename": "light.png",
"icon_dark_filename": "dark.png",
}
with patch(
"services.workflow_app_service.PluginService.get_plugin_icon_url",
side_effect=["https://cdn/light.png", "https://cdn/dark.png"],
) as mock_icon:
result = service.handle_trigger_metadata("tenant-1", json.dumps(meta))
assert result["icon"] == "https://cdn/light.png"
assert result["icon_dark"] == "https://cdn/dark.png"
assert mock_icon.call_count == 2
def test_non_plugin_metadata_without_icon_lookup(self):
service = WorkflowAppService()
meta = {"type": AppTriggerType.TRIGGER_WEBHOOK.value}
with patch("services.workflow_app_service.PluginService.get_plugin_icon_url") as mock_icon:
result = service.handle_trigger_metadata("tenant-1", json.dumps(meta))
assert result["type"] == AppTriggerType.TRIGGER_WEBHOOK.value
mock_icon.assert_not_called()
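
A hedged reconstruction of handle_trigger_metadata as these tests describe it; the PluginService.get_plugin_icon_url call shape is assumed, since the tests only patch it and count two invocations:

from models.enums import AppTriggerType
from services.plugin.plugin_service import PluginService


def handle_trigger_metadata_sketch(tenant_id, raw_metadata):
    meta = WorkflowAppService._safe_json_loads(raw_metadata)
    if not meta:
        return {}  # missing or unparseable metadata yields an empty dict
    if meta.get("type") == AppTriggerType.TRIGGER_PLUGIN.value:
        # Only plugin triggers get icon URL enrichment; argument order assumed.
        meta["icon"] = PluginService.get_plugin_icon_url(tenant_id, meta["icon_filename"])
        meta["icon_dark"] = PluginService.get_plugin_icon_url(tenant_id, meta["icon_dark_filename"])
    return meta
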
class TestSafeJsonLoads:
@pytest.mark.parametrize(
("value", "expected"),
[
(None, None),
("", None),
('{"k":"v"}', {"k": "v"}),
("not-json", None),
({"raw": True}, {"raw": True}),
],
)
def test_handles_various_inputs(self, value, expected):
assert WorkflowAppService._safe_json_loads(value) == expected
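
The parametrised expectations fully determine a small helper; a hypothetical equivalent:

import json


def safe_json_loads_sketch(value):
    if isinstance(value, dict):
        return value  # already-parsed payloads pass through untouched
    if not value:
        return None  # None and "" both map to None
    try:
        return json.loads(value)
    except (TypeError, ValueError):
        return None  # "not-json" and other garbage degrade to None
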
class TestSafeParseUuid:
def test_returns_none_for_short_or_invalid_values(self):
service = WorkflowAppService()
assert service._safe_parse_uuid("short") is None
assert service._safe_parse_uuid("x" * 40) is None
def test_returns_uuid_for_valid_string(self):
service = WorkflowAppService()
raw = str(uuid.uuid4())
result = service._safe_parse_uuid(raw)
assert result is not None
assert str(result) == raw
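
And the matching reconstruction for _safe_parse_uuid, again inferred only from the assertions above (the length guard is an assumption):

import uuid


def safe_parse_uuid_sketch(value):
    if not value or len(value) < 32:
        return None  # too short to be a UUID in any accepted form
    try:
        return uuid.UUID(value)  # canonical strings round-trip via str()
    except ValueError:
        return None  # forty characters of "x" is long enough but still invalid
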

View File

@ -1,12 +1,24 @@
from __future__ import annotations
from unittest.mock import Mock, patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.__base.tool import Tool
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.entities.tool_entities import (
ApiProviderSchemaType,
ToolDescription,
ToolEntity,
ToolIdentity,
ToolParameter,
ToolProviderEntity,
ToolProviderIdentity,
ToolProviderType,
)
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
from services.plugin.plugin_service import PluginService
from services.tools.tools_transform_service import ToolTransformService
@ -52,7 +64,7 @@ class TestToolTransformService:
user_id="test_user_id",
credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}',
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
tools_str="[]",
)
elif provider_type == "builtin":
@ -659,7 +671,7 @@ class TestToolTransformService:
user_id=fake.uuid4(),
credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}',
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
tools_str="[]",
)
@ -695,7 +707,7 @@ class TestToolTransformService:
user_id=fake.uuid4(),
credentials_str='{"auth_type": "api_key_query", "api_key": "test_key"}',
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
tools_str="[]",
)
@ -731,7 +743,7 @@ class TestToolTransformService:
user_id=fake.uuid4(),
credentials_str='{"auth_type": "api_key", "api_key": "test_key"}',
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
tools_str="[]",
)
@ -786,3 +798,192 @@ class TestToolTransformService:
assert result is not None
assert result == mock_controller
mock_from_db.assert_called_once_with(provider)
def _mock_tool(*, base_params, runtime_params):
"""Helper to build a Mock tool with real entity objects.
Tool is abstract and requires runtime behaviour (fork_tool_runtime,
get_runtime_parameters), so it stays as a Mock. Everything else uses
real Pydantic instances.
"""
entity = ToolEntity(
identity=ToolIdentity(
author="test_author",
name="test_tool",
label=I18nObject(en_US="Test Tool"),
provider="test_provider",
),
parameters=base_params or [],
description=ToolDescription(
human=I18nObject(en_US="Test description"),
llm="Test description for LLM",
),
output_schema={},
)
mock_tool = Mock(spec=Tool)
mock_tool.entity = entity
mock_tool.get_runtime_parameters.return_value = runtime_params
mock_tool.fork_tool_runtime.return_value = mock_tool
return mock_tool
def _param(name, *, form=ToolParameter.ToolParameterForm.FORM, label=None):
return ToolParameter(
name=name,
label=I18nObject(en_US=label or name),
human_description=I18nObject(en_US=name),
type=ToolParameter.ToolParameterType.STRING,
form=form,
)
class TestConvertToolEntityToApiEntity:
"""Tests for ToolTransformService.convert_tool_entity_to_api_entity."""
def test_parameter_override(self):
base = [_param("param1", label="Base 1"), _param("param2", label="Base 2")]
runtime = [_param("param1", label="Runtime 1")]
tool = _mock_tool(base_params=base, runtime_params=runtime)
result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None)
assert isinstance(result, ToolApiEntity)
assert len(result.parameters) == 2
assert next(p for p in result.parameters if p.name == "param1").label.en_US == "Runtime 1"
assert next(p for p in result.parameters if p.name == "param2").label.en_US == "Base 2"
def test_additional_runtime_parameters(self):
base = [_param("param1", label="Base 1")]
runtime = [_param("param1", label="Runtime 1"), _param("runtime_only", label="Runtime Only")]
tool = _mock_tool(base_params=base, runtime_params=runtime)
result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None)
assert len(result.parameters) == 2
names = [p.name for p in result.parameters]
assert "param1" in names
assert "runtime_only" in names
def test_non_form_runtime_parameters_excluded(self):
base = [_param("param1")]
runtime = [
_param("param1", label="Runtime 1"),
_param("llm_param", form=ToolParameter.ToolParameterForm.LLM),
]
tool = _mock_tool(base_params=base, runtime_params=runtime)
result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None)
assert len(result.parameters) == 1
assert result.parameters[0].name == "param1"
def test_empty_parameters(self):
tool = _mock_tool(base_params=[], runtime_params=[])
result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None)
assert isinstance(result, ToolApiEntity)
assert len(result.parameters) == 0
def test_none_parameters(self):
tool = _mock_tool(base_params=None, runtime_params=[])
result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None)
assert isinstance(result, ToolApiEntity)
assert len(result.parameters) == 0
def test_parameter_order_preserved(self):
base = [_param("p1", label="B1"), _param("p2", label="B2"), _param("p3", label="B3")]
runtime = [_param("p2", label="R2"), _param("p4", label="R4")]
tool = _mock_tool(base_params=base, runtime_params=runtime)
result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None)
assert [p.name for p in result.parameters] == ["p1", "p2", "p3", "p4"]
assert result.parameters[1].label.en_US == "R2"
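Taken together, these five cases imply a merge rule along the following lines, sketched here for orientation rather than as the service's actual code: base parameters keep their order, FORM-typed runtime parameters override them by name, extra runtime parameters append, and non-FORM runtime parameters are dropped.

from core.tools.entities.tool_entities import ToolParameter

def _merge_parameters(base_params, runtime_params):
    merged = {p.name: p for p in (base_params or [])}
    for param in runtime_params or []:
        if param.form != ToolParameter.ToolParameterForm.FORM:
            continue  # LLM-form parameters are filled by the model, not the user
        merged[param.name] = param  # overwriting an existing key keeps its position
    return list(merged.values())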
class TestWorkflowProviderToUserProvider:
"""Tests for ToolTransformService.workflow_provider_to_user_provider."""
@staticmethod
def _make_controller(provider_id="provider_123", **identity_overrides):
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
defaults = {
"author": "test_author",
"name": "test_workflow_tool",
"description": I18nObject(en_US="Test description"),
"icon": '{"type": "emoji", "content": "🔧"}',
"icon_dark": None,
"label": I18nObject(en_US="Test Workflow Tool"),
}
defaults.update(identity_overrides)
identity = ToolProviderIdentity(**defaults)
entity = ToolProviderEntity(identity=identity)
return WorkflowToolProviderController(entity=entity, provider_id=provider_id)
def test_with_workflow_app_id(self):
ctrl = self._make_controller()
result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=ctrl,
labels=["l1", "l2"],
workflow_app_id="app_123",
)
assert isinstance(result, ToolProviderApiEntity)
assert result.id == "provider_123"
assert result.type == ToolProviderType.WORKFLOW
assert result.workflow_app_id == "app_123"
assert result.labels == ["l1", "l2"]
assert result.is_team_authorization is True
def test_without_workflow_app_id(self):
ctrl = self._make_controller()
result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=ctrl,
labels=["l1"],
)
assert result.workflow_app_id is None
def test_workflow_app_id_none_explicit(self):
ctrl = self._make_controller()
result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=ctrl,
labels=None,
workflow_app_id=None,
)
assert result.workflow_app_id is None
assert result.labels == []
def test_preserves_other_fields(self):
ctrl = self._make_controller(
"provider_456",
author="another_author",
name="another_workflow_tool",
description=I18nObject(en_US="Another desc", zh_Hans="Another desc"),
icon='{"type": "emoji", "content": "⚙️"}',
icon_dark='{"type": "emoji", "content": "🔧"}',
label=I18nObject(en_US="Another Tool", zh_Hans="Another Tool"),
)
result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=ctrl,
labels=["automation"],
workflow_app_id="app_456",
)
assert result.id == "provider_456"
assert result.author == "another_author"
assert result.name == "another_workflow_tool"
assert result.type == ToolProviderType.WORKFLOW
assert result.workflow_app_id == "app_456"
assert result.is_team_authorization is True
assert result.allow_delete is True

View File

@ -4,7 +4,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetAutoDisableLog, Document, DocumentSegment
@ -81,7 +81,7 @@ class TestAddDocumentToIndexTask:
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=account.id,
)
db_session_with_containers.add(dataset)

View File

@ -13,6 +13,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.storage.storage_type import StorageType
from libs.datetime_utils import naive_utc_now
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -152,7 +153,7 @@ class TestBatchCleanDocumentTask:
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
indexing_status=IndexingStatus.COMPLETED,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
db_session_with_containers.add(document)
@ -392,7 +393,12 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit()
# Execute the task with non-existent dataset
batch_clean_document_task(document_ids=[document_id], dataset_id=dataset_id, doc_form="text_model", file_ids=[])
batch_clean_document_task(
document_ids=[document_id],
dataset_id=dataset_id,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
file_ids=[],
)
# Verify that no index processing occurred
mock_external_service_dependencies["index_processor"].clean.assert_not_called()
@ -525,7 +531,11 @@ class TestBatchCleanDocumentTask:
account = self._create_test_account(db_session_with_containers)
# Test different doc_form types
doc_forms = ["text_model", "qa_model", "hierarchical_model"]
doc_forms = [
IndexStructureType.PARAGRAPH_INDEX,
IndexStructureType.QA_INDEX,
IndexStructureType.PARENT_CHILD_INDEX,
]
for doc_form in doc_forms:
dataset = self._create_test_dataset(db_session_with_containers, account)
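For orientation, the paired old/new lines in this and the following files imply roughly this string-to-enum mapping (inferred from the diff itself; the authoritative definitions live in core.rag.index_processor.constant.index_type, and the ECONOMY value is an assumption not exercised here):

from enum import StrEnum

class IndexStructureType(StrEnum):
    PARAGRAPH_INDEX = "text_model"
    QA_INDEX = "qa_model"
    PARENT_CHILD_INDEX = "hierarchical_model"

class IndexTechniqueType(StrEnum):
    HIGH_QUALITY = "high_quality"
    ECONOMY = "economy"  # assumed counterpart, not shown in this diff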

View File

@ -19,6 +19,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from extensions.storage.storage_type import StorageType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
@ -141,7 +142,7 @@ class TestBatchCreateSegmentToIndexTask:
name=fake.company(),
description=fake.text(),
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model="text-embedding-ada-002",
embedding_model_provider="openai",
created_by=account.id,
@ -179,7 +180,7 @@ class TestBatchCreateSegmentToIndexTask:
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
word_count=0,
)
@ -221,17 +222,17 @@ class TestBatchCreateSegmentToIndexTask:
return upload_file
def _create_test_csv_content(self, content_type="text_model"):
def _create_test_csv_content(self, content_type=IndexStructureType.PARAGRAPH_INDEX):
"""
Helper method to create test CSV content.
Args:
content_type: Type of content to create ("text_model" or "qa_model")
content_type: Type of content to create (IndexStructureType.PARAGRAPH_INDEX or IndexStructureType.QA_INDEX)
Returns:
str: CSV content as string
"""
if content_type == "qa_model":
if content_type == IndexStructureType.QA_INDEX:
csv_content = "content,answer\n"
csv_content += "This is the first segment content,This is the first answer\n"
csv_content += "This is the second segment content,This is the second answer\n"
@ -264,7 +265,7 @@ class TestBatchCreateSegmentToIndexTask:
upload_file = self._create_test_upload_file(db_session_with_containers, account, tenant)
# Create CSV content
csv_content = self._create_test_csv_content("text_model")
csv_content = self._create_test_csv_content(IndexStructureType.PARAGRAPH_INDEX)
# Mock storage to return our CSV content
mock_storage = mock_external_service_dependencies["storage"]
@ -451,7 +452,7 @@ class TestBatchCreateSegmentToIndexTask:
indexing_status=IndexingStatus.COMPLETED,
enabled=False, # Document is disabled
archived=False,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
word_count=0,
),
# Archived document
@ -467,7 +468,7 @@ class TestBatchCreateSegmentToIndexTask:
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=True, # Document is archived
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
word_count=0,
),
# Document with incomplete indexing
@ -483,7 +484,7 @@ class TestBatchCreateSegmentToIndexTask:
indexing_status=IndexingStatus.INDEXING, # Not completed
enabled=True,
archived=False,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
word_count=0,
),
]
@ -655,7 +656,7 @@ class TestBatchCreateSegmentToIndexTask:
db_session_with_containers.commit()
# Create CSV content
csv_content = self._create_test_csv_content("text_model")
csv_content = self._create_test_csv_content(IndexStructureType.PARAGRAPH_INDEX)
# Mock storage to return our CSV content
mock_storage = mock_external_service_dependencies["storage"]

View File

@ -18,6 +18,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from extensions.storage.storage_type import StorageType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import (
@ -153,7 +154,7 @@ class TestCleanDatasetTask:
tenant_id=tenant.id,
name="test_dataset",
description="Test dataset for cleanup testing",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
index_struct='{"type": "paragraph"}',
collection_binding_id=str(uuid.uuid4()),
created_by=account.id,
@ -192,7 +193,7 @@ class TestCleanDatasetTask:
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
doc_form="paragraph_index",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
word_count=100,
created_at=datetime.now(),
updated_at=datetime.now(),
@ -869,7 +870,7 @@ class TestCleanDatasetTask:
tenant_id=tenant.id,
name=long_name,
description=long_description,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
index_struct='{"type": "paragraph", "max_length": 10000}',
collection_binding_id=str(uuid.uuid4()),
created_by=account.id,

View File

@ -12,6 +12,7 @@ from unittest.mock import Mock, patch
import pytest
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
from services.account_service import AccountService, TenantService
@ -114,7 +115,7 @@ class TestCleanNotionDocumentTask:
name=f"Notion Page {i}",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model", # Set doc_form to ensure dataset.doc_form works
doc_form=IndexStructureType.PARAGRAPH_INDEX, # Set doc_form to ensure dataset.doc_form works
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
)
@ -261,7 +262,7 @@ class TestCleanNotionDocumentTask:
# Test different index types
# Note: Only testing text_model to avoid dependency on external services
index_types = ["text_model"]
index_types = [IndexStructureType.PARAGRAPH_INDEX]
for index_type in index_types:
# Create dataset (doc_form will be set via document creation)

View File

@ -12,6 +12,7 @@ from uuid import uuid4
import pytest
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
@ -120,7 +121,7 @@ class TestCreateSegmentToIndexTask:
description=fake.text(max_nb_chars=100),
tenant_id=tenant_id,
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
created_by=account_id,
@ -141,7 +142,7 @@ class TestCreateSegmentToIndexTask:
enabled=True,
archived=False,
indexing_status=IndexingStatus.COMPLETED,
doc_form="qa_model",
doc_form=IndexStructureType.QA_INDEX,
)
db_session_with_containers.add(document)
db_session_with_containers.commit()
@ -301,7 +302,7 @@ class TestCreateSegmentToIndexTask:
enabled=True,
archived=False,
indexing_status=IndexingStatus.COMPLETED,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
db_session_with_containers.add(document)
db_session_with_containers.commit()
@ -552,7 +553,11 @@ class TestCreateSegmentToIndexTask:
- Processing completes successfully for different forms
"""
# Arrange: Test different doc_forms
doc_forms = ["qa_model", "text_model", "web_model"]
doc_forms = [
IndexStructureType.QA_INDEX,
IndexStructureType.PARAGRAPH_INDEX,
IndexStructureType.PARENT_CHILD_INDEX,
]
for doc_form in doc_forms:
# Create fresh test data for each form

View File

@ -8,6 +8,7 @@ import pytest
from faker import Faker
from core.indexing_runner import DocumentIsPausedError
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from enums.cloud_plan import CloudPlan
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
@ -141,7 +142,7 @@ class TestDatasetIndexingTaskIntegration:
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=account.id,
)
db_session_with_containers.add(dataset)

View File

@ -12,6 +12,7 @@ from unittest.mock import ANY, Mock, patch
import pytest
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
from services.account_service import AccountService, TenantService
@ -107,7 +108,7 @@ class TestDealDatasetVectorIndexTask:
name="Document for doc_form",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -167,7 +168,7 @@ class TestDealDatasetVectorIndexTask:
name="Document for doc_form",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -187,7 +188,7 @@ class TestDealDatasetVectorIndexTask:
name="Test Document",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -268,7 +269,7 @@ class TestDealDatasetVectorIndexTask:
name="Document for doc_form",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="parent_child_index",
doc_form=IndexStructureType.PARENT_CHILD_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -288,7 +289,7 @@ class TestDealDatasetVectorIndexTask:
name="Test Document",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="parent_child_index",
doc_form=IndexStructureType.PARENT_CHILD_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -416,7 +417,7 @@ class TestDealDatasetVectorIndexTask:
name="Test Document",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -505,7 +506,7 @@ class TestDealDatasetVectorIndexTask:
name="Document for doc_form",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -525,7 +526,7 @@ class TestDealDatasetVectorIndexTask:
name="Test Document",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -601,7 +602,7 @@ class TestDealDatasetVectorIndexTask:
name="Test Document",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="qa_index",
doc_form=IndexStructureType.QA_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -638,7 +639,7 @@ class TestDealDatasetVectorIndexTask:
assert updated_document.indexing_status == IndexingStatus.COMPLETED
# Verify index processor was initialized with custom index type
mock_index_processor_factory.assert_called_once_with("qa_index")
mock_index_processor_factory.assert_called_once_with(IndexStructureType.QA_INDEX)
mock_factory = mock_index_processor_factory.return_value
mock_processor = mock_factory.init_index_processor.return_value
mock_processor.load.assert_called_once()
@ -677,7 +678,7 @@ class TestDealDatasetVectorIndexTask:
name="Test Document",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -714,7 +715,7 @@ class TestDealDatasetVectorIndexTask:
assert updated_document.indexing_status == IndexingStatus.COMPLETED
# Verify index processor was initialized with the document's index type
mock_index_processor_factory.assert_called_once_with("text_model")
mock_index_processor_factory.assert_called_once_with(IndexStructureType.PARAGRAPH_INDEX)
mock_factory = mock_index_processor_factory.return_value
mock_processor = mock_factory.init_index_processor.return_value
mock_processor.load.assert_called_once()
@ -753,7 +754,7 @@ class TestDealDatasetVectorIndexTask:
name="Document for doc_form",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -775,7 +776,7 @@ class TestDealDatasetVectorIndexTask:
name=f"Test Document {i}",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -856,7 +857,7 @@ class TestDealDatasetVectorIndexTask:
name="Document for doc_form",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -876,7 +877,7 @@ class TestDealDatasetVectorIndexTask:
name="Test Document",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -953,7 +954,7 @@ class TestDealDatasetVectorIndexTask:
name="Document for doc_form",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -973,7 +974,7 @@ class TestDealDatasetVectorIndexTask:
name="Enabled Document",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -992,7 +993,7 @@ class TestDealDatasetVectorIndexTask:
name="Disabled Document",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=False, # This document should be skipped
@ -1074,7 +1075,7 @@ class TestDealDatasetVectorIndexTask:
name="Document for doc_form",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -1094,7 +1095,7 @@ class TestDealDatasetVectorIndexTask:
name="Active Document",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -1113,7 +1114,7 @@ class TestDealDatasetVectorIndexTask:
name="Archived Document",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -1195,7 +1196,7 @@ class TestDealDatasetVectorIndexTask:
name="Document for doc_form",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -1215,7 +1216,7 @@ class TestDealDatasetVectorIndexTask:
name="Completed Document",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -1234,7 +1235,7 @@ class TestDealDatasetVectorIndexTask:
name="Incomplete Document",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.INDEXING, # This document should be skipped
enabled=True,

View File

@ -12,7 +12,7 @@ from unittest.mock import MagicMock, patch
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models import Account, Dataset, Document, DocumentSegment, Tenant
from models.enums import DataSourceType, DocumentCreatedFrom, DocumentDocType, IndexingStatus, SegmentStatus
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
@ -108,7 +108,7 @@ class TestDeleteSegmentFromIndexTask:
dataset.provider = "vendor"
dataset.permission = "only_me"
dataset.data_source_type = DataSourceType.UPLOAD_FILE
dataset.indexing_technique = "high_quality"
dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY
dataset.index_struct = '{"type": "paragraph"}'
dataset.created_by = account.id
dataset.created_at = fake.date_time_this_year()

View File

@ -15,6 +15,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
@ -99,7 +100,7 @@ class TestDisableSegmentFromIndexTask:
name=fake.sentence(nb_words=3),
description=fake.text(max_nb_chars=200),
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -113,7 +114,7 @@ class TestDisableSegmentFromIndexTask:
dataset: Dataset,
tenant: Tenant,
account: Account,
doc_form: str = "text_model",
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
) -> Document:
"""
Helper method to create a test document.
@ -476,7 +477,11 @@ class TestDisableSegmentFromIndexTask:
- Index processor clean method is called correctly
"""
# Test different document forms
doc_forms = ["text_model", "qa_model", "table_model"]
doc_forms = [
IndexStructureType.PARAGRAPH_INDEX,
IndexStructureType.QA_INDEX,
IndexStructureType.PARENT_CHILD_INDEX,
]
for doc_form in doc_forms:
# Arrange: Create test data for each form

View File

@ -11,6 +11,7 @@ from unittest.mock import MagicMock, patch
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models import Account, Dataset, DocumentSegment
from models import Document as DatasetDocument
from models.dataset import DatasetProcessRule
@ -102,7 +103,7 @@ class TestDisableSegmentsFromIndexTask:
provider="vendor",
permission="only_me",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=account.id,
updated_by=account.id,
embedding_model="text-embedding-ada-002",
@ -153,7 +154,7 @@ class TestDisableSegmentsFromIndexTask:
document.indexing_status = "completed"
document.enabled = True
document.archived = False
document.doc_form = "text_model" # Use text_model form for testing
document.doc_form = IndexStructureType.PARAGRAPH_INDEX # Use the paragraph index form for testing
document.doc_language = "en"
db_session_with_containers.add(document)
db_session_with_containers.commit()
@ -500,7 +501,11 @@ class TestDisableSegmentsFromIndexTask:
segment_ids = [segment.id for segment in segments]
# Test different document forms
doc_forms = ["text_model", "qa_model", "hierarchical_model"]
doc_forms = [
IndexStructureType.PARAGRAPH_INDEX,
IndexStructureType.QA_INDEX,
IndexStructureType.PARENT_CHILD_INDEX,
]
for doc_form in doc_forms:
# Update document form

View File

@ -14,6 +14,7 @@ from uuid import uuid4
import pytest
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
@ -56,7 +57,7 @@ class DocumentIndexingSyncTaskTestDataFactory:
name=f"dataset-{uuid4()}",
description="sync test dataset",
data_source_type=DataSourceType.NOTION_IMPORT,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=created_by,
)
db_session_with_containers.add(dataset)
@ -85,7 +86,7 @@ class DocumentIndexingSyncTaskTestDataFactory:
created_by=created_by,
indexing_status=indexing_status,
enabled=True,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
)
db_session_with_containers.add(document)

Some files were not shown because too many files have changed in this diff