Merge branch 'main' into feat/end-user-oauth

# Conflicts:
#	web/app/components/app/configuration/config/agent/agent-tools/index.tsx
zhsama 2025-12-09 17:41:01 +08:00
commit 2ea07cd8f8
350 changed files with 9871 additions and 9515 deletions

View File

@ -654,3 +654,9 @@ TENANT_ISOLATED_TASK_CONCURRENCY=1
# Maximum number of segments for dataset segments API (0 for unlimited)
DATASET_MAX_SEGMENTS_PER_REQUEST=0
# Multimodal knowledgebase limit
SINGLE_CHUNK_ATTACHMENT_LIMIT=10
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60
IMAGE_FILE_BATCH_LIMIT=10

View File

@ -360,6 +360,26 @@ class FileUploadConfig(BaseSettings):
default=10,
)
IMAGE_FILE_BATCH_LIMIT: PositiveInt = Field(
description="Maximum number of files allowed in a image batch upload operation",
default=10,
)
SINGLE_CHUNK_ATTACHMENT_LIMIT: PositiveInt = Field(
description="Maximum number of files allowed in a single chunk attachment",
default=10,
)
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="Maximum allowed image file size for attachments in megabytes",
default=2,
)
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT: NonNegativeInt = Field(
description="Timeout for downloading image attachments in seconds",
default=60,
)
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
description=(
"Comma-separated list of file extensions that are blocked from upload. "

View File

@ -61,6 +61,7 @@ class ChatMessagesQuery(BaseModel):
class MessageFeedbackPayload(BaseModel):
message_id: str = Field(..., description="Message ID")
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
@field_validator("message_id")
@classmethod
@ -324,6 +325,7 @@ class MessageFeedbackApi(Resource):
db.session.delete(feedback)
elif args.rating and feedback:
feedback.rating = args.rating
feedback.content = args.content
elif not args.rating and not feedback:
raise ValueError("rating cannot be None when feedback not exists")
else:
@ -335,6 +337,7 @@ class MessageFeedbackApi(Resource):
conversation_id=message.conversation_id,
message_id=message.id,
rating=rating_value,
content=args.content,
from_source="admin",
from_account_id=current_user.id,
)
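For illustration, a minimal client-side sketch of the extended feedback payload: only the message_id / rating / content shape comes from MessageFeedbackPayload above; the endpoint path and auth header below are placeholders, not taken from this commit.

import requests

# Placeholder URL and token; the JSON body mirrors MessageFeedbackPayload above.
resp = requests.post(
    "https://dify.example.com/console/api/apps/<app-id>/feedbacks",  # assumed path
    headers={"Authorization": "Bearer <console-token>"},
    json={
        "message_id": "<message-uuid>",
        "rating": "like",  # "like", "dislike", or null to clear the rating
        "content": "Helpful answer, cited the right document.",  # new optional field
    },
    timeout=10,
)
resp.raise_for_status()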

View File

@ -151,6 +151,7 @@ class DatasetUpdatePayload(BaseModel):
external_knowledge_id: str | None = None
external_knowledge_api_id: str | None = None
icon_info: dict[str, Any] | None = None
is_multimodal: bool | None = False
@field_validator("indexing_technique")
@classmethod
@ -423,17 +424,16 @@ class DatasetApi(Resource):
payload = DatasetUpdatePayload.model_validate(console_ns.payload or {})
payload_data = payload.model_dump(exclude_unset=True)
current_user, current_tenant_id = current_account_with_tenant()
# check embedding model setting
if (
payload.indexing_technique == "high_quality"
and payload.embedding_model_provider is not None
and payload.embedding_model is not None
):
DatasetService.check_embedding_model_setting(
is_multimodal = DatasetService.check_is_multimodal_model(
dataset.tenant_id, payload.embedding_model_provider, payload.embedding_model
)
payload.is_multimodal = is_multimodal
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
current_user, dataset, payload.permission, payload.partial_member_list

View File

@ -424,6 +424,10 @@ class DatasetInitApi(Resource):
model_type=ModelType.TEXT_EMBEDDING,
model=knowledge_config.embedding_model,
)
is_multimodal = DatasetService.check_is_multimodal_model(
current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model
)
knowledge_config.is_multimodal = is_multimodal
except InvokeAuthorizationError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."

View File

@ -51,6 +51,7 @@ class SegmentCreatePayload(BaseModel):
content: str
answer: str | None = None
keywords: list[str] | None = None
attachment_ids: list[str] | None = None
class SegmentUpdatePayload(BaseModel):
@ -58,6 +59,7 @@ class SegmentUpdatePayload(BaseModel):
answer: str | None = None
keywords: list[str] | None = None
regenerate_child_chunks: bool = False
attachment_ids: list[str] | None = None
class BatchImportPayload(BaseModel):

View File

@ -1,7 +1,7 @@
import logging
from typing import Any
from flask_restx import marshal
from flask_restx import marshal, reqparse
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@ -33,6 +33,7 @@ class HitTestingPayload(BaseModel):
query: str = Field(max_length=250)
retrieval_model: dict[str, Any] | None = None
external_retrieval_model: dict[str, Any] | None = None
attachment_ids: list[str] | None = None
class DatasetsHitTestingBase:
@ -54,16 +55,28 @@ class DatasetsHitTestingBase:
def hit_testing_args_check(args: dict[str, Any]):
HitTestingService.hit_testing_args_check(args)
@staticmethod
def parse_args():
parser = (
reqparse.RequestParser()
.add_argument("query", type=str, required=False, location="json")
.add_argument("attachment_ids", type=list, required=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
)
return parser.parse_args()
@staticmethod
def perform_hit_testing(dataset, args):
assert isinstance(current_user, Account)
try:
response = HitTestingService.retrieve(
dataset=dataset,
query=args["query"],
query=args.get("query"),
account=current_user,
retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrieval_model"],
retrieval_model=args.get("retrieval_model"),
external_retrieval_model=args.get("external_retrieval_model"),
attachment_ids=args.get("attachment_ids"),
limit=10,
)
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}

View File

@ -45,6 +45,9 @@ class FileApi(Resource):
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
"image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT,
"single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT,
"attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT,
}, 200
@setup_required
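For reference, a sketch of just the keys this hunk adds to the upload-limits response, with the defaults declared in FileUploadConfig earlier in this commit; the real response also keeps the pre-existing size-limit keys.

# Added keys and their FileUploadConfig defaults (illustrative only).
new_limit_keys = {
    "image_file_batch_limit": 10,           # IMAGE_FILE_BATCH_LIMIT
    "single_chunk_attachment_limit": 10,    # SINGLE_CHUNK_ATTACHMENT_LIMIT
    "attachment_image_file_size_limit": 2,  # ATTACHMENT_IMAGE_FILE_SIZE_LIMIT, in MB
}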

View File

@ -22,7 +22,12 @@ from services.trigger.trigger_subscription_builder_service import TriggerSubscri
from services.trigger.trigger_subscription_operator_service import TriggerSubscriptionOperatorService
from .. import console_ns
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from ..wraps import (
account_initialization_required,
edit_permission_required,
is_admin_or_owner_required,
setup_required,
)
logger = logging.getLogger(__name__)
@ -72,7 +77,7 @@ class TriggerProviderInfoApi(Resource):
class TriggerSubscriptionListApi(Resource):
@setup_required
@login_required
@is_admin_or_owner_required
@edit_permission_required
@account_initialization_required
def get(self, provider):
"""List all trigger subscriptions for the current tenant's provider"""
@ -104,7 +109,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
@console_ns.expect(parser)
@setup_required
@login_required
@is_admin_or_owner_required
@edit_permission_required
@account_initialization_required
def post(self, provider):
"""Add a new subscription instance for a trigger provider"""
@ -133,6 +138,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
class TriggerSubscriptionBuilderGetApi(Resource):
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
def get(self, provider, subscription_builder_id):
"""Get a subscription instance for a trigger provider"""
@ -155,7 +161,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
@console_ns.expect(parser_api)
@setup_required
@login_required
@is_admin_or_owner_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Verify a subscription instance for a trigger provider"""
@ -200,6 +206,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
@console_ns.expect(parser_update_api)
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Update a subscription instance for a trigger provider"""
@ -233,6 +240,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
class TriggerSubscriptionBuilderLogsApi(Resource):
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
def get(self, provider, subscription_builder_id):
"""Get the request logs for a subscription instance for a trigger provider"""
@ -255,7 +263,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
@console_ns.expect(parser_update_api)
@setup_required
@login_required
@is_admin_or_owner_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Build a subscription instance for a trigger provider"""

View File

@ -83,6 +83,7 @@ class AppRunner:
context: str | None = None,
memory: TokenBufferMemory | None = None,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
context_files: list["File"] | None = None,
) -> tuple[list[PromptMessage], list[str] | None]:
"""
Organize prompt messages
@ -111,6 +112,7 @@ class AppRunner:
memory=memory,
model_config=model_config,
image_detail_config=image_detail_config,
context_files=context_files,
)
else:
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))

View File

@ -11,6 +11,7 @@ from core.app.entities.app_invoke_entities import (
)
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file import File
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
@ -146,6 +147,7 @@ class ChatAppRunner(AppRunner):
# get context from datasets
context = None
context_files: list[File] = []
if app_config.dataset and app_config.dataset.dataset_ids:
hit_callback = DatasetIndexToolCallbackHandler(
queue_manager,
@ -156,7 +158,7 @@ class ChatAppRunner(AppRunner):
)
dataset_retrieval = DatasetRetrieval(application_generate_entity)
context = dataset_retrieval.retrieve(
context, retrieved_files = dataset_retrieval.retrieve(
app_id=app_record.id,
user_id=application_generate_entity.user_id,
tenant_id=app_record.tenant_id,
@ -171,7 +173,11 @@ class ChatAppRunner(AppRunner):
memory=memory,
message_id=message.id,
inputs=inputs,
vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
"enabled", False
),
)
context_files = retrieved_files or []
# reorganize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
@ -186,6 +192,7 @@ class ChatAppRunner(AppRunner):
context=context,
memory=memory,
image_detail_config=image_detail_config,
context_files=context_files,
)
# check hosting moderation

View File

@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import (
CompletionAppGenerateEntity,
)
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file import File
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.moderation.base import ModerationError
@ -102,6 +103,7 @@ class CompletionAppRunner(AppRunner):
# get context from datasets
context = None
context_files: list[File] = []
if app_config.dataset and app_config.dataset.dataset_ids:
hit_callback = DatasetIndexToolCallbackHandler(
queue_manager,
@ -116,7 +118,7 @@ class CompletionAppRunner(AppRunner):
query = inputs.get(dataset_config.retrieve_config.query_variable, "")
dataset_retrieval = DatasetRetrieval(application_generate_entity)
context = dataset_retrieval.retrieve(
context, retrieved_files = dataset_retrieval.retrieve(
app_id=app_record.id,
user_id=application_generate_entity.user_id,
tenant_id=app_record.tenant_id,
@ -130,7 +132,11 @@ class CompletionAppRunner(AppRunner):
hit_callback=hit_callback,
message_id=message.id,
inputs=inputs,
vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
"enabled", False
),
)
context_files = retrieved_files or []
# reorganize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
@ -144,6 +150,7 @@ class CompletionAppRunner(AppRunner):
query=query,
context=context,
image_detail_config=image_detail_config,
context_files=context_files,
)
# check hosting moderation

View File

@ -7,7 +7,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
@ -59,7 +59,7 @@ class DatasetIndexToolCallbackHandler:
document_id,
)
continue
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunk_stmt = select(ChildChunk).where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,

View File

@ -7,7 +7,7 @@ import time
import uuid
from typing import Any
from flask import current_app
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm.exc import ObjectDeletedError
@ -21,7 +21,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
@ -36,6 +36,7 @@ from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
from libs.datetime_utils import naive_utc_now
from models import Account
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
@ -89,8 +90,17 @@ class IndexingRunner:
text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
# transform
current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first()
if not current_user:
raise ValueError("no current user found")
current_user.set_tenant_id(dataset.tenant_id)
documents = self._transform(
index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
index_processor,
dataset,
text_docs,
requeried_document.doc_language,
processing_rule.to_dict(),
current_user=current_user,
)
# save segment
self._load_segments(dataset, requeried_document, documents)
@ -136,7 +146,7 @@ class IndexingRunner:
for document_segment in document_segments:
db.session.delete(document_segment)
if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
if requeried_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
# delete child chunks
db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit()
@ -152,8 +162,17 @@ class IndexingRunner:
text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
# transform
current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first()
if not current_user:
raise ValueError("no current user found")
current_user.set_tenant_id(dataset.tenant_id)
documents = self._transform(
index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
index_processor,
dataset,
text_docs,
requeried_document.doc_language,
processing_rule.to_dict(),
current_user=current_user,
)
# save segment
self._load_segments(dataset, requeried_document, documents)
@ -209,7 +228,7 @@ class IndexingRunner:
"dataset_id": document_segment.dataset_id,
},
)
if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
if requeried_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = document_segment.get_child_chunks()
if child_chunks:
child_documents = []
@ -302,6 +321,7 @@ class IndexingRunner:
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
documents = index_processor.transform(
text_docs,
current_user=None,
embedding_model_instance=embedding_model_instance,
process_rule=processing_rule.to_dict(),
tenant_id=tenant_id,
@ -551,7 +571,10 @@ class IndexingRunner:
indexing_start_at = time.perf_counter()
tokens = 0
create_keyword_thread = None
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
if (
dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX
and dataset.indexing_technique == "economy"
):
# create keyword index
create_keyword_thread = threading.Thread(
target=self._process_keyword_index,
@ -590,7 +613,7 @@ class IndexingRunner:
for future in futures:
tokens += future.result()
if (
dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX
dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX
and dataset.indexing_technique == "economy"
and create_keyword_thread is not None
):
@ -635,7 +658,13 @@ class IndexingRunner:
db.session.commit()
def _process_chunk(
self, flask_app, index_processor, chunk_documents, dataset, dataset_document, embedding_model_instance
self,
flask_app: Flask,
index_processor: BaseIndexProcessor,
chunk_documents: list[Document],
dataset: Dataset,
dataset_document: DatasetDocument,
embedding_model_instance: ModelInstance | None,
):
with flask_app.app_context():
# check document is paused
@ -646,8 +675,15 @@ class IndexingRunner:
page_content_list = [document.page_content for document in chunk_documents]
tokens += sum(embedding_model_instance.get_text_embedding_num_tokens(page_content_list))
multimodal_documents = []
for document in chunk_documents:
if document.attachments and dataset.is_multimodal:
multimodal_documents.extend(document.attachments)
# load index
index_processor.load(dataset, chunk_documents, with_keywords=False)
index_processor.load(
dataset, chunk_documents, multimodal_documents=multimodal_documents, with_keywords=False
)
document_ids = [document.metadata["doc_id"] for document in chunk_documents]
db.session.query(DocumentSegment).where(
@ -710,6 +746,7 @@ class IndexingRunner:
text_docs: list[Document],
doc_language: str,
process_rule: dict,
current_user: Account | None = None,
) -> list[Document]:
# get embedding model instance
embedding_model_instance = None
@ -729,6 +766,7 @@ class IndexingRunner:
documents = index_processor.transform(
text_docs,
current_user,
embedding_model_instance=embedding_model_instance,
process_rule=process_rule,
tenant_id=dataset.tenant_id,
@ -737,14 +775,16 @@ class IndexingRunner:
return documents
def _load_segments(self, dataset, dataset_document, documents):
def _load_segments(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]):
# save node to document segment
doc_store = DatasetDocumentStore(
dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id
)
# add document segments
doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX)
doc_store.add_documents(
docs=documents, save_child=dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX
)
# update document status to indexing
cur_time = naive_utc_now()

View File

@ -10,9 +10,9 @@ from core.errors.error import ProviderTokenNotInitError
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.entities.text_embedding_entities import EmbeddingResult
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
@ -200,7 +200,7 @@ class ModelInstance:
def invoke_text_embedding(
self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
) -> TextEmbeddingResult:
) -> EmbeddingResult:
"""
Invoke large language model
@ -212,7 +212,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return cast(
TextEmbeddingResult,
EmbeddingResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
@ -223,6 +223,34 @@ class ModelInstance:
),
)
def invoke_multimodal_embedding(
self,
multimodel_documents: list[dict],
user: str | None = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> EmbeddingResult:
"""
Invoke multimodal embedding model
:param multimodel_documents: multimodal documents to embed
:param user: unique user id
:param input_type: input type
:return: embeddings result
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return cast(
EmbeddingResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
user=user,
input_type=input_type,
),
)
def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
"""
Get number of tokens for text embedding
@ -276,6 +304,40 @@ class ModelInstance:
),
)
def invoke_multimodal_rerank(
self,
query: dict,
docs: list[dict],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> RerankResult:
"""
Invoke multimodal rerank model
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return cast(
RerankResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke_multimodal_rerank,
model=self.model,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
user=user,
),
)
def invoke_moderation(self, text: str, user: str | None = None) -> bool:
"""
Invoke moderation model
@ -461,6 +523,32 @@ class ModelManager:
model=default_model_entity.model,
)
def check_model_support_vision(self, tenant_id: str, provider: str, model: str, model_type: ModelType) -> bool:
"""
Check if model supports vision
:param tenant_id: tenant id
:param provider: provider name
:param model: model name
:param model_type: model type of the model to check
:return: True if model supports vision, False otherwise
"""
model_instance = self.get_model_instance(tenant_id, provider, model_type, model)
model_type_instance = model_instance.model_type_instance
match model_type:
case ModelType.LLM:
model_type_instance = cast(LargeLanguageModel, model_type_instance)
case ModelType.TEXT_EMBEDDING:
model_type_instance = cast(TextEmbeddingModel, model_type_instance)
case ModelType.RERANK:
model_type_instance = cast(RerankModel, model_type_instance)
case _:
raise ValueError(f"Model type {model_type} is not supported")
model_schema = model_type_instance.get_model_schema(model, model_instance.credentials)
if not model_schema:
return False
if model_schema.features and ModelFeature.VISION in model_schema.features:
return True
return False
class LBModelManager:
def __init__(
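A usage sketch for the helpers added in this file, assuming it runs inside the API's app context with the named embedding provider already configured for the tenant; the provider, model, and id values are placeholders.

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType

manager = ModelManager()

# New capability probe added above: does the configured embedding model accept images?
supports_vision = manager.check_model_support_vision(
    tenant_id="<tenant-id>",
    provider="<embedding-provider>",
    model="<embedding-model>",
    model_type=ModelType.TEXT_EMBEDDING,
)

if supports_vision:
    model_instance = manager.get_model_instance(
        "<tenant-id>", "<embedding-provider>", ModelType.TEXT_EMBEDDING, "<embedding-model>"
    )
    # The dict shape mirrors what Vector.create_multimodal builds later in this
    # commit; the content_type value is a placeholder for the DocType used there.
    result = model_instance.invoke_multimodal_embedding(
        multimodel_documents=[
            {"content": "<base64-image>", "content_type": "image", "file_id": "<upload-file-id>"}
        ],
    )
    print(result.model, len(result.embeddings))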

View File

@ -19,7 +19,7 @@ class EmbeddingUsage(ModelUsage):
latency: float
class TextEmbeddingResult(BaseModel):
class EmbeddingResult(BaseModel):
"""
Model class for text embedding result.
"""
@ -27,3 +27,13 @@ class TextEmbeddingResult(BaseModel):
model: str
embeddings: list[list[float]]
usage: EmbeddingUsage
class FileEmbeddingResult(BaseModel):
"""
Model class for file embedding result.
"""
model: str
embeddings: list[list[float]]
usage: EmbeddingUsage

View File

@ -50,3 +50,43 @@ class RerankModel(AIModel):
)
except Exception as e:
raise self._transform_invoke_error(e)
def invoke_multimodal_rerank(
self,
model: str,
credentials: dict,
query: dict,
docs: list[dict],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> RerankResult:
"""
Invoke multimodal rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_multimodal_rerank(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model=model,
credentials=credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
)
except Exception as e:
raise self._transform_invoke_error(e)

View File

@ -2,7 +2,7 @@ from pydantic import ConfigDict
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.entities.text_embedding_entities import EmbeddingResult
from core.model_runtime.model_providers.__base.ai_model import AIModel
@ -20,16 +20,18 @@ class TextEmbeddingModel(AIModel):
self,
model: str,
credentials: dict,
texts: list[str],
texts: list[str] | None = None,
multimodel_documents: list[dict] | None = None,
user: str | None = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
) -> EmbeddingResult:
"""
Invoke text embedding model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param multimodel_documents: multimodal documents to embed
:param user: unique user id
:param input_type: input type
:return: embeddings result
@ -38,16 +40,29 @@ class TextEmbeddingModel(AIModel):
try:
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_text_embedding(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model=model,
credentials=credentials,
texts=texts,
input_type=input_type,
)
if texts:
return plugin_model_manager.invoke_text_embedding(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model=model,
credentials=credentials,
texts=texts,
input_type=input_type,
)
if multimodel_documents:
return plugin_model_manager.invoke_multimodal_embedding(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model=model,
credentials=credentials,
documents=multimodel_documents,
input_type=input_type,
)
raise ValueError("No texts or files provided")
except Exception as e:
raise self._transform_invoke_error(e)

View File

@ -6,7 +6,7 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.entities.text_embedding_entities import EmbeddingResult
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import (
PluginBasicBooleanResponse,
@ -243,14 +243,14 @@ class PluginModelClient(BasePluginClient):
credentials: dict,
texts: list[str],
input_type: str,
) -> TextEmbeddingResult:
) -> EmbeddingResult:
"""
Invoke text embedding
"""
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
type_=TextEmbeddingResult,
type_=EmbeddingResult,
data=jsonable_encoder(
{
"user_id": user_id,
@ -275,6 +275,48 @@ class PluginModelClient(BasePluginClient):
raise ValueError("Failed to invoke text embedding")
def invoke_multimodal_embedding(
self,
tenant_id: str,
user_id: str,
plugin_id: str,
provider: str,
model: str,
credentials: dict,
documents: list[dict],
input_type: str,
) -> EmbeddingResult:
"""
Invoke multimodal embedding
"""
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/multimodal_embedding/invoke",
type_=EmbeddingResult,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
"provider": provider,
"model_type": "text-embedding",
"model": model,
"credentials": credentials,
"documents": documents,
"input_type": input_type,
},
}
),
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("Failed to invoke file embedding")
def get_text_embedding_num_tokens(
self,
tenant_id: str,
@ -361,6 +403,51 @@ class PluginModelClient(BasePluginClient):
raise ValueError("Failed to invoke rerank")
def invoke_multimodal_rerank(
self,
tenant_id: str,
user_id: str,
plugin_id: str,
provider: str,
model: str,
credentials: dict,
query: dict,
docs: list[dict],
score_threshold: float | None = None,
top_n: int | None = None,
) -> RerankResult:
"""
Invoke multimodal rerank
"""
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/multimodal_rerank/invoke",
type_=RerankResult,
data=jsonable_encoder(
{
"user_id": user_id,
"data": {
"provider": provider,
"model_type": "rerank",
"model": model,
"credentials": credentials,
"query": query,
"docs": docs,
"score_threshold": score_threshold,
"top_n": top_n,
},
}
),
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("Failed to invoke multimodal rerank")
def invoke_tts(
self,
tenant_id: str,

View File

@ -49,6 +49,7 @@ class SimplePromptTransform(PromptTransform):
memory: TokenBufferMemory | None,
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
context_files: list["File"] | None = None,
) -> tuple[list[PromptMessage], list[str] | None]:
inputs = {key: str(value) for key, value in inputs.items()}
@ -64,6 +65,7 @@ class SimplePromptTransform(PromptTransform):
memory=memory,
model_config=model_config,
image_detail_config=image_detail_config,
context_files=context_files,
)
else:
prompt_messages, stops = self._get_completion_model_prompt_messages(
@ -76,6 +78,7 @@ class SimplePromptTransform(PromptTransform):
memory=memory,
model_config=model_config,
image_detail_config=image_detail_config,
context_files=context_files,
)
return prompt_messages, stops
@ -187,6 +190,7 @@ class SimplePromptTransform(PromptTransform):
memory: TokenBufferMemory | None,
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
context_files: list["File"] | None = None,
) -> tuple[list[PromptMessage], list[str] | None]:
prompt_messages: list[PromptMessage] = []
@ -216,9 +220,9 @@ class SimplePromptTransform(PromptTransform):
)
if query:
prompt_messages.append(self._get_last_user_message(query, files, image_detail_config))
prompt_messages.append(self._get_last_user_message(query, files, image_detail_config, context_files))
else:
prompt_messages.append(self._get_last_user_message(prompt, files, image_detail_config))
prompt_messages.append(self._get_last_user_message(prompt, files, image_detail_config, context_files))
return prompt_messages, None
@ -233,6 +237,7 @@ class SimplePromptTransform(PromptTransform):
memory: TokenBufferMemory | None,
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
context_files: list["File"] | None = None,
) -> tuple[list[PromptMessage], list[str] | None]:
# get prompt
prompt, prompt_rules = self._get_prompt_str_and_rules(
@ -275,20 +280,27 @@ class SimplePromptTransform(PromptTransform):
if stops is not None and len(stops) == 0:
stops = None
return [self._get_last_user_message(prompt, files, image_detail_config)], stops
return [self._get_last_user_message(prompt, files, image_detail_config, context_files)], stops
def _get_last_user_message(
self,
prompt: str,
files: Sequence["File"],
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
context_files: list["File"] | None = None,
) -> UserPromptMessage:
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
if files:
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
for file in files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
)
if context_files:
for file in context_files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
)
if prompt_message_contents:
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
prompt_message = UserPromptMessage(content=prompt_message_contents)
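The net effect of threading context_files through is that retrieved images are rendered into the final user message alongside directly uploaded files, ahead of the text part. A stand-alone sketch of that ordering, using plain dicts as stand-ins for Dify's prompt-content classes:

def build_user_message_content(prompt: str, files: list[str], context_files: list[str]) -> list[dict]:
    # Stand-in for _get_last_user_message: directly uploaded files first, then the
    # retrieved context files, then the text prompt, mirroring the order above.
    contents: list[dict] = []
    for file_id in files or []:
        contents.append({"type": "file", "id": file_id})
    for file_id in context_files or []:
        contents.append({"type": "file", "id": file_id})
    contents.append({"type": "text", "data": prompt})
    return contents

print(build_user_message_content("Summarize the attached spec.", ["user-upload-1"], ["retrieved-image-7"]))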

View File

@ -2,6 +2,7 @@ from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.data_post_processor.reorder import ReorderRunner
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights
from core.rag.rerank.rerank_base import BaseRerankRunner
@ -30,9 +31,10 @@ class DataPostProcessor:
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> list[Document]:
if self.rerank_runner:
documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user)
documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user, query_type)
if self.reorder_runner:
documents = self.reorder_runner.run(documents)

View File

@ -1,23 +1,30 @@
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session, load_only
from configs import dify_config
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import RetrievalSegments
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.signature import sign_upload_file
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DocumentSegment
from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = {
@ -37,14 +44,15 @@ class RetrievalService:
retrieval_method: RetrievalMethod,
dataset_id: str,
query: str,
top_k: int,
top_k: int = 4,
score_threshold: float | None = 0.0,
reranking_model: dict | None = None,
reranking_mode: str = "reranking_model",
weights: dict | None = None,
document_ids_filter: list[str] | None = None,
attachment_ids: list | None = None,
):
if not query:
if not query and not attachment_ids:
return []
dataset = cls._get_dataset(dataset_id)
if not dataset:
@ -56,69 +64,52 @@ class RetrievalService:
# Optimize multithreading with thread pools
with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore
futures = []
if retrieval_method == RetrievalMethod.KEYWORD_SEARCH:
retrieval_service = RetrievalService()
if query:
futures.append(
executor.submit(
cls.keyword_search,
retrieval_service._retrieve,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset_id,
query=query,
top_k=top_k,
all_documents=all_documents,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
)
)
if RetrievalMethod.is_support_semantic_search(retrieval_method):
futures.append(
executor.submit(
cls.embedding_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset_id,
retrieval_method=retrieval_method,
dataset=dataset,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
all_documents=all_documents,
retrieval_method=retrieval_method,
exceptions=exceptions,
reranking_mode=reranking_mode,
weights=weights,
document_ids_filter=document_ids_filter,
attachment_id=None,
all_documents=all_documents,
exceptions=exceptions,
)
)
if RetrievalMethod.is_support_fulltext_search(retrieval_method):
futures.append(
executor.submit(
cls.full_text_index_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset_id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
all_documents=all_documents,
retrieval_method=retrieval_method,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
if attachment_ids:
for attachment_id in attachment_ids:
futures.append(
executor.submit(
retrieval_service._retrieve,
flask_app=current_app._get_current_object(), # type: ignore
retrieval_method=retrieval_method,
dataset=dataset,
query=None,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
reranking_mode=reranking_mode,
weights=weights,
document_ids_filter=document_ids_filter,
attachment_id=attachment_id,
all_documents=all_documents,
exceptions=exceptions,
)
)
)
concurrent.futures.wait(futures, timeout=30, return_when=concurrent.futures.ALL_COMPLETED)
concurrent.futures.wait(futures, timeout=3600, return_when=concurrent.futures.ALL_COMPLETED)
if exceptions:
raise ValueError(";\n".join(exceptions))
# Deduplicate documents for hybrid search to avoid duplicate chunks
if retrieval_method == RetrievalMethod.HYBRID_SEARCH:
all_documents = cls._deduplicate_documents(all_documents)
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
)
all_documents = data_post_processor.invoke(
query=query,
documents=all_documents,
score_threshold=score_threshold,
top_n=top_k,
)
return all_documents
@classmethod
@ -223,6 +214,7 @@ class RetrievalService:
retrieval_method: RetrievalMethod,
exceptions: list,
document_ids_filter: list[str] | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
):
with flask_app.app_context():
try:
@ -231,14 +223,30 @@ class RetrievalService:
raise ValueError("dataset not found")
vector = Vector(dataset=dataset)
documents = vector.search_by_vector(
query,
search_type="similarity_score_threshold",
top_k=top_k,
score_threshold=score_threshold,
filter={"group_id": [dataset.id]},
document_ids_filter=document_ids_filter,
)
documents = []
if query_type == QueryType.TEXT_QUERY:
documents.extend(
vector.search_by_vector(
query,
search_type="similarity_score_threshold",
top_k=top_k,
score_threshold=score_threshold,
filter={"group_id": [dataset.id]},
document_ids_filter=document_ids_filter,
)
)
if query_type == QueryType.IMAGE_QUERY:
if not dataset.is_multimodal:
return
documents.extend(
vector.search_by_file(
file_id=query,
top_k=top_k,
score_threshold=score_threshold,
filter={"group_id": [dataset.id]},
document_ids_filter=document_ids_filter,
)
)
if documents:
if (
@ -250,14 +258,37 @@ class RetrievalService:
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False
)
all_documents.extend(
data_post_processor.invoke(
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=len(documents),
if dataset.is_multimodal:
model_manager = ModelManager()
is_support_vision = model_manager.check_model_support_vision(
tenant_id=dataset.tenant_id,
provider=reranking_model.get("reranking_provider_name") or "",
model=reranking_model.get("reranking_model_name") or "",
model_type=ModelType.RERANK,
)
if is_support_vision:
all_documents.extend(
data_post_processor.invoke(
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=len(documents),
query_type=query_type,
)
)
else:
# rerank model does not support vision; return the original documents unchanged
all_documents.extend(documents)
else:
all_documents.extend(
data_post_processor.invoke(
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=len(documents),
query_type=query_type,
)
)
)
else:
all_documents.extend(documents)
except Exception as e:
@ -339,103 +370,159 @@ class RetrievalService:
records = []
include_segment_ids = set()
segment_child_map = {}
# Process documents
for document in documents:
document_id = document.metadata.get("document_id")
if document_id not in dataset_documents:
continue
dataset_document = dataset_documents[document_id]
if not dataset_document:
continue
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# Handle parent-child documents
child_index_node_id = document.metadata.get("doc_id")
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
child_chunk = db.session.scalar(child_chunk_stmt)
if not child_chunk:
segment_file_map = {}
with Session(db.engine) as session:
# Process documents
for document in documents:
segment_id = None
attachment_info = None
child_chunk = None
document_id = document.metadata.get("document_id")
if document_id not in dataset_documents:
continue
segment = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id == child_chunk.segment_id,
)
.options(
load_only(
DocumentSegment.id,
DocumentSegment.content,
DocumentSegment.answer,
dataset_document = dataset_documents[document_id]
if not dataset_document:
continue
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
# Handle parent-child documents
if document.metadata.get("doc_type") == DocType.IMAGE:
attachment_info_dict = cls.get_segment_attachment_info(
dataset_document.dataset_id,
dataset_document.tenant_id,
document.metadata.get("doc_id") or "",
session,
)
if attachment_info_dict:
attachment_info = attachment_info_dict["attchment_info"]
segment_id = attachment_info_dict["segment_id"]
else:
child_index_node_id = document.metadata.get("doc_id")
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
child_chunk = session.scalar(child_chunk_stmt)
if not child_chunk:
continue
segment_id = child_chunk.segment_id
if not segment_id:
continue
segment = (
session.query(DocumentSegment)
.where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id == segment_id,
)
.options(
load_only(
DocumentSegment.id,
DocumentSegment.content,
DocumentSegment.answer,
)
)
.first()
)
.first()
)
if not segment:
continue
if not segment:
continue
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
map_detail = {
"max_score": document.metadata.get("score", 0.0),
"child_chunks": [child_chunk_detail],
}
segment_child_map[segment.id] = map_detail
record = {
"segment": segment,
}
records.append(record)
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
if child_chunk:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
map_detail = {
"max_score": document.metadata.get("score", 0.0),
"child_chunks": [child_chunk_detail],
}
segment_child_map[segment.id] = map_detail
record = {
"segment": segment,
}
if attachment_info:
segment_file_map[segment.id] = [attachment_info]
records.append(record)
else:
if child_chunk:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
)
if attachment_info:
segment_file_map[segment.id].append(attachment_info)
else:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
)
else:
# Handle normal documents
index_node_id = document.metadata.get("doc_id")
if not index_node_id:
continue
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
segment = db.session.scalar(document_segment_stmt)
# Handle normal documents
segment = None
if document.metadata.get("doc_type") == DocType.IMAGE:
attachment_info_dict = cls.get_segment_attachment_info(
dataset_document.dataset_id,
dataset_document.tenant_id,
document.metadata.get("doc_id") or "",
session,
)
if attachment_info_dict:
attachment_info = attachment_info_dict["attchment_info"]
segment_id = attachment_info_dict["segment_id"]
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id == segment_id,
)
segment = db.session.scalar(document_segment_stmt)
if segment:
segment_file_map[segment.id] = [attachment_info]
else:
index_node_id = document.metadata.get("doc_id")
if not index_node_id:
continue
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
segment = db.session.scalar(document_segment_stmt)
if not segment:
continue
include_segment_ids.add(segment.id)
record = {
"segment": segment,
"score": document.metadata.get("score"), # type: ignore
}
records.append(record)
if not segment:
continue
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
record = {
"segment": segment,
"score": document.metadata.get("score"), # type: ignore
}
if attachment_info:
segment_file_map[segment.id] = [attachment_info]
records.append(record)
else:
if attachment_info:
attachment_infos = segment_file_map.get(segment.id, [])
if attachment_info not in attachment_infos:
attachment_infos.append(attachment_info)
segment_file_map[segment.id] = attachment_infos
# Add child chunks information to records
for record in records:
if record["segment"].id in segment_child_map:
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
record["score"] = segment_child_map[record["segment"].id]["max_score"]
if record["segment"].id in segment_file_map:
record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment]
result = []
for record in records:
@ -447,6 +534,11 @@ class RetrievalService:
if not isinstance(child_chunks, list):
child_chunks = None
# Extract files, ensuring it's a list or None
files = record.get("files")
if not isinstance(files, list):
files = None
# Extract score, ensuring it's a float or None
score_value = record.get("score")
score = (
@ -456,10 +548,149 @@ class RetrievalService:
)
# Create RetrievalSegments object
retrieval_segment = RetrievalSegments(segment=segment, child_chunks=child_chunks, score=score)
retrieval_segment = RetrievalSegments(
segment=segment, child_chunks=child_chunks, score=score, files=files
)
result.append(retrieval_segment)
return result
except Exception as e:
db.session.rollback()
raise e
def _retrieve(
self,
flask_app: Flask,
retrieval_method: RetrievalMethod,
dataset: Dataset,
query: str | None = None,
top_k: int = 4,
score_threshold: float | None = 0.0,
reranking_model: dict | None = None,
reranking_mode: str = "reranking_model",
weights: dict | None = None,
document_ids_filter: list[str] | None = None,
attachment_id: str | None = None,
all_documents: list[Document] = [],
exceptions: list[str] = [],
):
if not query and not attachment_id:
return
with flask_app.app_context():
all_documents_item: list[Document] = []
# Optimize multithreading with thread pools
with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore
futures = []
if retrieval_method == RetrievalMethod.KEYWORD_SEARCH and query:
futures.append(
executor.submit(
self.keyword_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset.id,
query=query,
top_k=top_k,
all_documents=all_documents_item,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
)
)
if RetrievalMethod.is_support_semantic_search(retrieval_method):
if query:
futures.append(
executor.submit(
self.embedding_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset.id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
all_documents=all_documents_item,
retrieval_method=retrieval_method,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
query_type=QueryType.TEXT_QUERY,
)
)
if attachment_id:
futures.append(
executor.submit(
self.embedding_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset.id,
query=attachment_id,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
all_documents=all_documents_item,
retrieval_method=retrieval_method,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
query_type=QueryType.IMAGE_QUERY,
)
)
if RetrievalMethod.is_support_fulltext_search(retrieval_method) and query:
futures.append(
executor.submit(
self.full_text_index_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset.id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
all_documents=all_documents_item,
retrieval_method=retrieval_method,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
)
)
concurrent.futures.wait(futures, timeout=300, return_when=concurrent.futures.ALL_COMPLETED)
if exceptions:
raise ValueError(";\n".join(exceptions))
# Deduplicate documents for hybrid search to avoid duplicate chunks
if retrieval_method == RetrievalMethod.HYBRID_SEARCH:
if attachment_id and reranking_mode == RerankMode.WEIGHTED_SCORE:
all_documents.extend(all_documents_item)
all_documents_item = self._deduplicate_documents(all_documents_item)
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
)
is_text_query = bool(query)
query = query or attachment_id
if not query:
return
all_documents_item = data_post_processor.invoke(
query=query,
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.TEXT_QUERY if is_text_query else QueryType.IMAGE_QUERY,
)
all_documents.extend(all_documents_item)
@classmethod
def get_segment_attachment_info(
cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session
) -> dict[str, Any] | None:
upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first()
if upload_file:
attachment_binding = (
session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.attachment_id == upload_file.id)
.first()
)
if attachment_binding:
attchment_info = {
"id": upload_file.id,
"name": upload_file.name,
"extension": "." + upload_file.extension,
"mime_type": upload_file.mime_type,
"source_url": sign_upload_file(upload_file.id, upload_file.extension),
"size": upload_file.size,
}
return {"attchment_info": attchment_info, "segment_id": attachment_binding.segment_id}
return None
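A call sketch for the reworked retrieve() entry point, assuming a Flask app context and an existing multimodal dataset; the enum member name and the ids below are placeholder assumptions.

from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.retrieval.retrieval_methods import RetrievalMethod

# Text and image inputs can now be combined: the text query runs as a TEXT_QUERY
# and each attachment id fans out into its own IMAGE_QUERY retrieval.
documents = RetrievalService.retrieve(
    retrieval_method=RetrievalMethod.SEMANTIC_SEARCH,  # assumed member name
    dataset_id="<dataset-id>",
    query="installation steps for the pump",
    top_k=4,
    score_threshold=0.0,
    attachment_ids=["<query-image-upload-file-id>"],  # new parameter in this commit
)
for document in documents:
    print(document.metadata.get("doc_id"), document.metadata.get("score"))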

View File

@ -1,3 +1,4 @@
import base64
import logging
import time
from abc import ABC, abstractmethod
@ -12,10 +13,13 @@ from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.cached_embedding import CacheEmbedding
from core.rag.embedding.embedding_base import Embeddings
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from models.dataset import Dataset, Whitelist
from models.model import UploadFile
logger = logging.getLogger(__name__)
@ -203,6 +207,47 @@ class Vector:
self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs)
logger.info("Embedding %s texts took %s s", len(texts), time.time() - start)
def create_multimodal(self, file_documents: list | None = None, **kwargs):
if file_documents:
start = time.time()
logger.info("start embedding %s files %s", len(file_documents), start)
batch_size = 1000
total_batches = (len(file_documents) + batch_size - 1) // batch_size
for i in range(0, len(file_documents), batch_size):
batch = file_documents[i : i + batch_size]
batch_start = time.time()
logger.info("Processing batch %s/%s (%s files)", i // batch_size + 1, total_batches, len(batch))
# Batch query all upload files to avoid N+1 queries
attachment_ids = [doc.metadata["doc_id"] for doc in batch]
stmt = select(UploadFile).where(UploadFile.id.in_(attachment_ids))
upload_files = db.session.scalars(stmt).all()
upload_file_map = {str(f.id): f for f in upload_files}
file_base64_list = []
real_batch = []
for document in batch:
attachment_id = document.metadata["doc_id"]
doc_type = document.metadata["doc_type"]
upload_file = upload_file_map.get(attachment_id)
if upload_file:
blob = storage.load_once(upload_file.key)
file_base64_str = base64.b64encode(blob).decode()
file_base64_list.append(
{
"content": file_base64_str,
"content_type": doc_type,
"file_id": attachment_id,
}
)
real_batch.append(document)
batch_embeddings = self._embeddings.embed_multimodal_documents(file_base64_list)
logger.info(
"Embedding batch %s/%s took %s s", i // batch_size + 1, total_batches, time.time() - batch_start
)
self._vector_processor.create(texts=real_batch, embeddings=batch_embeddings, **kwargs)
logger.info("Embedding %s files took %s s", len(file_documents), time.time() - start)
def add_texts(self, documents: list[Document], **kwargs):
if kwargs.get("duplicate_check", False):
documents = self._filter_duplicate_texts(documents)
@ -223,6 +268,22 @@ class Vector:
query_vector = self._embeddings.embed_query(query)
return self._vector_processor.search_by_vector(query_vector, **kwargs)
def search_by_file(self, file_id: str, **kwargs: Any) -> list[Document]:
upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
return []
blob = storage.load_once(upload_file.key)
file_base64_str = base64.b64encode(blob).decode()
multimodal_vector = self._embeddings.embed_multimodal_query(
{
"content": file_base64_str,
"content_type": DocType.IMAGE,
"file_id": file_id,
}
)
return self._vector_processor.search_by_vector(multimodal_vector, **kwargs)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self._vector_processor.search_by_full_text(query, **kwargs)
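A hedged usage sketch of the two new entry points; dataset, attachment_documents and upload_file_id are assumed to already exist, and extra kwargs are forwarded to the underlying vector store:

vector = Vector(dataset)
# attachment_documents: list[AttachmentDocument] produced by an index processor
vector.create_multimodal(attachment_documents)

# Image query: the UploadFile is loaded, base64-encoded and embedded via embed_multimodal_query
hits = vector.search_by_file(upload_file_id, top_k=4)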

View File

@ -79,6 +79,18 @@ class WeaviateVector(BaseVector):
self._client = self._init_client(config)
self._attributes = attributes
def __del__(self):
"""
Destructor to properly close the Weaviate client connection.
Prevents connection leaks and resource warnings.
"""
if hasattr(self, "_client") and self._client is not None:
try:
self._client.close()
except Exception as e:
# Ignore errors during cleanup as object is being destroyed
logger.warning("Error closing Weaviate client %s", e, exc_info=True)
def _init_client(self, config: WeaviateConfig) -> weaviate.WeaviateClient:
"""
Initializes and returns a connected Weaviate client.

View File

@ -5,9 +5,9 @@ from sqlalchemy import func, select
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.models.document import Document
from core.rag.models.document import AttachmentDocument, Document
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DocumentSegment
from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
class DatasetDocumentStore:
@ -120,6 +120,9 @@ class DatasetDocumentStore:
db.session.add(segment_document)
db.session.flush()
self.add_multimodel_documents_binding(
segment_id=segment_document.id, multimodel_documents=doc.attachments
)
if save_child:
if doc.children:
for position, child in enumerate(doc.children, start=1):
@ -144,6 +147,9 @@ class DatasetDocumentStore:
segment_document.index_node_hash = doc.metadata.get("doc_hash")
segment_document.word_count = len(doc.page_content)
segment_document.tokens = tokens
self.add_multimodel_documents_binding(
segment_id=segment_document.id, multimodel_documents=doc.attachments
)
if save_child and doc.children:
# delete the existing child chunks
db.session.query(ChildChunk).where(
@ -233,3 +239,15 @@ class DatasetDocumentStore:
document_segment = db.session.scalar(stmt)
return document_segment
def add_multimodel_documents_binding(self, segment_id: str, multimodel_documents: list[AttachmentDocument] | None):
if multimodel_documents:
for multimodel_document in multimodel_documents:
binding = SegmentAttachmentBinding(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
document_id=self._document_id,
segment_id=segment_id,
attachment_id=multimodel_document.metadata["doc_id"],
)
db.session.add(binding)
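add_documents already performs this binding internally after each segment row is flushed; purely for orientation, the call shape looks like this (dataset, document, segment_document and doc are assumed to exist as in add_documents above):

doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
doc_store.add_multimodel_documents_binding(
    segment_id=segment_document.id,
    multimodel_documents=doc.attachments,  # list[AttachmentDocument] | None
)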

View File

@ -104,6 +104,88 @@ class CacheEmbedding(Embeddings):
return text_embeddings
    def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]:
        """Embed multimodal file documents, using the per-file embedding cache."""
# use doc embedding cache or store if not exists
multimodel_embeddings: list[Any] = [None for _ in range(len(multimodel_documents))]
embedding_queue_indices = []
for i, multimodel_document in enumerate(multimodel_documents):
file_id = multimodel_document["file_id"]
embedding = (
db.session.query(Embedding)
.filter_by(
model_name=self._model_instance.model, hash=file_id, provider_name=self._model_instance.provider
)
.first()
)
if embedding:
multimodel_embeddings[i] = embedding.get_embedding()
else:
embedding_queue_indices.append(i)
# NOTE: avoid closing the shared scoped session here; downstream code may still have pending work
if embedding_queue_indices:
embedding_queue_multimodel_documents = [multimodel_documents[i] for i in embedding_queue_indices]
embedding_queue_embeddings = []
try:
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
model_schema = model_type_instance.get_model_schema(
self._model_instance.model, self._model_instance.credentials
)
max_chunks = (
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties
else 1
)
for i in range(0, len(embedding_queue_multimodel_documents), max_chunks):
batch_multimodel_documents = embedding_queue_multimodel_documents[i : i + max_chunks]
embedding_result = self._model_instance.invoke_multimodal_embedding(
multimodel_documents=batch_multimodel_documents,
user=self._user,
input_type=EmbeddingInputType.DOCUMENT,
)
for vector in embedding_result.embeddings:
try:
# FIXME: type ignore for numpy here
normalized_embedding = (vector / np.linalg.norm(vector)).tolist() # type: ignore
# stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan
if np.isnan(normalized_embedding).any():
# for issue #11827 float values are not json compliant
logger.warning("Normalized embedding is nan: %s", normalized_embedding)
continue
embedding_queue_embeddings.append(normalized_embedding)
except IntegrityError:
db.session.rollback()
except Exception:
logger.exception("Failed transform embedding")
cache_embeddings = []
try:
for i, n_embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
multimodel_embeddings[i] = n_embedding
file_id = multimodel_documents[i]["file_id"]
if file_id not in cache_embeddings:
embedding_cache = Embedding(
model_name=self._model_instance.model,
hash=file_id,
provider_name=self._model_instance.provider,
embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
)
embedding_cache.set_embedding(n_embedding)
db.session.add(embedding_cache)
cache_embeddings.append(file_id)
db.session.commit()
except IntegrityError:
db.session.rollback()
except Exception as ex:
db.session.rollback()
logger.exception("Failed to embed documents")
raise ex
return multimodel_embeddings
def embed_query(self, text: str) -> list[float]:
"""Embed query text."""
# use doc embedding cache or store if not exists
@ -146,3 +228,46 @@ class CacheEmbedding(Embeddings):
raise ex
return embedding_results # type: ignore
    def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
        """Embed a multimodal query."""
# use doc embedding cache or store if not exists
file_id = multimodel_document["file_id"]
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{file_id}"
embedding = redis_client.get(embedding_cache_key)
if embedding:
redis_client.expire(embedding_cache_key, 600)
decoded_embedding = np.frombuffer(base64.b64decode(embedding), dtype="float")
return [float(x) for x in decoded_embedding]
try:
embedding_result = self._model_instance.invoke_multimodal_embedding(
multimodel_documents=[multimodel_document], user=self._user, input_type=EmbeddingInputType.QUERY
)
embedding_results = embedding_result.embeddings[0]
# FIXME: type ignore for numpy here
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() # type: ignore
if np.isnan(embedding_results).any():
raise ValueError("Normalized embedding is nan please try again")
except Exception as ex:
if dify_config.DEBUG:
logger.exception("Failed to embed multimodal document '%s'", multimodel_document["file_id"])
raise ex
try:
# encode embedding to base64
embedding_vector = np.array(embedding_results)
vector_bytes = embedding_vector.tobytes()
# Transform to Base64
encoded_vector = base64.b64encode(vector_bytes)
# Transform to string
encoded_str = encoded_vector.decode("utf-8")
redis_client.setex(embedding_cache_key, 600, encoded_str)
except Exception as ex:
if dify_config.DEBUG:
logger.exception(
"Failed to add embedding to redis for the multimodal document '%s'", multimodel_document["file_id"]
)
raise ex
return embedding_results # type: ignore
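The query-side cache stores the normalized vector in Redis as the base64 text of its raw float64 buffer. A minimal, self-contained sketch of that round trip (no Redis involved; the values are arbitrary):

import base64

import numpy as np

# Encode: normalized vector -> raw float64 bytes -> base64 text (what setex stores)
vector = np.array([0.6, 0.8])
encoded_str = base64.b64encode(vector.tobytes()).decode("utf-8")

# Decode: base64 text -> float64 buffer -> plain Python floats (what a cache hit returns)
decoded = np.frombuffer(base64.b64decode(encoded_str), dtype="float")
assert [float(x) for x in decoded] == [0.6, 0.8]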

View File

@ -9,11 +9,21 @@ class Embeddings(ABC):
"""Embed search docs."""
raise NotImplementedError
@abstractmethod
    def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]:
        """Embed multimodal documents."""
raise NotImplementedError
@abstractmethod
def embed_query(self, text: str) -> list[float]:
"""Embed query text."""
raise NotImplementedError
@abstractmethod
def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
"""Embed multimodal query."""
raise NotImplementedError
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Asynchronous Embed search docs."""
raise NotImplementedError

View File

@ -19,3 +19,4 @@ class RetrievalSegments(BaseModel):
segment: DocumentSegment
child_chunks: list[RetrievalChildChunk] | None = None
score: float | None = None
files: list[dict[str, str | int]] | None = None

View File

@ -21,3 +21,4 @@ class RetrievalSourceMetadata(BaseModel):
page: int | None = None
doc_metadata: dict[str, Any] | None = None
title: str | None = None
files: list[dict[str, Any]] | None = None

View File

@ -0,0 +1,6 @@
from enum import StrEnum
class DocType(StrEnum):
TEXT = "text"
IMAGE = "image"

View File

@ -1,7 +1,12 @@
from enum import StrEnum
class IndexType(StrEnum):
class IndexStructureType(StrEnum):
PARAGRAPH_INDEX = "text_model"
QA_INDEX = "qa_model"
PARENT_CHILD_INDEX = "hierarchical_model"
class IndexTechniqueType(StrEnum):
ECONOMY = "economy"
HIGH_QUALITY = "high_quality"

View File

@ -0,0 +1,6 @@
from enum import StrEnum
class QueryType(StrEnum):
TEXT_QUERY = "text_query"
IMAGE_QUERY = "image_query"

View File

@ -1,20 +1,34 @@
"""Abstract interface for document loader implementations."""
import cgi
import logging
import mimetypes
import os
import re
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Optional
from urllib.parse import unquote, urlparse
import httpx
from configs import dify_config
from core.helper import ssrf_proxy
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.models.document import Document
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.models.document import AttachmentDocument, Document
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.rag.splitter.fixed_text_splitter import (
EnhanceRecursiveCharacterTextSplitter,
FixedRecursiveCharacterTextSplitter,
)
from core.rag.splitter.text_splitter import TextSplitter
from extensions.ext_database import db
from extensions.ext_storage import storage
from models import Account, ToolFile
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
if TYPE_CHECKING:
from core.model_manager import ModelInstance
@ -28,11 +42,18 @@ class BaseIndexProcessor(ABC):
raise NotImplementedError
@abstractmethod
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
raise NotImplementedError
@abstractmethod
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
def load(
self,
dataset: Dataset,
documents: list[Document],
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
):
raise NotImplementedError
@abstractmethod
@ -96,3 +117,178 @@ class BaseIndexProcessor(ABC):
)
return character_splitter # type: ignore
def _get_content_files(self, document: Document, current_user: Account | None = None) -> list[AttachmentDocument]:
"""
Get the content files from the document.
"""
multi_model_documents: list[AttachmentDocument] = []
text = document.page_content
images = self._extract_markdown_images(text)
if not images:
return multi_model_documents
upload_file_id_list = []
for image in images:
# Collect all upload_file_ids including duplicates to preserve occurrence count
# For data before v0.10.0
pattern = r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?"
match = re.search(pattern, image)
if match:
upload_file_id = match.group(1)
upload_file_id_list.append(upload_file_id)
continue
# For data after v0.10.0
pattern = r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?"
match = re.search(pattern, image)
if match:
upload_file_id = match.group(1)
upload_file_id_list.append(upload_file_id)
continue
# For tools directory - direct file formats (e.g., .png, .jpg, etc.)
# Match URL including any query parameters up to common URL boundaries (space, parenthesis, quotes)
pattern = r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?"
match = re.search(pattern, image)
if match:
if current_user:
tool_file_id = match.group(1)
upload_file_id = self._download_tool_file(tool_file_id, current_user)
if upload_file_id:
upload_file_id_list.append(upload_file_id)
continue
if current_user:
upload_file_id = self._download_image(image.split(" ")[0], current_user)
if upload_file_id:
upload_file_id_list.append(upload_file_id)
if not upload_file_id_list:
return multi_model_documents
# Get unique IDs for database query
unique_upload_file_ids = list(set(upload_file_id_list))
upload_files = db.session.query(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids)).all()
# Create a mapping from ID to UploadFile for quick lookup
upload_file_map = {upload_file.id: upload_file for upload_file in upload_files}
# Create a Document for each occurrence (including duplicates)
for upload_file_id in upload_file_id_list:
upload_file = upload_file_map.get(upload_file_id)
if upload_file:
multi_model_documents.append(
AttachmentDocument(
page_content=upload_file.name,
metadata={
"doc_id": upload_file.id,
"doc_hash": "",
"document_id": document.metadata.get("document_id"),
"dataset_id": document.metadata.get("dataset_id"),
"doc_type": DocType.IMAGE,
},
)
)
return multi_model_documents
def _extract_markdown_images(self, text: str) -> list[str]:
"""
Extract the markdown images from the text.
"""
pattern = r"!\[.*?\]\((.*?)\)"
return re.findall(pattern, text)
def _download_image(self, image_url: str, current_user: Account) -> str | None:
"""
        Download the image from the URL.
        The image size must not exceed ATTACHMENT_IMAGE_FILE_SIZE_LIMIT (in MB).
"""
from services.file_service import FileService
MAX_IMAGE_SIZE = dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024
DOWNLOAD_TIMEOUT = dify_config.ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT
try:
# Download with timeout
response = ssrf_proxy.get(image_url, timeout=DOWNLOAD_TIMEOUT)
response.raise_for_status()
# Check Content-Length header if available
content_length = response.headers.get("Content-Length")
if content_length and int(content_length) > MAX_IMAGE_SIZE:
                logging.warning(
                    "Image from %s exceeds the %sMB limit (size: %s bytes)",
                    image_url,
                    dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT,
                    content_length,
                )
return None
filename = None
content_disposition = response.headers.get("content-disposition")
if content_disposition:
_, params = cgi.parse_header(content_disposition)
if "filename" in params:
filename = params["filename"]
filename = unquote(filename)
if not filename:
parsed_url = urlparse(image_url)
                # unquote to handle percent-encoded non-ASCII (e.g., Chinese) characters in the URL path
path = unquote(parsed_url.path)
filename = os.path.basename(path)
if not filename:
filename = "downloaded_image_file"
name, current_ext = os.path.splitext(filename)
content_type = response.headers.get("content-type", "").split(";")[0].strip()
real_ext = mimetypes.guess_extension(content_type)
            if real_ext and (not current_ext or current_ext in [".php", ".jsp", ".asp", ".html"]):
filename = f"{name}{real_ext}"
# Download content with size limit
blob = b""
for chunk in response.iter_bytes(chunk_size=8192):
blob += chunk
if len(blob) > MAX_IMAGE_SIZE:
                    logging.warning(
                        "Image from %s exceeds the %sMB limit during download",
                        image_url,
                        dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT,
                    )
return None
if not blob:
logging.warning("Image from %s is empty", image_url)
return None
upload_file = FileService(db.engine).upload_file(
filename=filename,
content=blob,
mimetype=content_type,
user=current_user,
)
return upload_file.id
except httpx.TimeoutException:
logging.warning("Timeout downloading image from %s after %s seconds", image_url, DOWNLOAD_TIMEOUT)
return None
except httpx.RequestError as e:
logging.warning("Error downloading image from %s: %s", image_url, str(e))
return None
except Exception:
logging.exception("Unexpected error downloading image from %s", image_url)
return None
def _download_tool_file(self, tool_file_id: str, current_user: Account) -> str | None:
"""
Download the tool file from the ID.
"""
from services.file_service import FileService
tool_file = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first()
if not tool_file:
return None
blob = storage.load_once(tool_file.file_key)
upload_file = FileService(db.engine).upload_file(
filename=tool_file.name,
content=blob,
mimetype=tool_file.mimetype,
user=current_user,
)
return upload_file.id
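For reference, a small self-contained check of the markdown-image extraction and the file-preview id pattern used above (the sample text and id are made up):

import re

sample = "See ![chart](https://example.com/files/0a1b2c3d-0000-0000-0000-000000000001/file-preview?timestamp=1) for details."
images = re.findall(r"!\[.*?\]\((.*?)\)", sample)
match = re.search(r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?", images[0])
assert match is not None and match.group(1) == "0a1b2c3d-0000-0000-0000-000000000001"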

View File

@ -1,6 +1,6 @@
"""Abstract interface for document loader implementations."""
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor
@ -19,11 +19,11 @@ class IndexProcessorFactory:
if not self._index_type:
raise ValueError("Index type must be specified.")
if self._index_type == IndexType.PARAGRAPH_INDEX:
if self._index_type == IndexStructureType.PARAGRAPH_INDEX:
return ParagraphIndexProcessor()
elif self._index_type == IndexType.QA_INDEX:
elif self._index_type == IndexStructureType.QA_INDEX:
return QAIndexProcessor()
elif self._index_type == IndexType.PARENT_CHILD_INDEX:
elif self._index_type == IndexStructureType.PARENT_CHILD_INDEX:
return ParentChildIndexProcessor()
else:
raise ValueError(f"Index type {self._index_type} is not supported.")

View File

@ -11,14 +11,17 @@ from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.account import Account
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import Rule
@ -33,7 +36,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
process_rule = kwargs.get("process_rule")
if not process_rule:
raise ValueError("No process rule found.")
@ -69,6 +72,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
if document_node.metadata is not None:
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
multimodal_documents = (
self._get_content_files(document_node, current_user) if document_node.metadata else None
)
if multimodal_documents:
document_node.attachments = multimodal_documents
# delete Splitter character
page_content = remove_leading_symbols(document_node.page_content).strip()
if len(page_content) > 0:
@ -77,10 +85,19 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
all_documents.extend(split_documents)
return all_documents
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
def load(
self,
dataset: Dataset,
documents: list[Document],
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
if multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(multimodal_documents)
with_keywords = False
if with_keywords:
keywords_list = kwargs.get("keywords_list")
@ -134,8 +151,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
return docs
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
documents: list[Any] = []
all_multimodal_documents: list[Any] = []
if isinstance(chunks, list):
documents = []
for content in chunks:
metadata = {
"dataset_id": dataset.id,
@ -144,26 +162,68 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
"doc_hash": helper.generate_text_hash(content),
}
doc = Document(page_content=content, metadata=metadata)
attachments = self._get_content_files(doc)
if attachments:
doc.attachments = attachments
all_multimodal_documents.extend(attachments)
documents.append(doc)
if documents:
# save node to document segment
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
# add document segments
doc_store.add_documents(docs=documents, save_child=False)
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
elif dataset.indexing_technique == "economy":
keyword = Keyword(dataset)
keyword.add_texts(documents)
else:
raise ValueError("Chunks is not a list")
multimodal_general_structure = MultimodalGeneralStructureChunk.model_validate(chunks)
for general_chunk in multimodal_general_structure.general_chunks:
metadata = {
"dataset_id": dataset.id,
"document_id": document.id,
"doc_id": str(uuid.uuid4()),
"doc_hash": helper.generate_text_hash(general_chunk.content),
}
doc = Document(page_content=general_chunk.content, metadata=metadata)
if general_chunk.files:
attachments = []
for file in general_chunk.files:
file_metadata = {
"doc_id": file.id,
"doc_hash": "",
"document_id": document.id,
"dataset_id": dataset.id,
"doc_type": DocType.IMAGE,
}
file_document = AttachmentDocument(
page_content=file.filename or "image_file", metadata=file_metadata
)
attachments.append(file_document)
all_multimodal_documents.append(file_document)
doc.attachments = attachments
else:
account = AccountService.load_user(document.created_by)
if not account:
raise ValueError("Invalid account")
doc.attachments = self._get_content_files(doc, current_user=account)
if doc.attachments:
all_multimodal_documents.extend(doc.attachments)
documents.append(doc)
if documents:
# save node to document segment
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
# add document segments
doc_store.add_documents(docs=documents, save_child=False)
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
if all_multimodal_documents:
vector.create_multimodal(all_multimodal_documents)
elif dataset.indexing_technique == "economy":
keyword = Keyword(dataset)
keyword.add_texts(documents)
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
if isinstance(chunks, list):
preview = []
for content in chunks:
preview.append({"content": content})
return {"chunk_structure": IndexType.PARAGRAPH_INDEX, "preview": preview, "total_segments": len(chunks)}
return {
"chunk_structure": IndexStructureType.PARAGRAPH_INDEX,
"preview": preview,
"total_segments": len(chunks),
}
else:
raise ValueError("Chunks is not a list")
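The list branch of format_preview keeps its original payload shape; a quick sketch of what it returns for plain string chunks:

from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor

processor = ParagraphIndexProcessor()
preview = processor.format_preview(["chunk one", "chunk two"])
# {"chunk_structure": IndexStructureType.PARAGRAPH_INDEX,
#  "preview": [{"content": "chunk one"}, {"content": "chunk two"}],
#  "total_segments": 2}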

View File

@ -13,14 +13,17 @@ from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk
from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from libs import helper
from models import Account
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
@ -35,7 +38,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
process_rule = kwargs.get("process_rule")
if not process_rule:
raise ValueError("No process rule found.")
@ -77,6 +80,9 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
page_content = page_content
if len(page_content) > 0:
document_node.page_content = page_content
multimodel_documents = self._get_content_files(document_node, current_user)
if multimodel_documents:
document_node.attachments = multimodel_documents
# parse document to child nodes
child_nodes = self._split_child_nodes(
document_node, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
@ -87,6 +93,9 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
elif rules.parent_mode == ParentMode.FULL_DOC:
page_content = "\n".join([document.page_content for document in documents])
document = Document(page_content=page_content, metadata=documents[0].metadata)
multimodel_documents = self._get_content_files(document)
if multimodel_documents:
document.attachments = multimodel_documents
# parse document to child nodes
child_nodes = self._split_child_nodes(
document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
@ -104,7 +113,14 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
return all_documents
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
def load(
self,
dataset: Dataset,
documents: list[Document],
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
for document in documents:
@ -114,6 +130,8 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
Document.model_validate(child_document.model_dump()) for child_document in child_documents
]
vector.create(formatted_child_documents)
if multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(multimodal_documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# node_ids is segment's node_ids
@ -244,6 +262,24 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
}
child_documents.append(ChildDocument(page_content=child, metadata=child_metadata))
doc = Document(page_content=parent_child.parent_content, metadata=metadata, children=child_documents)
if parent_child.files and len(parent_child.files) > 0:
attachments = []
for file in parent_child.files:
file_metadata = {
"doc_id": file.id,
"doc_hash": "",
"document_id": document.id,
"dataset_id": dataset.id,
"doc_type": DocType.IMAGE,
}
file_document = AttachmentDocument(page_content=file.filename or "", metadata=file_metadata)
attachments.append(file_document)
doc.attachments = attachments
else:
account = AccountService.load_user(document.created_by)
if not account:
raise ValueError("Invalid account")
doc.attachments = self._get_content_files(doc, current_user=account)
documents.append(doc)
if documents:
# update document parent mode
@ -267,12 +303,17 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
doc_store.add_documents(docs=documents, save_child=True)
if dataset.indexing_technique == "high_quality":
all_child_documents = []
all_multimodal_documents = []
for doc in documents:
if doc.children:
all_child_documents.extend(doc.children)
if doc.attachments:
all_multimodal_documents.extend(doc.attachments)
vector = Vector(dataset)
if all_child_documents:
vector = Vector(dataset)
vector.create(all_child_documents)
if all_multimodal_documents:
vector.create_multimodal(all_multimodal_documents)
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
parent_childs = ParentChildStructureChunk.model_validate(chunks)
@ -280,7 +321,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
for parent_child in parent_childs.parent_child_chunks:
preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
return {
"chunk_structure": IndexType.PARENT_CHILD_INDEX,
"chunk_structure": IndexStructureType.PARENT_CHILD_INDEX,
"parent_mode": parent_childs.parent_mode,
"preview": preview,
"total_segments": len(parent_childs.parent_child_chunks),

View File

@ -18,12 +18,13 @@ from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document, QAStructureChunk
from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.account import Account
from models.dataset import Dataset
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule
@ -41,7 +42,7 @@ class QAIndexProcessor(BaseIndexProcessor):
)
return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
preview = kwargs.get("preview")
process_rule = kwargs.get("process_rule")
if not process_rule:
@ -116,7 +117,7 @@ class QAIndexProcessor(BaseIndexProcessor):
try:
# Skip the first row
df = pd.read_csv(file)
df = pd.read_csv(file) # type: ignore
text_docs = []
for _, row in df.iterrows():
data = Document(page_content=row.iloc[0], metadata={"answer": row.iloc[1]})
@ -128,10 +129,19 @@ class QAIndexProcessor(BaseIndexProcessor):
raise ValueError(str(e))
return text_docs
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
def load(
self,
dataset: Dataset,
documents: list[Document],
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
if multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(multimodal_documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
vector = Vector(dataset)
@ -197,7 +207,7 @@ class QAIndexProcessor(BaseIndexProcessor):
for qa_chunk in qa_chunks.qa_chunks:
preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer})
return {
"chunk_structure": IndexType.QA_INDEX,
"chunk_structure": IndexStructureType.QA_INDEX,
"qa_preview": preview,
"total_segments": len(qa_chunks.qa_chunks),
}

View File

@ -4,6 +4,8 @@ from typing import Any
from pydantic import BaseModel, Field
from core.file import File
class ChildDocument(BaseModel):
"""Class for storing a piece of text and associated metadata."""
@ -15,7 +17,19 @@ class ChildDocument(BaseModel):
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
metadata: dict = Field(default_factory=dict)
metadata: dict[str, Any] = Field(default_factory=dict)
class AttachmentDocument(BaseModel):
"""Class for storing a piece of text and associated metadata."""
page_content: str
provider: str | None = "dify"
vector: list[float] | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
class Document(BaseModel):
@ -28,12 +42,31 @@ class Document(BaseModel):
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
metadata: dict = Field(default_factory=dict)
metadata: dict[str, Any] = Field(default_factory=dict)
provider: str | None = "dify"
children: list[ChildDocument] | None = None
attachments: list[AttachmentDocument] | None = None
class GeneralChunk(BaseModel):
"""
General Chunk.
"""
content: str
files: list[File] | None = None
class MultimodalGeneralStructureChunk(BaseModel):
"""
Multimodal General Structure Chunk.
"""
general_chunks: list[GeneralChunk]
class GeneralStructureChunk(BaseModel):
"""
@ -50,6 +83,7 @@ class ParentChildChunk(BaseModel):
parent_content: str
child_contents: list[str]
files: list[File] | None = None
class ParentChildStructureChunk(BaseModel):
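A minimal construction sketch for the extended models; the ids and hash below are placeholders:

attachment = AttachmentDocument(
    page_content="diagram.png",
    metadata={"doc_id": "<upload-file-id>", "doc_hash": "", "doc_type": "image"},  # "image" == DocType.IMAGE
)
doc = Document(
    page_content="Chunk text that referenced the image.",
    metadata={"doc_id": "<index-node-id>", "doc_hash": "<hash>"},
    attachments=[attachment],
)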

View File

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
@ -12,6 +13,7 @@ class BaseRerankRunner(ABC):
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> list[Document]:
"""
Run rerank model

View File

@ -1,6 +1,15 @@
from core.model_manager import ModelInstance
import base64
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.rerank_entities import RerankResult
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.rerank.rerank_base import BaseRerankRunner
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import UploadFile
class RerankModelRunner(BaseRerankRunner):
@ -14,6 +23,7 @@ class RerankModelRunner(BaseRerankRunner):
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> list[Document]:
"""
Run rerank model
@ -24,6 +34,56 @@ class RerankModelRunner(BaseRerankRunner):
:param user: unique user id if needed
:return:
"""
model_manager = ModelManager()
is_support_vision = model_manager.check_model_support_vision(
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id,
provider=self.rerank_model_instance.provider,
model=self.rerank_model_instance.model,
model_type=ModelType.RERANK,
)
if not is_support_vision:
if query_type == QueryType.TEXT_QUERY:
rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user)
else:
return documents
else:
rerank_result, unique_documents = self.fetch_multimodal_rerank(
query, documents, score_threshold, top_n, user, query_type
)
rerank_documents = []
for result in rerank_result.docs:
if score_threshold is None or result.score >= score_threshold:
# format document
rerank_document = Document(
page_content=result.text,
metadata=unique_documents[result.index].metadata,
provider=unique_documents[result.index].provider,
)
if rerank_document.metadata is not None:
rerank_document.metadata["score"] = result.score
rerank_documents.append(rerank_document)
rerank_documents.sort(key=lambda x: x.metadata.get("score", 0.0), reverse=True)
return rerank_documents[:top_n] if top_n else rerank_documents
def fetch_text_rerank(
self,
query: str,
documents: list[Document],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> tuple[RerankResult, list[Document]]:
"""
Fetch text rerank
:param query: search query
:param documents: documents for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id if needed
:return:
"""
docs = []
doc_ids = set()
unique_documents = []
@ -33,33 +93,99 @@ class RerankModelRunner(BaseRerankRunner):
and document.metadata is not None
and document.metadata["doc_id"] not in doc_ids
):
doc_ids.add(document.metadata["doc_id"])
docs.append(document.page_content)
unique_documents.append(document)
if not document.metadata.get("doc_type") or document.metadata.get("doc_type") == DocType.TEXT:
doc_ids.add(document.metadata["doc_id"])
docs.append(document.page_content)
unique_documents.append(document)
elif document.provider == "external":
if document not in unique_documents:
docs.append(document.page_content)
unique_documents.append(document)
documents = unique_documents
rerank_result = self.rerank_model_instance.invoke_rerank(
query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
)
return rerank_result, unique_documents
rerank_documents = []
def fetch_multimodal_rerank(
self,
query: str,
documents: list[Document],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> tuple[RerankResult, list[Document]]:
"""
Fetch multimodal rerank
:param query: search query
:param documents: documents for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id if needed
:param query_type: query type
:return: rerank result
"""
docs = []
doc_ids = set()
unique_documents = []
for document in documents:
if (
document.provider == "dify"
and document.metadata is not None
and document.metadata["doc_id"] not in doc_ids
):
if document.metadata.get("doc_type") == DocType.IMAGE:
# Query file info within db.session context to ensure thread-safe access
upload_file = (
db.session.query(UploadFile).where(UploadFile.id == document.metadata["doc_id"]).first()
)
if upload_file:
blob = storage.load_once(upload_file.key)
document_file_base64 = base64.b64encode(blob).decode()
document_file_dict = {
"content": document_file_base64,
"content_type": document.metadata["doc_type"],
}
docs.append(document_file_dict)
else:
document_text_dict = {
"content": document.page_content,
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
}
docs.append(document_text_dict)
doc_ids.add(document.metadata["doc_id"])
unique_documents.append(document)
elif document.provider == "external":
if document not in unique_documents:
docs.append(
{
"content": document.page_content,
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
}
)
unique_documents.append(document)
for result in rerank_result.docs:
if score_threshold is None or result.score >= score_threshold:
# format document
rerank_document = Document(
page_content=result.text,
metadata=documents[result.index].metadata,
provider=documents[result.index].provider,
documents = unique_documents
if query_type == QueryType.TEXT_QUERY:
rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user)
return rerank_result, unique_documents
elif query_type == QueryType.IMAGE_QUERY:
# Query file info within db.session context to ensure thread-safe access
upload_file = db.session.query(UploadFile).where(UploadFile.id == query).first()
if upload_file:
blob = storage.load_once(upload_file.key)
file_query = base64.b64encode(blob).decode()
file_query_dict = {
"content": file_query,
"content_type": DocType.IMAGE,
}
rerank_result = self.rerank_model_instance.invoke_multimodal_rerank(
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
)
if rerank_document.metadata is not None:
rerank_document.metadata["score"] = result.score
rerank_documents.append(rerank_document)
return rerank_result, unique_documents
else:
raise ValueError(f"Upload file not found for query: {query}")
rerank_documents.sort(key=lambda x: x.metadata.get("score", 0.0), reverse=True)
return rerank_documents[:top_n] if top_n else rerank_documents
else:
raise ValueError(f"Query type {query_type} is not supported")

View File

@ -7,6 +7,8 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.embedding.cached_embedding import CacheEmbedding
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.rerank.entity.weight import VectorSetting, Weights
from core.rag.rerank.rerank_base import BaseRerankRunner
@ -24,6 +26,7 @@ class WeightRerankRunner(BaseRerankRunner):
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> list[Document]:
"""
Run rerank model
@ -43,8 +46,10 @@ class WeightRerankRunner(BaseRerankRunner):
and document.metadata is not None
and document.metadata["doc_id"] not in doc_ids
):
doc_ids.add(document.metadata["doc_id"])
unique_documents.append(document)
                # weight rerank only supports text documents
if not document.metadata.get("doc_type") or document.metadata.get("doc_type") == DocType.TEXT:
doc_ids.add(document.metadata["doc_id"])
unique_documents.append(document)
else:
if document not in unique_documents:
unique_documents.append(document)

View File

@ -8,6 +8,7 @@ from typing import Any, Union, cast
from flask import Flask, current_app
from sqlalchemy import and_, or_, select
from sqlalchemy.orm import Session
from core.app.app_config.entities import (
DatasetEntity,
@ -19,6 +20,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCre
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.file import File, FileTransferMethod, FileType
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
@ -37,7 +39,9 @@ from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@ -52,10 +56,12 @@ from core.rag.retrieval.template_prompts import (
METADATA_FILTER_USER_PROMPT_2,
METADATA_FILTER_USER_PROMPT_3,
)
from core.tools.signature import sign_upload_file
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment
from models import UploadFile
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
@ -99,7 +105,8 @@ class DatasetRetrieval:
message_id: str,
memory: TokenBufferMemory | None = None,
inputs: Mapping[str, Any] | None = None,
) -> str | None:
vision_enabled: bool = False,
) -> tuple[str | None, list[File] | None]:
"""
Retrieve dataset.
:param app_id: app_id
@ -118,7 +125,7 @@ class DatasetRetrieval:
"""
dataset_ids = config.dataset_ids
if len(dataset_ids) == 0:
return None
return None, []
retrieve_config = config.retrieve_config
# check model is support tool calling
@ -136,7 +143,7 @@ class DatasetRetrieval:
)
if not model_schema:
return None
return None, []
planning_strategy = PlanningStrategy.REACT_ROUTER
features = model_schema.features
@ -182,8 +189,8 @@ class DatasetRetrieval:
tenant_id,
user_id,
user_from,
available_datasets,
query,
available_datasets,
model_instance,
model_config,
planning_strategy,
@ -213,6 +220,7 @@ class DatasetRetrieval:
dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"]
document_context_list: list[DocumentContext] = []
context_files: list[File] = []
retrieval_resource_list: list[RetrievalSourceMetadata] = []
# deal with external documents
for item in external_documents:
@ -248,6 +256,31 @@ class DatasetRetrieval:
score=record.score,
)
)
if vision_enabled:
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.segment_id == segment.id,
)
).all()
if attachments_with_bindings:
for _, upload_file in attachments_with_bindings:
                            attachment_info = File(
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=segment.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
size=upload_file.size,
storage_key=upload_file.key,
url=sign_upload_file(upload_file.id, upload_file.extension),
)
                            context_files.append(attachment_info)
if show_retrieve_source:
for record in records:
segment = record.segment
@ -288,8 +321,10 @@ class DatasetRetrieval:
hit_callback.return_retriever_resource_info(retrieval_resource_list)
if document_context_list:
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
return str("\n".join([document_context.content for document_context in document_context_list]))
return ""
return str(
"\n".join([document_context.content for document_context in document_context_list])
), context_files
return "", context_files
def single_retrieve(
self,
@ -297,8 +332,8 @@ class DatasetRetrieval:
tenant_id: str,
user_id: str,
user_from: str,
available_datasets: list,
query: str,
available_datasets: list,
model_instance: ModelInstance,
model_config: ModelConfigWithCredentialsEntity,
planning_strategy: PlanningStrategy,
@ -336,7 +371,7 @@ class DatasetRetrieval:
dataset_id, router_usage = function_call_router.invoke(query, tools, model_config, model_instance)
self._record_usage(router_usage)
timer = None
if dataset_id:
# get retrieval model config
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
@ -406,10 +441,19 @@ class DatasetRetrieval:
weights=retrieval_model_config.get("weights", None),
document_ids_filter=document_ids_filter,
)
self._on_query(query, [dataset_id], app_id, user_from, user_id)
self._on_query(query, None, [dataset_id], app_id, user_from, user_id)
if results:
self._on_retrieval_end(results, message_id, timer)
thread = threading.Thread(
target=self._on_retrieval_end,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"documents": results,
"message_id": message_id,
"timer": timer,
},
)
thread.start()
return results
return []
@ -421,7 +465,7 @@ class DatasetRetrieval:
user_id: str,
user_from: str,
available_datasets: list,
query: str,
query: str | None,
top_k: int,
score_threshold: float,
reranking_mode: str,
@ -431,10 +475,11 @@ class DatasetRetrieval:
message_id: str | None = None,
metadata_filter_document_ids: dict[str, list[str]] | None = None,
metadata_condition: MetadataCondition | None = None,
attachment_ids: list[str] | None = None,
):
if not available_datasets:
return []
threads = []
all_threads = []
all_documents: list[Document] = []
dataset_ids = [dataset.id for dataset in available_datasets]
index_type_check = all(
@ -467,131 +512,226 @@ class DatasetRetrieval:
0
].embedding_model_provider
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
for dataset in available_datasets:
index_type = dataset.indexing_technique
document_ids_filter = None
if dataset.provider != "external":
if metadata_condition and not metadata_filter_document_ids:
continue
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
if document_ids:
document_ids_filter = document_ids
else:
continue
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"dataset_id": dataset.id,
"query": query,
"top_k": top_k,
"all_documents": all_documents,
"document_ids_filter": document_ids_filter,
"metadata_condition": metadata_condition,
},
)
threads.append(retrieval_thread)
retrieval_thread.start()
for thread in threads:
thread.join()
with measure_time() as timer:
if reranking_enable:
# do rerank for searched documents
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
all_documents = data_post_processor.invoke(
query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
if query:
query_thread = threading.Thread(
target=self._multiple_retrieve_thread,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"available_datasets": available_datasets,
"metadata_condition": metadata_condition,
"metadata_filter_document_ids": metadata_filter_document_ids,
"all_documents": all_documents,
"tenant_id": tenant_id,
"reranking_enable": reranking_enable,
"reranking_mode": reranking_mode,
"reranking_model": reranking_model,
"weights": weights,
"top_k": top_k,
"score_threshold": score_threshold,
"query": query,
"attachment_id": None,
},
)
else:
if index_type == "economy":
all_documents = self.calculate_keyword_score(query, all_documents, top_k)
elif index_type == "high_quality":
all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold)
else:
all_documents = all_documents[:top_k] if top_k else all_documents
self._on_query(query, dataset_ids, app_id, user_from, user_id)
all_threads.append(query_thread)
query_thread.start()
if attachment_ids:
for attachment_id in attachment_ids:
attachment_thread = threading.Thread(
target=self._multiple_retrieve_thread,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"available_datasets": available_datasets,
"metadata_condition": metadata_condition,
"metadata_filter_document_ids": metadata_filter_document_ids,
"all_documents": all_documents,
"tenant_id": tenant_id,
"reranking_enable": reranking_enable,
"reranking_mode": reranking_mode,
"reranking_model": reranking_model,
"weights": weights,
"top_k": top_k,
"score_threshold": score_threshold,
"query": None,
"attachment_id": attachment_id,
},
)
all_threads.append(attachment_thread)
attachment_thread.start()
for thread in all_threads:
thread.join()
self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id)
if all_documents:
self._on_retrieval_end(all_documents, message_id, timer)
return all_documents
def _on_retrieval_end(self, documents: list[Document], message_id: str | None = None, timer: dict | None = None):
"""Handle retrieval end."""
dify_documents = [document for document in documents if document.provider == "dify"]
for document in dify_documents:
if document.metadata is not None:
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == document.metadata["document_id"]
)
dataset_document = db.session.scalar(dataset_document_stmt)
if dataset_document:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk_stmt = select(ChildChunk).where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
child_chunk = db.session.scalar(child_chunk_stmt)
if child_chunk:
_ = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == child_chunk.segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
)
)
else:
query = db.session.query(DocumentSegment).where(
DocumentSegment.index_node_id == document.metadata["doc_id"]
)
# if 'dataset_id' in document.metadata:
if "dataset_id" in document.metadata:
query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# add hit count to document segment
query.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
)
db.session.commit()
# get tracing instance
trace_manager: TraceQueueManager | None = (
self.application_generate_entity.trace_manager if self.application_generate_entity else None
)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
)
# add thread to call _on_retrieval_end
retrieval_end_thread = threading.Thread(
target=self._on_retrieval_end,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"documents": all_documents,
"message_id": message_id,
"timer": timer,
},
)
retrieval_end_thread.start()
retrieval_resource_list = []
doc_ids_filter = []
for document in all_documents:
if document.provider == "dify":
doc_id = document.metadata.get("doc_id")
if doc_id and doc_id not in doc_ids_filter:
doc_ids_filter.append(doc_id)
retrieval_resource_list.append(document)
elif document.provider == "external":
retrieval_resource_list.append(document)
return retrieval_resource_list
def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str):
def _on_retrieval_end(
self, flask_app: Flask, documents: list[Document], message_id: str | None = None, timer: dict | None = None
):
"""Handle retrieval end."""
with flask_app.app_context():
dify_documents = [document for document in documents if document.provider == "dify"]
segment_ids = []
segment_index_node_ids = []
with Session(db.engine) as session:
for document in dify_documents:
if document.metadata is not None:
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == document.metadata["document_id"]
)
dataset_document = session.scalar(dataset_document_stmt)
if dataset_document:
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
segment_id = None
if (
"doc_type" not in document.metadata
or document.metadata.get("doc_type") == DocType.TEXT
):
child_chunk_stmt = select(ChildChunk).where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
child_chunk = session.scalar(child_chunk_stmt)
if child_chunk:
segment_id = child_chunk.segment_id
elif (
"doc_type" in document.metadata
and document.metadata.get("doc_type") == DocType.IMAGE
):
attachment_info_dict = RetrievalService.get_segment_attachment_info(
dataset_document.dataset_id,
dataset_document.tenant_id,
document.metadata.get("doc_id") or "",
session,
)
if attachment_info_dict:
segment_id = attachment_info_dict["segment_id"]
if segment_id:
if segment_id not in segment_ids:
segment_ids.append(segment_id)
_ = (
session.query(DocumentSegment)
.where(DocumentSegment.id == segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
)
)
else:
query = None
if (
"doc_type" not in document.metadata
or document.metadata.get("doc_type") == DocType.TEXT
):
if document.metadata["doc_id"] not in segment_index_node_ids:
segment = (
session.query(DocumentSegment)
.where(DocumentSegment.index_node_id == document.metadata["doc_id"])
.first()
)
if segment:
segment_index_node_ids.append(document.metadata["doc_id"])
segment_ids.append(segment.id)
query = session.query(DocumentSegment).where(
DocumentSegment.id == segment.id
)
elif (
"doc_type" in document.metadata
and document.metadata.get("doc_type") == DocType.IMAGE
):
attachment_info_dict = RetrievalService.get_segment_attachment_info(
dataset_document.dataset_id,
dataset_document.tenant_id,
document.metadata.get("doc_id") or "",
session,
)
if attachment_info_dict:
segment_id = attachment_info_dict["segment_id"]
if segment_id not in segment_ids:
segment_ids.append(segment_id)
query = session.query(DocumentSegment).where(DocumentSegment.id == segment_id)
if query:
# if 'dataset_id' in document.metadata:
if "dataset_id" in document.metadata:
query = query.where(
DocumentSegment.dataset_id == document.metadata["dataset_id"]
)
# add hit count to document segment
query.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
)
db.session.commit()
# get tracing instance
trace_manager: TraceQueueManager | None = (
self.application_generate_entity.trace_manager if self.application_generate_entity else None
)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
)
)
def _on_query(
self,
query: str | None,
attachment_ids: list[str] | None,
dataset_ids: list[str],
app_id: str,
user_from: str,
user_id: str,
):
"""
Handle query.
"""
if not query:
if not query and not attachment_ids:
return
dataset_queries = []
for dataset_id in dataset_ids:
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=query,
source="app",
source_app_id=app_id,
created_by_role=user_from,
created_by=user_id,
)
dataset_queries.append(dataset_query)
if dataset_queries:
db.session.add_all(dataset_queries)
contents = []
if query:
contents.append({"content_type": QueryType.TEXT_QUERY, "content": query})
if attachment_ids:
for attachment_id in attachment_ids:
contents.append({"content_type": QueryType.IMAGE_QUERY, "content": attachment_id})
if contents:
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=json.dumps(contents),
source="app",
source_app_id=app_id,
created_by_role=user_from,
created_by=user_id,
)
dataset_queries.append(dataset_query)
if dataset_queries:
db.session.add_all(dataset_queries)
db.session.commit()
def _retriever(
@ -603,6 +743,7 @@ class DatasetRetrieval:
all_documents: list,
document_ids_filter: list[str] | None = None,
metadata_condition: MetadataCondition | None = None,
attachment_ids: list[str] | None = None,
):
with flask_app.app_context():
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
@ -611,7 +752,7 @@ class DatasetRetrieval:
if not dataset:
return []
if dataset.provider == "external":
if dataset.provider == "external" and query:
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id=dataset.tenant_id,
dataset_id=dataset_id,
@ -663,6 +804,7 @@ class DatasetRetrieval:
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
document_ids_filter=document_ids_filter,
attachment_ids=attachment_ids,
)
all_documents.extend(documents)
@ -1222,3 +1364,86 @@ class DatasetRetrieval:
usage = LLMUsage.empty_usage()
return full_text, usage
def _multiple_retrieve_thread(
self,
flask_app: Flask,
available_datasets: list,
metadata_condition: MetadataCondition | None,
metadata_filter_document_ids: dict[str, list[str]] | None,
all_documents: list[Document],
tenant_id: str,
reranking_enable: bool,
reranking_mode: str,
reranking_model: dict | None,
weights: dict[str, Any] | None,
top_k: int,
score_threshold: float,
query: str | None,
attachment_id: str | None,
):
with flask_app.app_context():
threads = []
all_documents_item: list[Document] = []
index_type = None
for dataset in available_datasets:
index_type = dataset.indexing_technique
document_ids_filter = None
if dataset.provider != "external":
if metadata_condition and not metadata_filter_document_ids:
continue
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
if document_ids:
document_ids_filter = document_ids
else:
continue
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": flask_app,
"dataset_id": dataset.id,
"query": query,
"top_k": top_k,
"all_documents": all_documents_item,
"document_ids_filter": document_ids_filter,
"metadata_condition": metadata_condition,
"attachment_ids": [attachment_id] if attachment_id else None,
},
)
threads.append(retrieval_thread)
retrieval_thread.start()
for thread in threads:
thread.join()
if reranking_enable:
# do rerank for searched documents
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
if query:
all_documents_item = data_post_processor.invoke(
query=query,
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.TEXT_QUERY,
)
if attachment_id:
all_documents_item = data_post_processor.invoke(
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.IMAGE_QUERY,
query=attachment_id,
)
else:
if index_type == IndexTechniqueType.ECONOMY:
if not query:
all_documents_item = []
else:
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
elif index_type == IndexTechniqueType.HIGH_QUALITY:
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
else:
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
if all_documents_item:
all_documents.extend(all_documents_item)

View File

@ -0,0 +1,65 @@
{
"$id": "https://dify.ai/schemas/v1/multimodal_general_structure.json",
"$schema": "http://json-schema.org/draft-07/schema#",
"version": "1.0.0",
"type": "array",
"title": "Multimodal General Structure",
"description": "Schema for multimodal general structure (v1) - array of objects",
"properties": {
"general_chunks": {
"type": "array",
"items": {
"type": "object",
"properties": {
"content": {
"type": "string",
"description": "The content"
},
"files": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "file name"
},
"size": {
"type": "number",
"description": "file size"
},
"extension": {
"type": "string",
"description": "file extension"
},
"type": {
"type": "string",
"description": "file type"
},
"mime_type": {
"type": "string",
"description": "file mime type"
},
"transfer_method": {
"type": "string",
"description": "file transfer method"
},
"url": {
"type": "string",
"description": "file url"
},
"related_id": {
"type": "string",
"description": "file related id"
}
},
"description": "List of files"
}
}
},
"required": ["content"]
},
"description": "List of content and files"
}
}
}

View File

@ -0,0 +1,78 @@
{
"$id": "https://dify.ai/schemas/v1/multimodal_parent_child_structure.json",
"$schema": "http://json-schema.org/draft-07/schema#",
"version": "1.0.0",
"type": "object",
"title": "Multimodal Parent-Child Structure",
"description": "Schema for multimodal parent-child structure (v1)",
"properties": {
"parent_mode": {
"type": "string",
"description": "The mode of parent-child relationship"
},
"parent_child_chunks": {
"type": "array",
"items": {
"type": "object",
"properties": {
"parent_content": {
"type": "string",
"description": "The parent content"
},
"files": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "file name"
},
"size": {
"type": "number",
"description": "file size"
},
"extension": {
"type": "string",
"description": "file extension"
},
"type": {
"type": "string",
"description": "file type"
},
"mime_type": {
"type": "string",
"description": "file mime type"
},
"transfer_method": {
"type": "string",
"description": "file transfer method"
},
"url": {
"type": "string",
"description": "file url"
},
"related_id": {
"type": "string",
"description": "file related id"
}
},
"required": ["name", "size", "extension", "type", "mime_type", "transfer_method", "url", "related_id"]
},
"description": "List of files"
},
"child_contents": {
"type": "array",
"items": {
"type": "string"
},
"description": "List of child contents"
}
},
"required": ["parent_content", "child_contents"]
},
"description": "List of parent-child chunk pairs"
}
},
"required": ["parent_mode", "parent_child_chunks"]
}
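For orientation, a hedged sketch of a payload this parent-child schema is meant to accept, validated with the third-party jsonschema package (assumed available); the local file path and all field values are illustrative only, not part of this commit.
import json
import jsonschema

with open("multimodal_parent_child_structure.json") as f:  # hypothetical local copy of the schema above
    schema = json.load(f)

sample = {
    "parent_mode": "paragraph",  # placeholder mode value
    "parent_child_chunks": [
        {
            "parent_content": "Quarterly report overview ...",
            "child_contents": ["Revenue grew 12%.", "Margins held steady."],
            "files": [
                {
                    "name": "chart.png",
                    "size": 10240,
                    "extension": "png",
                    "type": "image",
                    "mime_type": "image/png",
                    "transfer_method": "local_file",
                    "url": "https://example.com/files/chart.png",
                    "related_id": "00000000-0000-0000-0000-000000000000",
                }
            ],
        }
    ],
}

jsonschema.validate(instance=sample, schema=schema)  # raises ValidationError if the sample drifts from the schema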

View File

@ -25,6 +25,24 @@ def sign_tool_file(tool_file_id: str, extension: str) -> str:
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
def sign_upload_file(upload_file_id: str, extension: str) -> str:
"""
sign an upload file to get a temporary image-preview url
"""
# Use internal URL for file access in Docker environments
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
file_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
"""
verify signature

View File

@ -13,5 +13,5 @@ def remove_leading_symbols(text: str) -> str:
"""
# Match Unicode ranges for punctuation and symbols
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F\"#$%&'()*+,./:;<=>?@^_`~]+"
return re.sub(pattern, "", text)
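A hedged before/after illustration of the tightened pattern, based on a reading of the character class: leading `!` is no longer stripped, while other leading punctuation still is.
# Illustrative expectations under the new pattern (not part of the diff):
assert remove_leading_symbols("...hello") == "hello"     # leading dots still stripped
assert remove_leading_symbols("!note") == "!note"        # '!' is now preserved (was stripped before)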

View File

@ -3,6 +3,7 @@ from datetime import datetime
from pydantic import Field
from core.file import File
from core.model_runtime.entities.llm_entities import LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.pause_reason import PauseReason
@ -14,6 +15,7 @@ from .base import NodeEventBase
class RunRetrieverResourceEvent(NodeEventBase):
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
context_files: list[File] | None = Field(default=None, description="context files")
class ModelInvokeCompletedEvent(NodeEventBase):

View File

@ -3,6 +3,7 @@ from collections.abc import Sequence
from email.message import Message
from typing import Any, Literal
import charset_normalizer
import httpx
from pydantic import BaseModel, Field, ValidationInfo, field_validator
@ -96,10 +97,12 @@ class HttpRequestNodeData(BaseNodeData):
class Response:
headers: dict[str, str]
response: httpx.Response
_cached_text: str | None
def __init__(self, response: httpx.Response):
self.response = response
self.headers = dict(response.headers)
self._cached_text = None
@property
def is_file(self):
@ -159,7 +162,31 @@ class Response:
@property
def text(self) -> str:
return self.response.text
"""
Get response text with robust encoding detection.
Uses charset_normalizer for better encoding detection than httpx's default,
which helps handle Chinese and other non-ASCII characters properly.
"""
# Check cache first
if hasattr(self, "_cached_text") and self._cached_text is not None:
return self._cached_text
# Try charset_normalizer for robust encoding detection first
detected_encoding = charset_normalizer.from_bytes(self.response.content).best()
if detected_encoding and detected_encoding.encoding:
try:
text = self.response.content.decode(detected_encoding.encoding)
self._cached_text = text
return text
except (UnicodeDecodeError, TypeError, LookupError):
# Fallback to httpx's encoding detection if charset_normalizer fails
pass
# Fallback to httpx's built-in encoding detection
text = self.response.text
self._cached_text = text
return text
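For reference, a standalone sketch of the same detection-and-fallback chain used by the property above, assuming only charset_normalizer and a bytes payload; the function name below is not part of the diff.
import charset_normalizer

def decode_best_effort(raw: bytes, fallback: str = "utf-8") -> str:
    # Try charset_normalizer first, then fall back to a lenient decode.
    best = charset_normalizer.from_bytes(raw).best()
    if best and best.encoding:
        try:
            return raw.decode(best.encoding)
        except (UnicodeDecodeError, LookupError):
            pass
    return raw.decode(fallback, errors="replace")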
@property
def content(self) -> bytes:

View File

@ -114,7 +114,8 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
"""
type: str = "knowledge-retrieval"
query_variable_selector: list[str]
query_variable_selector: list[str] | None | str = None
query_attachment_selector: list[str] | None | str = None
dataset_ids: list[str]
retrieval_mode: Literal["single", "multiple"]
multiple_retrieval_config: MultipleRetrievalConfig | None = None

View File

@ -25,6 +25,8 @@ from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import (
ArrayFileSegment,
FileSegment,
StringSegment,
)
from core.variables.segments import ArrayObjectSegment
@ -119,20 +121,41 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
return "1"
def _run(self) -> NodeRunResult:
# extract variables
variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector)
if not isinstance(variable, StringSegment):
if not self._node_data.query_variable_selector and not self._node_data.query_attachment_selector:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={},
error="Query variable is not string type.",
)
query = variable.value
variables = {"query": query}
if not query:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required."
process_data={},
outputs={},
metadata={},
llm_usage=LLMUsage.empty_usage(),
)
variables: dict[str, Any] = {}
# extract variables
if self._node_data.query_variable_selector:
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
if not isinstance(variable, StringSegment):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
error="Query variable is not string type.",
)
query = variable.value
variables["query"] = query
if self._node_data.query_attachment_selector:
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_attachment_selector)
if not isinstance(variable, ArrayFileSegment) and not isinstance(variable, FileSegment):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
error="Attachments variable is not array file or file type.",
)
if isinstance(variable, ArrayFileSegment):
variables["attachments"] = variable.value
else:
variables["attachments"] = [variable.value]
# TODO(-LAN-): Move this check outside.
# check rate limit
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
@ -161,7 +184,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
# retrieve knowledge
usage = LLMUsage.empty_usage()
try:
results, usage = self._fetch_dataset_retriever(node_data=self.node_data, query=query)
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables)
outputs = {"result": ArrayObjectSegment(value=results)}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -198,12 +221,16 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
db.session.close()
def _fetch_dataset_retriever(
self, node_data: KnowledgeRetrievalNodeData, query: str
self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]
) -> tuple[list[dict[str, Any]], LLMUsage]:
usage = LLMUsage.empty_usage()
available_datasets = []
dataset_ids = node_data.dataset_ids
query = variables.get("query")
attachments = variables.get("attachments")
metadata_filter_document_ids = None
metadata_condition = None
metadata_usage = LLMUsage.empty_usage()
# Subquery: Count the number of available documents for each dataset
subquery = (
db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count"))
@ -234,13 +261,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
if not dataset:
continue
available_datasets.append(dataset)
metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition(
[dataset.id for dataset in available_datasets], query, node_data
)
usage = self._merge_usage(usage, metadata_usage)
if query:
metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition(
[dataset.id for dataset in available_datasets], query, node_data
)
usage = self._merge_usage(usage, metadata_usage)
all_documents = []
dataset_retrieval = DatasetRetrieval()
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query:
# fetch model config
if node_data.single_retrieval_config is None:
raise ValueError("single_retrieval_config is required")
@ -272,7 +300,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
)
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
if node_data.multiple_retrieval_config is None:
raise ValueError("multiple_retrieval_config is required")
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
@ -319,6 +347,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
)
usage = self._merge_usage(usage, dataset_retrieval.llm_usage)
@ -327,7 +356,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
retrieval_resource_list = []
# deal with external documents
for item in external_documents:
source = {
source: dict[str, dict[str, str | Any | dict[Any, Any] | None] | Any | str | None] = {
"metadata": {
"_source": "knowledge",
"dataset_id": item.metadata.get("dataset_id"),
@ -384,6 +413,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
"doc_metadata": document.doc_metadata,
},
"title": document.name,
"files": list(record.files) if record.files else None,
}
if segment.answer:
source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
@ -393,13 +423,21 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
if retrieval_resource_list:
retrieval_resource_list = sorted(
retrieval_resource_list,
key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0,
key=self._score, # type: ignore[arg-type, return-value]
reverse=True,
)
for position, item in enumerate(retrieval_resource_list, start=1):
item["metadata"]["position"] = position
item["metadata"]["position"] = position # type: ignore[index]
return retrieval_resource_list, usage
def _score(self, item: dict[str, Any]) -> float:
meta = item.get("metadata")
if isinstance(meta, dict):
s = meta.get("score")
if isinstance(s, (int, float)):
return float(s)
return 0.0
def _get_metadata_filter_condition(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]:
@ -659,7 +697,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data)
variable_mapping = {}
variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
if typed_node_data.query_variable_selector:
variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
if typed_node_data.query_attachment_selector:
variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector
return variable_mapping
def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:

View File

@ -7,8 +7,10 @@ import time
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal
from sqlalchemy import select
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import FileType, file_manager
from core.file import File, FileTransferMethod, FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
@ -44,6 +46,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.tools.signature import sign_upload_file
from core.variables import (
ArrayFileSegment,
ArraySegment,
@ -72,6 +75,9 @@ from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from models.dataset import SegmentAttachmentBinding
from models.model import UploadFile
from . import llm_utils
from .entities import (
@ -179,12 +185,17 @@ class LLMNode(Node[LLMNodeData]):
# fetch context value
generator = self._fetch_context(node_data=self.node_data)
context = None
context_files: list[File] = []
for event in generator:
context = event.context
context_files = event.context_files or []
yield event
if context:
node_inputs["#context#"] = context
if context_files:
node_inputs["#context_files#"] = [file.model_dump() for file in context_files]
# fetch model config
model_instance, model_config = LLMNode._fetch_model_config(
node_data_model=self.node_data.model,
@ -220,6 +231,7 @@ class LLMNode(Node[LLMNodeData]):
variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
tenant_id=self.tenant_id,
context_files=context_files,
)
# handle invoke result
@ -654,10 +666,13 @@ class LLMNode(Node[LLMNodeData]):
context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector)
if context_value_variable:
if isinstance(context_value_variable, StringSegment):
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value)
yield RunRetrieverResourceEvent(
retriever_resources=[], context=context_value_variable.value, context_files=[]
)
elif isinstance(context_value_variable, ArraySegment):
context_str = ""
original_retriever_resource: list[RetrievalSourceMetadata] = []
context_files: list[File] = []
for item in context_value_variable.value:
if isinstance(item, str):
context_str += item + "\n"
@ -670,9 +685,34 @@ class LLMNode(Node[LLMNodeData]):
retriever_resource = self._convert_to_original_retriever_resource(item)
if retriever_resource:
original_retriever_resource.append(retriever_resource)
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.segment_id == retriever_resource.segment_id,
)
).all()
if attachments_with_bindings:
for _, upload_file in attachments_with_bindings:
attachment_info = File(
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
size=upload_file.size,
storage_key=upload_file.key,
url=sign_upload_file(upload_file.id, upload_file.extension),
)
context_files.append(attachment_info)
yield RunRetrieverResourceEvent(
retriever_resources=original_retriever_resource, context=context_str.strip()
retriever_resources=original_retriever_resource,
context=context_str.strip(),
context_files=context_files,
)
def _convert_to_original_retriever_resource(self, context_dict: dict) -> RetrievalSourceMetadata | None:
@ -700,6 +740,7 @@ class LLMNode(Node[LLMNodeData]):
content=context_dict.get("content"),
page=metadata.get("page"),
doc_metadata=metadata.get("doc_metadata"),
files=context_dict.get("files"),
)
return source
@ -741,6 +782,7 @@ class LLMNode(Node[LLMNodeData]):
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
tenant_id: str,
context_files: list["File"] | None = None,
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
prompt_messages: list[PromptMessage] = []
@ -853,6 +895,23 @@ class LLMNode(Node[LLMNodeData]):
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
# The context_files
if vision_enabled and context_files:
file_prompts = []
for file in context_files:
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
file_prompts.append(file_prompt)
# If last prompt is a user prompt, add files into its contents,
# otherwise append a new user prompt
if (
len(prompt_messages) > 0
and isinstance(prompt_messages[-1], UserPromptMessage)
and isinstance(prompt_messages[-1].content, list)
):
prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
# Remove empty messages and filter unsupported content
filtered_prompt_messages = []
for prompt_message in prompt_messages:

View File

@ -97,11 +97,27 @@ dataset_detail_fields = {
"total_documents": fields.Integer,
"total_available_documents": fields.Integer,
"enable_api": fields.Boolean,
"is_multimodal": fields.Boolean,
}
file_info_fields = {
"id": fields.String,
"name": fields.String,
"size": fields.Integer,
"extension": fields.String,
"mime_type": fields.String,
"source_url": fields.String,
}
content_fields = {
"content_type": fields.String,
"content": fields.String,
"file_info": fields.Nested(file_info_fields, allow_null=True),
}
dataset_query_detail_fields = {
"id": fields.String,
"content": fields.String,
"queries": fields.Nested(content_fields),
"source": fields.String,
"source_app_id": fields.String,
"created_by_role": fields.String,

View File

@ -9,6 +9,8 @@ upload_config_fields = {
"video_file_size_limit": fields.Integer,
"audio_file_size_limit": fields.Integer,
"workflow_file_upload_limit": fields.Integer,
"image_file_batch_limit": fields.Integer,
"single_chunk_attachment_limit": fields.Integer,
}

View File

@ -43,9 +43,19 @@ child_chunk_fields = {
"score": fields.Float,
}
files_fields = {
"id": fields.String,
"name": fields.String,
"size": fields.Integer,
"extension": fields.String,
"mime_type": fields.String,
"source_url": fields.String,
}
hit_testing_record_fields = {
"segment": fields.Nested(segment_fields),
"child_chunks": fields.List(fields.Nested(child_chunk_fields)),
"score": fields.Float,
"tsne_position": fields.Raw,
"files": fields.List(fields.Nested(files_fields)),
}

View File

@ -13,6 +13,15 @@ child_chunk_fields = {
"updated_at": TimestampField,
}
attachment_fields = {
"id": fields.String,
"name": fields.String,
"size": fields.Integer,
"extension": fields.String,
"mime_type": fields.String,
"source_url": fields.String,
}
segment_fields = {
"id": fields.String,
"position": fields.Integer,
@ -39,4 +48,5 @@ segment_fields = {
"error": fields.String,
"stopped_at": TimestampField,
"child_chunks": fields.List(fields.Nested(child_chunk_fields)),
"attachments": fields.List(fields.Nested(attachment_fields)),
}

View File

@ -0,0 +1,57 @@
"""support-multi-modal
Revision ID: d57accd375ae
Revises: 7bb281b7a422
Create Date: 2025-11-12 15:37:12.363670
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'd57accd375ae'
down_revision = '7bb281b7a422'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('segment_attachment_bindings',
sa.Column('id', models.types.StringUUID(), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
sa.Column('document_id', models.types.StringUUID(), nullable=False),
sa.Column('segment_id', models.types.StringUUID(), nullable=False),
sa.Column('attachment_id', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.PrimaryKeyConstraint('id', name='segment_attachment_binding_pkey')
)
with op.batch_alter_table('segment_attachment_bindings', schema=None) as batch_op:
batch_op.create_index(
'segment_attachment_binding_tenant_dataset_document_segment_idx',
['tenant_id', 'dataset_id', 'document_id', 'segment_id'],
unique=False
)
batch_op.create_index('segment_attachment_binding_attachment_idx', ['attachment_id'], unique=False)
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.add_column(sa.Column('is_multimodal', sa.Boolean(), server_default=sa.text('false'), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.drop_column('is_multimodal')
with op.batch_alter_table('segment_attachment_bindings', schema=None) as batch_op:
batch_op.drop_index('segment_attachment_binding_attachment_idx')
batch_op.drop_index('segment_attachment_binding_tenant_dataset_document_segment_idx')
op.drop_table('segment_attachment_bindings')
# ### end Alembic commands ###

View File

@ -19,7 +19,9 @@ from sqlalchemy.orm import Mapped, Session, mapped_column
from configs import dify_config
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.signature import sign_upload_file
from extensions.ext_storage import storage
from libs.uuid_utils import uuidv7
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
@ -76,6 +78,7 @@ class Dataset(Base):
pipeline_id = mapped_column(StringUUID, nullable=True)
chunk_structure = mapped_column(sa.String(255), nullable=True)
enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
is_multimodal = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
@property
def total_documents(self):
@ -728,9 +731,7 @@ class DocumentSegment(Base):
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
error = mapped_column(LongText, nullable=True)
@ -866,6 +867,47 @@ class DocumentSegment(Base):
return text
@property
def attachments(self) -> list[dict[str, Any]]:
# Use JOIN to fetch attachments in a single query instead of two separate queries
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.tenant_id == self.tenant_id,
SegmentAttachmentBinding.dataset_id == self.dataset_id,
SegmentAttachmentBinding.document_id == self.document_id,
SegmentAttachmentBinding.segment_id == self.id,
)
).all()
if not attachments_with_bindings:
return []
attachment_list = []
for _, attachment in attachments_with_bindings:
upload_file_id = attachment.id
nonce = os.urandom(16).hex()
timestamp = str(int(time.time()))
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
reference_url = dify_config.CONSOLE_API_URL or ""
base_url = f"{reference_url}/files/{upload_file_id}/image-preview"
source_url = f"{base_url}?{params}"
attachment_list.append(
{
"id": attachment.id,
"name": attachment.name,
"size": attachment.size,
"extension": attachment.extension,
"mime_type": attachment.mime_type,
"source_url": source_url,
}
)
return attachment_list
class ChildChunk(Base):
__tablename__ = "child_chunks"
@ -963,6 +1005,38 @@ class DatasetQuery(TypeBase):
DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
)
@property
def queries(self) -> list[dict[str, Any]]:
try:
queries = json.loads(self.content)
if isinstance(queries, list):
for query in queries:
if query["content_type"] == QueryType.IMAGE_QUERY:
file_info = db.session.query(UploadFile).filter_by(id=query["content"]).first()
if file_info:
query["file_info"] = {
"id": file_info.id,
"name": file_info.name,
"size": file_info.size,
"extension": file_info.extension,
"mime_type": file_info.mime_type,
"source_url": sign_upload_file(file_info.id, file_info.extension),
}
else:
query["file_info"] = None
return queries
else:
return [queries]
except JSONDecodeError:
return [
{
"content_type": QueryType.TEXT_QUERY,
"content": self.content,
"file_info": None,
}
]
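For orientation, a hedged example of the JSON that the queries property above expects DatasetQuery.content to hold once _on_query stores mixed text and image queries; the attachment ID is a placeholder.
# Illustrative only; the image entry's content is a placeholder upload-file ID.
example_content = json.dumps([
    {"content_type": QueryType.TEXT_QUERY, "content": "how do refunds work?"},
    {"content_type": QueryType.IMAGE_QUERY, "content": "00000000-0000-0000-0000-000000000000"},
])
# queries() parses this list back, attaching file_info (with a signed source_url) to image
# entries, and falls back to a single TEXT_QUERY entry when content is plain text.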
class DatasetKeywordTable(TypeBase):
__tablename__ = "dataset_keyword_tables"
@ -1470,3 +1544,25 @@ class PipelineRecommendedPlugin(TypeBase):
onupdate=func.current_timestamp(),
init=False,
)
class SegmentAttachmentBinding(Base):
__tablename__ = "segment_attachment_bindings"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="segment_attachment_binding_pkey"),
sa.Index(
"segment_attachment_binding_tenant_dataset_document_segment_idx",
"tenant_id",
"dataset_id",
"document_id",
"segment_id",
),
sa.Index("segment_attachment_binding_attachment_idx", "attachment_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
attachment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())

View File

@ -0,0 +1,31 @@
import base64
from sqlalchemy import Engine
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from extensions.ext_storage import storage
from models.model import UploadFile
PREVIEW_WORDS_LIMIT = 3000
class AttachmentService:
_session_maker: sessionmaker
def __init__(self, session_factory: sessionmaker | Engine | None = None):
if isinstance(session_factory, Engine):
self._session_maker = sessionmaker(bind=session_factory)
elif isinstance(session_factory, sessionmaker):
self._session_maker = session_factory
else:
raise AssertionError("must be a sessionmaker or an Engine.")
def get_file_base64(self, file_id: str) -> str:
upload_file = (
self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first()
)
if not upload_file:
raise NotFound("File not found")
blob = storage.load_once(upload_file.key)
return base64.b64encode(blob).decode()
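A hedged usage sketch; db.engine is assumed to be available as elsewhere in the codebase, and the file ID is a placeholder.
# Sketch only: load an attachment as base64, e.g. for a multimodal embedding call.
from extensions.ext_database import db

service = AttachmentService(session_factory=db.engine)
image_b64 = service.get_file_base64("00000000-0000-0000-0000-000000000000")  # placeholder upload-file ID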

View File

@ -7,7 +7,7 @@ import time
import uuid
from collections import Counter
from collections.abc import Sequence
from typing import Any, Literal
from typing import Any, Literal, cast
import sqlalchemy as sa
from redis.exceptions import LockNotOwnedError
@ -19,9 +19,10 @@ from configs import dify_config
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.helper.name_generator import generate_incremental_name
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from enums.cloud_plan import CloudPlan
from events.dataset_event import dataset_was_deleted
@ -46,12 +47,14 @@ from models.dataset import (
DocumentSegment,
ExternalKnowledgeBindings,
Pipeline,
SegmentAttachmentBinding,
)
from models.model import UploadFile
from models.provider_ids import ModelProviderID
from models.source import DataSourceOauthBinding
from models.workflow import Workflow
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy
from services.document_indexing_proxy.duplicate_document_indexing_task_proxy import DuplicateDocumentIndexingTaskProxy
from services.entities.knowledge_entities.knowledge_entities import (
ChildChunkUpdateArgs,
KnowledgeConfig,
@ -82,7 +85,6 @@ from tasks.delete_segment_from_index_task import delete_segment_from_index_task
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
from tasks.disable_segments_from_index_task import disable_segments_from_index_task
from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
from tasks.enable_segments_to_index_task import enable_segments_to_index_task
from tasks.recover_document_indexing_task import recover_document_indexing_task
from tasks.remove_document_from_index_task import remove_document_from_index_task
@ -363,6 +365,27 @@ class DatasetService:
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
@staticmethod
def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str):
try:
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
provider=model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=model,
)
text_embedding_model = cast(TextEmbeddingModel, model_instance.model_type_instance)
model_schema = text_embedding_model.get_model_schema(model_instance.model, model_instance.credentials)
if not model_schema:
raise ValueError("Model schema not found")
if model_schema.features and ModelFeature.VISION in model_schema.features:
return True
else:
return False
except LLMBadRequestError:
raise ValueError("No Model available. Please configure a valid provider in the Settings -> Model Provider.")
@staticmethod
def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str):
try:
@ -402,13 +425,13 @@ class DatasetService:
if not dataset:
raise ValueError("Dataset not found")
# check if dataset name is exists
if DatasetService._has_dataset_same_name(
tenant_id=dataset.tenant_id,
dataset_id=dataset_id,
name=data.get("name", dataset.name),
):
raise ValueError("Dataset name already exists")
if data.get("name") and data.get("name") != dataset.name:
if DatasetService._has_dataset_same_name(
tenant_id=dataset.tenant_id,
dataset_id=dataset_id,
name=data.get("name", dataset.name),
):
raise ValueError("Dataset name already exists")
# Verify user has permission to update this dataset
DatasetService.check_dataset_permission(dataset, user)
@ -844,6 +867,12 @@ class DatasetService:
model_type=ModelType.TEXT_EMBEDDING,
model=knowledge_configuration.embedding_model or "",
)
is_multimodal = DatasetService.check_is_multimodal_model(
current_user.current_tenant_id,
knowledge_configuration.embedding_model_provider,
knowledge_configuration.embedding_model,
)
dataset.is_multimodal = is_multimodal
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
@ -880,6 +909,12 @@ class DatasetService:
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
is_multimodal = DatasetService.check_is_multimodal_model(
current_user.current_tenant_id,
knowledge_configuration.embedding_model_provider,
knowledge_configuration.embedding_model,
)
dataset.is_multimodal = is_multimodal
dataset.collection_binding_id = dataset_collection_binding.id
dataset.indexing_technique = knowledge_configuration.indexing_technique
except LLMBadRequestError:
@ -937,6 +972,12 @@ class DatasetService:
)
)
dataset.collection_binding_id = dataset_collection_binding.id
is_multimodal = DatasetService.check_is_multimodal_model(
current_user.current_tenant_id,
knowledge_configuration.embedding_model_provider,
knowledge_configuration.embedding_model,
)
dataset.is_multimodal = is_multimodal
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
@ -1761,7 +1802,9 @@ class DocumentService:
if document_ids:
DocumentIndexingTaskProxy(dataset.tenant_id, dataset.id, document_ids).delay()
if duplicate_document_ids:
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
DuplicateDocumentIndexingTaskProxy(
dataset.tenant_id, dataset.id, duplicate_document_ids
).delay()
except LockNotOwnedError:
pass
@ -2303,6 +2346,7 @@ class DocumentService:
embedding_model_provider=knowledge_config.embedding_model_provider,
collection_binding_id=dataset_collection_binding_id,
retrieval_model=retrieval_model.model_dump() if retrieval_model else None,
is_multimodal=knowledge_config.is_multimodal,
)
db.session.add(dataset)
@ -2683,6 +2727,13 @@ class SegmentService:
if "content" not in args or not args["content"] or not args["content"].strip():
raise ValueError("Content is empty")
if args.get("attachment_ids"):
if not isinstance(args["attachment_ids"], list):
raise ValueError("Attachment IDs is invalid")
single_chunk_attachment_limit = dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT
if len(args["attachment_ids"]) > single_chunk_attachment_limit:
raise ValueError(f"Exceeded maximum attachment limit of {single_chunk_attachment_limit}")
@classmethod
def create_segment(cls, args: dict, document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
@ -2729,11 +2780,23 @@ class SegmentService:
segment_document.word_count += len(args["answer"])
segment_document.answer = args["answer"]
db.session.add(segment_document)
# update document word count
assert document.word_count is not None
document.word_count += segment_document.word_count
db.session.add(document)
db.session.add(segment_document)
# update document word count
assert document.word_count is not None
document.word_count += segment_document.word_count
db.session.add(document)
db.session.commit()
if args.get("attachment_ids"):
for attachment_id in args["attachment_ids"]:
binding = SegmentAttachmentBinding(
tenant_id=current_user.current_tenant_id,
dataset_id=document.dataset_id,
document_id=document.id,
segment_id=segment_document.id,
attachment_id=attachment_id,
)
db.session.add(binding)
db.session.commit()
# save vector index
@ -2897,7 +2960,7 @@ class SegmentService:
document.word_count = max(0, document.word_count + word_count_change)
db.session.add(document)
# update segment index task
if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
# regenerate child chunks
# get embedding model instance
if dataset.indexing_technique == "high_quality":
@ -2924,12 +2987,11 @@ class SegmentService:
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
if not processing_rule:
raise ValueError("No processing rule found.")
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
)
elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
if processing_rule:
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
)
elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX):
if args.enabled or keyword_changed:
# update segment vector index
VectorService.update_segment_vector(args.keywords, segment, dataset)
@ -2974,7 +3036,7 @@ class SegmentService:
db.session.add(document)
db.session.add(segment)
db.session.commit()
if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
# get embedding model instance
if dataset.indexing_technique == "high_quality":
# check embedding model setting
@ -3000,15 +3062,15 @@ class SegmentService:
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
if not processing_rule:
raise ValueError("No processing rule found.")
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
)
elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
if processing_rule:
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
)
elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX):
# update segment vector index
VectorService.update_segment_vector(args.keywords, segment, dataset)
# update multimodel vector index
VectorService.update_multimodel_vector(segment, args.attachment_ids or [], dataset)
except Exception as e:
logger.exception("update segment index failed")
segment.enabled = False
@ -3046,7 +3108,9 @@ class SegmentService:
)
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id, child_node_ids)
delete_segment_from_index_task.delay(
[segment.index_node_id], dataset.id, document.id, [segment.id], child_node_ids
)
db.session.delete(segment)
# update document word count
@ -3095,7 +3159,9 @@ class SegmentService:
# Start async cleanup with both parent and child node IDs
if index_node_ids or child_node_ids:
delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id, child_node_ids)
delete_segment_from_index_task.delay(
index_node_ids, dataset.id, document.id, segment_db_ids, child_node_ids
)
if document.word_count is None:
document.word_count = 0

View File

@ -29,8 +29,14 @@ def get_current_user():
from models.account import Account
from models.model import EndUser
if not isinstance(current_user._get_current_object(), (Account, EndUser)): # type: ignore
raise TypeError(f"current_user must be Account or EndUser, got {type(current_user).__name__}")
try:
user_object = current_user._get_current_object()
except AttributeError:
# Handle case where current_user might not be a LocalProxy in test environments
user_object = current_user
if not isinstance(user_object, (Account, EndUser)):
raise TypeError(f"current_user must be Account or EndUser, got {type(user_object).__name__}")
return current_user
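A hedged sketch of the test-environment case the try/except now tolerates; the patch target below is a placeholder for wherever current_user is imported by this module, and the test name is hypothetical.
# Sketch only: tests may inject a plain Account instead of a werkzeug LocalProxy.
from unittest.mock import patch

from models.account import Account


def test_get_current_user_accepts_plain_account():
    fake_account = Account()
    with patch("services.some_module.current_user", fake_account):  # placeholder patch target
        assert get_current_user() is fake_account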

View File

@ -0,0 +1,11 @@
from .base import DocumentTaskProxyBase
from .batch_indexing_base import BatchDocumentIndexingProxy
from .document_indexing_task_proxy import DocumentIndexingTaskProxy
from .duplicate_document_indexing_task_proxy import DuplicateDocumentIndexingTaskProxy
__all__ = [
"BatchDocumentIndexingProxy",
"DocumentIndexingTaskProxy",
"DocumentTaskProxyBase",
"DuplicateDocumentIndexingTaskProxy",
]

View File

@ -0,0 +1,111 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import cached_property
from typing import Any, ClassVar
from enums.cloud_plan import CloudPlan
from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
class DocumentTaskProxyBase(ABC):
"""
Base proxy for all document processing tasks.
Handles common logic:
- Feature/billing checks
- Dispatch routing based on plan
Subclasses must define:
- QUEUE_NAME: Redis queue identifier
- NORMAL_TASK_FUNC: Task function for normal priority
- PRIORITY_TASK_FUNC: Task function for high priority
"""
QUEUE_NAME: ClassVar[str]
NORMAL_TASK_FUNC: ClassVar[Callable[..., Any]]
PRIORITY_TASK_FUNC: ClassVar[Callable[..., Any]]
def __init__(self, tenant_id: str, dataset_id: str):
"""
Initialize with minimal required parameters.
Args:
tenant_id: Tenant identifier for billing/features
dataset_id: Dataset identifier for logging
"""
self._tenant_id = tenant_id
self._dataset_id = dataset_id
@cached_property
def features(self):
return FeatureService.get_features(self._tenant_id)
@abstractmethod
def _send_to_direct_queue(self, task_func: Callable[..., Any]):
"""
Send task directly to Celery queue without tenant isolation.
Subclasses implement this to pass task-specific parameters.
Args:
task_func: The Celery task function to call
"""
pass
@abstractmethod
def _send_to_tenant_queue(self, task_func: Callable[..., Any]):
"""
Send task to tenant-isolated queue.
Subclasses implement this to handle queue management.
Args:
task_func: The Celery task function to call
"""
pass
def _send_to_default_tenant_queue(self):
"""Route to normal priority with tenant isolation."""
self._send_to_tenant_queue(self.NORMAL_TASK_FUNC)
def _send_to_priority_tenant_queue(self):
"""Route to priority queue with tenant isolation."""
self._send_to_tenant_queue(self.PRIORITY_TASK_FUNC)
def _send_to_priority_direct_queue(self):
"""Route to priority queue without tenant isolation."""
self._send_to_direct_queue(self.PRIORITY_TASK_FUNC)
def _dispatch(self):
"""
Dispatch task based on billing plan.
Routing logic:
- Sandbox plan -> normal queue + tenant isolation
- Paid plans -> priority queue + tenant isolation
- Self-hosted -> priority queue, no isolation
"""
logger.info(
"dispatch args: %s - %s - %s",
self._tenant_id,
self.features.billing.enabled,
self.features.billing.subscription.plan,
)
# dispatch to different indexing queue with tenant isolation when billing enabled
if self.features.billing.enabled:
if self.features.billing.subscription.plan == CloudPlan.SANDBOX:
# dispatch to normal pipeline queue with tenant self sub queue for sandbox plan
self._send_to_default_tenant_queue()
else:
# dispatch to priority pipeline queue with tenant self sub queue for other plans
self._send_to_priority_tenant_queue()
else:
# dispatch to priority queue without tenant isolation for others, e.g.: self-hosted or enterprise
self._send_to_priority_direct_queue()
def delay(self):
"""Public API: Queue the task asynchronously."""
self._dispatch()
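A hedged sketch of how a concrete proxy plugs into this base class; normal_task and priority_task are hypothetical Celery tasks, not ones defined in this repository.
# Sketch only: a minimal proxy for a hypothetical single-document task.
class SingleDocumentIndexingProxy(DocumentTaskProxyBase):
    QUEUE_NAME = "single_document_indexing"
    NORMAL_TASK_FUNC = normal_task        # hypothetical Celery task
    PRIORITY_TASK_FUNC = priority_task    # hypothetical Celery task

    def __init__(self, tenant_id: str, dataset_id: str, document_id: str):
        super().__init__(tenant_id, dataset_id)
        self._document_id = document_id

    def _send_to_direct_queue(self, task_func):
        task_func.delay(tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_id=self._document_id)

    def _send_to_tenant_queue(self, task_func):
        # A real implementation would go through TenantIsolatedTaskQueue as
        # BatchDocumentIndexingProxy does; the sketch dispatches directly for brevity.
        task_func.delay(tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_id=self._document_id)

# Usage: SingleDocumentIndexingProxy(tenant_id, dataset_id, document_id).delay()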

View File

@ -0,0 +1,76 @@
import logging
from collections.abc import Callable, Sequence
from dataclasses import asdict
from typing import Any
from core.entities.document_task import DocumentTask
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from .base import DocumentTaskProxyBase
logger = logging.getLogger(__name__)
class BatchDocumentIndexingProxy(DocumentTaskProxyBase):
"""
Base proxy for batch document indexing tasks (document_ids in plural).
Adds:
- Tenant isolated queue management
- Batch document handling
"""
def __init__(self, tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
"""
Initialize with batch documents.
Args:
tenant_id: Tenant identifier
dataset_id: Dataset identifier
document_ids: List of document IDs to process
"""
super().__init__(tenant_id, dataset_id)
self._document_ids = document_ids
self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, self.QUEUE_NAME)
def _send_to_direct_queue(self, task_func: Callable[[str, str, Sequence[str]], Any]):
"""
Send batch task to direct queue.
Args:
task_func: The Celery task function to call with (tenant_id, dataset_id, document_ids)
"""
logger.info("tenant %s send documents %s to direct queue", self._tenant_id, self._document_ids)
task_func.delay( # type: ignore
tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
)
def _send_to_tenant_queue(self, task_func: Callable[[str, str, Sequence[str]], Any]):
"""
Send batch task to tenant-isolated queue.
Args:
task_func: The Celery task function to call with (tenant_id, dataset_id, document_ids)
"""
logger.info(
"tenant %s send documents %s to tenant queue %s", self._tenant_id, self._document_ids, self.QUEUE_NAME
)
if self._tenant_isolated_task_queue.get_task_key():
# Add to waiting queue using List operations (lpush)
self._tenant_isolated_task_queue.push_tasks(
[
asdict(
DocumentTask(
tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
)
)
]
)
logger.info("tenant %s push tasks: %s - %s", self._tenant_id, self._dataset_id, self._document_ids)
else:
# Set flag and execute task
self._tenant_isolated_task_queue.set_task_waiting_time()
task_func.delay( # type: ignore
tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
)
logger.info("tenant %s init tasks: %s - %s", self._tenant_id, self._dataset_id, self._document_ids)

View File

@ -0,0 +1,12 @@
from typing import ClassVar
from services.document_indexing_proxy.batch_indexing_base import BatchDocumentIndexingProxy
from tasks.document_indexing_task import normal_document_indexing_task, priority_document_indexing_task
class DocumentIndexingTaskProxy(BatchDocumentIndexingProxy):
"""Proxy for document indexing tasks."""
QUEUE_NAME: ClassVar[str] = "document_indexing"
NORMAL_TASK_FUNC = normal_document_indexing_task
PRIORITY_TASK_FUNC = priority_document_indexing_task

View File

@ -0,0 +1,15 @@
from typing import ClassVar
from services.document_indexing_proxy.batch_indexing_base import BatchDocumentIndexingProxy
from tasks.duplicate_document_indexing_task import (
normal_duplicate_document_indexing_task,
priority_duplicate_document_indexing_task,
)
class DuplicateDocumentIndexingTaskProxy(BatchDocumentIndexingProxy):
"""Proxy for duplicate document indexing tasks."""
QUEUE_NAME: ClassVar[str] = "duplicate_document_indexing"
NORMAL_TASK_FUNC = normal_duplicate_document_indexing_task
PRIORITY_TASK_FUNC = priority_duplicate_document_indexing_task

View File

@ -1,83 +0,0 @@
import logging
from collections.abc import Callable, Sequence
from dataclasses import asdict
from functools import cached_property
from core.entities.document_task import DocumentTask
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from services.feature_service import FeatureService
from tasks.document_indexing_task import normal_document_indexing_task, priority_document_indexing_task
logger = logging.getLogger(__name__)
class DocumentIndexingTaskProxy:
def __init__(self, tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
self._tenant_id = tenant_id
self._dataset_id = dataset_id
self._document_ids = document_ids
self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")
@cached_property
def features(self):
return FeatureService.get_features(self._tenant_id)
def _send_to_direct_queue(self, task_func: Callable[[str, str, Sequence[str]], None]):
logger.info("send dataset %s to direct queue", self._dataset_id)
task_func.delay( # type: ignore
tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
)
def _send_to_tenant_queue(self, task_func: Callable[[str, str, Sequence[str]], None]):
logger.info("send dataset %s to tenant queue", self._dataset_id)
if self._tenant_isolated_task_queue.get_task_key():
# Add to waiting queue using List operations (lpush)
self._tenant_isolated_task_queue.push_tasks(
[
asdict(
DocumentTask(
tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
)
)
]
)
logger.info("push tasks: %s - %s", self._dataset_id, self._document_ids)
else:
# Set flag and execute task
self._tenant_isolated_task_queue.set_task_waiting_time()
task_func.delay( # type: ignore
tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
)
logger.info("init tasks: %s - %s", self._dataset_id, self._document_ids)
def _send_to_default_tenant_queue(self):
self._send_to_tenant_queue(normal_document_indexing_task)
def _send_to_priority_tenant_queue(self):
self._send_to_tenant_queue(priority_document_indexing_task)
def _send_to_priority_direct_queue(self):
self._send_to_direct_queue(priority_document_indexing_task)
def _dispatch(self):
logger.info(
"dispatch args: %s - %s - %s",
self._tenant_id,
self.features.billing.enabled,
self.features.billing.subscription.plan,
)
# dispatch to different indexing queue with tenant isolation when billing enabled
if self.features.billing.enabled:
if self.features.billing.subscription.plan == CloudPlan.SANDBOX:
# dispatch to normal pipeline queue with tenant self sub queue for sandbox plan
self._send_to_default_tenant_queue()
else:
# dispatch to priority pipeline queue with tenant self sub queue for other plans
self._send_to_priority_tenant_queue()
else:
# dispatch to priority queue without tenant isolation for others, e.g.: self-hosted or enterprise
self._send_to_priority_direct_queue()
def delay(self):
self._dispatch()
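The removed module above and the task-side queue drains later in this commit touch TenantIsolatedTaskQueue through only a handful of calls. The stub below summarizes that inferred surface; it is reconstructed from the call sites and is not the real implementation in core.rag.pipeline.queue.

# Inferred interface of TenantIsolatedTaskQueue, reconstructed from call sites
# in this commit. Bodies are placeholders; the docstring semantics are
# paraphrased from surrounding comments and are assumptions.
class TenantIsolatedTaskQueueSketch:
    def __init__(self, tenant_id: str, queue_name: str) -> None:
        self._tenant_id = tenant_id
        self._queue_name = queue_name

    def get_task_key(self) -> str | None:
        """Truthy while a task for this tenant/queue is already in flight."""
        raise NotImplementedError

    def set_task_waiting_time(self) -> None:
        """Set or refresh the in-flight flag before delaying a task."""
        raise NotImplementedError

    def push_tasks(self, tasks: list) -> None:
        """lpush serialized task payloads onto the tenant's waiting list."""
        raise NotImplementedError

    def pull_tasks(self, count: int) -> list:
        """rpop up to `count` payloads from the waiting list (FIFO order)."""
        raise NotImplementedError

    def delete_task_key(self) -> None:
        """Clear the in-flight flag once the waiting list is drained."""
        raise NotImplementedError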

View File

@ -124,6 +124,14 @@ class KnowledgeConfig(BaseModel):
embedding_model: str | None = None
embedding_model_provider: str | None = None
name: str | None = None
is_multimodal: bool = False
class SegmentCreateArgs(BaseModel):
content: str | None = None
answer: str | None = None
keywords: list[str] | None = None
attachment_ids: list[str] | None = None
class SegmentUpdateArgs(BaseModel):
@ -132,6 +140,7 @@ class SegmentUpdateArgs(BaseModel):
keywords: list[str] | None = None
regenerate_child_chunks: bool = False
enabled: bool | None = None
attachment_ids: list[str] | None = None
class ChildChunkUpdateArgs(BaseModel):

View File

@ -324,4 +324,5 @@ class ExternalDatasetService:
)
if response.status_code == 200:
return cast(list[Any], response.json().get("records", []))
return []
else:
raise ValueError(response.text)

View File

@ -1,3 +1,4 @@
import base64
import hashlib
import os
import uuid
@ -123,6 +124,15 @@ class FileService:
return file_size <= file_size_limit
def get_file_base64(self, file_id: str) -> str:
upload_file = (
self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first()
)
if not upload_file:
raise NotFound("File not found")
blob = storage.load_once(upload_file.key)
return base64.b64encode(blob).decode()
def upload_text(self, text: str, text_name: str, user_id: str, tenant_id: str) -> UploadFile:
if len(text_name) > 200:
text_name = text_name[:200]

View File

@ -1,3 +1,4 @@
import json
import logging
import time
from typing import Any
@ -5,6 +6,7 @@ from typing import Any
from core.app.app_config.entities import ModelConfig
from core.model_runtime.entities import LLMMode
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@ -32,6 +34,7 @@ class HitTestingService:
account: Account,
retrieval_model: Any, # FIXME drop this any
external_retrieval_model: dict,
attachment_ids: list | None = None,
limit: int = 10,
):
start = time.perf_counter()
@ -41,7 +44,7 @@ class HitTestingService:
retrieval_model = dataset.retrieval_model or default_retrieval_model
document_ids_filter = None
metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {})
if metadata_filtering_conditions:
if metadata_filtering_conditions and query:
dataset_retrieval = DatasetRetrieval()
from core.app.app_config.entities import MetadataFilteringCondition
@ -66,6 +69,7 @@ class HitTestingService:
retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)),
dataset_id=dataset.id,
query=query,
attachment_ids=attachment_ids,
top_k=retrieval_model.get("top_k", 4),
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
@ -80,17 +84,24 @@ class HitTestingService:
end = time.perf_counter()
logger.debug("Hit testing retrieve in %s seconds", end - start)
dataset_query = DatasetQuery(
dataset_id=dataset.id,
content=query,
source="hit_testing",
source_app_id=None,
created_by_role="account",
created_by=account.id,
)
db.session.add(dataset_query)
dataset_queries = []
if query:
content = {"content_type": QueryType.TEXT_QUERY, "content": query}
dataset_queries.append(content)
if attachment_ids:
for attachment_id in attachment_ids:
content = {"content_type": QueryType.IMAGE_QUERY, "content": attachment_id}
dataset_queries.append(content)
if dataset_queries:
dataset_query = DatasetQuery(
dataset_id=dataset.id,
content=json.dumps(dataset_queries),
source="hit_testing",
source_app_id=None,
created_by_role="account",
created_by=account.id,
)
db.session.add(dataset_query)
db.session.commit()
return cls.compact_retrieve_response(query, all_documents)
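With image hit testing, DatasetQuery.content now stores a JSON-encoded list of typed parts instead of the bare query string. An illustrative payload for a mixed text-plus-image request, assuming QueryType serializes to string values like the ones used below (the exact enum values are not shown in this diff):

# Illustrative shape of DatasetQuery.content after this change; the literal
# content_type values and the ids are placeholders, not taken from the code.
import json

example_content = json.dumps(
    [
        {"content_type": "text_query", "content": "how do refunds work?"},
        {"content_type": "image_query", "content": "<upload_file_id>"},
    ]
)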
@ -168,9 +179,14 @@ class HitTestingService:
@classmethod
def hit_testing_args_check(cls, args):
query = args["query"]
attachment_ids = args["attachment_ids"]
if not query or len(query) > 250:
raise ValueError("Query is required and cannot exceed 250 characters")
if not attachment_ids and not query:
raise ValueError("Query or attachment_ids is required")
if query and len(query) > 250:
raise ValueError("Query cannot exceed 250 characters")
if attachment_ids and not isinstance(attachment_ids, list):
raise ValueError("Attachment_ids must be a list")
@staticmethod
def escape_query_for_search(query: str) -> str:

View File

@ -38,21 +38,24 @@ class RagPipelineTaskProxy:
upload_file = FileService(db.engine).upload_text(
json_text, self._RAG_PIPELINE_INVOKE_ENTITIES_FILE_NAME, self._user_id, self._dataset_tenant_id
)
logger.info(
"tenant %s upload %d invoke entities", self._dataset_tenant_id, len(self._rag_pipeline_invoke_entities)
)
return upload_file.id
def _send_to_direct_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]):
logger.info("send file %s to direct queue", upload_file_id)
logger.info("tenant %s send file %s to direct queue", self._dataset_tenant_id, upload_file_id)
task_func.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file_id,
tenant_id=self._dataset_tenant_id,
)
def _send_to_tenant_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]):
logger.info("send file %s to tenant queue", upload_file_id)
logger.info("tenant %s send file %s to tenant queue", self._dataset_tenant_id, upload_file_id)
if self._tenant_isolated_task_queue.get_task_key():
# Add to waiting queue using List operations (lpush)
self._tenant_isolated_task_queue.push_tasks([upload_file_id])
logger.info("push tasks: %s", upload_file_id)
logger.info("tenant %s push tasks: %s", self._dataset_tenant_id, upload_file_id)
else:
# Set flag and execute task
self._tenant_isolated_task_queue.set_task_waiting_time()
@ -60,7 +63,7 @@ class RagPipelineTaskProxy:
rag_pipeline_invoke_entities_file_id=upload_file_id,
tenant_id=self._dataset_tenant_id,
)
logger.info("init tasks: %s", upload_file_id)
logger.info("tenant %s init tasks: %s", self._dataset_tenant_id, upload_file_id)
def _send_to_default_tenant_queue(self, upload_file_id: str):
self._send_to_tenant_queue(upload_file_id, rag_pipeline_run_task)

View File

@ -4,11 +4,14 @@ from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from core.rag.models.document import AttachmentDocument, Document
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models import UploadFile
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import ParentMode
@ -21,9 +24,10 @@ class VectorService:
cls, keywords_list: list[list[str]] | None, segments: list[DocumentSegment], dataset: Dataset, doc_form: str
):
documents: list[Document] = []
multimodal_documents: list[AttachmentDocument] = []
for segment in segments:
if doc_form == IndexType.PARENT_CHILD_INDEX:
if doc_form == IndexStructureType.PARENT_CHILD_INDEX:
dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
if not dataset_document:
logger.warning(
@ -70,12 +74,29 @@ class VectorService:
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.TEXT,
},
)
documents.append(rag_document)
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_document: AttachmentDocument = AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
multimodal_documents.append(multimodal_document)
index_processor: BaseIndexProcessor = IndexProcessorFactory(doc_form).init_index_processor()
if len(documents) > 0:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)
index_processor.load(dataset, documents, None, with_keywords=True, keywords_list=keywords_list)
if len(multimodal_documents) > 0:
index_processor.load(dataset, [], multimodal_documents, with_keywords=False)
@classmethod
def update_segment_vector(cls, keywords: list[str] | None, segment: DocumentSegment, dataset: Dataset):
@ -130,6 +151,7 @@ class VectorService:
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.TEXT,
},
)
# use full doc mode to generate segment's child chunk
@ -226,3 +248,92 @@ class VectorService:
def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset):
vector = Vector(dataset=dataset)
vector.delete_by_ids([child_chunk.index_node_id])
@classmethod
def update_multimodel_vector(cls, segment: DocumentSegment, attachment_ids: list[str], dataset: Dataset):
if dataset.indexing_technique != "high_quality":
return
attachments = segment.attachments
old_attachment_ids = [attachment["id"] for attachment in attachments] if attachments else []
# Check if there's any actual change needed
if set(attachment_ids) == set(old_attachment_ids):
return
try:
vector = Vector(dataset=dataset)
if dataset.is_multimodal:
# Delete old vectors if they exist
if old_attachment_ids:
vector.delete_by_ids(old_attachment_ids)
# Delete existing segment attachment bindings in one operation
db.session.query(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id == segment.id).delete(
synchronize_session=False
)
if not attachment_ids:
db.session.commit()
return
# Bulk fetch the upload files referenced by the new attachment ids
upload_file_list = db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
if not upload_file_list:
db.session.commit()
return
# Create a mapping for quick lookup
upload_file_map = {upload_file.id: upload_file for upload_file in upload_file_list}
# Prepare batch operations
bindings = []
documents = []
# Create common metadata base to avoid repetition
base_metadata = {
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
}
# Process attachments in the order specified by attachment_ids
for attachment_id in attachment_ids:
upload_file = upload_file_map.get(attachment_id)
if not upload_file:
logger.warning("Upload file not found for attachment_id: %s", attachment_id)
continue
# Create segment attachment binding
bindings.append(
SegmentAttachmentBinding(
tenant_id=segment.tenant_id,
dataset_id=segment.dataset_id,
document_id=segment.document_id,
segment_id=segment.id,
attachment_id=upload_file.id,
)
)
# Create document for vector indexing
documents.append(
Document(page_content=upload_file.name, metadata={**base_metadata, "doc_id": upload_file.id})
)
# Bulk insert all bindings at once
if bindings:
db.session.add_all(bindings)
# Add documents to vector store if any
if documents and dataset.is_multimodal:
vector.add_texts(documents, duplicate_check=True)
# Single commit for all operations
db.session.commit()
except Exception:
logger.exception("Failed to update multimodal vector for segment %s", segment.id)
db.session.rollback()
raise
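Several tasks in this commit rebuild the same AttachmentDocument block (empty doc_hash, DocType.IMAGE, the segment's document and dataset ids) from segment.attachments. The helper below is not part of the commit; it is only a sketch of how that repetition could be factored out, with field names mirroring the blocks above.

# Sketch of a shared helper for the repeated attachment-to-document mapping.
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.models.document import AttachmentDocument


def attachment_documents_for_segment(segment) -> list[AttachmentDocument]:
    # segment.attachments is assumed to be a list of {"id": ..., "name": ...}
    # dicts, as it is used throughout this diff.
    return [
        AttachmentDocument(
            page_content=attachment["name"],
            metadata={
                "doc_id": attachment["id"],
                "doc_hash": "",
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
                "doc_type": DocType.IMAGE,
            },
        )
        for attachment in segment.attachments
    ]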

View File

@ -4,9 +4,10 @@ import time
import click
from celery import shared_task
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
@ -55,6 +56,7 @@ def add_document_to_index_task(dataset_document_id: str):
)
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
@ -65,7 +67,7 @@ def add_document_to_index_task(dataset_document_id: str):
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
@ -81,11 +83,25 @@ def add_document_to_index_task(dataset_document_id: str):
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.load(dataset, documents)
index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
# delete auto disable log
db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id).delete()

View File

@ -18,6 +18,7 @@ from models.dataset import (
DatasetQuery,
Document,
DocumentSegment,
SegmentAttachmentBinding,
)
from models.model import UploadFile
@ -58,14 +59,20 @@ def clean_dataset_task(
)
documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all()
# Use JOIN to fetch attachments with bindings in a single query
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(SegmentAttachmentBinding.tenant_id == tenant_id, SegmentAttachmentBinding.dataset_id == dataset_id)
).all()
# Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
# This ensures all invalid doc_form values are properly handled
if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()):
# Use default paragraph index type for empty/invalid datasets to enable vector database cleanup
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
doc_form = IndexType.PARAGRAPH_INDEX
doc_form = IndexStructureType.PARAGRAPH_INDEX
logger.info(
click.style(f"Invalid doc_form detected, using default index type for cleanup: {doc_form}", fg="yellow")
)
@ -90,6 +97,7 @@ def clean_dataset_task(
for document in documents:
db.session.delete(document)
# delete document file
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
@ -107,6 +115,19 @@ def clean_dataset_task(
)
db.session.delete(image_file)
db.session.delete(segment)
# delete segment attachments
if attachments_with_bindings:
for binding, attachment_file in attachments_with_bindings:
try:
storage.delete(attachment_file.key)
except Exception:
logger.exception(
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
binding.attachment_id,
)
db.session.delete(attachment_file)
db.session.delete(binding)
db.session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
db.session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()

View File

@ -9,7 +9,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment, SegmentAttachmentBinding
from models.model import UploadFile
logger = logging.getLogger(__name__)
@ -36,6 +36,16 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
raise Exception("Document has no dataset")
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
# Use JOIN to fetch attachments with bindings in a single query
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.tenant_id == dataset.tenant_id,
SegmentAttachmentBinding.dataset_id == dataset_id,
SegmentAttachmentBinding.document_id == document_id,
)
).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
@ -69,6 +79,19 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
db.session.delete(file)
db.session.commit()
# delete segment attachments
if attachments_with_bindings:
for binding, attachment_file in attachments_with_bindings:
try:
storage.delete(attachment_file.key)
except Exception:
logger.exception(
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
binding.attachment_id,
)
db.session.delete(attachment_file)
db.session.delete(binding)
# delete dataset metadata binding
db.session.query(DatasetMetadataBinding).where(

View File

@ -4,9 +4,10 @@ import time
import click
from celery import shared_task # type: ignore
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@ -28,7 +29,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "upgrade":
dataset_documents = (
@ -119,6 +120,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
)
if segments:
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
@ -129,7 +131,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
@ -145,9 +147,25 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
index_processor.load(
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
)
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)

View File

@ -1,14 +1,14 @@
import logging
import time
from typing import Literal
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "add", "update"]):
def deal_dataset_vector_index_task(dataset_id: str, action: str):
"""
Async deal dataset from index
:param dataset_id: dataset_id
@ -32,7 +32,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "remove":
index_processor.clean(dataset, None, with_keywords=False)
@ -119,6 +119,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
)
if segments:
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
@ -129,7 +130,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
@ -145,9 +146,25 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
index_processor.load(
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
)
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)

View File

@ -6,14 +6,15 @@ from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from models.dataset import Dataset, Document
from models.dataset import Dataset, Document, SegmentAttachmentBinding
from models.model import UploadFile
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def delete_segment_from_index_task(
index_node_ids: list, dataset_id: str, document_id: str, child_node_ids: list | None = None
index_node_ids: list, dataset_id: str, document_id: str, segment_ids: list, child_node_ids: list | None = None
):
"""
Async Remove segment from index
@ -49,6 +50,21 @@ def delete_segment_from_index_task(
delete_child_chunks=True,
precomputed_child_node_ids=child_node_ids,
)
if dataset.is_multimodal:
# delete segment attachment binding
segment_attachment_bindings = (
db.session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
.all()
)
if segment_attachment_bindings:
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
for binding in segment_attachment_bindings:
db.session.delete(binding)
# delete upload file
db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
db.session.commit()
end_at = time.perf_counter()
logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green"))

View File

@ -8,7 +8,7 @@ from sqlalchemy import select
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DocumentSegment
from models.dataset import Dataset, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
logger = logging.getLogger(__name__)
@ -59,6 +59,16 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
try:
index_node_ids = [segment.index_node_id for segment in segments]
if dataset.is_multimodal:
segment_ids = [segment.id for segment in segments]
segment_attachment_bindings = (
db.session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
.all()
)
if segment_attachment_bindings:
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
index_node_ids.extend(attachment_ids)
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
end_at = time.perf_counter()

View File

@ -114,7 +114,13 @@ def _document_indexing_with_tenant_queue(
try:
_document_indexing(dataset_id, document_ids)
except Exception:
logger.exception("Error processing document indexing %s for tenant %s: %s", dataset_id, tenant_id)
logger.exception(
"Error processing document indexing %s for tenant %s: %s",
dataset_id,
tenant_id,
document_ids,
)
finally:
tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")
@ -122,7 +128,7 @@ def _document_indexing_with_tenant_queue(
# Use rpop to get the next task from the queue (FIFO order)
next_tasks = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
logger.info("document indexing tenant isolation queue next tasks: %s", next_tasks)
logger.info("document indexing tenant isolation queue %s next tasks: %s", tenant_id, next_tasks)
if next_tasks:
for next_task in next_tasks:

View File

@ -1,13 +1,16 @@
import logging
import time
from collections.abc import Callable, Sequence
import click
from celery import shared_task
from sqlalchemy import select
from configs import dify_config
from core.entities.document_task import DocumentTask
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
@ -24,8 +27,55 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
:param dataset_id:
:param document_ids:
.. warning:: TO BE DEPRECATED
This function will be deprecated and removed in a future version.
Use normal_duplicate_document_indexing_task or priority_duplicate_document_indexing_task instead.
Usage: duplicate_document_indexing_task.delay(dataset_id, document_ids)
"""
logger.warning("duplicate document indexing task received: %s - %s", dataset_id, document_ids)
_duplicate_document_indexing_task(dataset_id, document_ids)
def _duplicate_document_indexing_task_with_tenant_queue(
tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: Callable[[str, str, Sequence[str]], None]
):
try:
_duplicate_document_indexing_task(dataset_id, document_ids)
except Exception:
logger.exception(
"Error processing duplicate document indexing %s for tenant %s: %s",
dataset_id,
tenant_id,
document_ids,
)
finally:
tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "duplicate_document_indexing")
# Check if there are waiting tasks in the queue
# Use rpop to get the next task from the queue (FIFO order)
next_tasks = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
logger.info("duplicate document indexing tenant isolation queue %s next tasks: %s", tenant_id, next_tasks)
if next_tasks:
for next_task in next_tasks:
document_task = DocumentTask(**next_task)
# Process the next waiting task
# Keep the flag set to indicate a task is running
tenant_isolated_task_queue.set_task_waiting_time()
task_func.delay( # type: ignore
tenant_id=document_task.tenant_id,
dataset_id=document_task.dataset_id,
document_ids=document_task.document_ids,
)
else:
# No more waiting tasks, clear the flag
tenant_isolated_task_queue.delete_task_key()
def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[str]):
documents = []
start_at = time.perf_counter()
@ -110,3 +160,35 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
logger.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id)
finally:
db.session.close()
@shared_task(queue="dataset")
def normal_duplicate_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
"""
Async process duplicate documents
:param tenant_id:
:param dataset_id:
:param document_ids:
Usage: normal_duplicate_document_indexing_task.delay(tenant_id, dataset_id, document_ids)
"""
logger.info("normal duplicate document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids)
_duplicate_document_indexing_task_with_tenant_queue(
tenant_id, dataset_id, document_ids, normal_duplicate_document_indexing_task
)
@shared_task(queue="priority_dataset")
def priority_duplicate_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
"""
Async process duplicate documents
:param tenant_id:
:param dataset_id:
:param document_ids:
Usage: priority_duplicate_document_indexing_task.delay(tenant_id, dataset_id, document_ids)
"""
logger.info("priority duplicate document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids)
_duplicate_document_indexing_task_with_tenant_queue(
tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task
)
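The deprecated entry point above keeps working, but new callers are expected to go through the tenant-aware tasks, normally via the DuplicateDocumentIndexingTaskProxy added earlier in this diff. A direct-invocation sketch that mirrors the Usage lines in the docstrings, with placeholder ids:

# Sketch: direct dispatch of the new tenant-aware duplicate indexing tasks.
# In practice the proxy chooses between the normal and priority variants.
from tasks.duplicate_document_indexing_task import (
    normal_duplicate_document_indexing_task,
    priority_duplicate_document_indexing_task,
)

normal_duplicate_document_indexing_task.delay(
    tenant_id="tenant-uuid", dataset_id="dataset-uuid", document_ids=["document-uuid"]
)
priority_duplicate_document_indexing_task.delay(
    tenant_id="tenant-uuid", dataset_id="dataset-uuid", document_ids=["document-uuid"]
)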

View File

@ -4,9 +4,10 @@ import time
import click
from celery import shared_task
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
@ -67,7 +68,7 @@ def enable_segment_to_index_task(segment_id: str):
return
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
@ -83,8 +84,24 @@ def enable_segment_to_index_task(segment_id: str):
)
child_documents.append(child_document)
document.children = child_documents
multimodal_documents = []
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
# save vector index
index_processor.load(dataset, [document])
index_processor.load(dataset, [document], multimodal_documents=multimodal_documents)
end_at = time.perf_counter()
logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green"))

View File

@ -5,9 +5,10 @@ import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
@ -60,6 +61,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
try:
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
@ -71,7 +73,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
},
)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
@ -87,9 +89,24 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)
# save vector index
index_processor.load(dataset, documents)
index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
end_at = time.perf_counter()
logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green"))

View File

@ -47,6 +47,8 @@ def priority_rag_pipeline_run_task(
)
rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)
logger.info("tenant %s received %d rag pipeline invoke entities", tenant_id, len(rag_pipeline_invoke_entities))
# Get Flask app object for thread context
flask_app = current_app._get_current_object() # type: ignore
@ -66,7 +68,7 @@ def priority_rag_pipeline_run_task(
end_at = time.perf_counter()
logging.info(
click.style(
f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
f"tenant_id: {tenant_id}, Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
)
)
except Exception:
@ -78,7 +80,7 @@ def priority_rag_pipeline_run_task(
# Check if there are waiting tasks in the queue
# Use rpop to get the next task from the queue (FIFO order)
next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
logger.info("priority rag pipeline tenant isolation queue next files: %s", next_file_ids)
logger.info("priority rag pipeline tenant isolation queue %s next files: %s", tenant_id, next_file_ids)
if next_file_ids:
for next_file_id in next_file_ids:

View File

@ -47,6 +47,8 @@ def rag_pipeline_run_task(
)
rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)
logger.info("tenant %s received %d rag pipeline invoke entities", tenant_id, len(rag_pipeline_invoke_entities))
# Get Flask app object for thread context
flask_app = current_app._get_current_object() # type: ignore
@ -66,7 +68,7 @@ def rag_pipeline_run_task(
end_at = time.perf_counter()
logging.info(
click.style(
f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
f"tenant_id: {tenant_id}, Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
)
)
except Exception:
@ -78,7 +80,7 @@ def rag_pipeline_run_task(
# Check if there are waiting tasks in the queue
# Use rpop to get the next task from the queue (FIFO order)
next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
logger.info("rag pipeline tenant isolation queue next files: %s", next_file_ids)
logger.info("rag pipeline tenant isolation queue %s next files: %s", tenant_id, next_file_ids)
if next_file_ids:
for next_file_id in next_file_ids:

View File

@ -0,0 +1,244 @@
"""Integration tests for Trigger Provider subscription permission verification."""
import uuid
from unittest import mock
import pytest
from flask.testing import FlaskClient
from controllers.console.workspace import trigger_providers as trigger_providers_api
from libs.datetime_utils import naive_utc_now
from models import Tenant
from models.account import Account, TenantAccountJoin, TenantAccountRole
class TestTriggerProviderSubscriptionPermissions:
"""Test permission verification for Trigger Provider subscription endpoints."""
@pytest.fixture
def mock_account(self, monkeypatch: pytest.MonkeyPatch):
"""Create a mock Account for testing."""
account = Account(name="Test User", email="test@example.com")
account.id = str(uuid.uuid4())
account.last_active_at = naive_utc_now()
account.created_at = naive_utc_now()
account.updated_at = naive_utc_now()
# Create mock tenant
tenant = Tenant(name="Test Tenant")
tenant.id = str(uuid.uuid4())
mock_session_instance = mock.Mock()
mock_tenant_join = TenantAccountJoin(role=TenantAccountRole.OWNER)
monkeypatch.setattr(mock_session_instance, "scalar", mock.Mock(return_value=mock_tenant_join))
mock_scalars_result = mock.Mock()
mock_scalars_result.one.return_value = tenant
monkeypatch.setattr(mock_session_instance, "scalars", mock.Mock(return_value=mock_scalars_result))
mock_session_context = mock.Mock()
mock_session_context.__enter__.return_value = mock_session_instance
monkeypatch.setattr("models.account.Session", lambda _, expire_on_commit: mock_session_context)
account.current_tenant = tenant
account.current_tenant_id = tenant.id
return account
@pytest.mark.parametrize(
("role", "list_status", "get_status", "update_status", "create_status", "build_status", "delete_status"),
[
# Admin/Owner can do everything
(TenantAccountRole.OWNER, 200, 200, 200, 200, 200, 200),
(TenantAccountRole.ADMIN, 200, 200, 200, 200, 200, 200),
# Editor can list, get, update (parameters), but not create, build, or delete
(TenantAccountRole.EDITOR, 200, 200, 200, 403, 403, 403),
# Normal user cannot do anything
(TenantAccountRole.NORMAL, 403, 403, 403, 403, 403, 403),
# Dataset operator cannot do anything
(TenantAccountRole.DATASET_OPERATOR, 403, 403, 403, 403, 403, 403),
],
)
def test_trigger_subscription_permissions(
self,
test_client: FlaskClient,
auth_header,
monkeypatch,
mock_account,
role: TenantAccountRole,
list_status: int,
get_status: int,
update_status: int,
create_status: int,
build_status: int,
delete_status: int,
):
"""Test that different roles have appropriate permissions for trigger subscription operations."""
# Set user role
mock_account.role = role
# Mock current user
monkeypatch.setattr(trigger_providers_api, "current_user", mock_account)
# Mock AccountService.load_user to prevent authentication issues
from services.account_service import AccountService
mock_load_user = mock.Mock(return_value=mock_account)
monkeypatch.setattr(AccountService, "load_user", mock_load_user)
# Test data
provider = "some_provider/some_trigger"
subscription_builder_id = str(uuid.uuid4())
subscription_id = str(uuid.uuid4())
# Mock service methods
mock_list_subscriptions = mock.Mock(return_value=[])
monkeypatch.setattr(
"services.trigger.trigger_provider_service.TriggerProviderService.list_trigger_provider_subscriptions",
mock_list_subscriptions,
)
mock_get_subscription_builder = mock.Mock(return_value={"id": subscription_builder_id})
monkeypatch.setattr(
"services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.get_subscription_builder_by_id",
mock_get_subscription_builder,
)
mock_update_subscription_builder = mock.Mock(return_value={"id": subscription_builder_id})
monkeypatch.setattr(
"services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.update_trigger_subscription_builder",
mock_update_subscription_builder,
)
mock_create_subscription_builder = mock.Mock(return_value={"id": subscription_builder_id})
monkeypatch.setattr(
"services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.create_trigger_subscription_builder",
mock_create_subscription_builder,
)
mock_update_and_build_builder = mock.Mock()
monkeypatch.setattr(
"services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.update_and_build_builder",
mock_update_and_build_builder,
)
mock_delete_provider = mock.Mock()
mock_delete_plugin_trigger = mock.Mock()
mock_db_session = mock.Mock()
mock_db_session.commit = mock.Mock()
def mock_session_func(engine=None):
return mock_session_context
mock_session_context = mock.Mock()
mock_session_context.__enter__.return_value = mock_db_session
mock_session_context.__exit__.return_value = None
monkeypatch.setattr("services.trigger.trigger_provider_service.Session", mock_session_func)
monkeypatch.setattr("services.trigger.trigger_subscription_operator_service.Session", mock_session_func)
monkeypatch.setattr(
"services.trigger.trigger_provider_service.TriggerProviderService.delete_trigger_provider",
mock_delete_provider,
)
monkeypatch.setattr(
"services.trigger.trigger_subscription_operator_service.TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription",
mock_delete_plugin_trigger,
)
# Test 1: List subscriptions (should work for Editor, Admin, Owner)
response = test_client.get(
f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/list",
headers=auth_header,
)
assert response.status_code == list_status
# Test 2: Get subscription builder (should work for Editor, Admin, Owner)
response = test_client.get(
f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/{subscription_builder_id}",
headers=auth_header,
)
assert response.status_code == get_status
# Test 3: Update subscription builder parameters (should work for Editor, Admin, Owner)
response = test_client.post(
f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/update/{subscription_builder_id}",
headers=auth_header,
json={"parameters": {"webhook_url": "https://example.com/webhook"}},
)
assert response.status_code == update_status
# Test 4: Create subscription builder (should only work for Admin, Owner)
response = test_client.post(
f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/create",
headers=auth_header,
json={"credential_type": "api_key"},
)
assert response.status_code == create_status
# Test 5: Build/activate subscription (should only work for Admin, Owner)
response = test_client.post(
f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/build/{subscription_builder_id}",
headers=auth_header,
json={"name": "Test Subscription"},
)
assert response.status_code == build_status
# Test 6: Delete subscription (should only work for Admin, Owner)
response = test_client.post(
f"/console/api/workspaces/current/trigger-provider/{subscription_id}/subscriptions/delete",
headers=auth_header,
)
assert response.status_code == delete_status
@pytest.mark.parametrize(
("role", "status"),
[
(TenantAccountRole.OWNER, 200),
(TenantAccountRole.ADMIN, 200),
# Editor should be able to access logs for debugging
(TenantAccountRole.EDITOR, 200),
(TenantAccountRole.NORMAL, 403),
(TenantAccountRole.DATASET_OPERATOR, 403),
],
)
def test_trigger_subscription_logs_permissions(
self,
test_client: FlaskClient,
auth_header,
monkeypatch,
mock_account,
role: TenantAccountRole,
status: int,
):
"""Test that different roles have appropriate permissions for accessing subscription logs."""
# Set user role
mock_account.role = role
# Mock current user
monkeypatch.setattr(trigger_providers_api, "current_user", mock_account)
# Mock AccountService.load_user to prevent authentication issues
from services.account_service import AccountService
mock_load_user = mock.Mock(return_value=mock_account)
monkeypatch.setattr(AccountService, "load_user", mock_load_user)
# Test data
provider = "some_provider/some_trigger"
subscription_builder_id = str(uuid.uuid4())
# Mock service method
mock_list_logs = mock.Mock(return_value=[])
monkeypatch.setattr(
"services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.list_logs",
mock_list_logs,
)
# Test access to logs
response = test_client.get(
f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/logs/{subscription_builder_id}",
headers=auth_header,
)
assert response.status_code == status

View File

@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -95,7 +95,7 @@ class TestAddDocumentToIndexTask:
created_by=account.id,
indexing_status="completed",
enabled=True,
doc_form=IndexType.PARAGRAPH_INDEX,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
db.session.add(document)
db.session.commit()
@ -172,7 +172,9 @@ class TestAddDocumentToIndexTask:
# Assert: Verify the expected outcomes
# Verify index processor was called correctly
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX)
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexStructureType.PARAGRAPH_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_called_once()
# Verify database state changes
@ -204,7 +206,7 @@ class TestAddDocumentToIndexTask:
)
# Update document to use different index type
document.doc_form = IndexType.QA_INDEX
document.doc_form = IndexStructureType.QA_INDEX
db.session.commit()
# Refresh dataset to ensure doc_form property reflects the updated document
@ -221,7 +223,9 @@ class TestAddDocumentToIndexTask:
add_document_to_index_task(document.id)
# Assert: Verify different index type handling
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.QA_INDEX)
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexStructureType.QA_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_called_once()
# Verify the load method was called with correct parameters
@ -360,7 +364,7 @@ class TestAddDocumentToIndexTask:
)
# Update document to use parent-child index type
document.doc_form = IndexType.PARENT_CHILD_INDEX
document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
db.session.commit()
# Refresh dataset to ensure doc_form property reflects the updated document
@ -391,7 +395,7 @@ class TestAddDocumentToIndexTask:
# Assert: Verify parent-child index processing
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexType.PARENT_CHILD_INDEX
IndexStructureType.PARENT_CHILD_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_called_once()
@ -465,8 +469,10 @@ class TestAddDocumentToIndexTask:
# Act: Execute the task
add_document_to_index_task(document.id)
# Assert: Verify index processing occurred with all completed segments
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX)
# Assert: Verify index processing occurred but with empty documents list
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexStructureType.PARAGRAPH_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_called_once()
# Verify the load method was called with all completed segments
@ -532,7 +538,9 @@ class TestAddDocumentToIndexTask:
assert len(remaining_logs) == 0
# Verify index processing occurred normally
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX)
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexStructureType.PARAGRAPH_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_called_once()
# Verify segments were enabled
@ -699,7 +707,9 @@ class TestAddDocumentToIndexTask:
add_document_to_index_task(document.id)
# Assert: Verify only eligible segments were processed
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX)
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexStructureType.PARAGRAPH_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_called_once()
# Verify the load method was called with correct parameters

View File

@ -12,7 +12,7 @@ from unittest.mock import MagicMock, patch
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from models import Account, Dataset, Document, DocumentSegment, Tenant
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
@ -164,7 +164,7 @@ class TestDeleteSegmentFromIndexTask:
document.updated_at = fake.date_time_this_year()
document.doc_type = kwargs.get("doc_type", "text")
document.doc_metadata = kwargs.get("doc_metadata", {})
document.doc_form = kwargs.get("doc_form", IndexType.PARAGRAPH_INDEX)
document.doc_form = kwargs.get("doc_form", IndexStructureType.PARAGRAPH_INDEX)
document.doc_language = kwargs.get("doc_language", "en")
db_session_with_containers.add(document)
@ -244,8 +244,11 @@ class TestDeleteSegmentFromIndexTask:
mock_processor = MagicMock()
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
# Extract segment IDs for the task
segment_ids = [segment.id for segment in segments]
# Execute the task
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
# Verify the task completed successfully
assert result is None # Task should return None on success
@ -279,7 +282,7 @@ class TestDeleteSegmentFromIndexTask:
index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)]
# Execute the task with non-existent dataset
result = delete_segment_from_index_task(index_node_ids, non_existent_dataset_id, non_existent_document_id)
result = delete_segment_from_index_task(index_node_ids, non_existent_dataset_id, non_existent_document_id, [])
# Verify the task completed without exceptions
assert result is None # Task should return None when dataset not found
@ -305,7 +308,7 @@ class TestDeleteSegmentFromIndexTask:
index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)]
# Execute the task with non-existent document
result = delete_segment_from_index_task(index_node_ids, dataset.id, non_existent_document_id)
result = delete_segment_from_index_task(index_node_ids, dataset.id, non_existent_document_id, [])
# Verify the task completed without exceptions
assert result is None # Task should return None when document not found
@ -330,9 +333,10 @@ class TestDeleteSegmentFromIndexTask:
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
index_node_ids = [segment.index_node_id for segment in segments]
segment_ids = [segment.id for segment in segments]
# Execute the task with disabled document
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
# Verify the task completed without exceptions
assert result is None # Task should return None when document is disabled
@ -357,9 +361,10 @@ class TestDeleteSegmentFromIndexTask:
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
index_node_ids = [segment.index_node_id for segment in segments]
segment_ids = [segment.id for segment in segments]
# Execute the task with archived document
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
# Verify the task completed without exceptions
assert result is None # Task should return None when document is archived
@ -386,9 +391,10 @@ class TestDeleteSegmentFromIndexTask:
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
index_node_ids = [segment.index_node_id for segment in segments]
segment_ids = [segment.id for segment in segments]
# Execute the task with incomplete indexing
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
# Verify the task completed without exceptions
assert result is None # Task should return None when indexing is not completed
@ -409,7 +415,11 @@ class TestDeleteSegmentFromIndexTask:
fake = Faker()
# Test different document forms
document_forms = [IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX, IndexType.PARENT_CHILD_INDEX]
document_forms = [
IndexStructureType.PARAGRAPH_INDEX,
IndexStructureType.QA_INDEX,
IndexStructureType.PARENT_CHILD_INDEX,
]
for doc_form in document_forms:
# Create test data for each document form
@ -420,13 +430,14 @@ class TestDeleteSegmentFromIndexTask:
segments = self._create_test_document_segments(db_session_with_containers, document, account, 2, fake)
index_node_ids = [segment.index_node_id for segment in segments]
segment_ids = [segment.id for segment in segments]
# Mock the index processor
mock_processor = MagicMock()
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
# Execute the task
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
# Verify the task completed successfully
assert result is None
@ -469,6 +480,7 @@ class TestDeleteSegmentFromIndexTask:
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
index_node_ids = [segment.index_node_id for segment in segments]
segment_ids = [segment.id for segment in segments]
# Mock the index processor to raise an exception
mock_processor = MagicMock()
@ -476,7 +488,7 @@ class TestDeleteSegmentFromIndexTask:
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
# Execute the task - should not raise exception
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
# Verify the task completed without raising exceptions
assert result is None # Task should return None even when exceptions occur
@ -518,7 +530,7 @@ class TestDeleteSegmentFromIndexTask:
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
# Execute the task
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, [])
# Verify the task completed successfully
assert result is None
@ -555,13 +567,14 @@ class TestDeleteSegmentFromIndexTask:
# Create large number of segments
segments = self._create_test_document_segments(db_session_with_containers, document, account, 50, fake)
index_node_ids = [segment.index_node_id for segment in segments]
segment_ids = [segment.id for segment in segments]
# Mock the index processor
mock_processor = MagicMock()
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
# Execute the task
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
# Verify the task completed successfully
assert result is None

View File

@ -0,0 +1,763 @@
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from tasks.duplicate_document_indexing_task import (
_duplicate_document_indexing_task, # Core function
_duplicate_document_indexing_task_with_tenant_queue, # Tenant queue wrapper function
duplicate_document_indexing_task, # Deprecated old interface
normal_duplicate_document_indexing_task, # New normal task
priority_duplicate_document_indexing_task, # New priority task
)
class TestDuplicateDocumentIndexingTasks:
"""Integration tests for duplicate document indexing tasks using testcontainers.
This test class covers:
- Core _duplicate_document_indexing_task function
- Deprecated duplicate_document_indexing_task function
- New normal_duplicate_document_indexing_task function
- New priority_duplicate_document_indexing_task function
- Tenant queue wrapper _duplicate_document_indexing_task_with_tenant_queue function
- Document segment cleanup logic
"""
@pytest.fixture
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("tasks.duplicate_document_indexing_task.IndexingRunner") as mock_indexing_runner,
patch("tasks.duplicate_document_indexing_task.FeatureService") as mock_feature_service,
patch("tasks.duplicate_document_indexing_task.IndexProcessorFactory") as mock_index_processor_factory,
):
# Setup mock indexing runner
mock_runner_instance = MagicMock()
mock_indexing_runner.return_value = mock_runner_instance
# Setup mock feature service
mock_features = MagicMock()
mock_features.billing.enabled = False
mock_feature_service.get_features.return_value = mock_features
# Setup mock index processor factory
mock_processor = MagicMock()
mock_processor.clean = MagicMock()
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
yield {
"indexing_runner": mock_indexing_runner,
"indexing_runner_instance": mock_runner_instance,
"feature_service": mock_feature_service,
"features": mock_features,
"index_processor_factory": mock_index_processor_factory,
"index_processor": mock_processor,
}
def _create_test_dataset_and_documents(
self, db_session_with_containers, mock_external_service_dependencies, document_count=3
):
"""
Helper method to create a test dataset and documents for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies
document_count: Number of documents to create
Returns:
tuple: (dataset, documents) - Created dataset and document instances
"""
fake = Faker()
# Create account and tenant
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
# Create tenant-account join
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
# Create dataset
dataset = Dataset(
id=fake.uuid4(),
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="upload_file",
indexing_technique="high_quality",
created_by=account.id,
)
db.session.add(dataset)
db.session.commit()
# Create documents
documents = []
for i in range(document_count):
document = Document(
id=fake.uuid4(),
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_by=account.id,
indexing_status="waiting",
enabled=True,
doc_form="text_model",
)
db.session.add(document)
documents.append(document)
db.session.commit()
# Refresh dataset to ensure it's properly loaded
db.session.refresh(dataset)
return dataset, documents
def _create_test_dataset_with_segments(
self, db_session_with_containers, mock_external_service_dependencies, document_count=3, segments_per_doc=2
):
"""
Helper method to create a test dataset with documents and segments.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies
document_count: Number of documents to create
segments_per_doc: Number of segments per document
Returns:
tuple: (dataset, documents, segments) - Created dataset, documents and segments
"""
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count
)
fake = Faker()
segments = []
# Create segments for each document
for document in documents:
for i in range(segments_per_doc):
segment = DocumentSegment(
id=fake.uuid4(),
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
position=i,
index_node_id=f"{document.id}-node-{i}",
index_node_hash=fake.sha256(),
content=fake.text(max_nb_chars=200),
word_count=50,
tokens=100,
status="completed",
enabled=True,
indexing_at=fake.date_time_this_year(),
created_by=dataset.created_by, # Add required field
)
db.session.add(segment)
segments.append(segment)
db.session.commit()
# Refresh to ensure all relationships are loaded
for document in documents:
db.session.refresh(document)
return dataset, documents, segments
def _create_test_dataset_with_billing_features(
self, db_session_with_containers, mock_external_service_dependencies, billing_enabled=True
):
"""
Helper method to create a test dataset with billing features configured.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies
billing_enabled: Whether billing is enabled
Returns:
tuple: (dataset, documents) - Created dataset and document instances
"""
fake = Faker()
# Create account and tenant
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
# Create tenant-account join
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
# Create dataset
dataset = Dataset(
id=fake.uuid4(),
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="upload_file",
indexing_technique="high_quality",
created_by=account.id,
)
db.session.add(dataset)
db.session.commit()
# Create documents
documents = []
for i in range(3):
document = Document(
id=fake.uuid4(),
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_by=account.id,
indexing_status="waiting",
enabled=True,
doc_form="text_model",
)
db.session.add(document)
documents.append(document)
db.session.commit()
# Configure billing features
mock_external_service_dependencies["features"].billing.enabled = billing_enabled
if billing_enabled:
mock_external_service_dependencies["features"].billing.subscription.plan = CloudPlan.SANDBOX
mock_external_service_dependencies["features"].vector_space.limit = 100
mock_external_service_dependencies["features"].vector_space.size = 50
# Refresh dataset to ensure it's properly loaded
db.session.refresh(dataset)
return dataset, documents
def test_duplicate_document_indexing_task_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful duplicate document indexing with multiple documents.
This test verifies:
- Proper dataset retrieval from database
- Correct document processing and status updates
- IndexingRunner integration
- Database state updates
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=3
)
document_ids = [doc.id for doc in documents]
# Act: Execute the task
_duplicate_document_indexing_task(dataset.id, document_ids)
# Assert: Verify the expected outcomes
# Verify indexing runner was called correctly
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify documents were updated to parsing status
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
# Verify the run method was called with correct documents
call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
assert call_args is not None
processed_documents = call_args[0][0] # First argument should be documents list
assert len(processed_documents) == 3
def test_duplicate_document_indexing_task_with_segment_cleanup(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test duplicate document indexing with existing segments that need cleanup.
This test verifies:
- Old segments are identified and cleaned
- Index processor clean method is called
- Segments are deleted from database
- New indexing proceeds after cleanup
"""
# Arrange: Create test data with existing segments
dataset, documents, segments = self._create_test_dataset_with_segments(
db_session_with_containers, mock_external_service_dependencies, document_count=2, segments_per_doc=3
)
document_ids = [doc.id for doc in documents]
# Act: Execute the task
_duplicate_document_indexing_task(dataset.id, document_ids)
# Assert: Verify segment cleanup
# Verify index processor clean was called for each document with segments
assert mock_external_service_dependencies["index_processor"].clean.call_count == len(documents)
# Verify segments were deleted from database
# Re-query segments from database since _duplicate_document_indexing_task uses a different session
for segment in segments:
deleted_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first()
assert deleted_segment is None
# Verify documents were updated to parsing status
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
# Verify indexing runner was called
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
def test_duplicate_document_indexing_task_dataset_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test handling of non-existent dataset.
This test verifies:
- Proper error handling for missing datasets
- Early return without processing
- Database session cleanup
- No unnecessary indexing runner calls
"""
# Arrange: Use non-existent dataset ID
fake = Faker()
non_existent_dataset_id = fake.uuid4()
document_ids = [fake.uuid4() for _ in range(3)]
# Act: Execute the task with non-existent dataset
_duplicate_document_indexing_task(non_existent_dataset_id, document_ids)
# Assert: Verify no processing occurred
mock_external_service_dependencies["indexing_runner"].assert_not_called()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called()
mock_external_service_dependencies["index_processor"].clean.assert_not_called()
def test_duplicate_document_indexing_task_document_not_found_in_dataset(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test handling when some documents don't exist in the dataset.
This test verifies:
- Only existing documents are processed
- Non-existent documents are ignored
- Indexing runner receives only valid documents
- Database state updates correctly
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=2
)
# Mix existing and non-existent document IDs
fake = Faker()
existing_document_ids = [doc.id for doc in documents]
non_existent_document_ids = [fake.uuid4() for _ in range(2)]
all_document_ids = existing_document_ids + non_existent_document_ids
# Act: Execute the task with mixed document IDs
_duplicate_document_indexing_task(dataset.id, all_document_ids)
# Assert: Verify only existing documents were processed
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify only existing documents were updated
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
for doc_id in existing_document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
# Verify the run method was called with only existing documents
call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
assert call_args is not None
processed_documents = call_args[0][0] # First argument should be documents list
assert len(processed_documents) == 2 # Only existing documents
def test_duplicate_document_indexing_task_indexing_runner_exception(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test handling of IndexingRunner exceptions.
This test verifies:
- Exceptions from IndexingRunner are properly caught
- Task completes without raising exceptions
- Database session is properly closed
- Error logging occurs
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=2
)
document_ids = [doc.id for doc in documents]
# Mock IndexingRunner to raise an exception
mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = Exception(
"Indexing runner failed"
)
# Act: Execute the task
_duplicate_document_indexing_task(dataset.id, document_ids)
# Assert: Verify exception was handled gracefully
# The task should complete without raising exceptions
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify documents were still updated to parsing status before the exception
# Re-query documents from database since _duplicate_document_indexing_task closes the session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
def test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test billing validation for sandbox plan batch upload limit.
This test verifies:
- Sandbox plan batch upload limit enforcement
- Error handling for batch upload limit exceeded
- Document status updates to error state
- Proper error message recording
"""
# Arrange: Create test data with billing enabled
dataset, documents = self._create_test_dataset_with_billing_features(
db_session_with_containers, mock_external_service_dependencies, billing_enabled=True
)
# Configure sandbox plan with batch limit
mock_external_service_dependencies["features"].billing.subscription.plan = CloudPlan.SANDBOX
# Create more documents than sandbox plan allows (limit is 1)
fake = Faker()
extra_documents = []
for i in range(2): # Total will be 5 documents (3 existing + 2 new)
document = Document(
id=fake.uuid4(),
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=i + 3,
data_source_type="upload_file",
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_by=dataset.created_by,
indexing_status="waiting",
enabled=True,
doc_form="text_model",
)
db.session.add(document)
extra_documents.append(document)
db.session.commit()
all_documents = documents + extra_documents
document_ids = [doc.id for doc in all_documents]
# Act: Execute the task with too many documents for sandbox plan
_duplicate_document_indexing_task(dataset.id, document_ids)
# Assert: Verify error handling
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "error"
assert updated_document.error is not None
assert "batch upload" in updated_document.error.lower()
assert updated_document.stopped_at is not None
# Verify indexing runner was not called due to early validation error
mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called()
def test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test billing validation for vector space limit.
This test verifies:
- Vector space limit enforcement
- Error handling for vector space limit exceeded
- Document status updates to error state
- Proper error message recording
"""
# Arrange: Create test data with billing enabled
dataset, documents = self._create_test_dataset_with_billing_features(
db_session_with_containers, mock_external_service_dependencies, billing_enabled=True
)
# Configure TEAM plan with vector space limit exceeded
mock_external_service_dependencies["features"].billing.subscription.plan = CloudPlan.TEAM
mock_external_service_dependencies["features"].vector_space.limit = 100
mock_external_service_dependencies["features"].vector_space.size = 98 # Almost at limit
document_ids = [doc.id for doc in documents] # 3 documents will exceed limit
# Act: Execute the task with documents that will exceed vector space limit
_duplicate_document_indexing_task(dataset.id, document_ids)
# Assert: Verify error handling
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "error"
assert updated_document.error is not None
assert "limit" in updated_document.error.lower()
assert updated_document.stopped_at is not None
# Verify indexing runner was not called due to early validation error
mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called()
def test_duplicate_document_indexing_task_with_empty_document_list(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test handling of empty document list.
This test verifies:
- Empty document list is handled gracefully
- IndexingRunner is still invoked, but with an empty document list
- No errors are raised
- Database session is properly closed
"""
# Arrange: Create test dataset
dataset, _ = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=0
)
document_ids = []
# Act: Execute the task with empty document list
_duplicate_document_indexing_task(dataset.id, document_ids)
# Assert: Verify IndexingRunner was called with empty list
# Note: The actual implementation does call run([]) with empty list
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once_with([])
def test_deprecated_duplicate_document_indexing_task_delegates_to_core(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test that deprecated duplicate_document_indexing_task delegates to core function.
This test verifies:
- Deprecated function calls core _duplicate_document_indexing_task
- Proper parameter passing
- Backward compatibility
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=2
)
document_ids = [doc.id for doc in documents]
# Act: Execute the deprecated task
duplicate_document_indexing_task(dataset.id, document_ids)
# Assert: Verify core function was executed
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Clear session cache to see database updates from task's session
db.session.expire_all()
# Verify documents were processed
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
def test_normal_duplicate_document_indexing_task_with_tenant_queue(
self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies
):
"""
Test normal_duplicate_document_indexing_task with tenant isolation queue.
This test verifies:
- Task uses tenant isolation queue correctly
- Core processing function is called
- Queue management (pull tasks, delete key) works properly
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=2
)
document_ids = [doc.id for doc in documents]
# Mock tenant isolated queue to return no next tasks
mock_queue = MagicMock()
mock_queue.pull_tasks.return_value = []
mock_queue_class.return_value = mock_queue
# Act: Execute the normal task
normal_duplicate_document_indexing_task(dataset.tenant_id, dataset.id, document_ids)
# Assert: Verify processing occurred
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify tenant queue was used
mock_queue_class.assert_called_with(dataset.tenant_id, "duplicate_document_indexing")
mock_queue.pull_tasks.assert_called_once()
mock_queue.delete_task_key.assert_called_once()
# Clear session cache to see database updates from task's session
db.session.expire_all()
# Verify documents were processed
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
def test_priority_duplicate_document_indexing_task_with_tenant_queue(
self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies
):
"""
Test priority_duplicate_document_indexing_task with tenant isolation queue.
This test verifies:
- Task uses tenant isolation queue correctly
- Core processing function is called
- Queue management works properly
- Same behavior as normal task with different queue assignment
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=2
)
document_ids = [doc.id for doc in documents]
# Mock tenant isolated queue to return no next tasks
mock_queue = MagicMock()
mock_queue.pull_tasks.return_value = []
mock_queue_class.return_value = mock_queue
# Act: Execute the priority task
priority_duplicate_document_indexing_task(dataset.tenant_id, dataset.id, document_ids)
# Assert: Verify processing occurred
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify tenant queue was used
mock_queue_class.assert_called_with(dataset.tenant_id, "duplicate_document_indexing")
mock_queue.pull_tasks.assert_called_once()
mock_queue.delete_task_key.assert_called_once()
# Clear session cache to see database updates from task's session
db.session.expire_all()
# Verify documents were processed
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
def test_tenant_queue_wrapper_processes_next_tasks(
self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies
):
"""
Test tenant queue wrapper processes next queued tasks.
This test verifies:
- After completing current task, next tasks are pulled from queue
- Next tasks are executed correctly
- Task waiting time is set for next tasks
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=2
)
document_ids = [doc.id for doc in documents]
# Extract values before session detachment
tenant_id = dataset.tenant_id
dataset_id = dataset.id
# Mock tenant isolated queue to return next task
mock_queue = MagicMock()
next_task = {
"tenant_id": tenant_id,
"dataset_id": dataset_id,
"document_ids": document_ids,
}
mock_queue.pull_tasks.return_value = [next_task]
mock_queue_class.return_value = mock_queue
# Mock the task function to track calls
mock_task_func = MagicMock()
# Act: Execute the wrapper function
_duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func)
# Assert: Verify next task was scheduled
mock_queue.pull_tasks.assert_called_once()
mock_queue.set_task_waiting_time.assert_called_once()
mock_task_func.delay.assert_called_once_with(
tenant_id=tenant_id,
dataset_id=dataset_id,
document_ids=document_ids,
)
mock_queue.delete_task_key.assert_not_called()

View File

@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -95,7 +95,7 @@ class TestEnableSegmentsToIndexTask:
created_by=account.id,
indexing_status="completed",
enabled=True,
doc_form=IndexType.PARAGRAPH_INDEX,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
db.session.add(document)
db.session.commit()
@ -166,7 +166,7 @@ class TestEnableSegmentsToIndexTask:
)
# Update document to use different index type
document.doc_form = IndexType.QA_INDEX
document.doc_form = IndexStructureType.QA_INDEX
db.session.commit()
# Refresh dataset to ensure doc_form property reflects the updated document
@ -185,7 +185,9 @@ class TestEnableSegmentsToIndexTask:
enable_segments_to_index_task(segment_ids, dataset.id, document.id)
# Assert: Verify different index type handling
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.QA_INDEX)
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexStructureType.QA_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_called_once()
# Verify the load method was called with correct parameters
@ -328,7 +330,9 @@ class TestEnableSegmentsToIndexTask:
enable_segments_to_index_task(non_existent_segment_ids, dataset.id, document.id)
# Assert: Verify index processor was created but load was not called
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX)
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexStructureType.PARAGRAPH_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_not_called()
def test_enable_segments_to_index_with_parent_child_structure(
@ -350,7 +354,7 @@ class TestEnableSegmentsToIndexTask:
)
# Update document to use parent-child index type
document.doc_form = IndexType.PARENT_CHILD_INDEX
document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
db.session.commit()
# Refresh dataset to ensure doc_form property reflects the updated document
@ -383,7 +387,7 @@ class TestEnableSegmentsToIndexTask:
# Assert: Verify parent-child index processing
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexType.PARENT_CHILD_INDEX
IndexStructureType.PARENT_CHILD_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_called_once()

View File

@ -53,7 +53,7 @@ from sqlalchemy.exc import IntegrityError
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeConnectionError,
@ -99,10 +99,10 @@ class TestCacheEmbeddingDocuments:
@pytest.fixture
def sample_embedding_result(self):
"""Create a sample TextEmbeddingResult for testing.
"""Create a sample EmbeddingResult for testing.
Returns:
TextEmbeddingResult: Mock embedding result with proper structure
EmbeddingResult: Mock embedding result with proper structure
"""
# Create normalized embedding vectors (dimension 1536 for ada-002)
embedding_vector = np.random.randn(1536)
@ -118,7 +118,7 @@ class TestCacheEmbeddingDocuments:
latency=0.5,
)
return TextEmbeddingResult(
return EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized_vector],
usage=usage,
@ -197,7 +197,7 @@ class TestCacheEmbeddingDocuments:
latency=0.8,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@ -296,7 +296,7 @@ class TestCacheEmbeddingDocuments:
latency=0.6,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=new_embeddings,
usage=usage,
@ -386,7 +386,7 @@ class TestCacheEmbeddingDocuments:
latency=0.5,
)
return TextEmbeddingResult(
return EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@ -449,7 +449,7 @@ class TestCacheEmbeddingDocuments:
latency=0.5,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[valid_vector.tolist(), nan_vector],
usage=usage,
@ -629,7 +629,7 @@ class TestCacheEmbeddingQuery:
latency=0.3,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,
@ -728,7 +728,7 @@ class TestCacheEmbeddingQuery:
latency=0.3,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[nan_vector],
usage=usage,
@ -793,7 +793,7 @@ class TestCacheEmbeddingQuery:
latency=0.3,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,
@ -873,13 +873,13 @@ class TestEmbeddingModelSwitching:
latency=0.3,
)
result_ada = TextEmbeddingResult(
result_ada = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized_ada],
usage=usage,
)
result_3_small = TextEmbeddingResult(
result_3_small = EmbeddingResult(
model="text-embedding-3-small",
embeddings=[normalized_3_small],
usage=usage,
@ -953,13 +953,13 @@ class TestEmbeddingModelSwitching:
latency=0.4,
)
result_openai = TextEmbeddingResult(
result_openai = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized_openai],
usage=usage_openai,
)
result_cohere = TextEmbeddingResult(
result_cohere = EmbeddingResult(
model="embed-english-v3.0",
embeddings=[normalized_cohere],
usage=usage_cohere,
@ -1042,7 +1042,7 @@ class TestEmbeddingDimensionValidation:
latency=0.7,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@ -1095,7 +1095,7 @@ class TestEmbeddingDimensionValidation:
latency=0.5,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@ -1148,7 +1148,7 @@ class TestEmbeddingDimensionValidation:
latency=0.3,
)
result_ada = TextEmbeddingResult(
result_ada = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized_ada],
usage=usage_ada,
@ -1181,7 +1181,7 @@ class TestEmbeddingDimensionValidation:
latency=0.4,
)
result_cohere = TextEmbeddingResult(
result_cohere = EmbeddingResult(
model="embed-english-v3.0",
embeddings=[normalized_cohere],
usage=usage_cohere,
@ -1279,7 +1279,7 @@ class TestEmbeddingEdgeCases:
latency=0.1,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,
@ -1322,7 +1322,7 @@ class TestEmbeddingEdgeCases:
latency=1.5,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,
@ -1370,7 +1370,7 @@ class TestEmbeddingEdgeCases:
latency=0.5,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@ -1422,7 +1422,7 @@ class TestEmbeddingEdgeCases:
latency=0.2,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@ -1478,7 +1478,7 @@ class TestEmbeddingEdgeCases:
)
# Model returns embeddings for all texts
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@ -1546,7 +1546,7 @@ class TestEmbeddingEdgeCases:
latency=0.8,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@ -1603,7 +1603,7 @@ class TestEmbeddingEdgeCases:
latency=0.3,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,
@ -1657,7 +1657,7 @@ class TestEmbeddingEdgeCases:
latency=0.5,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@ -1757,7 +1757,7 @@ class TestEmbeddingCachePerformance:
latency=0.3,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,
@ -1826,7 +1826,7 @@ class TestEmbeddingCachePerformance:
latency=0.5,
)
return TextEmbeddingResult(
return EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@ -1888,7 +1888,7 @@ class TestEmbeddingCachePerformance:
latency=0.3,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,

View File

@ -62,7 +62,7 @@ from core.indexing_runner import (
IndexingRunner,
)
from core.model_runtime.entities.model_entities import ModelType
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.models.document import ChildDocument, Document
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, DatasetProcessRule
@ -112,7 +112,7 @@ def create_mock_dataset_document(
document_id: str | None = None,
dataset_id: str | None = None,
tenant_id: str | None = None,
doc_form: str = IndexType.PARAGRAPH_INDEX,
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
data_source_type: str = "upload_file",
doc_language: str = "English",
) -> Mock:
@ -133,8 +133,8 @@ def create_mock_dataset_document(
Mock: A configured mock DatasetDocument object with all required attributes.
Example:
>>> doc = create_mock_dataset_document(doc_form=IndexType.QA_INDEX)
>>> assert doc.doc_form == IndexType.QA_INDEX
>>> doc = create_mock_dataset_document(doc_form=IndexStructureType.QA_INDEX)
>>> assert doc.doc_form == IndexStructureType.QA_INDEX
"""
doc = Mock(spec=DatasetDocument)
doc.id = document_id or str(uuid.uuid4())
@ -276,7 +276,7 @@ class TestIndexingRunnerExtract:
doc.id = str(uuid.uuid4())
doc.dataset_id = str(uuid.uuid4())
doc.tenant_id = str(uuid.uuid4())
doc.doc_form = IndexType.PARAGRAPH_INDEX
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
doc.data_source_type = "upload_file"
doc.data_source_info_dict = {"upload_file_id": str(uuid.uuid4())}
return doc
@ -616,7 +616,7 @@ class TestIndexingRunnerLoad:
doc = Mock(spec=DatasetDocument)
doc.id = str(uuid.uuid4())
doc.dataset_id = str(uuid.uuid4())
doc.doc_form = IndexType.PARAGRAPH_INDEX
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
return doc
@pytest.fixture
@ -700,7 +700,7 @@ class TestIndexingRunnerLoad:
"""Test loading with parent-child index structure."""
# Arrange
runner = IndexingRunner()
sample_dataset_document.doc_form = IndexType.PARENT_CHILD_INDEX
sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
sample_dataset.indexing_technique = "high_quality"
# Add child documents
@ -775,7 +775,7 @@ class TestIndexingRunnerRun:
doc.id = str(uuid.uuid4())
doc.dataset_id = str(uuid.uuid4())
doc.tenant_id = str(uuid.uuid4())
doc.doc_form = IndexType.PARAGRAPH_INDEX
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
doc.doc_language = "English"
doc.data_source_type = "upload_file"
doc.data_source_info_dict = {"upload_file_id": str(uuid.uuid4())}
@ -802,6 +802,21 @@ class TestIndexingRunnerRun:
mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}}
mock_dependencies["db"].session.scalar.return_value = mock_process_rule
# Mock current_user (Account) for _transform
mock_current_user = MagicMock()
mock_current_user.set_tenant_id = MagicMock()
# Setup db.session.query to return different results based on the model
def mock_query_side_effect(model):
mock_query_result = MagicMock()
if model.__name__ == "Dataset":
mock_query_result.filter_by.return_value.first.return_value = mock_dataset
elif model.__name__ == "Account":
mock_query_result.filter_by.return_value.first.return_value = mock_current_user
return mock_query_result
mock_dependencies["db"].session.query.side_effect = mock_query_side_effect
# Mock processor
mock_processor = MagicMock()
mock_dependencies["factory"].return_value.init_index_processor.return_value = mock_processor
@ -1268,7 +1283,7 @@ class TestIndexingRunnerLoadSegments:
doc.id = str(uuid.uuid4())
doc.dataset_id = str(uuid.uuid4())
doc.created_by = str(uuid.uuid4())
doc.doc_form = IndexType.PARAGRAPH_INDEX
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
return doc
@pytest.fixture
@ -1316,7 +1331,7 @@ class TestIndexingRunnerLoadSegments:
"""Test loading segments for parent-child index."""
# Arrange
runner = IndexingRunner()
sample_dataset_document.doc_form = IndexType.PARENT_CHILD_INDEX
sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
# Add child documents
for doc in sample_documents:
@ -1413,7 +1428,7 @@ class TestIndexingRunnerEstimate:
tenant_id=tenant_id,
extract_settings=extract_settings,
tmp_processing_rule={"mode": "automatic", "rules": {}},
doc_form=IndexType.PARAGRAPH_INDEX,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)

View File

@ -26,6 +26,18 @@ from core.rag.rerank.rerank_type import RerankMode
from core.rag.rerank.weight_rerank import WeightRerankRunner
def create_mock_model_instance():
"""Create a properly configured mock ModelInstance for reranking tests."""
mock_instance = Mock(spec=ModelInstance)
# Setup provider_model_bundle chain for check_model_support_vision
mock_instance.provider_model_bundle = Mock()
mock_instance.provider_model_bundle.configuration = Mock()
mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id"
mock_instance.provider = "test-provider"
mock_instance.model = "test-model"
return mock_instance
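# Usage sketch (illustrative only): the helper replaces bare Mock(spec=ModelInstance)
# objects so that rerank code asking ModelManager whether the model supports vision can
# resolve provider_model_bundle.configuration.tenant_id, provider and model.
#
#   mock_model_instance = create_mock_model_instance()
#   mock_model_instance.invoke_rerank.return_value = some_rerank_result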
class TestRerankModelRunner:
"""Unit tests for RerankModelRunner.
@ -37,10 +49,23 @@ class TestRerankModelRunner:
- Metadata preservation and score injection
"""
@pytest.fixture(autouse=True)
def mock_model_manager(self):
"""Auto-use fixture to patch ModelManager for all tests in this class."""
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
mock_mm.return_value.check_model_support_vision.return_value = False
yield mock_mm
@pytest.fixture
def mock_model_instance(self):
"""Create a mock ModelInstance for reranking."""
mock_instance = Mock(spec=ModelInstance)
# Setup provider_model_bundle chain for check_model_support_vision
mock_instance.provider_model_bundle = Mock()
mock_instance.provider_model_bundle.configuration = Mock()
mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id"
mock_instance.provider = "test-provider"
mock_instance.model = "test-model"
return mock_instance
@pytest.fixture
@ -803,7 +828,7 @@ class TestRerankRunnerFactory:
- Parameters are forwarded to runner constructor
"""
# Arrange: Mock model instance
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
# Act: Create runner via factory
runner = RerankRunnerFactory.create_rerank_runner(
@ -865,7 +890,7 @@ class TestRerankRunnerFactory:
- String values are properly matched
"""
# Arrange: Mock model instance
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
# Act: Create runner using enum value
runner = RerankRunnerFactory.create_rerank_runner(
@ -886,6 +911,13 @@ class TestRerankIntegration:
- Real-world usage scenarios
"""
@pytest.fixture(autouse=True)
def mock_model_manager(self):
"""Auto-use fixture to patch ModelManager for all tests in this class."""
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
mock_mm.return_value.check_model_support_vision.return_value = False
yield mock_mm
def test_model_reranking_full_workflow(self):
"""Test complete model-based reranking workflow.
@ -895,7 +927,7 @@ class TestRerankIntegration:
- Top results are returned correctly
"""
# Arrange: Create mock model and documents
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@ -951,7 +983,7 @@ class TestRerankIntegration:
- Normalization is consistent
"""
# Arrange: Create mock model with various scores
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@ -991,6 +1023,13 @@ class TestRerankEdgeCases:
- Concurrent reranking scenarios
"""
@pytest.fixture(autouse=True)
def mock_model_manager(self):
"""Auto-use fixture to patch ModelManager for all tests in this class."""
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
mock_mm.return_value.check_model_support_vision.return_value = False
yield mock_mm
def test_rerank_with_empty_metadata(self):
"""Test reranking when documents have empty metadata.
@ -1000,7 +1039,7 @@ class TestRerankEdgeCases:
- Empty metadata documents are processed correctly
"""
# Arrange: Create documents with empty metadata
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@ -1046,7 +1085,7 @@ class TestRerankEdgeCases:
- Score comparison logic works at boundary
"""
# Arrange: Create mock with various scores including negatives
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@ -1082,7 +1121,7 @@ class TestRerankEdgeCases:
- No overflow or precision issues
"""
# Arrange: All documents with perfect scores
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@ -1117,7 +1156,7 @@ class TestRerankEdgeCases:
- Content encoding is preserved
"""
# Arrange: Documents with special characters
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@ -1159,7 +1198,7 @@ class TestRerankEdgeCases:
- Content is not truncated unexpectedly
"""
# Arrange: Documents with very long content
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
long_content = "This is a very long document. " * 1000 # ~30,000 characters
mock_rerank_result = RerankResult(
@ -1196,7 +1235,7 @@ class TestRerankEdgeCases:
- All documents are processed correctly
"""
# Arrange: Create 100 documents
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
num_docs = 100
# Create rerank results for all documents
@ -1287,7 +1326,7 @@ class TestRerankEdgeCases:
- Documents can still be ranked
"""
# Arrange: Empty query
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@ -1325,6 +1364,13 @@ class TestRerankPerformance:
- Score calculation optimization
"""
@pytest.fixture(autouse=True)
def mock_model_manager(self):
"""Auto-use fixture to patch ModelManager for all tests in this class."""
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
mock_mm.return_value.check_model_support_vision.return_value = False
yield mock_mm
def test_rerank_batch_processing(self):
"""Test that documents are processed in a single batch.
@ -1334,7 +1380,7 @@ class TestRerankPerformance:
- Efficient batch processing
"""
# Arrange: Multiple documents
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[RerankDocument(index=i, text=f"Doc {i}", score=0.9 - i * 0.1) for i in range(5)],
@ -1435,6 +1481,13 @@ class TestRerankErrorHandling:
- Error propagation
"""
@pytest.fixture(autouse=True)
def mock_model_manager(self):
"""Auto-use fixture to patch ModelManager for all tests in this class."""
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
mock_mm.return_value.check_model_support_vision.return_value = False
yield mock_mm
def test_rerank_model_invocation_error(self):
"""Test handling of model invocation errors.
@ -1444,7 +1497,7 @@ class TestRerankErrorHandling:
- Error context is preserved
"""
# Arrange: Mock model that raises exception
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_model_instance.invoke_rerank.side_effect = RuntimeError("Model invocation failed")
documents = [
@ -1470,7 +1523,7 @@ class TestRerankErrorHandling:
- Invalid results don't corrupt output
"""
# Arrange: Rerank result with invalid index
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[

View File

@ -425,15 +425,15 @@ class TestRetrievalService:
# ==================== Vector Search Tests ====================
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_vector_search_basic(self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents):
def test_vector_search_basic(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
"""
Test basic vector/semantic search functionality.
This test validates the core vector search flow:
1. Dataset is retrieved from database
2. embedding_search is called via ThreadPoolExecutor
2. _retrieve is called via ThreadPoolExecutor
3. Documents are added to shared all_documents list
4. Results are returned to caller
@ -447,28 +447,28 @@ class TestRetrievalService:
# Set up the mock dataset that will be "retrieved" from database
mock_get_dataset.return_value = mock_dataset
# Create a side effect function that simulates embedding_search behavior
# In the real implementation, embedding_search:
# 1. Gets the dataset
# 2. Creates a Vector instance
# 3. Calls search_by_vector with embeddings
# 4. Extends all_documents with results
def side_effect_embedding_search(
# Create a side effect function that simulates _retrieve behavior
# _retrieve modifies the all_documents list in place
def side_effect_retrieve(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
"""Simulate embedding_search adding documents to the shared list."""
all_documents.extend(sample_documents)
"""Simulate _retrieve adding documents to the shared list."""
if all_documents is not None:
all_documents.extend(sample_documents)
mock_embedding_search.side_effect = side_effect_embedding_search
mock_retrieve.side_effect = side_effect_retrieve
# Define test parameters
query = "What is Python?" # Natural language query
@ -481,7 +481,7 @@ class TestRetrievalService:
# 1. Check if query is empty (early return if so)
# 2. Get the dataset using _get_dataset
# 3. Create ThreadPoolExecutor
# 4. Submit embedding_search task
# 4. Submit _retrieve task
# 5. Wait for completion
# 6. Return all_documents list
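# A rough sketch of those steps, assuming the current RetrievalService internals
# (illustrative only, not the service source):
#
#   with ThreadPoolExecutor() as executor:
#       executor.submit(RetrievalService._retrieve, flask_app, retrieval_method, dataset,
#                       query=query, top_k=top_k, document_ids_filter=document_ids_filter,
#                       all_documents=all_documents, exceptions=exceptions)
#   if exceptions:
#       raise ValueError(";".join(exceptions))
#   return all_documents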
results = RetrievalService.retrieve(
@ -502,15 +502,13 @@ class TestRetrievalService:
# Verify documents maintain their scores (highest score first in sample_documents)
assert results[0].metadata["score"] == 0.95, "First document should have highest score from sample_documents"
# Verify embedding_search was called exactly once
# Verify _retrieve was called exactly once
# This confirms the search method was invoked by ThreadPoolExecutor
mock_embedding_search.assert_called_once()
mock_retrieve.assert_called_once()
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_vector_search_with_document_filter(
self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents
):
def test_vector_search_with_document_filter(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
"""
Test vector search with document ID filtering.
@ -522,21 +520,25 @@ class TestRetrievalService:
mock_get_dataset.return_value = mock_dataset
filtered_docs = [sample_documents[0]]
def side_effect_embedding_search(
def side_effect_retrieve(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
all_documents.extend(filtered_docs)
if all_documents is not None:
all_documents.extend(filtered_docs)
mock_embedding_search.side_effect = side_effect_embedding_search
mock_retrieve.side_effect = side_effect_retrieve
document_ids_filter = [sample_documents[0].metadata["document_id"]]
# Act
@ -552,12 +554,12 @@ class TestRetrievalService:
assert len(results) == 1
assert results[0].metadata["doc_id"] == "doc1"
# Verify document_ids_filter was passed
call_kwargs = mock_embedding_search.call_args.kwargs
call_kwargs = mock_retrieve.call_args.kwargs
assert call_kwargs["document_ids_filter"] == document_ids_filter
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_vector_search_empty_results(self, mock_get_dataset, mock_embedding_search, mock_dataset):
def test_vector_search_empty_results(self, mock_get_dataset, mock_retrieve, mock_dataset):
"""
Test vector search when no results match the query.
@ -567,8 +569,8 @@ class TestRetrievalService:
"""
# Arrange
mock_get_dataset.return_value = mock_dataset
# embedding_search doesn't add anything to all_documents
mock_embedding_search.side_effect = lambda *args, **kwargs: None
# _retrieve doesn't add anything to all_documents
mock_retrieve.side_effect = lambda *args, **kwargs: None
# Act
results = RetrievalService.retrieve(
@ -583,9 +585,9 @@ class TestRetrievalService:
# ==================== Keyword Search Tests ====================
@patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_keyword_search_basic(self, mock_get_dataset, mock_keyword_search, mock_dataset, sample_documents):
def test_keyword_search_basic(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
"""
Test basic keyword search functionality.
@ -597,12 +599,25 @@ class TestRetrievalService:
# Arrange
mock_get_dataset.return_value = mock_dataset
def side_effect_keyword_search(
flask_app, dataset_id, query, top_k, all_documents, exceptions, document_ids_filter=None
def side_effect_retrieve(
flask_app,
retrieval_method,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
all_documents.extend(sample_documents)
if all_documents is not None:
all_documents.extend(sample_documents)
mock_keyword_search.side_effect = side_effect_keyword_search
mock_retrieve.side_effect = side_effect_retrieve
query = "Python programming"
top_k = 3
@ -618,7 +633,7 @@ class TestRetrievalService:
# Assert
assert len(results) == 3
assert all(isinstance(doc, Document) for doc in results)
mock_keyword_search.assert_called_once()
mock_retrieve.assert_called_once()
@patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
@ -1147,11 +1162,9 @@ class TestRetrievalService:
# ==================== Metadata Filtering Tests ====================
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_vector_search_with_metadata_filter(
self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents
):
def test_vector_search_with_metadata_filter(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
"""
Test vector search with metadata-based document filtering.
@ -1166,21 +1179,25 @@ class TestRetrievalService:
filtered_doc = sample_documents[0]
filtered_doc.metadata["category"] = "programming"
def side_effect_embedding(
def side_effect_retrieve(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
all_documents.append(filtered_doc)
if all_documents is not None:
all_documents.append(filtered_doc)
mock_embedding_search.side_effect = side_effect_embedding
mock_retrieve.side_effect = side_effect_retrieve
# Act
results = RetrievalService.retrieve(
@ -1243,9 +1260,9 @@ class TestRetrievalService:
# Assert
assert results == []
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_retrieve_with_exception_handling(self, mock_get_dataset, mock_embedding_search, mock_dataset):
def test_retrieve_with_exception_handling(self, mock_get_dataset, mock_retrieve, mock_dataset):
"""
Test that exceptions during retrieval are properly handled.
@ -1256,22 +1273,26 @@ class TestRetrievalService:
# Arrange
mock_get_dataset.return_value = mock_dataset
# Make embedding_search add an exception to the exceptions list
# Make _retrieve add an exception to the exceptions list
def side_effect_with_exception(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
exceptions.append("Search failed")
if exceptions is not None:
exceptions.append("Search failed")
mock_embedding_search.side_effect = side_effect_with_exception
mock_retrieve.side_effect = side_effect_with_exception
# Act & Assert
with pytest.raises(ValueError) as exc_info:
@ -1286,9 +1307,9 @@ class TestRetrievalService:
# ==================== Score Threshold Tests ====================
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_vector_search_with_score_threshold(self, mock_get_dataset, mock_embedding_search, mock_dataset):
def test_vector_search_with_score_threshold(self, mock_get_dataset, mock_retrieve, mock_dataset):
"""
Test vector search with score threshold filtering.
@ -1306,21 +1327,25 @@ class TestRetrievalService:
provider="dify",
)
def side_effect_embedding(
def side_effect_retrieve(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
all_documents.append(high_score_doc)
if all_documents is not None:
all_documents.append(high_score_doc)
mock_embedding_search.side_effect = side_effect_embedding
mock_retrieve.side_effect = side_effect_retrieve
score_threshold = 0.8
@ -1339,9 +1364,9 @@ class TestRetrievalService:
# ==================== Top-K Limiting Tests ====================
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_retrieve_respects_top_k_limit(self, mock_get_dataset, mock_embedding_search, mock_dataset):
def test_retrieve_respects_top_k_limit(self, mock_get_dataset, mock_retrieve, mock_dataset):
"""
Test that retrieval respects top_k parameter.
@ -1362,22 +1387,26 @@ class TestRetrievalService:
for i in range(10)
]
def side_effect_embedding(
def side_effect_retrieve(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
# Return only top_k documents
all_documents.extend(many_docs[:top_k])
if all_documents is not None:
all_documents.extend(many_docs[:top_k])
mock_embedding_search.side_effect = side_effect_embedding
mock_retrieve.side_effect = side_effect_retrieve
top_k = 3
@ -1390,9 +1419,9 @@ class TestRetrievalService:
)
# Assert
# Verify top_k was passed to embedding_search
assert mock_embedding_search.called
call_kwargs = mock_embedding_search.call_args.kwargs
# Verify _retrieve was called
assert mock_retrieve.called
call_kwargs = mock_retrieve.call_args.kwargs
assert call_kwargs["top_k"] == top_k
# Verify we got the right number of results
assert len(results) == top_k
@ -1421,11 +1450,9 @@ class TestRetrievalService:
# ==================== Reranking Tests ====================
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_semantic_search_with_reranking(
self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents
):
def test_semantic_search_with_reranking(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
"""
Test semantic search with reranking model.
@ -1439,22 +1466,26 @@ class TestRetrievalService:
# Simulate reranking changing order
reranked_docs = list(reversed(sample_documents))
def side_effect_embedding(
def side_effect_retrieve(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
# embedding_search handles reranking internally
all_documents.extend(reranked_docs)
# _retrieve handles reranking internally
if all_documents is not None:
all_documents.extend(reranked_docs)
mock_embedding_search.side_effect = side_effect_embedding
mock_retrieve.side_effect = side_effect_retrieve
reranking_model = {
"reranking_provider_name": "cohere",
@ -1473,7 +1504,7 @@ class TestRetrievalService:
# Assert
# For semantic search with reranking, reranking_model should be passed
assert len(results) == 3
call_kwargs = mock_embedding_search.call_args.kwargs
call_kwargs = mock_retrieve.call_args.kwargs
assert call_kwargs["reranking_model"] == reranking_model

View File

@ -1,3 +1,4 @@
import json
from unittest.mock import Mock, PropertyMock, patch
import httpx
@ -138,3 +139,95 @@ def test_is_file_with_no_content_disposition(mock_response):
type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512))
response = Response(mock_response)
assert response.is_file
# UTF-8 Encoding Tests
@pytest.mark.parametrize(
("content_bytes", "expected_text", "description"),
[
# Chinese UTF-8 bytes
(
b'{"message": "\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c"}',
'{"message": "你好世界"}',
"Chinese characters UTF-8",
),
# Japanese UTF-8 bytes
(
b'{"message": "\xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf"}',
'{"message": "こんにちは"}',
"Japanese characters UTF-8",
),
# Korean UTF-8 bytes
(
b'{"message": "\xec\x95\x88\xeb\x85\x95\xed\x95\x98\xec\x84\xb8\xec\x9a\x94"}',
'{"message": "안녕하세요"}',
"Korean characters UTF-8",
),
# Arabic UTF-8
(b'{"text": "\xd9\x85\xd8\xb1\xd8\xad\xd8\xa8\xd8\xa7"}', '{"text": "مرحبا"}', "Arabic characters UTF-8"),
# European characters UTF-8
(b'{"text": "Caf\xc3\xa9 M\xc3\xbcnchen"}', '{"text": "Café München"}', "European accented characters"),
# Simple ASCII
(b'{"text": "Hello World"}', '{"text": "Hello World"}', "Simple ASCII text"),
],
)
def test_text_property_utf8_decoding(mock_response, content_bytes, expected_text, description):
"""Test that Response.text properly decodes UTF-8 content with charset_normalizer"""
mock_response.headers = {"content-type": "application/json; charset=utf-8"}
type(mock_response).content = PropertyMock(return_value=content_bytes)
# Mock httpx response.text to return something different (simulating potential encoding issues)
mock_response.text = "incorrect-fallback-text" # To ensure we are not falling back to httpx's text property
response = Response(mock_response)
# Our enhanced text property should decode properly using charset_normalizer
assert response.text == expected_text, (
f"Failed for {description}: got {repr(response.text)}, expected {repr(expected_text)}"
)
def test_text_property_fallback_to_httpx(mock_response):
"""Test that Response.text falls back to httpx.text when charset_normalizer fails"""
mock_response.headers = {"content-type": "application/json"}
# Create malformed UTF-8 bytes
malformed_bytes = b'{"text": "\xff\xfe\x00\x00 invalid"}'
type(mock_response).content = PropertyMock(return_value=malformed_bytes)
# Mock httpx.text to return some fallback value
fallback_text = '{"text": "fallback"}'
mock_response.text = fallback_text
response = Response(mock_response)
# Should fall back to httpx's text when charset_normalizer fails
assert response.text == fallback_text
@pytest.mark.parametrize(
("json_content", "description"),
[
# JSON with escaped Unicode (like Flask jsonify())
('{"message": "\\u4f60\\u597d\\u4e16\\u754c"}', "JSON with escaped Unicode"),
# JSON with mixed escape sequences and UTF-8
('{"mixed": "Hello \\u4f60\\u597d"}', "Mixed escaped and regular text"),
# JSON with complex escape sequences
('{"complex": "\\ud83d\\ude00\\u4f60\\u597d"}', "Emoji and Chinese escapes"),
],
)
def test_text_property_with_escaped_unicode(mock_response, json_content, description):
"""Test Response.text with JSON containing Unicode escape sequences"""
mock_response.headers = {"content-type": "application/json"}
content_bytes = json_content.encode("utf-8")
type(mock_response).content = PropertyMock(return_value=content_bytes)
mock_response.text = json_content # httpx would return the same for valid UTF-8
response = Response(mock_response)
# Should preserve the escape sequences (valid JSON)
assert response.text == json_content, f"Failed for {description}"
# The text should be valid JSON that can be parsed back to proper Unicode
parsed = json.loads(response.text)
assert isinstance(parsed, dict), f"Invalid JSON for {description}"
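These tests describe a `Response.text` that first decodes the raw bytes with charset_normalizer and only falls back to httpx's own `text` when detection fails. The following is a minimal sketch of that decode-then-fallback order, assuming the wrapper keeps the underlying httpx response in an attribute (here `self._response`); the class name and internals are assumptions, only the behaviour asserted above is taken from the tests.

# Minimal sketch of the decode-then-fallback behaviour exercised above; the
# wrapper internals (`self._response`, class name) are assumed.
from charset_normalizer import from_bytes


class ResponseTextSketch:
    def __init__(self, response):
        self._response = response  # underlying httpx.Response

    @property
    def text(self) -> str:
        try:
            best = from_bytes(self._response.content).best()
        except Exception:
            best = None
        if best is not None:
            return str(best)
        # Detection failed: fall back to httpx's own decoding.
        return self._response.text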

View File

@ -117,7 +117,7 @@ import pytest
from core.entities.document_task import DocumentTask
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy
# ============================================================================
# Test Data Factory
@ -370,7 +370,7 @@ class TestDocumentIndexingTaskProxy:
# Features Property Tests
# ========================================================================
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
def test_features_property(self, mock_feature_service):
"""
Test cached_property features.
@ -400,7 +400,7 @@ class TestDocumentIndexingTaskProxy:
mock_feature_service.get_features.assert_called_once_with("tenant-123")
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
def test_features_property_with_different_tenants(self, mock_feature_service):
"""
Test features property with different tenant IDs.
@ -438,7 +438,7 @@ class TestDocumentIndexingTaskProxy:
# Direct Queue Routing Tests
# ========================================================================
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_direct_queue(self, mock_task):
"""
Test _send_to_direct_queue method.
@ -460,7 +460,7 @@ class TestDocumentIndexingTaskProxy:
# Assert
mock_task.delay.assert_called_once_with(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids)
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
def test_send_to_direct_queue_with_priority_task(self, mock_task):
"""
Test _send_to_direct_queue with priority task function.
@ -481,7 +481,7 @@ class TestDocumentIndexingTaskProxy:
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
)
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_direct_queue_with_single_document(self, mock_task):
"""
Test _send_to_direct_queue with single document ID.
@ -502,7 +502,7 @@ class TestDocumentIndexingTaskProxy:
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1"]
)
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_direct_queue_with_empty_documents(self, mock_task):
"""
Test _send_to_direct_queue with empty document_ids list.
@ -525,7 +525,7 @@ class TestDocumentIndexingTaskProxy:
# Tenant Queue Routing Tests
# ========================================================================
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
"""
Test _send_to_tenant_queue when task key exists.
@ -564,7 +564,7 @@ class TestDocumentIndexingTaskProxy:
mock_task.delay.assert_not_called()
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_tenant_queue_without_task_key(self, mock_task):
"""
Test _send_to_tenant_queue when no task key exists.
@ -594,7 +594,7 @@ class TestDocumentIndexingTaskProxy:
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
def test_send_to_tenant_queue_with_priority_task(self, mock_task):
"""
Test _send_to_tenant_queue with priority task function.
@ -621,7 +621,7 @@ class TestDocumentIndexingTaskProxy:
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
)
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_tenant_queue_document_task_serialization(self, mock_task):
"""
Test DocumentTask serialization in _send_to_tenant_queue.
@ -659,7 +659,7 @@ class TestDocumentIndexingTaskProxy:
# Queue Type Selection Tests
# ========================================================================
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_default_tenant_queue(self, mock_task):
"""
Test _send_to_default_tenant_queue method.
@ -678,7 +678,7 @@ class TestDocumentIndexingTaskProxy:
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
def test_send_to_priority_tenant_queue(self, mock_task):
"""
Test _send_to_priority_tenant_queue method.
@ -697,7 +697,7 @@ class TestDocumentIndexingTaskProxy:
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
def test_send_to_priority_direct_queue(self, mock_task):
"""
Test _send_to_priority_direct_queue method.
@ -720,7 +720,7 @@ class TestDocumentIndexingTaskProxy:
# Dispatch Logic Tests
# ========================================================================
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service):
"""
Test _dispatch method when billing is enabled with SANDBOX plan.
@ -745,7 +745,7 @@ class TestDocumentIndexingTaskProxy:
# Assert
proxy._send_to_default_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
def test_dispatch_with_billing_enabled_team_plan(self, mock_feature_service):
"""
Test _dispatch method when billing is enabled with TEAM plan.
@ -770,7 +770,7 @@ class TestDocumentIndexingTaskProxy:
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
def test_dispatch_with_billing_enabled_professional_plan(self, mock_feature_service):
"""
Test _dispatch method when billing is enabled with PROFESSIONAL plan.
@ -795,7 +795,7 @@ class TestDocumentIndexingTaskProxy:
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
def test_dispatch_with_billing_disabled(self, mock_feature_service):
"""
Test _dispatch method when billing is disabled.
@ -818,7 +818,7 @@ class TestDocumentIndexingTaskProxy:
# Assert
proxy._send_to_priority_direct_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
def test_dispatch_edge_case_empty_plan(self, mock_feature_service):
"""
Test _dispatch method with empty plan string.
@ -842,7 +842,7 @@ class TestDocumentIndexingTaskProxy:
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
def test_dispatch_edge_case_none_plan(self, mock_feature_service):
"""
Test _dispatch method with None plan.
@ -870,7 +870,7 @@ class TestDocumentIndexingTaskProxy:
# Delay Method Tests
# ========================================================================
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
def test_delay_method(self, mock_feature_service):
"""
Test delay method integration.
@ -895,7 +895,7 @@ class TestDocumentIndexingTaskProxy:
# Assert
proxy._send_to_default_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
def test_delay_method_with_team_plan(self, mock_feature_service):
"""
Test delay method with TEAM plan.
@ -920,7 +920,7 @@ class TestDocumentIndexingTaskProxy:
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
def test_delay_method_with_billing_disabled(self, mock_feature_service):
"""
Test delay method with billing disabled.
@ -1021,7 +1021,7 @@ class TestDocumentIndexingTaskProxy:
# Batch Operations Tests
# ========================================================================
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_batch_operation_with_multiple_documents(self, mock_task):
"""
Test batch operation with multiple documents.
@ -1044,7 +1044,7 @@ class TestDocumentIndexingTaskProxy:
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=document_ids
)
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_batch_operation_with_large_batch(self, mock_task):
"""
Test batch operation with large batch of documents.
@ -1073,7 +1073,7 @@ class TestDocumentIndexingTaskProxy:
# Error Handling Tests
# ========================================================================
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_direct_queue_task_delay_failure(self, mock_task):
"""
Test _send_to_direct_queue when task.delay() raises an exception.
@ -1090,7 +1090,7 @@ class TestDocumentIndexingTaskProxy:
with pytest.raises(Exception, match="Task delay failed"):
proxy._send_to_direct_queue(mock_task)
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_tenant_queue_push_tasks_failure(self, mock_task):
"""
Test _send_to_tenant_queue when push_tasks raises an exception.
@ -1111,7 +1111,7 @@ class TestDocumentIndexingTaskProxy:
with pytest.raises(Exception, match="Push tasks failed"):
proxy._send_to_tenant_queue(mock_task)
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_tenant_queue_set_waiting_time_failure(self, mock_task):
"""
Test _send_to_tenant_queue when set_task_waiting_time raises an exception.
@ -1132,7 +1132,7 @@ class TestDocumentIndexingTaskProxy:
with pytest.raises(Exception, match="Set waiting time failed"):
proxy._send_to_tenant_queue(mock_task)
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
def test_dispatch_feature_service_failure(self, mock_feature_service):
"""
Test _dispatch when FeatureService.get_features raises an exception.
@ -1153,8 +1153,8 @@ class TestDocumentIndexingTaskProxy:
# Integration Tests
# ========================================================================
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_full_flow_sandbox_plan(self, mock_task, mock_feature_service):
"""
Test full flow for SANDBOX plan with tenant queue.
@ -1187,8 +1187,8 @@ class TestDocumentIndexingTaskProxy:
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
)
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
def test_full_flow_team_plan(self, mock_task, mock_feature_service):
"""
Test full flow for TEAM plan with priority tenant queue.
@ -1221,8 +1221,8 @@ class TestDocumentIndexingTaskProxy:
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
)
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
def test_full_flow_billing_disabled(self, mock_task, mock_feature_service):
"""
Test full flow for billing disabled (self-hosted/enterprise).

View File

@ -3,7 +3,7 @@ from unittest.mock import Mock, patch
from core.entities.document_task import DocumentTask
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy
class DocumentIndexingTaskProxyTestDataFactory:
@ -59,7 +59,7 @@ class TestDocumentIndexingTaskProxy:
assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id
assert proxy._tenant_isolated_task_queue._unique_key == "document_indexing"
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.base.FeatureService")
def test_features_property(self, mock_feature_service):
"""Test cached_property features."""
# Arrange
@ -77,7 +77,7 @@ class TestDocumentIndexingTaskProxy:
assert features1 is features2 # Should be the same instance due to caching
mock_feature_service.get_features.assert_called_once_with("tenant-123")
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_direct_queue(self, mock_task):
"""Test _send_to_direct_queue method."""
# Arrange
@ -92,7 +92,7 @@ class TestDocumentIndexingTaskProxy:
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
)
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
"""Test _send_to_tenant_queue when task key exists."""
# Arrange
@ -115,7 +115,7 @@ class TestDocumentIndexingTaskProxy:
assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"]
mock_task.delay.assert_not_called()
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_tenant_queue_without_task_key(self, mock_task):
"""Test _send_to_tenant_queue when no task key exists."""
# Arrange
@ -135,8 +135,7 @@ class TestDocumentIndexingTaskProxy:
)
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_default_tenant_queue(self, mock_task):
def test_send_to_default_tenant_queue(self):
"""Test _send_to_default_tenant_queue method."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
@ -146,10 +145,9 @@ class TestDocumentIndexingTaskProxy:
proxy._send_to_default_tenant_queue()
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
proxy._send_to_tenant_queue.assert_called_once_with(proxy.NORMAL_TASK_FUNC)
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
def test_send_to_priority_tenant_queue(self, mock_task):
def test_send_to_priority_tenant_queue(self):
"""Test _send_to_priority_tenant_queue method."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
@ -159,10 +157,9 @@ class TestDocumentIndexingTaskProxy:
proxy._send_to_priority_tenant_queue()
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
proxy._send_to_tenant_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC)
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
def test_send_to_priority_direct_queue(self, mock_task):
def test_send_to_priority_direct_queue(self):
"""Test _send_to_priority_direct_queue method."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
@ -172,9 +169,9 @@ class TestDocumentIndexingTaskProxy:
proxy._send_to_priority_direct_queue()
# Assert
proxy._send_to_direct_queue.assert_called_once_with(mock_task)
proxy._send_to_direct_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC)
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.base.FeatureService")
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service):
"""Test _dispatch method when billing is enabled with sandbox plan."""
# Arrange
@ -191,7 +188,7 @@ class TestDocumentIndexingTaskProxy:
# Assert
proxy._send_to_default_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.base.FeatureService")
def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service):
"""Test _dispatch method when billing is enabled with non-sandbox plan."""
# Arrange
@ -208,7 +205,7 @@ class TestDocumentIndexingTaskProxy:
# If billing is enabled with a non-sandbox plan, send to the priority tenant queue
proxy._send_to_priority_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.base.FeatureService")
def test_dispatch_with_billing_disabled(self, mock_feature_service):
"""Test _dispatch method when billing is disabled."""
# Arrange
@ -223,7 +220,7 @@ class TestDocumentIndexingTaskProxy:
# If billing is disabled (e.g. self-hosted or enterprise), send to the priority direct queue
proxy._send_to_priority_direct_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.base.FeatureService")
def test_delay_method(self, mock_feature_service):
"""Test delay method integration."""
# Arrange
@ -256,7 +253,7 @@ class TestDocumentIndexingTaskProxy:
assert task.dataset_id == dataset_id
assert task.document_ids == document_ids
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.base.FeatureService")
def test_dispatch_edge_case_empty_plan(self, mock_feature_service):
"""Test _dispatch method with empty plan string."""
# Arrange
@ -271,7 +268,7 @@ class TestDocumentIndexingTaskProxy:
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
@patch("services.document_indexing_proxy.base.FeatureService")
def test_dispatch_edge_case_none_plan(self, mock_feature_service):
"""Test _dispatch method with None plan."""
# Arrange
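The patch targets above (`services.document_indexing_proxy.base.FeatureService`) and the assertions against `proxy.NORMAL_TASK_FUNC` / `proxy.PRIORITY_TASK_FUNC` suggest the proxies now share a base class that carries the Celery task functions as class attributes and centralises the billing-based routing. The sketch below shows one possible shape under those assumptions; the class name and the FeatureService import path are guesses, and only the routing the tests assert (sandbox plan to the default tenant queue, other plans to the priority tenant queue, billing disabled to the priority direct queue) is taken from this diff.

# Hypothetical base-class shape inferred from the tests; names other than the
# task-function attributes and the asserted routing rules are assumptions.
from functools import cached_property

from enums.cloud_plan import CloudPlan
from services.feature_service import FeatureService  # assumed import path


class DocumentIndexingProxyBase:
    QUEUE_NAME = ""            # e.g. "document_indexing" / "duplicate_document_indexing" in subclasses
    NORMAL_TASK_FUNC = None    # subclasses bind a Celery task here
    PRIORITY_TASK_FUNC = None  # subclasses bind a Celery task here

    def __init__(self, tenant_id, dataset_id, document_ids):
        self._tenant_id = tenant_id
        self._dataset_id = dataset_id
        self._document_ids = document_ids

    @cached_property
    def features(self):
        return FeatureService.get_features(self._tenant_id)

    def _send_to_direct_queue(self, task_func):
        task_func.delay(
            tenant_id=self._tenant_id,
            dataset_id=self._dataset_id,
            document_ids=self._document_ids,
        )

    def _send_to_tenant_queue(self, task_func):
        raise NotImplementedError  # buffering behaviour sketched further below

    def _send_to_default_tenant_queue(self):
        self._send_to_tenant_queue(self.NORMAL_TASK_FUNC)

    def _send_to_priority_tenant_queue(self):
        self._send_to_tenant_queue(self.PRIORITY_TASK_FUNC)

    def _send_to_priority_direct_queue(self):
        self._send_to_direct_queue(self.PRIORITY_TASK_FUNC)

    def _dispatch(self):
        if not self.features.billing.enabled:
            # Self-hosted / enterprise: no tenant isolation, priority task.
            self._send_to_priority_direct_queue()
        elif self.features.billing.subscription.plan == CloudPlan.SANDBOX:
            self._send_to_default_tenant_queue()
        else:
            # Any other (or unknown) plan goes through the priority tenant queue.
            self._send_to_priority_tenant_queue()

    def delay(self):
        self._dispatch()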

View File

@ -0,0 +1,363 @@
from unittest.mock import Mock, patch
from core.entities.document_task import DocumentTask
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from services.document_indexing_proxy.duplicate_document_indexing_task_proxy import (
DuplicateDocumentIndexingTaskProxy,
)
class DuplicateDocumentIndexingTaskProxyTestDataFactory:
"""Factory class for creating test data and mock objects for DuplicateDocumentIndexingTaskProxy tests."""
@staticmethod
def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock:
"""Create mock features with billing configuration."""
features = Mock()
features.billing = Mock()
features.billing.enabled = billing_enabled
features.billing.subscription = Mock()
features.billing.subscription.plan = plan
return features
@staticmethod
def create_mock_tenant_queue(has_task_key: bool = False) -> Mock:
"""Create mock TenantIsolatedTaskQueue."""
queue = Mock(spec=TenantIsolatedTaskQueue)
queue.get_task_key.return_value = "task_key" if has_task_key else None
queue.push_tasks = Mock()
queue.set_task_waiting_time = Mock()
return queue
@staticmethod
def create_duplicate_document_task_proxy(
tenant_id: str = "tenant-123", dataset_id: str = "dataset-456", document_ids: list[str] | None = None
) -> DuplicateDocumentIndexingTaskProxy:
"""Create DuplicateDocumentIndexingTaskProxy instance for testing."""
if document_ids is None:
document_ids = ["doc-1", "doc-2", "doc-3"]
return DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
class TestDuplicateDocumentIndexingTaskProxy:
"""Test cases for DuplicateDocumentIndexingTaskProxy class."""
def test_initialization(self):
"""Test DuplicateDocumentIndexingTaskProxy initialization."""
# Arrange
tenant_id = "tenant-123"
dataset_id = "dataset-456"
document_ids = ["doc-1", "doc-2", "doc-3"]
# Act
proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
# Assert
assert proxy._tenant_id == tenant_id
assert proxy._dataset_id == dataset_id
assert proxy._document_ids == document_ids
assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue)
assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id
assert proxy._tenant_isolated_task_queue._unique_key == "duplicate_document_indexing"
def test_queue_name(self):
"""Test QUEUE_NAME class variable."""
# Arrange & Act
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
# Assert
assert proxy.QUEUE_NAME == "duplicate_document_indexing"
def test_task_functions(self):
"""Test NORMAL_TASK_FUNC and PRIORITY_TASK_FUNC class variables."""
# Arrange & Act
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
# Assert
assert proxy.NORMAL_TASK_FUNC.__name__ == "normal_duplicate_document_indexing_task"
assert proxy.PRIORITY_TASK_FUNC.__name__ == "priority_duplicate_document_indexing_task"
@patch("services.document_indexing_proxy.base.FeatureService")
def test_features_property(self, mock_feature_service):
"""Test cached_property features."""
# Arrange
mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features()
mock_feature_service.get_features.return_value = mock_features
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
# Act
features1 = proxy.features
features2 = proxy.features # Second call should use cached property
# Assert
assert features1 == mock_features
assert features2 == mock_features
assert features1 is features2 # Should be the same instance due to caching
mock_feature_service.get_features.assert_called_once_with("tenant-123")
@patch(
"services.document_indexing_proxy.duplicate_document_indexing_task_proxy.normal_duplicate_document_indexing_task"
)
def test_send_to_direct_queue(self, mock_task):
"""Test _send_to_direct_queue method."""
# Arrange
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
mock_task.delay = Mock()
# Act
proxy._send_to_direct_queue(mock_task)
# Assert
mock_task.delay.assert_called_once_with(
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
)
@patch(
"services.document_indexing_proxy.duplicate_document_indexing_task_proxy.normal_duplicate_document_indexing_task"
)
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
"""Test _send_to_tenant_queue when task key exists."""
# Arrange
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
proxy._tenant_isolated_task_queue = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
has_task_key=True
)
mock_task.delay = Mock()
# Act
proxy._send_to_tenant_queue(mock_task)
# Assert
proxy._tenant_isolated_task_queue.push_tasks.assert_called_once()
pushed_tasks = proxy._tenant_isolated_task_queue.push_tasks.call_args[0][0]
assert len(pushed_tasks) == 1
assert isinstance(DocumentTask(**pushed_tasks[0]), DocumentTask)
assert pushed_tasks[0]["tenant_id"] == "tenant-123"
assert pushed_tasks[0]["dataset_id"] == "dataset-456"
assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"]
mock_task.delay.assert_not_called()
@patch(
"services.document_indexing_proxy.duplicate_document_indexing_task_proxy.normal_duplicate_document_indexing_task"
)
def test_send_to_tenant_queue_without_task_key(self, mock_task):
"""Test _send_to_tenant_queue when no task key exists."""
# Arrange
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
proxy._tenant_isolated_task_queue = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
has_task_key=False
)
mock_task.delay = Mock()
# Act
proxy._send_to_tenant_queue(mock_task)
# Assert
proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once()
mock_task.delay.assert_called_once_with(
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
)
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
def test_send_to_default_tenant_queue(self):
"""Test _send_to_default_tenant_queue method."""
# Arrange
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
proxy._send_to_tenant_queue = Mock()
# Act
proxy._send_to_default_tenant_queue()
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(proxy.NORMAL_TASK_FUNC)
def test_send_to_priority_tenant_queue(self):
"""Test _send_to_priority_tenant_queue method."""
# Arrange
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
proxy._send_to_tenant_queue = Mock()
# Act
proxy._send_to_priority_tenant_queue()
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC)
def test_send_to_priority_direct_queue(self):
"""Test _send_to_priority_direct_queue method."""
# Arrange
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
proxy._send_to_direct_queue = Mock()
# Act
proxy._send_to_priority_direct_queue()
# Assert
proxy._send_to_direct_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC)
@patch("services.document_indexing_proxy.base.FeatureService")
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service):
"""Test _dispatch method when billing is enabled with sandbox plan."""
# Arrange
mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.SANDBOX
)
mock_feature_service.get_features.return_value = mock_features
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
proxy._send_to_default_tenant_queue = Mock()
# Act
proxy._dispatch()
# Assert
proxy._send_to_default_tenant_queue.assert_called_once()
@patch("services.document_indexing_proxy.base.FeatureService")
def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service):
"""Test _dispatch method when billing is enabled with non-sandbox plan."""
# Arrange
mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.TEAM
)
mock_feature_service.get_features.return_value = mock_features
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
# Act
proxy._dispatch()
# Assert
# If billing is enabled with a non-sandbox plan, send to the priority tenant queue
proxy._send_to_priority_tenant_queue.assert_called_once()
@patch("services.document_indexing_proxy.base.FeatureService")
def test_dispatch_with_billing_disabled(self, mock_feature_service):
"""Test _dispatch method when billing is disabled."""
# Arrange
mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=False)
mock_feature_service.get_features.return_value = mock_features
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
proxy._send_to_priority_direct_queue = Mock()
# Act
proxy._dispatch()
# Assert
# If billing is disabled (e.g. self-hosted or enterprise), send to the priority direct queue
proxy._send_to_priority_direct_queue.assert_called_once()
@patch("services.document_indexing_proxy.base.FeatureService")
def test_delay_method(self, mock_feature_service):
"""Test delay method integration."""
# Arrange
mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.SANDBOX
)
mock_feature_service.get_features.return_value = mock_features
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
proxy._send_to_default_tenant_queue = Mock()
# Act
proxy.delay()
# Assert
# If billing enabled with sandbox plan, should send to default tenant queue
proxy._send_to_default_tenant_queue.assert_called_once()
@patch("services.document_indexing_proxy.base.FeatureService")
def test_dispatch_edge_case_empty_plan(self, mock_feature_service):
"""Test _dispatch method with empty plan string."""
# Arrange
mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=""
)
mock_feature_service.get_features.return_value = mock_features
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
# Act
proxy._dispatch()
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once()
@patch("services.document_indexing_proxy.base.FeatureService")
def test_dispatch_edge_case_none_plan(self, mock_feature_service):
"""Test _dispatch method with None plan."""
# Arrange
mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=None
)
mock_feature_service.get_features.return_value = mock_features
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
# Act
proxy._dispatch()
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once()
def test_initialization_with_empty_document_ids(self):
"""Test initialization with empty document_ids list."""
# Arrange
tenant_id = "tenant-123"
dataset_id = "dataset-456"
document_ids = []
# Act
proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
# Assert
assert proxy._tenant_id == tenant_id
assert proxy._dataset_id == dataset_id
assert proxy._document_ids == document_ids
def test_initialization_with_single_document_id(self):
"""Test initialization with single document_id."""
# Arrange
tenant_id = "tenant-123"
dataset_id = "dataset-456"
document_ids = ["doc-1"]
# Act
proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
# Assert
assert proxy._tenant_id == tenant_id
assert proxy._dataset_id == dataset_id
assert proxy._document_ids == document_ids
def test_initialization_with_large_batch(self):
"""Test initialization with large batch of document IDs."""
# Arrange
tenant_id = "tenant-123"
dataset_id = "dataset-456"
document_ids = [f"doc-{i}" for i in range(100)]
# Act
proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
# Assert
assert proxy._tenant_id == tenant_id
assert proxy._dataset_id == dataset_id
assert proxy._document_ids == document_ids
assert len(proxy._document_ids) == 100
@patch("services.document_indexing_proxy.base.FeatureService")
def test_dispatch_with_professional_plan(self, mock_feature_service):
"""Test _dispatch method when billing is enabled with professional plan."""
# Arrange
mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.PROFESSIONAL
)
mock_feature_service.get_features.return_value = mock_features
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
# Act
proxy._dispatch()
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once()
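The tenant-queue tests above pin down the buffering rule: if the tenant already has an active task key, the new work is pushed onto the isolated queue as a serialized task payload; otherwise a waiting-time marker is set and the Celery task is dispatched immediately. A sketch of `_send_to_tenant_queue` consistent with those assertions follows, written as a free function over the proxy for brevity; anything the mocks do not assert, such as the waiting-time arguments, is left out.

# Sketch of the buffering rule asserted above, written against the proxy
# attributes used in these tests; details the mocks do not check are omitted.
def send_to_tenant_queue(proxy, task_func):
    queue = proxy._tenant_isolated_task_queue
    payload = {
        "tenant_id": proxy._tenant_id,
        "dataset_id": proxy._dataset_id,
        "document_ids": proxy._document_ids,
    }
    if queue.get_task_key():
        # A task for this tenant is already in flight: buffer the new work.
        queue.push_tasks([payload])
    else:
        # Nothing queued yet: record a waiting-time marker, then dispatch now.
        queue.set_task_waiting_time()
        task_func.delay(**payload)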

View File

@ -6,6 +6,7 @@ Target: 1500+ lines of comprehensive test coverage.
"""
import json
import re
from datetime import datetime
from unittest.mock import MagicMock, Mock, patch
@ -1791,8 +1792,8 @@ class TestExternalDatasetServiceFetchRetrieval:
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
@patch("services.external_knowledge_service.db")
def test_fetch_external_knowledge_retrieval_non_200_status(self, mock_db, mock_process, factory):
"""Test retrieval returns empty list on non-200 status."""
def test_fetch_external_knowledge_retrieval_non_200_status_raises_exception(self, mock_db, mock_process, factory):
"""Test that non-200 status code raises Exception with response text."""
# Arrange
binding = factory.create_external_knowledge_binding_mock()
api = factory.create_external_knowledge_api_mock()
@ -1817,12 +1818,103 @@ class TestExternalDatasetServiceFetchRetrieval:
mock_response = MagicMock()
mock_response.status_code = 500
mock_response.text = "Internal Server Error: Database connection failed"
mock_process.return_value = mock_response
# Act
result = ExternalDatasetService.fetch_external_knowledge_retrieval(
"tenant-123", "dataset-123", "query", {"top_k": 5}
)
# Act & Assert
with pytest.raises(Exception, match="Internal Server Error: Database connection failed"):
ExternalDatasetService.fetch_external_knowledge_retrieval(
"tenant-123", "dataset-123", "query", {"top_k": 5}
)
# Assert
assert result == []
@pytest.mark.parametrize(
("status_code", "error_message"),
[
(400, "Bad Request: Invalid query parameters"),
(401, "Unauthorized: Invalid API key"),
(403, "Forbidden: Access denied to resource"),
(404, "Not Found: Knowledge base not found"),
(429, "Too Many Requests: Rate limit exceeded"),
(500, "Internal Server Error: Database connection failed"),
(502, "Bad Gateway: External service unavailable"),
(503, "Service Unavailable: Maintenance mode"),
],
)
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
@patch("services.external_knowledge_service.db")
def test_fetch_external_knowledge_retrieval_various_error_status_codes(
self, mock_db, mock_process, factory, status_code, error_message
):
"""Test that various error status codes raise exceptions with response text."""
# Arrange
tenant_id = "tenant-123"
dataset_id = "dataset-123"
binding = factory.create_external_knowledge_binding_mock(
dataset_id=dataset_id, external_knowledge_api_id="api-123"
)
api = factory.create_external_knowledge_api_mock(api_id="api-123")
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_response = MagicMock()
mock_response.status_code = status_code
mock_response.text = error_message
mock_process.return_value = mock_response
# Act & Assert
with pytest.raises(ValueError, match=re.escape(error_message)):
ExternalDatasetService.fetch_external_knowledge_retrieval(tenant_id, dataset_id, "query", {"top_k": 5})
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
@patch("services.external_knowledge_service.db")
def test_fetch_external_knowledge_retrieval_empty_response_text(self, mock_db, mock_process, factory):
"""Test exception with empty response text."""
# Arrange
binding = factory.create_external_knowledge_binding_mock()
api = factory.create_external_knowledge_api_mock()
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_response = MagicMock()
mock_response.status_code = 503
mock_response.text = ""
mock_process.return_value = mock_response
# Act & Assert
with pytest.raises(Exception, match=""):
ExternalDatasetService.fetch_external_knowledge_retrieval(
"tenant-123", "dataset-123", "query", {"top_k": 5}
)
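The rewritten retrieval tests expect `fetch_external_knowledge_retrieval` to raise on any non-200 response, carrying the upstream body as the message (a ValueError satisfies both the `Exception` and `ValueError` expectations used above), rather than returning an empty list. The sketch below shows just that status check; the binding/API lookup and the `process_external_api` call around it are assumed to be unchanged.

# Sketch of only the status check exercised above; `response` is whatever
# process_external_api(...) returned.
def raise_for_external_retrieval_error(response):
    if response.status_code != 200:
        # Surface the upstream error body so callers can see the real cause.
        raise ValueError(response.text)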

View File

@ -19,7 +19,7 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Document
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy
from tasks.document_indexing_task import (
_document_indexing,
_document_indexing_with_tenant_queue,
@ -138,7 +138,9 @@ class TestTaskEnqueuing:
with patch.object(DocumentIndexingTaskProxy, "features") as mock_features:
mock_features.billing.enabled = False
with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task:
# Mock the class variable directly
mock_task = Mock()
with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task):
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
# Act
@ -163,7 +165,9 @@ class TestTaskEnqueuing:
mock_features.billing.enabled = True
mock_features.billing.subscription.plan = CloudPlan.SANDBOX
with patch("services.document_indexing_task_proxy.normal_document_indexing_task") as mock_task:
# Mock the class variable directly
mock_task = Mock()
with patch.object(DocumentIndexingTaskProxy, "NORMAL_TASK_FUNC", mock_task):
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
# Act
@ -187,7 +191,9 @@ class TestTaskEnqueuing:
mock_features.billing.enabled = True
mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL
with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task:
# Mock the class variable directly
mock_task = Mock()
with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task):
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
# Act
@ -211,7 +217,9 @@ class TestTaskEnqueuing:
mock_features.billing.enabled = True
mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL
with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task:
# Mock the class variable directly
mock_task = Mock()
with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task):
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
# Act
@ -1493,7 +1501,9 @@ class TestEdgeCases:
mock_features.billing.enabled = True
mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL
with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task:
# Mock the class variable directly
mock_task = Mock()
with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task):
# Act - Enqueue multiple tasks rapidly
for doc_ids in document_ids_list:
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, doc_ids)
@ -1898,7 +1908,7 @@ class TestRobustness:
- Error is propagated appropriately
"""
# Arrange
with patch("services.document_indexing_task_proxy.FeatureService.get_features") as mock_get_features:
with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_get_features:
# Simulate FeatureService failure
mock_get_features.side_effect = Exception("Feature service unavailable")
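The last hunks swap `patch("services.document_indexing_task_proxy....")` for `patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task)`, presumably because the proxy now resolves the task through its class attribute, so patching the old module-level name would no longer reach the function the proxy actually calls. A self-contained variant of that pattern, reusing the constructor arguments and delay kwargs asserted elsewhere in this diff:

# Patching the class attribute is what the proxy's dispatch path actually sees.
from unittest.mock import Mock, patch

from services.document_indexing_proxy.document_indexing_task_proxy import (
    DocumentIndexingTaskProxy,
)

mock_task = Mock()
with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task):
    proxy = DocumentIndexingTaskProxy("tenant-123", "dataset-456", ["doc-1"])
    proxy._send_to_priority_direct_queue()

mock_task.delay.assert_called_once_with(
    tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1"]
)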

Some files were not shown because too many files have changed in this diff.