mirror of
https://github.com/langgenius/dify.git
synced 2026-04-21 23:38:53 +08:00
Merge branch 'main' into feat/end-user-oauth
# Conflicts: # web/app/components/app/configuration/config/agent/agent-tools/index.tsx
This commit is contained in:
commit
2ea07cd8f8
@ -654,3 +654,9 @@ TENANT_ISOLATED_TASK_CONCURRENCY=1
|
|||||||
|
|
||||||
# Maximum number of segments for dataset segments API (0 for unlimited)
|
# Maximum number of segments for dataset segments API (0 for unlimited)
|
||||||
DATASET_MAX_SEGMENTS_PER_REQUEST=0
|
DATASET_MAX_SEGMENTS_PER_REQUEST=0
|
||||||
|
|
||||||
|
# Multimodal knowledgebase limit
|
||||||
|
SINGLE_CHUNK_ATTACHMENT_LIMIT=10
|
||||||
|
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2
|
||||||
|
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60
|
||||||
|
IMAGE_FILE_BATCH_LIMIT=10
|
||||||
|
|||||||
@ -360,6 +360,26 @@ class FileUploadConfig(BaseSettings):
|
|||||||
default=10,
|
default=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
IMAGE_FILE_BATCH_LIMIT: PositiveInt = Field(
|
||||||
|
description="Maximum number of files allowed in a image batch upload operation",
|
||||||
|
default=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
SINGLE_CHUNK_ATTACHMENT_LIMIT: PositiveInt = Field(
|
||||||
|
description="Maximum number of files allowed in a single chunk attachment",
|
||||||
|
default=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field(
|
||||||
|
description="Maximum allowed image file size for attachments in megabytes",
|
||||||
|
default=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT: NonNegativeInt = Field(
|
||||||
|
description="Timeout for downloading image attachments in seconds",
|
||||||
|
default=60,
|
||||||
|
)
|
||||||
|
|
||||||
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
|
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
|
||||||
description=(
|
description=(
|
||||||
"Comma-separated list of file extensions that are blocked from upload. "
|
"Comma-separated list of file extensions that are blocked from upload. "
|
||||||
|
|||||||
@ -61,6 +61,7 @@ class ChatMessagesQuery(BaseModel):
|
|||||||
class MessageFeedbackPayload(BaseModel):
|
class MessageFeedbackPayload(BaseModel):
|
||||||
message_id: str = Field(..., description="Message ID")
|
message_id: str = Field(..., description="Message ID")
|
||||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
||||||
|
content: str | None = Field(default=None, description="Feedback content")
|
||||||
|
|
||||||
@field_validator("message_id")
|
@field_validator("message_id")
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -324,6 +325,7 @@ class MessageFeedbackApi(Resource):
|
|||||||
db.session.delete(feedback)
|
db.session.delete(feedback)
|
||||||
elif args.rating and feedback:
|
elif args.rating and feedback:
|
||||||
feedback.rating = args.rating
|
feedback.rating = args.rating
|
||||||
|
feedback.content = args.content
|
||||||
elif not args.rating and not feedback:
|
elif not args.rating and not feedback:
|
||||||
raise ValueError("rating cannot be None when feedback not exists")
|
raise ValueError("rating cannot be None when feedback not exists")
|
||||||
else:
|
else:
|
||||||
@ -335,6 +337,7 @@ class MessageFeedbackApi(Resource):
|
|||||||
conversation_id=message.conversation_id,
|
conversation_id=message.conversation_id,
|
||||||
message_id=message.id,
|
message_id=message.id,
|
||||||
rating=rating_value,
|
rating=rating_value,
|
||||||
|
content=args.content,
|
||||||
from_source="admin",
|
from_source="admin",
|
||||||
from_account_id=current_user.id,
|
from_account_id=current_user.id,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -151,6 +151,7 @@ class DatasetUpdatePayload(BaseModel):
|
|||||||
external_knowledge_id: str | None = None
|
external_knowledge_id: str | None = None
|
||||||
external_knowledge_api_id: str | None = None
|
external_knowledge_api_id: str | None = None
|
||||||
icon_info: dict[str, Any] | None = None
|
icon_info: dict[str, Any] | None = None
|
||||||
|
is_multimodal: bool | None = False
|
||||||
|
|
||||||
@field_validator("indexing_technique")
|
@field_validator("indexing_technique")
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -423,17 +424,16 @@ class DatasetApi(Resource):
|
|||||||
payload = DatasetUpdatePayload.model_validate(console_ns.payload or {})
|
payload = DatasetUpdatePayload.model_validate(console_ns.payload or {})
|
||||||
payload_data = payload.model_dump(exclude_unset=True)
|
payload_data = payload.model_dump(exclude_unset=True)
|
||||||
current_user, current_tenant_id = current_account_with_tenant()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check embedding model setting
|
# check embedding model setting
|
||||||
if (
|
if (
|
||||||
payload.indexing_technique == "high_quality"
|
payload.indexing_technique == "high_quality"
|
||||||
and payload.embedding_model_provider is not None
|
and payload.embedding_model_provider is not None
|
||||||
and payload.embedding_model is not None
|
and payload.embedding_model is not None
|
||||||
):
|
):
|
||||||
DatasetService.check_embedding_model_setting(
|
is_multimodal = DatasetService.check_is_multimodal_model(
|
||||||
dataset.tenant_id, payload.embedding_model_provider, payload.embedding_model
|
dataset.tenant_id, payload.embedding_model_provider, payload.embedding_model
|
||||||
)
|
)
|
||||||
|
payload.is_multimodal = is_multimodal
|
||||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||||
DatasetPermissionService.check_permission(
|
DatasetPermissionService.check_permission(
|
||||||
current_user, dataset, payload.permission, payload.partial_member_list
|
current_user, dataset, payload.permission, payload.partial_member_list
|
||||||
|
|||||||
@ -424,6 +424,10 @@ class DatasetInitApi(Resource):
|
|||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=knowledge_config.embedding_model,
|
model=knowledge_config.embedding_model,
|
||||||
)
|
)
|
||||||
|
is_multimodal = DatasetService.check_is_multimodal_model(
|
||||||
|
current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model
|
||||||
|
)
|
||||||
|
knowledge_config.is_multimodal = is_multimodal
|
||||||
except InvokeAuthorizationError:
|
except InvokeAuthorizationError:
|
||||||
raise ProviderNotInitializeError(
|
raise ProviderNotInitializeError(
|
||||||
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
||||||
|
|||||||
@ -51,6 +51,7 @@ class SegmentCreatePayload(BaseModel):
|
|||||||
content: str
|
content: str
|
||||||
answer: str | None = None
|
answer: str | None = None
|
||||||
keywords: list[str] | None = None
|
keywords: list[str] | None = None
|
||||||
|
attachment_ids: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
class SegmentUpdatePayload(BaseModel):
|
class SegmentUpdatePayload(BaseModel):
|
||||||
@ -58,6 +59,7 @@ class SegmentUpdatePayload(BaseModel):
|
|||||||
answer: str | None = None
|
answer: str | None = None
|
||||||
keywords: list[str] | None = None
|
keywords: list[str] | None = None
|
||||||
regenerate_child_chunks: bool = False
|
regenerate_child_chunks: bool = False
|
||||||
|
attachment_ids: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
class BatchImportPayload(BaseModel):
|
class BatchImportPayload(BaseModel):
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from flask_restx import marshal
|
from flask_restx import marshal, reqparse
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||||
|
|
||||||
@ -33,6 +33,7 @@ class HitTestingPayload(BaseModel):
|
|||||||
query: str = Field(max_length=250)
|
query: str = Field(max_length=250)
|
||||||
retrieval_model: dict[str, Any] | None = None
|
retrieval_model: dict[str, Any] | None = None
|
||||||
external_retrieval_model: dict[str, Any] | None = None
|
external_retrieval_model: dict[str, Any] | None = None
|
||||||
|
attachment_ids: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
class DatasetsHitTestingBase:
|
class DatasetsHitTestingBase:
|
||||||
@ -54,16 +55,28 @@ class DatasetsHitTestingBase:
|
|||||||
def hit_testing_args_check(args: dict[str, Any]):
|
def hit_testing_args_check(args: dict[str, Any]):
|
||||||
HitTestingService.hit_testing_args_check(args)
|
HitTestingService.hit_testing_args_check(args)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_args():
|
||||||
|
parser = (
|
||||||
|
reqparse.RequestParser()
|
||||||
|
.add_argument("query", type=str, required=False, location="json")
|
||||||
|
.add_argument("attachment_ids", type=list, required=False, location="json")
|
||||||
|
.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||||
|
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def perform_hit_testing(dataset, args):
|
def perform_hit_testing(dataset, args):
|
||||||
assert isinstance(current_user, Account)
|
assert isinstance(current_user, Account)
|
||||||
try:
|
try:
|
||||||
response = HitTestingService.retrieve(
|
response = HitTestingService.retrieve(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
query=args["query"],
|
query=args.get("query"),
|
||||||
account=current_user,
|
account=current_user,
|
||||||
retrieval_model=args["retrieval_model"],
|
retrieval_model=args.get("retrieval_model"),
|
||||||
external_retrieval_model=args["external_retrieval_model"],
|
external_retrieval_model=args.get("external_retrieval_model"),
|
||||||
|
attachment_ids=args.get("attachment_ids"),
|
||||||
limit=10,
|
limit=10,
|
||||||
)
|
)
|
||||||
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
|
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
|
||||||
|
|||||||
@ -45,6 +45,9 @@ class FileApi(Resource):
|
|||||||
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
||||||
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
||||||
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
|
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
|
||||||
|
"image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT,
|
||||||
|
"single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT,
|
||||||
|
"attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT,
|
||||||
}, 200
|
}, 200
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|||||||
@ -22,7 +22,12 @@ from services.trigger.trigger_subscription_builder_service import TriggerSubscri
|
|||||||
from services.trigger.trigger_subscription_operator_service import TriggerSubscriptionOperatorService
|
from services.trigger.trigger_subscription_operator_service import TriggerSubscriptionOperatorService
|
||||||
|
|
||||||
from .. import console_ns
|
from .. import console_ns
|
||||||
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
from ..wraps import (
|
||||||
|
account_initialization_required,
|
||||||
|
edit_permission_required,
|
||||||
|
is_admin_or_owner_required,
|
||||||
|
setup_required,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -72,7 +77,7 @@ class TriggerProviderInfoApi(Resource):
|
|||||||
class TriggerSubscriptionListApi(Resource):
|
class TriggerSubscriptionListApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@is_admin_or_owner_required
|
@edit_permission_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider):
|
def get(self, provider):
|
||||||
"""List all trigger subscriptions for the current tenant's provider"""
|
"""List all trigger subscriptions for the current tenant's provider"""
|
||||||
@ -104,7 +109,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
|
|||||||
@console_ns.expect(parser)
|
@console_ns.expect(parser)
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@is_admin_or_owner_required
|
@edit_permission_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider):
|
def post(self, provider):
|
||||||
"""Add a new subscription instance for a trigger provider"""
|
"""Add a new subscription instance for a trigger provider"""
|
||||||
@ -133,6 +138,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
|
|||||||
class TriggerSubscriptionBuilderGetApi(Resource):
|
class TriggerSubscriptionBuilderGetApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@edit_permission_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider, subscription_builder_id):
|
def get(self, provider, subscription_builder_id):
|
||||||
"""Get a subscription instance for a trigger provider"""
|
"""Get a subscription instance for a trigger provider"""
|
||||||
@ -155,7 +161,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
|
|||||||
@console_ns.expect(parser_api)
|
@console_ns.expect(parser_api)
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@is_admin_or_owner_required
|
@edit_permission_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider, subscription_builder_id):
|
def post(self, provider, subscription_builder_id):
|
||||||
"""Verify a subscription instance for a trigger provider"""
|
"""Verify a subscription instance for a trigger provider"""
|
||||||
@ -200,6 +206,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
|
|||||||
@console_ns.expect(parser_update_api)
|
@console_ns.expect(parser_update_api)
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@edit_permission_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider, subscription_builder_id):
|
def post(self, provider, subscription_builder_id):
|
||||||
"""Update a subscription instance for a trigger provider"""
|
"""Update a subscription instance for a trigger provider"""
|
||||||
@ -233,6 +240,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
|
|||||||
class TriggerSubscriptionBuilderLogsApi(Resource):
|
class TriggerSubscriptionBuilderLogsApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
@edit_permission_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider, subscription_builder_id):
|
def get(self, provider, subscription_builder_id):
|
||||||
"""Get the request logs for a subscription instance for a trigger provider"""
|
"""Get the request logs for a subscription instance for a trigger provider"""
|
||||||
@ -255,7 +263,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
|
|||||||
@console_ns.expect(parser_update_api)
|
@console_ns.expect(parser_update_api)
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@is_admin_or_owner_required
|
@edit_permission_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider, subscription_builder_id):
|
def post(self, provider, subscription_builder_id):
|
||||||
"""Build a subscription instance for a trigger provider"""
|
"""Build a subscription instance for a trigger provider"""
|
||||||
|
|||||||
@ -83,6 +83,7 @@ class AppRunner:
|
|||||||
context: str | None = None,
|
context: str | None = None,
|
||||||
memory: TokenBufferMemory | None = None,
|
memory: TokenBufferMemory | None = None,
|
||||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||||
|
context_files: list["File"] | None = None,
|
||||||
) -> tuple[list[PromptMessage], list[str] | None]:
|
) -> tuple[list[PromptMessage], list[str] | None]:
|
||||||
"""
|
"""
|
||||||
Organize prompt messages
|
Organize prompt messages
|
||||||
@ -111,6 +112,7 @@ class AppRunner:
|
|||||||
memory=memory,
|
memory=memory,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
image_detail_config=image_detail_config,
|
image_detail_config=image_detail_config,
|
||||||
|
context_files=context_files,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
|
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
|
||||||
|
|||||||
@ -11,6 +11,7 @@ from core.app.entities.app_invoke_entities import (
|
|||||||
)
|
)
|
||||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||||
|
from core.file import File
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||||
@ -146,6 +147,7 @@ class ChatAppRunner(AppRunner):
|
|||||||
|
|
||||||
# get context from datasets
|
# get context from datasets
|
||||||
context = None
|
context = None
|
||||||
|
context_files: list[File] = []
|
||||||
if app_config.dataset and app_config.dataset.dataset_ids:
|
if app_config.dataset and app_config.dataset.dataset_ids:
|
||||||
hit_callback = DatasetIndexToolCallbackHandler(
|
hit_callback = DatasetIndexToolCallbackHandler(
|
||||||
queue_manager,
|
queue_manager,
|
||||||
@ -156,7 +158,7 @@ class ChatAppRunner(AppRunner):
|
|||||||
)
|
)
|
||||||
|
|
||||||
dataset_retrieval = DatasetRetrieval(application_generate_entity)
|
dataset_retrieval = DatasetRetrieval(application_generate_entity)
|
||||||
context = dataset_retrieval.retrieve(
|
context, retrieved_files = dataset_retrieval.retrieve(
|
||||||
app_id=app_record.id,
|
app_id=app_record.id,
|
||||||
user_id=application_generate_entity.user_id,
|
user_id=application_generate_entity.user_id,
|
||||||
tenant_id=app_record.tenant_id,
|
tenant_id=app_record.tenant_id,
|
||||||
@ -171,7 +173,11 @@ class ChatAppRunner(AppRunner):
|
|||||||
memory=memory,
|
memory=memory,
|
||||||
message_id=message.id,
|
message_id=message.id,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
|
vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
|
||||||
|
"enabled", False
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
context_files = retrieved_files or []
|
||||||
|
|
||||||
# reorganize all inputs and template to prompt messages
|
# reorganize all inputs and template to prompt messages
|
||||||
# Include: prompt template, inputs, query(optional), files(optional)
|
# Include: prompt template, inputs, query(optional), files(optional)
|
||||||
@ -186,6 +192,7 @@ class ChatAppRunner(AppRunner):
|
|||||||
context=context,
|
context=context,
|
||||||
memory=memory,
|
memory=memory,
|
||||||
image_detail_config=image_detail_config,
|
image_detail_config=image_detail_config,
|
||||||
|
context_files=context_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
# check hosting moderation
|
# check hosting moderation
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import (
|
|||||||
CompletionAppGenerateEntity,
|
CompletionAppGenerateEntity,
|
||||||
)
|
)
|
||||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||||
|
from core.file import File
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||||
from core.moderation.base import ModerationError
|
from core.moderation.base import ModerationError
|
||||||
@ -102,6 +103,7 @@ class CompletionAppRunner(AppRunner):
|
|||||||
|
|
||||||
# get context from datasets
|
# get context from datasets
|
||||||
context = None
|
context = None
|
||||||
|
context_files: list[File] = []
|
||||||
if app_config.dataset and app_config.dataset.dataset_ids:
|
if app_config.dataset and app_config.dataset.dataset_ids:
|
||||||
hit_callback = DatasetIndexToolCallbackHandler(
|
hit_callback = DatasetIndexToolCallbackHandler(
|
||||||
queue_manager,
|
queue_manager,
|
||||||
@ -116,7 +118,7 @@ class CompletionAppRunner(AppRunner):
|
|||||||
query = inputs.get(dataset_config.retrieve_config.query_variable, "")
|
query = inputs.get(dataset_config.retrieve_config.query_variable, "")
|
||||||
|
|
||||||
dataset_retrieval = DatasetRetrieval(application_generate_entity)
|
dataset_retrieval = DatasetRetrieval(application_generate_entity)
|
||||||
context = dataset_retrieval.retrieve(
|
context, retrieved_files = dataset_retrieval.retrieve(
|
||||||
app_id=app_record.id,
|
app_id=app_record.id,
|
||||||
user_id=application_generate_entity.user_id,
|
user_id=application_generate_entity.user_id,
|
||||||
tenant_id=app_record.tenant_id,
|
tenant_id=app_record.tenant_id,
|
||||||
@ -130,7 +132,11 @@ class CompletionAppRunner(AppRunner):
|
|||||||
hit_callback=hit_callback,
|
hit_callback=hit_callback,
|
||||||
message_id=message.id,
|
message_id=message.id,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
|
vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
|
||||||
|
"enabled", False
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
context_files = retrieved_files or []
|
||||||
|
|
||||||
# reorganize all inputs and template to prompt messages
|
# reorganize all inputs and template to prompt messages
|
||||||
# Include: prompt template, inputs, query(optional), files(optional)
|
# Include: prompt template, inputs, query(optional), files(optional)
|
||||||
@ -144,6 +150,7 @@ class CompletionAppRunner(AppRunner):
|
|||||||
query=query,
|
query=query,
|
||||||
context=context,
|
context=context,
|
||||||
image_detail_config=image_detail_config,
|
image_detail_config=image_detail_config,
|
||||||
|
context_files=context_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
# check hosting moderation
|
# check hosting moderation
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
|||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
||||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||||
from core.rag.index_processor.constant.index_type import IndexType
|
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
|
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
|
||||||
@ -59,7 +59,7 @@ class DatasetIndexToolCallbackHandler:
|
|||||||
document_id,
|
document_id,
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||||
child_chunk_stmt = select(ChildChunk).where(
|
child_chunk_stmt = select(ChildChunk).where(
|
||||||
ChildChunk.index_node_id == document.metadata["doc_id"],
|
ChildChunk.index_node_id == document.metadata["doc_id"],
|
||||||
ChildChunk.dataset_id == dataset_document.dataset_id,
|
ChildChunk.dataset_id == dataset_document.dataset_id,
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import time
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from flask import current_app
|
from flask import Flask, current_app
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm.exc import ObjectDeletedError
|
from sqlalchemy.orm.exc import ObjectDeletedError
|
||||||
|
|
||||||
@ -21,7 +21,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword
|
|||||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||||
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
|
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
|
||||||
from core.rag.index_processor.constant.index_type import IndexType
|
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||||
from core.rag.models.document import ChildDocument, Document
|
from core.rag.models.document import ChildDocument, Document
|
||||||
@ -36,6 +36,7 @@ from extensions.ext_redis import redis_client
|
|||||||
from extensions.ext_storage import storage
|
from extensions.ext_storage import storage
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
|
from models import Account
|
||||||
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
|
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
|
||||||
from models.dataset import Document as DatasetDocument
|
from models.dataset import Document as DatasetDocument
|
||||||
from models.model import UploadFile
|
from models.model import UploadFile
|
||||||
@ -89,8 +90,17 @@ class IndexingRunner:
|
|||||||
text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
|
text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
|
||||||
|
|
||||||
# transform
|
# transform
|
||||||
|
current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first()
|
||||||
|
if not current_user:
|
||||||
|
raise ValueError("no current user found")
|
||||||
|
current_user.set_tenant_id(dataset.tenant_id)
|
||||||
documents = self._transform(
|
documents = self._transform(
|
||||||
index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
|
index_processor,
|
||||||
|
dataset,
|
||||||
|
text_docs,
|
||||||
|
requeried_document.doc_language,
|
||||||
|
processing_rule.to_dict(),
|
||||||
|
current_user=current_user,
|
||||||
)
|
)
|
||||||
# save segment
|
# save segment
|
||||||
self._load_segments(dataset, requeried_document, documents)
|
self._load_segments(dataset, requeried_document, documents)
|
||||||
@ -136,7 +146,7 @@ class IndexingRunner:
|
|||||||
|
|
||||||
for document_segment in document_segments:
|
for document_segment in document_segments:
|
||||||
db.session.delete(document_segment)
|
db.session.delete(document_segment)
|
||||||
if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
if requeried_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||||
# delete child chunks
|
# delete child chunks
|
||||||
db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
|
db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
@ -152,8 +162,17 @@ class IndexingRunner:
|
|||||||
text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
|
text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
|
||||||
|
|
||||||
# transform
|
# transform
|
||||||
|
current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first()
|
||||||
|
if not current_user:
|
||||||
|
raise ValueError("no current user found")
|
||||||
|
current_user.set_tenant_id(dataset.tenant_id)
|
||||||
documents = self._transform(
|
documents = self._transform(
|
||||||
index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
|
index_processor,
|
||||||
|
dataset,
|
||||||
|
text_docs,
|
||||||
|
requeried_document.doc_language,
|
||||||
|
processing_rule.to_dict(),
|
||||||
|
current_user=current_user,
|
||||||
)
|
)
|
||||||
# save segment
|
# save segment
|
||||||
self._load_segments(dataset, requeried_document, documents)
|
self._load_segments(dataset, requeried_document, documents)
|
||||||
@ -209,7 +228,7 @@ class IndexingRunner:
|
|||||||
"dataset_id": document_segment.dataset_id,
|
"dataset_id": document_segment.dataset_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
if requeried_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||||
child_chunks = document_segment.get_child_chunks()
|
child_chunks = document_segment.get_child_chunks()
|
||||||
if child_chunks:
|
if child_chunks:
|
||||||
child_documents = []
|
child_documents = []
|
||||||
@ -302,6 +321,7 @@ class IndexingRunner:
|
|||||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
|
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
|
||||||
documents = index_processor.transform(
|
documents = index_processor.transform(
|
||||||
text_docs,
|
text_docs,
|
||||||
|
current_user=None,
|
||||||
embedding_model_instance=embedding_model_instance,
|
embedding_model_instance=embedding_model_instance,
|
||||||
process_rule=processing_rule.to_dict(),
|
process_rule=processing_rule.to_dict(),
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
@ -551,7 +571,10 @@ class IndexingRunner:
|
|||||||
indexing_start_at = time.perf_counter()
|
indexing_start_at = time.perf_counter()
|
||||||
tokens = 0
|
tokens = 0
|
||||||
create_keyword_thread = None
|
create_keyword_thread = None
|
||||||
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
|
if (
|
||||||
|
dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX
|
||||||
|
and dataset.indexing_technique == "economy"
|
||||||
|
):
|
||||||
# create keyword index
|
# create keyword index
|
||||||
create_keyword_thread = threading.Thread(
|
create_keyword_thread = threading.Thread(
|
||||||
target=self._process_keyword_index,
|
target=self._process_keyword_index,
|
||||||
@ -590,7 +613,7 @@ class IndexingRunner:
|
|||||||
for future in futures:
|
for future in futures:
|
||||||
tokens += future.result()
|
tokens += future.result()
|
||||||
if (
|
if (
|
||||||
dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX
|
dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX
|
||||||
and dataset.indexing_technique == "economy"
|
and dataset.indexing_technique == "economy"
|
||||||
and create_keyword_thread is not None
|
and create_keyword_thread is not None
|
||||||
):
|
):
|
||||||
@ -635,7 +658,13 @@ class IndexingRunner:
|
|||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
def _process_chunk(
|
def _process_chunk(
|
||||||
self, flask_app, index_processor, chunk_documents, dataset, dataset_document, embedding_model_instance
|
self,
|
||||||
|
flask_app: Flask,
|
||||||
|
index_processor: BaseIndexProcessor,
|
||||||
|
chunk_documents: list[Document],
|
||||||
|
dataset: Dataset,
|
||||||
|
dataset_document: DatasetDocument,
|
||||||
|
embedding_model_instance: ModelInstance | None,
|
||||||
):
|
):
|
||||||
with flask_app.app_context():
|
with flask_app.app_context():
|
||||||
# check document is paused
|
# check document is paused
|
||||||
@ -646,8 +675,15 @@ class IndexingRunner:
|
|||||||
page_content_list = [document.page_content for document in chunk_documents]
|
page_content_list = [document.page_content for document in chunk_documents]
|
||||||
tokens += sum(embedding_model_instance.get_text_embedding_num_tokens(page_content_list))
|
tokens += sum(embedding_model_instance.get_text_embedding_num_tokens(page_content_list))
|
||||||
|
|
||||||
|
multimodal_documents = []
|
||||||
|
for document in chunk_documents:
|
||||||
|
if document.attachments and dataset.is_multimodal:
|
||||||
|
multimodal_documents.extend(document.attachments)
|
||||||
|
|
||||||
# load index
|
# load index
|
||||||
index_processor.load(dataset, chunk_documents, with_keywords=False)
|
index_processor.load(
|
||||||
|
dataset, chunk_documents, multimodal_documents=multimodal_documents, with_keywords=False
|
||||||
|
)
|
||||||
|
|
||||||
document_ids = [document.metadata["doc_id"] for document in chunk_documents]
|
document_ids = [document.metadata["doc_id"] for document in chunk_documents]
|
||||||
db.session.query(DocumentSegment).where(
|
db.session.query(DocumentSegment).where(
|
||||||
@ -710,6 +746,7 @@ class IndexingRunner:
|
|||||||
text_docs: list[Document],
|
text_docs: list[Document],
|
||||||
doc_language: str,
|
doc_language: str,
|
||||||
process_rule: dict,
|
process_rule: dict,
|
||||||
|
current_user: Account | None = None,
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
# get embedding model instance
|
# get embedding model instance
|
||||||
embedding_model_instance = None
|
embedding_model_instance = None
|
||||||
@ -729,6 +766,7 @@ class IndexingRunner:
|
|||||||
|
|
||||||
documents = index_processor.transform(
|
documents = index_processor.transform(
|
||||||
text_docs,
|
text_docs,
|
||||||
|
current_user,
|
||||||
embedding_model_instance=embedding_model_instance,
|
embedding_model_instance=embedding_model_instance,
|
||||||
process_rule=process_rule,
|
process_rule=process_rule,
|
||||||
tenant_id=dataset.tenant_id,
|
tenant_id=dataset.tenant_id,
|
||||||
@ -737,14 +775,16 @@ class IndexingRunner:
|
|||||||
|
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
def _load_segments(self, dataset, dataset_document, documents):
|
def _load_segments(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]):
|
||||||
# save node to document segment
|
# save node to document segment
|
||||||
doc_store = DatasetDocumentStore(
|
doc_store = DatasetDocumentStore(
|
||||||
dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id
|
dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id
|
||||||
)
|
)
|
||||||
|
|
||||||
# add document segments
|
# add document segments
|
||||||
doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX)
|
doc_store.add_documents(
|
||||||
|
docs=documents, save_child=dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX
|
||||||
|
)
|
||||||
|
|
||||||
# update document status to indexing
|
# update document status to indexing
|
||||||
cur_time = naive_utc_now()
|
cur_time = naive_utc_now()
|
||||||
|
|||||||
@ -10,9 +10,9 @@ from core.errors.error import ProviderTokenNotInitError
|
|||||||
from core.model_runtime.callbacks.base_callback import Callback
|
from core.model_runtime.callbacks.base_callback import Callback
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult
|
from core.model_runtime.entities.llm_entities import LLMResult
|
||||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingResult
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||||
@ -200,7 +200,7 @@ class ModelInstance:
|
|||||||
|
|
||||||
def invoke_text_embedding(
|
def invoke_text_embedding(
|
||||||
self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
|
self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
|
||||||
) -> TextEmbeddingResult:
|
) -> EmbeddingResult:
|
||||||
"""
|
"""
|
||||||
Invoke large language model
|
Invoke large language model
|
||||||
|
|
||||||
@ -212,7 +212,7 @@ class ModelInstance:
|
|||||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||||
return cast(
|
return cast(
|
||||||
TextEmbeddingResult,
|
EmbeddingResult,
|
||||||
self._round_robin_invoke(
|
self._round_robin_invoke(
|
||||||
function=self.model_type_instance.invoke,
|
function=self.model_type_instance.invoke,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
@ -223,6 +223,34 @@ class ModelInstance:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def invoke_multimodal_embedding(
|
||||||
|
self,
|
||||||
|
multimodel_documents: list[dict],
|
||||||
|
user: str | None = None,
|
||||||
|
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||||
|
) -> EmbeddingResult:
|
||||||
|
"""
|
||||||
|
Invoke large language model
|
||||||
|
|
||||||
|
:param multimodel_documents: multimodel documents to embed
|
||||||
|
:param user: unique user id
|
||||||
|
:param input_type: input type
|
||||||
|
:return: embeddings result
|
||||||
|
"""
|
||||||
|
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||||
|
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||||
|
return cast(
|
||||||
|
EmbeddingResult,
|
||||||
|
self._round_robin_invoke(
|
||||||
|
function=self.model_type_instance.invoke,
|
||||||
|
model=self.model,
|
||||||
|
credentials=self.credentials,
|
||||||
|
multimodel_documents=multimodel_documents,
|
||||||
|
user=user,
|
||||||
|
input_type=input_type,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
|
def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
|
||||||
"""
|
"""
|
||||||
Get number of tokens for text embedding
|
Get number of tokens for text embedding
|
||||||
@ -276,6 +304,40 @@ class ModelInstance:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def invoke_multimodal_rerank(
|
||||||
|
self,
|
||||||
|
query: dict,
|
||||||
|
docs: list[dict],
|
||||||
|
score_threshold: float | None = None,
|
||||||
|
top_n: int | None = None,
|
||||||
|
user: str | None = None,
|
||||||
|
) -> RerankResult:
|
||||||
|
"""
|
||||||
|
Invoke rerank model
|
||||||
|
|
||||||
|
:param query: search query
|
||||||
|
:param docs: docs for reranking
|
||||||
|
:param score_threshold: score threshold
|
||||||
|
:param top_n: top n
|
||||||
|
:param user: unique user id
|
||||||
|
:return: rerank result
|
||||||
|
"""
|
||||||
|
if not isinstance(self.model_type_instance, RerankModel):
|
||||||
|
raise Exception("Model type instance is not RerankModel")
|
||||||
|
return cast(
|
||||||
|
RerankResult,
|
||||||
|
self._round_robin_invoke(
|
||||||
|
function=self.model_type_instance.invoke_multimodal_rerank,
|
||||||
|
model=self.model,
|
||||||
|
credentials=self.credentials,
|
||||||
|
query=query,
|
||||||
|
docs=docs,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
top_n=top_n,
|
||||||
|
user=user,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def invoke_moderation(self, text: str, user: str | None = None) -> bool:
|
def invoke_moderation(self, text: str, user: str | None = None) -> bool:
|
||||||
"""
|
"""
|
||||||
Invoke moderation model
|
Invoke moderation model
|
||||||
@ -461,6 +523,32 @@ class ModelManager:
|
|||||||
model=default_model_entity.model,
|
model=default_model_entity.model,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def check_model_support_vision(self, tenant_id: str, provider: str, model: str, model_type: ModelType) -> bool:
|
||||||
|
"""
|
||||||
|
Check if model supports vision
|
||||||
|
:param tenant_id: tenant id
|
||||||
|
:param provider: provider name
|
||||||
|
:param model: model name
|
||||||
|
:return: True if model supports vision, False otherwise
|
||||||
|
"""
|
||||||
|
model_instance = self.get_model_instance(tenant_id, provider, model_type, model)
|
||||||
|
model_type_instance = model_instance.model_type_instance
|
||||||
|
match model_type:
|
||||||
|
case ModelType.LLM:
|
||||||
|
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||||
|
case ModelType.TEXT_EMBEDDING:
|
||||||
|
model_type_instance = cast(TextEmbeddingModel, model_type_instance)
|
||||||
|
case ModelType.RERANK:
|
||||||
|
model_type_instance = cast(RerankModel, model_type_instance)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Model type {model_type} is not supported")
|
||||||
|
model_schema = model_type_instance.get_model_schema(model, model_instance.credentials)
|
||||||
|
if not model_schema:
|
||||||
|
return False
|
||||||
|
if model_schema.features and ModelFeature.VISION in model_schema.features:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class LBModelManager:
|
class LBModelManager:
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@ -19,7 +19,7 @@ class EmbeddingUsage(ModelUsage):
|
|||||||
latency: float
|
latency: float
|
||||||
|
|
||||||
|
|
||||||
class TextEmbeddingResult(BaseModel):
|
class EmbeddingResult(BaseModel):
|
||||||
"""
|
"""
|
||||||
Model class for text embedding result.
|
Model class for text embedding result.
|
||||||
"""
|
"""
|
||||||
@ -27,3 +27,13 @@ class TextEmbeddingResult(BaseModel):
|
|||||||
model: str
|
model: str
|
||||||
embeddings: list[list[float]]
|
embeddings: list[list[float]]
|
||||||
usage: EmbeddingUsage
|
usage: EmbeddingUsage
|
||||||
|
|
||||||
|
|
||||||
|
class FileEmbeddingResult(BaseModel):
|
||||||
|
"""
|
||||||
|
Model class for file embedding result.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model: str
|
||||||
|
embeddings: list[list[float]]
|
||||||
|
usage: EmbeddingUsage
|
||||||
|
|||||||
@ -50,3 +50,43 @@ class RerankModel(AIModel):
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise self._transform_invoke_error(e)
|
raise self._transform_invoke_error(e)
|
||||||
|
|
||||||
|
def invoke_multimodal_rerank(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
query: dict,
|
||||||
|
docs: list[dict],
|
||||||
|
score_threshold: float | None = None,
|
||||||
|
top_n: int | None = None,
|
||||||
|
user: str | None = None,
|
||||||
|
) -> RerankResult:
|
||||||
|
"""
|
||||||
|
Invoke multimodal rerank model
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param query: search query
|
||||||
|
:param docs: docs for reranking
|
||||||
|
:param score_threshold: score threshold
|
||||||
|
:param top_n: top n
|
||||||
|
:param user: unique user id
|
||||||
|
:return: rerank result
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from core.plugin.impl.model import PluginModelClient
|
||||||
|
|
||||||
|
plugin_model_manager = PluginModelClient()
|
||||||
|
return plugin_model_manager.invoke_multimodal_rerank(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
user_id=user or "unknown",
|
||||||
|
plugin_id=self.plugin_id,
|
||||||
|
provider=self.provider_name,
|
||||||
|
model=model,
|
||||||
|
credentials=credentials,
|
||||||
|
query=query,
|
||||||
|
docs=docs,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
top_n=top_n,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise self._transform_invoke_error(e)
|
||||||
|
|||||||
@ -2,7 +2,7 @@ from pydantic import ConfigDict
|
|||||||
|
|
||||||
from core.entities.embedding_type import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingResult
|
||||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||||
|
|
||||||
|
|
||||||
@ -20,16 +20,18 @@ class TextEmbeddingModel(AIModel):
|
|||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
credentials: dict,
|
credentials: dict,
|
||||||
texts: list[str],
|
texts: list[str] | None = None,
|
||||||
|
multimodel_documents: list[dict] | None = None,
|
||||||
user: str | None = None,
|
user: str | None = None,
|
||||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||||
) -> TextEmbeddingResult:
|
) -> EmbeddingResult:
|
||||||
"""
|
"""
|
||||||
Invoke text embedding model
|
Invoke text embedding model
|
||||||
|
|
||||||
:param model: model name
|
:param model: model name
|
||||||
:param credentials: model credentials
|
:param credentials: model credentials
|
||||||
:param texts: texts to embed
|
:param texts: texts to embed
|
||||||
|
:param files: files to embed
|
||||||
:param user: unique user id
|
:param user: unique user id
|
||||||
:param input_type: input type
|
:param input_type: input type
|
||||||
:return: embeddings result
|
:return: embeddings result
|
||||||
@ -38,16 +40,29 @@ class TextEmbeddingModel(AIModel):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
plugin_model_manager = PluginModelClient()
|
plugin_model_manager = PluginModelClient()
|
||||||
return plugin_model_manager.invoke_text_embedding(
|
if texts:
|
||||||
tenant_id=self.tenant_id,
|
return plugin_model_manager.invoke_text_embedding(
|
||||||
user_id=user or "unknown",
|
tenant_id=self.tenant_id,
|
||||||
plugin_id=self.plugin_id,
|
user_id=user or "unknown",
|
||||||
provider=self.provider_name,
|
plugin_id=self.plugin_id,
|
||||||
model=model,
|
provider=self.provider_name,
|
||||||
credentials=credentials,
|
model=model,
|
||||||
texts=texts,
|
credentials=credentials,
|
||||||
input_type=input_type,
|
texts=texts,
|
||||||
)
|
input_type=input_type,
|
||||||
|
)
|
||||||
|
if multimodel_documents:
|
||||||
|
return plugin_model_manager.invoke_multimodal_embedding(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
user_id=user or "unknown",
|
||||||
|
plugin_id=self.plugin_id,
|
||||||
|
provider=self.provider_name,
|
||||||
|
model=model,
|
||||||
|
credentials=credentials,
|
||||||
|
documents=multimodel_documents,
|
||||||
|
input_type=input_type,
|
||||||
|
)
|
||||||
|
raise ValueError("No texts or files provided")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise self._transform_invoke_error(e)
|
raise self._transform_invoke_error(e)
|
||||||
|
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk
|
|||||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingResult
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.plugin.entities.plugin_daemon import (
|
from core.plugin.entities.plugin_daemon import (
|
||||||
PluginBasicBooleanResponse,
|
PluginBasicBooleanResponse,
|
||||||
@ -243,14 +243,14 @@ class PluginModelClient(BasePluginClient):
|
|||||||
credentials: dict,
|
credentials: dict,
|
||||||
texts: list[str],
|
texts: list[str],
|
||||||
input_type: str,
|
input_type: str,
|
||||||
) -> TextEmbeddingResult:
|
) -> EmbeddingResult:
|
||||||
"""
|
"""
|
||||||
Invoke text embedding
|
Invoke text embedding
|
||||||
"""
|
"""
|
||||||
response = self._request_with_plugin_daemon_response_stream(
|
response = self._request_with_plugin_daemon_response_stream(
|
||||||
method="POST",
|
method="POST",
|
||||||
path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
|
path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
|
||||||
type_=TextEmbeddingResult,
|
type_=EmbeddingResult,
|
||||||
data=jsonable_encoder(
|
data=jsonable_encoder(
|
||||||
{
|
{
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
@ -275,6 +275,48 @@ class PluginModelClient(BasePluginClient):
|
|||||||
|
|
||||||
raise ValueError("Failed to invoke text embedding")
|
raise ValueError("Failed to invoke text embedding")
|
||||||
|
|
||||||
|
def invoke_multimodal_embedding(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
plugin_id: str,
|
||||||
|
provider: str,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
documents: list[dict],
|
||||||
|
input_type: str,
|
||||||
|
) -> EmbeddingResult:
|
||||||
|
"""
|
||||||
|
Invoke file embedding
|
||||||
|
"""
|
||||||
|
response = self._request_with_plugin_daemon_response_stream(
|
||||||
|
method="POST",
|
||||||
|
path=f"plugin/{tenant_id}/dispatch/multimodal_embedding/invoke",
|
||||||
|
type_=EmbeddingResult,
|
||||||
|
data=jsonable_encoder(
|
||||||
|
{
|
||||||
|
"user_id": user_id,
|
||||||
|
"data": {
|
||||||
|
"provider": provider,
|
||||||
|
"model_type": "text-embedding",
|
||||||
|
"model": model,
|
||||||
|
"credentials": credentials,
|
||||||
|
"documents": documents,
|
||||||
|
"input_type": input_type,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
headers={
|
||||||
|
"X-Plugin-ID": plugin_id,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
for resp in response:
|
||||||
|
return resp
|
||||||
|
|
||||||
|
raise ValueError("Failed to invoke file embedding")
|
||||||
|
|
||||||
def get_text_embedding_num_tokens(
|
def get_text_embedding_num_tokens(
|
||||||
self,
|
self,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
@ -361,6 +403,51 @@ class PluginModelClient(BasePluginClient):
|
|||||||
|
|
||||||
raise ValueError("Failed to invoke rerank")
|
raise ValueError("Failed to invoke rerank")
|
||||||
|
|
||||||
|
def invoke_multimodal_rerank(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
plugin_id: str,
|
||||||
|
provider: str,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
query: dict,
|
||||||
|
docs: list[dict],
|
||||||
|
score_threshold: float | None = None,
|
||||||
|
top_n: int | None = None,
|
||||||
|
) -> RerankResult:
|
||||||
|
"""
|
||||||
|
Invoke multimodal rerank
|
||||||
|
"""
|
||||||
|
response = self._request_with_plugin_daemon_response_stream(
|
||||||
|
method="POST",
|
||||||
|
path=f"plugin/{tenant_id}/dispatch/multimodal_rerank/invoke",
|
||||||
|
type_=RerankResult,
|
||||||
|
data=jsonable_encoder(
|
||||||
|
{
|
||||||
|
"user_id": user_id,
|
||||||
|
"data": {
|
||||||
|
"provider": provider,
|
||||||
|
"model_type": "rerank",
|
||||||
|
"model": model,
|
||||||
|
"credentials": credentials,
|
||||||
|
"query": query,
|
||||||
|
"docs": docs,
|
||||||
|
"score_threshold": score_threshold,
|
||||||
|
"top_n": top_n,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
headers={
|
||||||
|
"X-Plugin-ID": plugin_id,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for resp in response:
|
||||||
|
return resp
|
||||||
|
|
||||||
|
raise ValueError("Failed to invoke multimodal rerank")
|
||||||
|
|
||||||
def invoke_tts(
|
def invoke_tts(
|
||||||
self,
|
self,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
|
|||||||
@ -49,6 +49,7 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
memory: TokenBufferMemory | None,
|
memory: TokenBufferMemory | None,
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||||
|
context_files: list["File"] | None = None,
|
||||||
) -> tuple[list[PromptMessage], list[str] | None]:
|
) -> tuple[list[PromptMessage], list[str] | None]:
|
||||||
inputs = {key: str(value) for key, value in inputs.items()}
|
inputs = {key: str(value) for key, value in inputs.items()}
|
||||||
|
|
||||||
@ -64,6 +65,7 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
memory=memory,
|
memory=memory,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
image_detail_config=image_detail_config,
|
image_detail_config=image_detail_config,
|
||||||
|
context_files=context_files,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
prompt_messages, stops = self._get_completion_model_prompt_messages(
|
prompt_messages, stops = self._get_completion_model_prompt_messages(
|
||||||
@ -76,6 +78,7 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
memory=memory,
|
memory=memory,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
image_detail_config=image_detail_config,
|
image_detail_config=image_detail_config,
|
||||||
|
context_files=context_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
return prompt_messages, stops
|
return prompt_messages, stops
|
||||||
@ -187,6 +190,7 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
memory: TokenBufferMemory | None,
|
memory: TokenBufferMemory | None,
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||||
|
context_files: list["File"] | None = None,
|
||||||
) -> tuple[list[PromptMessage], list[str] | None]:
|
) -> tuple[list[PromptMessage], list[str] | None]:
|
||||||
prompt_messages: list[PromptMessage] = []
|
prompt_messages: list[PromptMessage] = []
|
||||||
|
|
||||||
@ -216,9 +220,9 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if query:
|
if query:
|
||||||
prompt_messages.append(self._get_last_user_message(query, files, image_detail_config))
|
prompt_messages.append(self._get_last_user_message(query, files, image_detail_config, context_files))
|
||||||
else:
|
else:
|
||||||
prompt_messages.append(self._get_last_user_message(prompt, files, image_detail_config))
|
prompt_messages.append(self._get_last_user_message(prompt, files, image_detail_config, context_files))
|
||||||
|
|
||||||
return prompt_messages, None
|
return prompt_messages, None
|
||||||
|
|
||||||
@ -233,6 +237,7 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
memory: TokenBufferMemory | None,
|
memory: TokenBufferMemory | None,
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||||
|
context_files: list["File"] | None = None,
|
||||||
) -> tuple[list[PromptMessage], list[str] | None]:
|
) -> tuple[list[PromptMessage], list[str] | None]:
|
||||||
# get prompt
|
# get prompt
|
||||||
prompt, prompt_rules = self._get_prompt_str_and_rules(
|
prompt, prompt_rules = self._get_prompt_str_and_rules(
|
||||||
@ -275,20 +280,27 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
if stops is not None and len(stops) == 0:
|
if stops is not None and len(stops) == 0:
|
||||||
stops = None
|
stops = None
|
||||||
|
|
||||||
return [self._get_last_user_message(prompt, files, image_detail_config)], stops
|
return [self._get_last_user_message(prompt, files, image_detail_config, context_files)], stops
|
||||||
|
|
||||||
def _get_last_user_message(
|
def _get_last_user_message(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
files: Sequence["File"],
|
files: Sequence["File"],
|
||||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||||
|
context_files: list["File"] | None = None,
|
||||||
) -> UserPromptMessage:
|
) -> UserPromptMessage:
|
||||||
|
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||||
if files:
|
if files:
|
||||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
|
||||||
for file in files:
|
for file in files:
|
||||||
prompt_message_contents.append(
|
prompt_message_contents.append(
|
||||||
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
|
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
|
||||||
)
|
)
|
||||||
|
if context_files:
|
||||||
|
for file in context_files:
|
||||||
|
prompt_message_contents.append(
|
||||||
|
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
|
||||||
|
)
|
||||||
|
if prompt_message_contents:
|
||||||
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
|
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
|
||||||
|
|
||||||
prompt_message = UserPromptMessage(content=prompt_message_contents)
|
prompt_message = UserPromptMessage(content=prompt_message_contents)
|
||||||
|
|||||||
@ -2,6 +2,7 @@ from core.model_manager import ModelInstance, ModelManager
|
|||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||||
from core.rag.data_post_processor.reorder import ReorderRunner
|
from core.rag.data_post_processor.reorder import ReorderRunner
|
||||||
|
from core.rag.index_processor.constant.query_type import QueryType
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights
|
from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights
|
||||||
from core.rag.rerank.rerank_base import BaseRerankRunner
|
from core.rag.rerank.rerank_base import BaseRerankRunner
|
||||||
@ -30,9 +31,10 @@ class DataPostProcessor:
|
|||||||
score_threshold: float | None = None,
|
score_threshold: float | None = None,
|
||||||
top_n: int | None = None,
|
top_n: int | None = None,
|
||||||
user: str | None = None,
|
user: str | None = None,
|
||||||
|
query_type: QueryType = QueryType.TEXT_QUERY,
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
if self.rerank_runner:
|
if self.rerank_runner:
|
||||||
documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user)
|
documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user, query_type)
|
||||||
|
|
||||||
if self.reorder_runner:
|
if self.reorder_runner:
|
||||||
documents = self.reorder_runner.run(documents)
|
documents = self.reorder_runner.run(documents)
|
||||||
|
|||||||
@ -1,23 +1,30 @@
|
|||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session, load_only
|
from sqlalchemy.orm import Session, load_only
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from core.model_manager import ModelManager
|
||||||
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||||
from core.rag.datasource.vdb.vector_factory import Vector
|
from core.rag.datasource.vdb.vector_factory import Vector
|
||||||
from core.rag.embedding.retrieval import RetrievalSegments
|
from core.rag.embedding.retrieval import RetrievalSegments
|
||||||
from core.rag.entities.metadata_entities import MetadataCondition
|
from core.rag.entities.metadata_entities import MetadataCondition
|
||||||
from core.rag.index_processor.constant.index_type import IndexType
|
from core.rag.index_processor.constant.doc_type import DocType
|
||||||
|
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||||
|
from core.rag.index_processor.constant.query_type import QueryType
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
from core.rag.rerank.rerank_type import RerankMode
|
from core.rag.rerank.rerank_type import RerankMode
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
|
from core.tools.signature import sign_upload_file
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.dataset import ChildChunk, Dataset, DocumentSegment
|
from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
|
||||||
from models.dataset import Document as DatasetDocument
|
from models.dataset import Document as DatasetDocument
|
||||||
|
from models.model import UploadFile
|
||||||
from services.external_knowledge_service import ExternalDatasetService
|
from services.external_knowledge_service import ExternalDatasetService
|
||||||
|
|
||||||
default_retrieval_model = {
|
default_retrieval_model = {
|
||||||
@ -37,14 +44,15 @@ class RetrievalService:
|
|||||||
retrieval_method: RetrievalMethod,
|
retrieval_method: RetrievalMethod,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
query: str,
|
query: str,
|
||||||
top_k: int,
|
top_k: int = 4,
|
||||||
score_threshold: float | None = 0.0,
|
score_threshold: float | None = 0.0,
|
||||||
reranking_model: dict | None = None,
|
reranking_model: dict | None = None,
|
||||||
reranking_mode: str = "reranking_model",
|
reranking_mode: str = "reranking_model",
|
||||||
weights: dict | None = None,
|
weights: dict | None = None,
|
||||||
document_ids_filter: list[str] | None = None,
|
document_ids_filter: list[str] | None = None,
|
||||||
|
attachment_ids: list | None = None,
|
||||||
):
|
):
|
||||||
if not query:
|
if not query and not attachment_ids:
|
||||||
return []
|
return []
|
||||||
dataset = cls._get_dataset(dataset_id)
|
dataset = cls._get_dataset(dataset_id)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
@ -56,69 +64,52 @@ class RetrievalService:
|
|||||||
# Optimize multithreading with thread pools
|
# Optimize multithreading with thread pools
|
||||||
with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore
|
with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore
|
||||||
futures = []
|
futures = []
|
||||||
if retrieval_method == RetrievalMethod.KEYWORD_SEARCH:
|
retrieval_service = RetrievalService()
|
||||||
|
if query:
|
||||||
futures.append(
|
futures.append(
|
||||||
executor.submit(
|
executor.submit(
|
||||||
cls.keyword_search,
|
retrieval_service._retrieve,
|
||||||
flask_app=current_app._get_current_object(), # type: ignore
|
flask_app=current_app._get_current_object(), # type: ignore
|
||||||
dataset_id=dataset_id,
|
retrieval_method=retrieval_method,
|
||||||
query=query,
|
dataset=dataset,
|
||||||
top_k=top_k,
|
|
||||||
all_documents=all_documents,
|
|
||||||
exceptions=exceptions,
|
|
||||||
document_ids_filter=document_ids_filter,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if RetrievalMethod.is_support_semantic_search(retrieval_method):
|
|
||||||
futures.append(
|
|
||||||
executor.submit(
|
|
||||||
cls.embedding_search,
|
|
||||||
flask_app=current_app._get_current_object(), # type: ignore
|
|
||||||
dataset_id=dataset_id,
|
|
||||||
query=query,
|
query=query,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
reranking_model=reranking_model,
|
reranking_model=reranking_model,
|
||||||
all_documents=all_documents,
|
reranking_mode=reranking_mode,
|
||||||
retrieval_method=retrieval_method,
|
weights=weights,
|
||||||
exceptions=exceptions,
|
|
||||||
document_ids_filter=document_ids_filter,
|
document_ids_filter=document_ids_filter,
|
||||||
|
attachment_id=None,
|
||||||
|
all_documents=all_documents,
|
||||||
|
exceptions=exceptions,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if RetrievalMethod.is_support_fulltext_search(retrieval_method):
|
if attachment_ids:
|
||||||
futures.append(
|
for attachment_id in attachment_ids:
|
||||||
executor.submit(
|
futures.append(
|
||||||
cls.full_text_index_search,
|
executor.submit(
|
||||||
flask_app=current_app._get_current_object(), # type: ignore
|
retrieval_service._retrieve,
|
||||||
dataset_id=dataset_id,
|
flask_app=current_app._get_current_object(), # type: ignore
|
||||||
query=query,
|
retrieval_method=retrieval_method,
|
||||||
top_k=top_k,
|
dataset=dataset,
|
||||||
score_threshold=score_threshold,
|
query=None,
|
||||||
reranking_model=reranking_model,
|
top_k=top_k,
|
||||||
all_documents=all_documents,
|
score_threshold=score_threshold,
|
||||||
retrieval_method=retrieval_method,
|
reranking_model=reranking_model,
|
||||||
exceptions=exceptions,
|
reranking_mode=reranking_mode,
|
||||||
document_ids_filter=document_ids_filter,
|
weights=weights,
|
||||||
|
document_ids_filter=document_ids_filter,
|
||||||
|
attachment_id=attachment_id,
|
||||||
|
all_documents=all_documents,
|
||||||
|
exceptions=exceptions,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
concurrent.futures.wait(futures, timeout=30, return_when=concurrent.futures.ALL_COMPLETED)
|
concurrent.futures.wait(futures, timeout=3600, return_when=concurrent.futures.ALL_COMPLETED)
|
||||||
|
|
||||||
if exceptions:
|
if exceptions:
|
||||||
raise ValueError(";\n".join(exceptions))
|
raise ValueError(";\n".join(exceptions))
|
||||||
|
|
||||||
# Deduplicate documents for hybrid search to avoid duplicate chunks
|
|
||||||
if retrieval_method == RetrievalMethod.HYBRID_SEARCH:
|
|
||||||
all_documents = cls._deduplicate_documents(all_documents)
|
|
||||||
data_post_processor = DataPostProcessor(
|
|
||||||
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
|
|
||||||
)
|
|
||||||
all_documents = data_post_processor.invoke(
|
|
||||||
query=query,
|
|
||||||
documents=all_documents,
|
|
||||||
score_threshold=score_threshold,
|
|
||||||
top_n=top_k,
|
|
||||||
)
|
|
||||||
|
|
||||||
return all_documents
|
return all_documents
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -223,6 +214,7 @@ class RetrievalService:
|
|||||||
retrieval_method: RetrievalMethod,
|
retrieval_method: RetrievalMethod,
|
||||||
exceptions: list,
|
exceptions: list,
|
||||||
document_ids_filter: list[str] | None = None,
|
document_ids_filter: list[str] | None = None,
|
||||||
|
query_type: QueryType = QueryType.TEXT_QUERY,
|
||||||
):
|
):
|
||||||
with flask_app.app_context():
|
with flask_app.app_context():
|
||||||
try:
|
try:
|
||||||
@ -231,14 +223,30 @@ class RetrievalService:
|
|||||||
raise ValueError("dataset not found")
|
raise ValueError("dataset not found")
|
||||||
|
|
||||||
vector = Vector(dataset=dataset)
|
vector = Vector(dataset=dataset)
|
||||||
documents = vector.search_by_vector(
|
documents = []
|
||||||
query,
|
if query_type == QueryType.TEXT_QUERY:
|
||||||
search_type="similarity_score_threshold",
|
documents.extend(
|
||||||
top_k=top_k,
|
vector.search_by_vector(
|
||||||
score_threshold=score_threshold,
|
query,
|
||||||
filter={"group_id": [dataset.id]},
|
search_type="similarity_score_threshold",
|
||||||
document_ids_filter=document_ids_filter,
|
top_k=top_k,
|
||||||
)
|
score_threshold=score_threshold,
|
||||||
|
filter={"group_id": [dataset.id]},
|
||||||
|
document_ids_filter=document_ids_filter,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if query_type == QueryType.IMAGE_QUERY:
|
||||||
|
if not dataset.is_multimodal:
|
||||||
|
return
|
||||||
|
documents.extend(
|
||||||
|
vector.search_by_file(
|
||||||
|
file_id=query,
|
||||||
|
top_k=top_k,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
filter={"group_id": [dataset.id]},
|
||||||
|
document_ids_filter=document_ids_filter,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if documents:
|
if documents:
|
||||||
if (
|
if (
|
||||||
@ -250,14 +258,37 @@ class RetrievalService:
|
|||||||
data_post_processor = DataPostProcessor(
|
data_post_processor = DataPostProcessor(
|
||||||
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False
|
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False
|
||||||
)
|
)
|
||||||
all_documents.extend(
|
if dataset.is_multimodal:
|
||||||
data_post_processor.invoke(
|
model_manager = ModelManager()
|
||||||
query=query,
|
is_support_vision = model_manager.check_model_support_vision(
|
||||||
documents=documents,
|
tenant_id=dataset.tenant_id,
|
||||||
score_threshold=score_threshold,
|
provider=reranking_model.get("reranking_provider_name") or "",
|
||||||
top_n=len(documents),
|
model=reranking_model.get("reranking_model_name") or "",
|
||||||
|
model_type=ModelType.RERANK,
|
||||||
|
)
|
||||||
|
if is_support_vision:
|
||||||
|
all_documents.extend(
|
||||||
|
data_post_processor.invoke(
|
||||||
|
query=query,
|
||||||
|
documents=documents,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
top_n=len(documents),
|
||||||
|
query_type=query_type,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# not effective, return original documents
|
||||||
|
all_documents.extend(documents)
|
||||||
|
else:
|
||||||
|
all_documents.extend(
|
||||||
|
data_post_processor.invoke(
|
||||||
|
query=query,
|
||||||
|
documents=documents,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
top_n=len(documents),
|
||||||
|
query_type=query_type,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
all_documents.extend(documents)
|
all_documents.extend(documents)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -339,103 +370,159 @@ class RetrievalService:
|
|||||||
records = []
|
records = []
|
||||||
include_segment_ids = set()
|
include_segment_ids = set()
|
||||||
segment_child_map = {}
|
segment_child_map = {}
|
||||||
|
segment_file_map = {}
|
||||||
# Process documents
|
with Session(db.engine) as session:
|
||||||
for document in documents:
|
# Process documents
|
||||||
document_id = document.metadata.get("document_id")
|
for document in documents:
|
||||||
if document_id not in dataset_documents:
|
segment_id = None
|
||||||
continue
|
attachment_info = None
|
||||||
|
child_chunk = None
|
||||||
dataset_document = dataset_documents[document_id]
|
document_id = document.metadata.get("document_id")
|
||||||
if not dataset_document:
|
if document_id not in dataset_documents:
|
||||||
continue
|
|
||||||
|
|
||||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
|
||||||
# Handle parent-child documents
|
|
||||||
child_index_node_id = document.metadata.get("doc_id")
|
|
||||||
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
|
|
||||||
child_chunk = db.session.scalar(child_chunk_stmt)
|
|
||||||
|
|
||||||
if not child_chunk:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
segment = (
|
dataset_document = dataset_documents[document_id]
|
||||||
db.session.query(DocumentSegment)
|
if not dataset_document:
|
||||||
.where(
|
continue
|
||||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
|
||||||
DocumentSegment.enabled == True,
|
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||||
DocumentSegment.status == "completed",
|
# Handle parent-child documents
|
||||||
DocumentSegment.id == child_chunk.segment_id,
|
if document.metadata.get("doc_type") == DocType.IMAGE:
|
||||||
)
|
attachment_info_dict = cls.get_segment_attachment_info(
|
||||||
.options(
|
dataset_document.dataset_id,
|
||||||
load_only(
|
dataset_document.tenant_id,
|
||||||
DocumentSegment.id,
|
document.metadata.get("doc_id") or "",
|
||||||
DocumentSegment.content,
|
session,
|
||||||
DocumentSegment.answer,
|
|
||||||
)
|
)
|
||||||
|
if attachment_info_dict:
|
||||||
|
attachment_info = attachment_info_dict["attchment_info"]
|
||||||
|
segment_id = attachment_info_dict["segment_id"]
|
||||||
|
else:
|
||||||
|
child_index_node_id = document.metadata.get("doc_id")
|
||||||
|
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
|
||||||
|
child_chunk = session.scalar(child_chunk_stmt)
|
||||||
|
|
||||||
|
if not child_chunk:
|
||||||
|
continue
|
||||||
|
segment_id = child_chunk.segment_id
|
||||||
|
|
||||||
|
if not segment_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
segment = (
|
||||||
|
session.query(DocumentSegment)
|
||||||
|
.where(
|
||||||
|
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||||
|
DocumentSegment.enabled == True,
|
||||||
|
DocumentSegment.status == "completed",
|
||||||
|
DocumentSegment.id == segment_id,
|
||||||
|
)
|
||||||
|
.options(
|
||||||
|
load_only(
|
||||||
|
DocumentSegment.id,
|
||||||
|
DocumentSegment.content,
|
||||||
|
DocumentSegment.answer,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.first()
|
||||||
)
|
)
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not segment:
|
if not segment:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if segment.id not in include_segment_ids:
|
if segment.id not in include_segment_ids:
|
||||||
include_segment_ids.add(segment.id)
|
include_segment_ids.add(segment.id)
|
||||||
child_chunk_detail = {
|
if child_chunk:
|
||||||
"id": child_chunk.id,
|
child_chunk_detail = {
|
||||||
"content": child_chunk.content,
|
"id": child_chunk.id,
|
||||||
"position": child_chunk.position,
|
"content": child_chunk.content,
|
||||||
"score": document.metadata.get("score", 0.0),
|
"position": child_chunk.position,
|
||||||
}
|
"score": document.metadata.get("score", 0.0),
|
||||||
map_detail = {
|
}
|
||||||
"max_score": document.metadata.get("score", 0.0),
|
map_detail = {
|
||||||
"child_chunks": [child_chunk_detail],
|
"max_score": document.metadata.get("score", 0.0),
|
||||||
}
|
"child_chunks": [child_chunk_detail],
|
||||||
segment_child_map[segment.id] = map_detail
|
}
|
||||||
record = {
|
segment_child_map[segment.id] = map_detail
|
||||||
"segment": segment,
|
record = {
|
||||||
}
|
"segment": segment,
|
||||||
records.append(record)
|
}
|
||||||
|
if attachment_info:
|
||||||
|
segment_file_map[segment.id] = [attachment_info]
|
||||||
|
records.append(record)
|
||||||
|
else:
|
||||||
|
if child_chunk:
|
||||||
|
child_chunk_detail = {
|
||||||
|
"id": child_chunk.id,
|
||||||
|
"content": child_chunk.content,
|
||||||
|
"position": child_chunk.position,
|
||||||
|
"score": document.metadata.get("score", 0.0),
|
||||||
|
}
|
||||||
|
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
|
||||||
|
segment_child_map[segment.id]["max_score"] = max(
|
||||||
|
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
|
||||||
|
)
|
||||||
|
if attachment_info:
|
||||||
|
segment_file_map[segment.id].append(attachment_info)
|
||||||
else:
|
else:
|
||||||
child_chunk_detail = {
|
# Handle normal documents
|
||||||
"id": child_chunk.id,
|
segment = None
|
||||||
"content": child_chunk.content,
|
if document.metadata.get("doc_type") == DocType.IMAGE:
|
||||||
"position": child_chunk.position,
|
attachment_info_dict = cls.get_segment_attachment_info(
|
||||||
"score": document.metadata.get("score", 0.0),
|
dataset_document.dataset_id,
|
||||||
}
|
dataset_document.tenant_id,
|
||||||
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
|
document.metadata.get("doc_id") or "",
|
||||||
segment_child_map[segment.id]["max_score"] = max(
|
session,
|
||||||
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
|
)
|
||||||
)
|
if attachment_info_dict:
|
||||||
else:
|
attachment_info = attachment_info_dict["attchment_info"]
|
||||||
# Handle normal documents
|
segment_id = attachment_info_dict["segment_id"]
|
||||||
index_node_id = document.metadata.get("doc_id")
|
document_segment_stmt = select(DocumentSegment).where(
|
||||||
if not index_node_id:
|
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||||
continue
|
DocumentSegment.enabled == True,
|
||||||
document_segment_stmt = select(DocumentSegment).where(
|
DocumentSegment.status == "completed",
|
||||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
DocumentSegment.id == segment_id,
|
||||||
DocumentSegment.enabled == True,
|
)
|
||||||
DocumentSegment.status == "completed",
|
segment = db.session.scalar(document_segment_stmt)
|
||||||
DocumentSegment.index_node_id == index_node_id,
|
if segment:
|
||||||
)
|
segment_file_map[segment.id] = [attachment_info]
|
||||||
segment = db.session.scalar(document_segment_stmt)
|
else:
|
||||||
|
index_node_id = document.metadata.get("doc_id")
|
||||||
|
if not index_node_id:
|
||||||
|
continue
|
||||||
|
document_segment_stmt = select(DocumentSegment).where(
|
||||||
|
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||||
|
DocumentSegment.enabled == True,
|
||||||
|
DocumentSegment.status == "completed",
|
||||||
|
DocumentSegment.index_node_id == index_node_id,
|
||||||
|
)
|
||||||
|
segment = db.session.scalar(document_segment_stmt)
|
||||||
|
|
||||||
if not segment:
|
if not segment:
|
||||||
continue
|
continue
|
||||||
|
if segment.id not in include_segment_ids:
|
||||||
include_segment_ids.add(segment.id)
|
include_segment_ids.add(segment.id)
|
||||||
record = {
|
record = {
|
||||||
"segment": segment,
|
"segment": segment,
|
||||||
"score": document.metadata.get("score"), # type: ignore
|
"score": document.metadata.get("score"), # type: ignore
|
||||||
}
|
}
|
||||||
records.append(record)
|
if attachment_info:
|
||||||
|
segment_file_map[segment.id] = [attachment_info]
|
||||||
|
records.append(record)
|
||||||
|
else:
|
||||||
|
if attachment_info:
|
||||||
|
attachment_infos = segment_file_map.get(segment.id, [])
|
||||||
|
if attachment_info not in attachment_infos:
|
||||||
|
attachment_infos.append(attachment_info)
|
||||||
|
segment_file_map[segment.id] = attachment_infos
|
||||||
|
|
||||||
# Add child chunks information to records
|
# Add child chunks information to records
|
||||||
for record in records:
|
for record in records:
|
||||||
if record["segment"].id in segment_child_map:
|
if record["segment"].id in segment_child_map:
|
||||||
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
|
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
|
||||||
record["score"] = segment_child_map[record["segment"].id]["max_score"]
|
record["score"] = segment_child_map[record["segment"].id]["max_score"]
|
||||||
|
if record["segment"].id in segment_file_map:
|
||||||
|
record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment]
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
for record in records:
|
for record in records:
|
||||||
@ -447,6 +534,11 @@ class RetrievalService:
|
|||||||
if not isinstance(child_chunks, list):
|
if not isinstance(child_chunks, list):
|
||||||
child_chunks = None
|
child_chunks = None
|
||||||
|
|
||||||
|
# Extract files, ensuring it's a list or None
|
||||||
|
files = record.get("files")
|
||||||
|
if not isinstance(files, list):
|
||||||
|
files = None
|
||||||
|
|
||||||
# Extract score, ensuring it's a float or None
|
# Extract score, ensuring it's a float or None
|
||||||
score_value = record.get("score")
|
score_value = record.get("score")
|
||||||
score = (
|
score = (
|
||||||
@ -456,10 +548,149 @@ class RetrievalService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Create RetrievalSegments object
|
# Create RetrievalSegments object
|
||||||
retrieval_segment = RetrievalSegments(segment=segment, child_chunks=child_chunks, score=score)
|
retrieval_segment = RetrievalSegments(
|
||||||
|
segment=segment, child_chunks=child_chunks, score=score, files=files
|
||||||
|
)
|
||||||
result.append(retrieval_segment)
|
result.append(retrieval_segment)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.session.rollback()
|
db.session.rollback()
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
def _retrieve(
|
||||||
|
self,
|
||||||
|
flask_app: Flask,
|
||||||
|
retrieval_method: RetrievalMethod,
|
||||||
|
dataset: Dataset,
|
||||||
|
query: str | None = None,
|
||||||
|
top_k: int = 4,
|
||||||
|
score_threshold: float | None = 0.0,
|
||||||
|
reranking_model: dict | None = None,
|
||||||
|
reranking_mode: str = "reranking_model",
|
||||||
|
weights: dict | None = None,
|
||||||
|
document_ids_filter: list[str] | None = None,
|
||||||
|
attachment_id: str | None = None,
|
||||||
|
all_documents: list[Document] = [],
|
||||||
|
exceptions: list[str] = [],
|
||||||
|
):
|
||||||
|
if not query and not attachment_id:
|
||||||
|
return
|
||||||
|
with flask_app.app_context():
|
||||||
|
all_documents_item: list[Document] = []
|
||||||
|
# Optimize multithreading with thread pools
|
||||||
|
with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore
|
||||||
|
futures = []
|
||||||
|
if retrieval_method == RetrievalMethod.KEYWORD_SEARCH and query:
|
||||||
|
futures.append(
|
||||||
|
executor.submit(
|
||||||
|
self.keyword_search,
|
||||||
|
flask_app=current_app._get_current_object(), # type: ignore
|
||||||
|
dataset_id=dataset.id,
|
||||||
|
query=query,
|
||||||
|
top_k=top_k,
|
||||||
|
all_documents=all_documents_item,
|
||||||
|
exceptions=exceptions,
|
||||||
|
document_ids_filter=document_ids_filter,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if RetrievalMethod.is_support_semantic_search(retrieval_method):
|
||||||
|
if query:
|
||||||
|
futures.append(
|
||||||
|
executor.submit(
|
||||||
|
self.embedding_search,
|
||||||
|
flask_app=current_app._get_current_object(), # type: ignore
|
||||||
|
dataset_id=dataset.id,
|
||||||
|
query=query,
|
||||||
|
top_k=top_k,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
reranking_model=reranking_model,
|
||||||
|
all_documents=all_documents_item,
|
||||||
|
retrieval_method=retrieval_method,
|
||||||
|
exceptions=exceptions,
|
||||||
|
document_ids_filter=document_ids_filter,
|
||||||
|
query_type=QueryType.TEXT_QUERY,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if attachment_id:
|
||||||
|
futures.append(
|
||||||
|
executor.submit(
|
||||||
|
self.embedding_search,
|
||||||
|
flask_app=current_app._get_current_object(), # type: ignore
|
||||||
|
dataset_id=dataset.id,
|
||||||
|
query=attachment_id,
|
||||||
|
top_k=top_k,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
reranking_model=reranking_model,
|
||||||
|
all_documents=all_documents_item,
|
||||||
|
retrieval_method=retrieval_method,
|
||||||
|
exceptions=exceptions,
|
||||||
|
document_ids_filter=document_ids_filter,
|
||||||
|
query_type=QueryType.IMAGE_QUERY,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if RetrievalMethod.is_support_fulltext_search(retrieval_method) and query:
|
||||||
|
futures.append(
|
||||||
|
executor.submit(
|
||||||
|
self.full_text_index_search,
|
||||||
|
flask_app=current_app._get_current_object(), # type: ignore
|
||||||
|
dataset_id=dataset.id,
|
||||||
|
query=query,
|
||||||
|
top_k=top_k,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
reranking_model=reranking_model,
|
||||||
|
all_documents=all_documents_item,
|
||||||
|
retrieval_method=retrieval_method,
|
||||||
|
exceptions=exceptions,
|
||||||
|
document_ids_filter=document_ids_filter,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
concurrent.futures.wait(futures, timeout=300, return_when=concurrent.futures.ALL_COMPLETED)
|
||||||
|
|
||||||
|
if exceptions:
|
||||||
|
raise ValueError(";\n".join(exceptions))
|
||||||
|
|
||||||
|
# Deduplicate documents for hybrid search to avoid duplicate chunks
|
||||||
|
if retrieval_method == RetrievalMethod.HYBRID_SEARCH:
|
||||||
|
if attachment_id and reranking_mode == RerankMode.WEIGHTED_SCORE:
|
||||||
|
all_documents.extend(all_documents_item)
|
||||||
|
all_documents_item = self._deduplicate_documents(all_documents_item)
|
||||||
|
data_post_processor = DataPostProcessor(
|
||||||
|
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
|
||||||
|
)
|
||||||
|
|
||||||
|
query = query or attachment_id
|
||||||
|
if not query:
|
||||||
|
return
|
||||||
|
all_documents_item = data_post_processor.invoke(
|
||||||
|
query=query,
|
||||||
|
documents=all_documents_item,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
top_n=top_k,
|
||||||
|
query_type=QueryType.TEXT_QUERY if query else QueryType.IMAGE_QUERY,
|
||||||
|
)
|
||||||
|
|
||||||
|
all_documents.extend(all_documents_item)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_segment_attachment_info(
|
||||||
|
cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first()
|
||||||
|
if upload_file:
|
||||||
|
attachment_binding = (
|
||||||
|
session.query(SegmentAttachmentBinding)
|
||||||
|
.where(SegmentAttachmentBinding.attachment_id == upload_file.id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if attachment_binding:
|
||||||
|
attchment_info = {
|
||||||
|
"id": upload_file.id,
|
||||||
|
"name": upload_file.name,
|
||||||
|
"extension": "." + upload_file.extension,
|
||||||
|
"mime_type": upload_file.mime_type,
|
||||||
|
"source_url": sign_upload_file(upload_file.id, upload_file.extension),
|
||||||
|
"size": upload_file.size,
|
||||||
|
}
|
||||||
|
return {"attchment_info": attchment_info, "segment_id": attachment_binding.segment_id}
|
||||||
|
return None
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import base64
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
@ -12,10 +13,13 @@ from core.rag.datasource.vdb.vector_base import BaseVector
|
|||||||
from core.rag.datasource.vdb.vector_type import VectorType
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
from core.rag.embedding.cached_embedding import CacheEmbedding
|
from core.rag.embedding.cached_embedding import CacheEmbedding
|
||||||
from core.rag.embedding.embedding_base import Embeddings
|
from core.rag.embedding.embedding_base import Embeddings
|
||||||
|
from core.rag.index_processor.constant.doc_type import DocType
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
|
from extensions.ext_storage import storage
|
||||||
from models.dataset import Dataset, Whitelist
|
from models.dataset import Dataset, Whitelist
|
||||||
|
from models.model import UploadFile
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -203,6 +207,47 @@ class Vector:
|
|||||||
self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs)
|
self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs)
|
||||||
logger.info("Embedding %s texts took %s s", len(texts), time.time() - start)
|
logger.info("Embedding %s texts took %s s", len(texts), time.time() - start)
|
||||||
|
|
||||||
|
def create_multimodal(self, file_documents: list | None = None, **kwargs):
|
||||||
|
if file_documents:
|
||||||
|
start = time.time()
|
||||||
|
logger.info("start embedding %s files %s", len(file_documents), start)
|
||||||
|
batch_size = 1000
|
||||||
|
total_batches = len(file_documents) + batch_size - 1
|
||||||
|
for i in range(0, len(file_documents), batch_size):
|
||||||
|
batch = file_documents[i : i + batch_size]
|
||||||
|
batch_start = time.time()
|
||||||
|
logger.info("Processing batch %s/%s (%s files)", i // batch_size + 1, total_batches, len(batch))
|
||||||
|
|
||||||
|
# Batch query all upload files to avoid N+1 queries
|
||||||
|
attachment_ids = [doc.metadata["doc_id"] for doc in batch]
|
||||||
|
stmt = select(UploadFile).where(UploadFile.id.in_(attachment_ids))
|
||||||
|
upload_files = db.session.scalars(stmt).all()
|
||||||
|
upload_file_map = {str(f.id): f for f in upload_files}
|
||||||
|
|
||||||
|
file_base64_list = []
|
||||||
|
real_batch = []
|
||||||
|
for document in batch:
|
||||||
|
attachment_id = document.metadata["doc_id"]
|
||||||
|
doc_type = document.metadata["doc_type"]
|
||||||
|
upload_file = upload_file_map.get(attachment_id)
|
||||||
|
if upload_file:
|
||||||
|
blob = storage.load_once(upload_file.key)
|
||||||
|
file_base64_str = base64.b64encode(blob).decode()
|
||||||
|
file_base64_list.append(
|
||||||
|
{
|
||||||
|
"content": file_base64_str,
|
||||||
|
"content_type": doc_type,
|
||||||
|
"file_id": attachment_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
real_batch.append(document)
|
||||||
|
batch_embeddings = self._embeddings.embed_multimodal_documents(file_base64_list)
|
||||||
|
logger.info(
|
||||||
|
"Embedding batch %s/%s took %s s", i // batch_size + 1, total_batches, time.time() - batch_start
|
||||||
|
)
|
||||||
|
self._vector_processor.create(texts=real_batch, embeddings=batch_embeddings, **kwargs)
|
||||||
|
logger.info("Embedding %s files took %s s", len(file_documents), time.time() - start)
|
||||||
|
|
||||||
def add_texts(self, documents: list[Document], **kwargs):
|
def add_texts(self, documents: list[Document], **kwargs):
|
||||||
if kwargs.get("duplicate_check", False):
|
if kwargs.get("duplicate_check", False):
|
||||||
documents = self._filter_duplicate_texts(documents)
|
documents = self._filter_duplicate_texts(documents)
|
||||||
@ -223,6 +268,22 @@ class Vector:
|
|||||||
query_vector = self._embeddings.embed_query(query)
|
query_vector = self._embeddings.embed_query(query)
|
||||||
return self._vector_processor.search_by_vector(query_vector, **kwargs)
|
return self._vector_processor.search_by_vector(query_vector, **kwargs)
|
||||||
|
|
||||||
|
def search_by_file(self, file_id: str, **kwargs: Any) -> list[Document]:
|
||||||
|
upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||||
|
|
||||||
|
if not upload_file:
|
||||||
|
return []
|
||||||
|
blob = storage.load_once(upload_file.key)
|
||||||
|
file_base64_str = base64.b64encode(blob).decode()
|
||||||
|
multimodal_vector = self._embeddings.embed_multimodal_query(
|
||||||
|
{
|
||||||
|
"content": file_base64_str,
|
||||||
|
"content_type": DocType.IMAGE,
|
||||||
|
"file_id": file_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return self._vector_processor.search_by_vector(multimodal_vector, **kwargs)
|
||||||
|
|
||||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||||
return self._vector_processor.search_by_full_text(query, **kwargs)
|
return self._vector_processor.search_by_full_text(query, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@ -79,6 +79,18 @@ class WeaviateVector(BaseVector):
|
|||||||
self._client = self._init_client(config)
|
self._client = self._init_client(config)
|
||||||
self._attributes = attributes
|
self._attributes = attributes
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
"""
|
||||||
|
Destructor to properly close the Weaviate client connection.
|
||||||
|
Prevents connection leaks and resource warnings.
|
||||||
|
"""
|
||||||
|
if hasattr(self, "_client") and self._client is not None:
|
||||||
|
try:
|
||||||
|
self._client.close()
|
||||||
|
except Exception as e:
|
||||||
|
# Ignore errors during cleanup as object is being destroyed
|
||||||
|
logger.warning("Error closing Weaviate client %s", e, exc_info=True)
|
||||||
|
|
||||||
def _init_client(self, config: WeaviateConfig) -> weaviate.WeaviateClient:
|
def _init_client(self, config: WeaviateConfig) -> weaviate.WeaviateClient:
|
||||||
"""
|
"""
|
||||||
Initializes and returns a connected Weaviate client.
|
Initializes and returns a connected Weaviate client.
|
||||||
|
|||||||
@ -5,9 +5,9 @@ from sqlalchemy import func, select
|
|||||||
|
|
||||||
from core.model_manager import ModelManager
|
from core.model_manager import ModelManager
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import AttachmentDocument, Document
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.dataset import ChildChunk, Dataset, DocumentSegment
|
from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
|
||||||
|
|
||||||
|
|
||||||
class DatasetDocumentStore:
|
class DatasetDocumentStore:
|
||||||
@ -120,6 +120,9 @@ class DatasetDocumentStore:
|
|||||||
|
|
||||||
db.session.add(segment_document)
|
db.session.add(segment_document)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
|
self.add_multimodel_documents_binding(
|
||||||
|
segment_id=segment_document.id, multimodel_documents=doc.attachments
|
||||||
|
)
|
||||||
if save_child:
|
if save_child:
|
||||||
if doc.children:
|
if doc.children:
|
||||||
for position, child in enumerate(doc.children, start=1):
|
for position, child in enumerate(doc.children, start=1):
|
||||||
@ -144,6 +147,9 @@ class DatasetDocumentStore:
|
|||||||
segment_document.index_node_hash = doc.metadata.get("doc_hash")
|
segment_document.index_node_hash = doc.metadata.get("doc_hash")
|
||||||
segment_document.word_count = len(doc.page_content)
|
segment_document.word_count = len(doc.page_content)
|
||||||
segment_document.tokens = tokens
|
segment_document.tokens = tokens
|
||||||
|
self.add_multimodel_documents_binding(
|
||||||
|
segment_id=segment_document.id, multimodel_documents=doc.attachments
|
||||||
|
)
|
||||||
if save_child and doc.children:
|
if save_child and doc.children:
|
||||||
# delete the existing child chunks
|
# delete the existing child chunks
|
||||||
db.session.query(ChildChunk).where(
|
db.session.query(ChildChunk).where(
|
||||||
@ -233,3 +239,15 @@ class DatasetDocumentStore:
|
|||||||
document_segment = db.session.scalar(stmt)
|
document_segment = db.session.scalar(stmt)
|
||||||
|
|
||||||
return document_segment
|
return document_segment
|
||||||
|
|
||||||
|
def add_multimodel_documents_binding(self, segment_id: str, multimodel_documents: list[AttachmentDocument] | None):
|
||||||
|
if multimodel_documents:
|
||||||
|
for multimodel_document in multimodel_documents:
|
||||||
|
binding = SegmentAttachmentBinding(
|
||||||
|
tenant_id=self._dataset.tenant_id,
|
||||||
|
dataset_id=self._dataset.id,
|
||||||
|
document_id=self._document_id,
|
||||||
|
segment_id=segment_id,
|
||||||
|
attachment_id=multimodel_document.metadata["doc_id"],
|
||||||
|
)
|
||||||
|
db.session.add(binding)
|
||||||
|
|||||||
@ -104,6 +104,88 @@ class CacheEmbedding(Embeddings):
|
|||||||
|
|
||||||
return text_embeddings
|
return text_embeddings
|
||||||
|
|
||||||
|
def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]:
|
||||||
|
"""Embed file documents."""
|
||||||
|
# use doc embedding cache or store if not exists
|
||||||
|
multimodel_embeddings: list[Any] = [None for _ in range(len(multimodel_documents))]
|
||||||
|
embedding_queue_indices = []
|
||||||
|
for i, multimodel_document in enumerate(multimodel_documents):
|
||||||
|
file_id = multimodel_document["file_id"]
|
||||||
|
embedding = (
|
||||||
|
db.session.query(Embedding)
|
||||||
|
.filter_by(
|
||||||
|
model_name=self._model_instance.model, hash=file_id, provider_name=self._model_instance.provider
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if embedding:
|
||||||
|
multimodel_embeddings[i] = embedding.get_embedding()
|
||||||
|
else:
|
||||||
|
embedding_queue_indices.append(i)
|
||||||
|
|
||||||
|
# NOTE: avoid closing the shared scoped session here; downstream code may still have pending work
|
||||||
|
|
||||||
|
if embedding_queue_indices:
|
||||||
|
embedding_queue_multimodel_documents = [multimodel_documents[i] for i in embedding_queue_indices]
|
||||||
|
embedding_queue_embeddings = []
|
||||||
|
try:
|
||||||
|
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
|
||||||
|
model_schema = model_type_instance.get_model_schema(
|
||||||
|
self._model_instance.model, self._model_instance.credentials
|
||||||
|
)
|
||||||
|
max_chunks = (
|
||||||
|
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
|
||||||
|
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties
|
||||||
|
else 1
|
||||||
|
)
|
||||||
|
for i in range(0, len(embedding_queue_multimodel_documents), max_chunks):
|
||||||
|
batch_multimodel_documents = embedding_queue_multimodel_documents[i : i + max_chunks]
|
||||||
|
|
||||||
|
embedding_result = self._model_instance.invoke_multimodal_embedding(
|
||||||
|
multimodel_documents=batch_multimodel_documents,
|
||||||
|
user=self._user,
|
||||||
|
input_type=EmbeddingInputType.DOCUMENT,
|
||||||
|
)
|
||||||
|
|
||||||
|
for vector in embedding_result.embeddings:
|
||||||
|
try:
|
||||||
|
# FIXME: type ignore for numpy here
|
||||||
|
normalized_embedding = (vector / np.linalg.norm(vector)).tolist() # type: ignore
|
||||||
|
# stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan
|
||||||
|
if np.isnan(normalized_embedding).any():
|
||||||
|
# for issue #11827 float values are not json compliant
|
||||||
|
logger.warning("Normalized embedding is nan: %s", normalized_embedding)
|
||||||
|
continue
|
||||||
|
embedding_queue_embeddings.append(normalized_embedding)
|
||||||
|
except IntegrityError:
|
||||||
|
db.session.rollback()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed transform embedding")
|
||||||
|
cache_embeddings = []
|
||||||
|
try:
|
||||||
|
for i, n_embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
|
||||||
|
multimodel_embeddings[i] = n_embedding
|
||||||
|
file_id = multimodel_documents[i]["file_id"]
|
||||||
|
if file_id not in cache_embeddings:
|
||||||
|
embedding_cache = Embedding(
|
||||||
|
model_name=self._model_instance.model,
|
||||||
|
hash=file_id,
|
||||||
|
provider_name=self._model_instance.provider,
|
||||||
|
embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
|
||||||
|
)
|
||||||
|
embedding_cache.set_embedding(n_embedding)
|
||||||
|
db.session.add(embedding_cache)
|
||||||
|
cache_embeddings.append(file_id)
|
||||||
|
db.session.commit()
|
||||||
|
except IntegrityError:
|
||||||
|
db.session.rollback()
|
||||||
|
except Exception as ex:
|
||||||
|
db.session.rollback()
|
||||||
|
logger.exception("Failed to embed documents")
|
||||||
|
raise ex
|
||||||
|
|
||||||
|
return multimodel_embeddings
|
||||||
|
|
||||||
def embed_query(self, text: str) -> list[float]:
|
def embed_query(self, text: str) -> list[float]:
|
||||||
"""Embed query text."""
|
"""Embed query text."""
|
||||||
# use doc embedding cache or store if not exists
|
# use doc embedding cache or store if not exists
|
||||||
@ -146,3 +228,46 @@ class CacheEmbedding(Embeddings):
|
|||||||
raise ex
|
raise ex
|
||||||
|
|
||||||
return embedding_results # type: ignore
|
return embedding_results # type: ignore
|
||||||
|
|
||||||
|
def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
|
||||||
|
"""Embed multimodal documents."""
|
||||||
|
# use doc embedding cache or store if not exists
|
||||||
|
file_id = multimodel_document["file_id"]
|
||||||
|
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{file_id}"
|
||||||
|
embedding = redis_client.get(embedding_cache_key)
|
||||||
|
if embedding:
|
||||||
|
redis_client.expire(embedding_cache_key, 600)
|
||||||
|
decoded_embedding = np.frombuffer(base64.b64decode(embedding), dtype="float")
|
||||||
|
return [float(x) for x in decoded_embedding]
|
||||||
|
try:
|
||||||
|
embedding_result = self._model_instance.invoke_multimodal_embedding(
|
||||||
|
multimodel_documents=[multimodel_document], user=self._user, input_type=EmbeddingInputType.QUERY
|
||||||
|
)
|
||||||
|
|
||||||
|
embedding_results = embedding_result.embeddings[0]
|
||||||
|
# FIXME: type ignore for numpy here
|
||||||
|
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() # type: ignore
|
||||||
|
if np.isnan(embedding_results).any():
|
||||||
|
raise ValueError("Normalized embedding is nan please try again")
|
||||||
|
except Exception as ex:
|
||||||
|
if dify_config.DEBUG:
|
||||||
|
logger.exception("Failed to embed multimodal document '%s'", multimodel_document["file_id"])
|
||||||
|
raise ex
|
||||||
|
|
||||||
|
try:
|
||||||
|
# encode embedding to base64
|
||||||
|
embedding_vector = np.array(embedding_results)
|
||||||
|
vector_bytes = embedding_vector.tobytes()
|
||||||
|
# Transform to Base64
|
||||||
|
encoded_vector = base64.b64encode(vector_bytes)
|
||||||
|
# Transform to string
|
||||||
|
encoded_str = encoded_vector.decode("utf-8")
|
||||||
|
redis_client.setex(embedding_cache_key, 600, encoded_str)
|
||||||
|
except Exception as ex:
|
||||||
|
if dify_config.DEBUG:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to add embedding to redis for the multimodal document '%s'", multimodel_document["file_id"]
|
||||||
|
)
|
||||||
|
raise ex
|
||||||
|
|
||||||
|
return embedding_results # type: ignore
|
||||||
|
|||||||
@ -9,11 +9,21 @@ class Embeddings(ABC):
|
|||||||
"""Embed search docs."""
|
"""Embed search docs."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]:
|
||||||
|
"""Embed file documents."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def embed_query(self, text: str) -> list[float]:
|
def embed_query(self, text: str) -> list[float]:
|
||||||
"""Embed query text."""
|
"""Embed query text."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
|
||||||
|
"""Embed multimodal query."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
|
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||||
"""Asynchronous Embed search docs."""
|
"""Asynchronous Embed search docs."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@ -19,3 +19,4 @@ class RetrievalSegments(BaseModel):
|
|||||||
segment: DocumentSegment
|
segment: DocumentSegment
|
||||||
child_chunks: list[RetrievalChildChunk] | None = None
|
child_chunks: list[RetrievalChildChunk] | None = None
|
||||||
score: float | None = None
|
score: float | None = None
|
||||||
|
files: list[dict[str, str | int]] | None = None
|
||||||
|
|||||||
@ -21,3 +21,4 @@ class RetrievalSourceMetadata(BaseModel):
|
|||||||
page: int | None = None
|
page: int | None = None
|
||||||
doc_metadata: dict[str, Any] | None = None
|
doc_metadata: dict[str, Any] | None = None
|
||||||
title: str | None = None
|
title: str | None = None
|
||||||
|
files: list[dict[str, Any]] | None = None
|
||||||
|
|||||||
6
api/core/rag/index_processor/constant/doc_type.py
Normal file
6
api/core/rag/index_processor/constant/doc_type.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
from enum import StrEnum
|
||||||
|
|
||||||
|
|
||||||
|
class DocType(StrEnum):
|
||||||
|
TEXT = "text"
|
||||||
|
IMAGE = "image"
|
||||||
@ -1,7 +1,12 @@
|
|||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
|
||||||
|
|
||||||
class IndexType(StrEnum):
|
class IndexStructureType(StrEnum):
|
||||||
PARAGRAPH_INDEX = "text_model"
|
PARAGRAPH_INDEX = "text_model"
|
||||||
QA_INDEX = "qa_model"
|
QA_INDEX = "qa_model"
|
||||||
PARENT_CHILD_INDEX = "hierarchical_model"
|
PARENT_CHILD_INDEX = "hierarchical_model"
|
||||||
|
|
||||||
|
|
||||||
|
class IndexTechniqueType(StrEnum):
|
||||||
|
ECONOMY = "economy"
|
||||||
|
HIGH_QUALITY = "high_quality"
|
||||||
|
|||||||
6
api/core/rag/index_processor/constant/query_type.py
Normal file
6
api/core/rag/index_processor/constant/query_type.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
from enum import StrEnum
|
||||||
|
|
||||||
|
|
||||||
|
class QueryType(StrEnum):
|
||||||
|
TEXT_QUERY = "text_query"
|
||||||
|
IMAGE_QUERY = "image_query"
|
||||||
@ -1,20 +1,34 @@
|
|||||||
"""Abstract interface for document loader implementations."""
|
"""Abstract interface for document loader implementations."""
|
||||||
|
|
||||||
|
import cgi
|
||||||
|
import logging
|
||||||
|
import mimetypes
|
||||||
|
import os
|
||||||
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
from urllib.parse import unquote, urlparse
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from core.helper import ssrf_proxy
|
||||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||||
from core.rag.models.document import Document
|
from core.rag.index_processor.constant.doc_type import DocType
|
||||||
|
from core.rag.models.document import AttachmentDocument, Document
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
from core.rag.splitter.fixed_text_splitter import (
|
from core.rag.splitter.fixed_text_splitter import (
|
||||||
EnhanceRecursiveCharacterTextSplitter,
|
EnhanceRecursiveCharacterTextSplitter,
|
||||||
FixedRecursiveCharacterTextSplitter,
|
FixedRecursiveCharacterTextSplitter,
|
||||||
)
|
)
|
||||||
from core.rag.splitter.text_splitter import TextSplitter
|
from core.rag.splitter.text_splitter import TextSplitter
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from extensions.ext_storage import storage
|
||||||
|
from models import Account, ToolFile
|
||||||
from models.dataset import Dataset, DatasetProcessRule
|
from models.dataset import Dataset, DatasetProcessRule
|
||||||
from models.dataset import Document as DatasetDocument
|
from models.dataset import Document as DatasetDocument
|
||||||
|
from models.model import UploadFile
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
@ -28,11 +42,18 @@ class BaseIndexProcessor(ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
|
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
|
def load(
|
||||||
|
self,
|
||||||
|
dataset: Dataset,
|
||||||
|
documents: list[Document],
|
||||||
|
multimodal_documents: list[AttachmentDocument] | None = None,
|
||||||
|
with_keywords: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -96,3 +117,178 @@ class BaseIndexProcessor(ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return character_splitter # type: ignore
|
return character_splitter # type: ignore
|
||||||
|
|
||||||
|
def _get_content_files(self, document: Document, current_user: Account | None = None) -> list[AttachmentDocument]:
|
||||||
|
"""
|
||||||
|
Get the content files from the document.
|
||||||
|
"""
|
||||||
|
multi_model_documents: list[AttachmentDocument] = []
|
||||||
|
text = document.page_content
|
||||||
|
images = self._extract_markdown_images(text)
|
||||||
|
if not images:
|
||||||
|
return multi_model_documents
|
||||||
|
upload_file_id_list = []
|
||||||
|
|
||||||
|
for image in images:
|
||||||
|
# Collect all upload_file_ids including duplicates to preserve occurrence count
|
||||||
|
|
||||||
|
# For data before v0.10.0
|
||||||
|
pattern = r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?"
|
||||||
|
match = re.search(pattern, image)
|
||||||
|
if match:
|
||||||
|
upload_file_id = match.group(1)
|
||||||
|
upload_file_id_list.append(upload_file_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# For data after v0.10.0
|
||||||
|
pattern = r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?"
|
||||||
|
match = re.search(pattern, image)
|
||||||
|
if match:
|
||||||
|
upload_file_id = match.group(1)
|
||||||
|
upload_file_id_list.append(upload_file_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# For tools directory - direct file formats (e.g., .png, .jpg, etc.)
|
||||||
|
# Match URL including any query parameters up to common URL boundaries (space, parenthesis, quotes)
|
||||||
|
pattern = r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?"
|
||||||
|
match = re.search(pattern, image)
|
||||||
|
if match:
|
||||||
|
if current_user:
|
||||||
|
tool_file_id = match.group(1)
|
||||||
|
upload_file_id = self._download_tool_file(tool_file_id, current_user)
|
||||||
|
if upload_file_id:
|
||||||
|
upload_file_id_list.append(upload_file_id)
|
||||||
|
continue
|
||||||
|
if current_user:
|
||||||
|
upload_file_id = self._download_image(image.split(" ")[0], current_user)
|
||||||
|
if upload_file_id:
|
||||||
|
upload_file_id_list.append(upload_file_id)
|
||||||
|
|
||||||
|
if not upload_file_id_list:
|
||||||
|
return multi_model_documents
|
||||||
|
|
||||||
|
# Get unique IDs for database query
|
||||||
|
unique_upload_file_ids = list(set(upload_file_id_list))
|
||||||
|
upload_files = db.session.query(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids)).all()
|
||||||
|
|
||||||
|
# Create a mapping from ID to UploadFile for quick lookup
|
||||||
|
upload_file_map = {upload_file.id: upload_file for upload_file in upload_files}
|
||||||
|
|
||||||
|
# Create a Document for each occurrence (including duplicates)
|
||||||
|
for upload_file_id in upload_file_id_list:
|
||||||
|
upload_file = upload_file_map.get(upload_file_id)
|
||||||
|
if upload_file:
|
||||||
|
multi_model_documents.append(
|
||||||
|
AttachmentDocument(
|
||||||
|
page_content=upload_file.name,
|
||||||
|
metadata={
|
||||||
|
"doc_id": upload_file.id,
|
||||||
|
"doc_hash": "",
|
||||||
|
"document_id": document.metadata.get("document_id"),
|
||||||
|
"dataset_id": document.metadata.get("dataset_id"),
|
||||||
|
"doc_type": DocType.IMAGE,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return multi_model_documents
|
||||||
|
|
||||||
|
def _extract_markdown_images(self, text: str) -> list[str]:
|
||||||
|
"""
|
||||||
|
Extract the markdown images from the text.
|
||||||
|
"""
|
||||||
|
pattern = r"!\[.*?\]\((.*?)\)"
|
||||||
|
return re.findall(pattern, text)
|
||||||
|
|
||||||
|
def _download_image(self, image_url: str, current_user: Account) -> str | None:
|
||||||
|
"""
|
||||||
|
Download the image from the URL.
|
||||||
|
Image size must not exceed 2MB.
|
||||||
|
"""
|
||||||
|
from services.file_service import FileService
|
||||||
|
|
||||||
|
MAX_IMAGE_SIZE = dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024
|
||||||
|
DOWNLOAD_TIMEOUT = dify_config.ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Download with timeout
|
||||||
|
response = ssrf_proxy.get(image_url, timeout=DOWNLOAD_TIMEOUT)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# Check Content-Length header if available
|
||||||
|
content_length = response.headers.get("Content-Length")
|
||||||
|
if content_length and int(content_length) > MAX_IMAGE_SIZE:
|
||||||
|
logging.warning("Image from %s exceeds 2MB limit (size: %s bytes)", image_url, content_length)
|
||||||
|
return None
|
||||||
|
|
||||||
|
filename = None
|
||||||
|
|
||||||
|
content_disposition = response.headers.get("content-disposition")
|
||||||
|
if content_disposition:
|
||||||
|
_, params = cgi.parse_header(content_disposition)
|
||||||
|
if "filename" in params:
|
||||||
|
filename = params["filename"]
|
||||||
|
filename = unquote(filename)
|
||||||
|
|
||||||
|
if not filename:
|
||||||
|
parsed_url = urlparse(image_url)
|
||||||
|
# unquote 处理 URL 中的中文
|
||||||
|
path = unquote(parsed_url.path)
|
||||||
|
filename = os.path.basename(path)
|
||||||
|
|
||||||
|
if not filename:
|
||||||
|
filename = "downloaded_image_file"
|
||||||
|
|
||||||
|
name, current_ext = os.path.splitext(filename)
|
||||||
|
|
||||||
|
content_type = response.headers.get("content-type", "").split(";")[0].strip()
|
||||||
|
|
||||||
|
real_ext = mimetypes.guess_extension(content_type)
|
||||||
|
|
||||||
|
if not current_ext and real_ext or current_ext in [".php", ".jsp", ".asp", ".html"] and real_ext:
|
||||||
|
filename = f"{name}{real_ext}"
|
||||||
|
# Download content with size limit
|
||||||
|
blob = b""
|
||||||
|
for chunk in response.iter_bytes(chunk_size=8192):
|
||||||
|
blob += chunk
|
||||||
|
if len(blob) > MAX_IMAGE_SIZE:
|
||||||
|
logging.warning("Image from %s exceeds 2MB limit during download", image_url)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not blob:
|
||||||
|
logging.warning("Image from %s is empty", image_url)
|
||||||
|
return None
|
||||||
|
|
||||||
|
upload_file = FileService(db.engine).upload_file(
|
||||||
|
filename=filename,
|
||||||
|
content=blob,
|
||||||
|
mimetype=content_type,
|
||||||
|
user=current_user,
|
||||||
|
)
|
||||||
|
return upload_file.id
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
logging.warning("Timeout downloading image from %s after %s seconds", image_url, DOWNLOAD_TIMEOUT)
|
||||||
|
return None
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
logging.warning("Error downloading image from %s: %s", image_url, str(e))
|
||||||
|
return None
|
||||||
|
except Exception:
|
||||||
|
logging.exception("Unexpected error downloading image from %s", image_url)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _download_tool_file(self, tool_file_id: str, current_user: Account) -> str | None:
|
||||||
|
"""
|
||||||
|
Download the tool file from the ID.
|
||||||
|
"""
|
||||||
|
from services.file_service import FileService
|
||||||
|
|
||||||
|
tool_file = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first()
|
||||||
|
if not tool_file:
|
||||||
|
return None
|
||||||
|
blob = storage.load_once(tool_file.file_key)
|
||||||
|
upload_file = FileService(db.engine).upload_file(
|
||||||
|
filename=tool_file.name,
|
||||||
|
content=blob,
|
||||||
|
mimetype=tool_file.mimetype,
|
||||||
|
user=current_user,
|
||||||
|
)
|
||||||
|
return upload_file.id
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
"""Abstract interface for document loader implementations."""
|
"""Abstract interface for document loader implementations."""
|
||||||
|
|
||||||
from core.rag.index_processor.constant.index_type import IndexType
|
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||||
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
|
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
|
||||||
from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor
|
from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor
|
||||||
@ -19,11 +19,11 @@ class IndexProcessorFactory:
|
|||||||
if not self._index_type:
|
if not self._index_type:
|
||||||
raise ValueError("Index type must be specified.")
|
raise ValueError("Index type must be specified.")
|
||||||
|
|
||||||
if self._index_type == IndexType.PARAGRAPH_INDEX:
|
if self._index_type == IndexStructureType.PARAGRAPH_INDEX:
|
||||||
return ParagraphIndexProcessor()
|
return ParagraphIndexProcessor()
|
||||||
elif self._index_type == IndexType.QA_INDEX:
|
elif self._index_type == IndexStructureType.QA_INDEX:
|
||||||
return QAIndexProcessor()
|
return QAIndexProcessor()
|
||||||
elif self._index_type == IndexType.PARENT_CHILD_INDEX:
|
elif self._index_type == IndexStructureType.PARENT_CHILD_INDEX:
|
||||||
return ParentChildIndexProcessor()
|
return ParentChildIndexProcessor()
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Index type {self._index_type} is not supported.")
|
raise ValueError(f"Index type {self._index_type} is not supported.")
|
||||||
|
|||||||
@ -11,14 +11,17 @@ from core.rag.datasource.vdb.vector_factory import Vector
|
|||||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||||
from core.rag.index_processor.constant.index_type import IndexType
|
from core.rag.index_processor.constant.doc_type import DocType
|
||||||
|
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||||
from libs import helper
|
from libs import helper
|
||||||
|
from models.account import Account
|
||||||
from models.dataset import Dataset, DatasetProcessRule
|
from models.dataset import Dataset, DatasetProcessRule
|
||||||
from models.dataset import Document as DatasetDocument
|
from models.dataset import Document as DatasetDocument
|
||||||
|
from services.account_service import AccountService
|
||||||
from services.entities.knowledge_entities.knowledge_entities import Rule
|
from services.entities.knowledge_entities.knowledge_entities import Rule
|
||||||
|
|
||||||
|
|
||||||
@ -33,7 +36,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
|||||||
|
|
||||||
return text_docs
|
return text_docs
|
||||||
|
|
||||||
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
|
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
|
||||||
process_rule = kwargs.get("process_rule")
|
process_rule = kwargs.get("process_rule")
|
||||||
if not process_rule:
|
if not process_rule:
|
||||||
raise ValueError("No process rule found.")
|
raise ValueError("No process rule found.")
|
||||||
@ -69,6 +72,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
|||||||
if document_node.metadata is not None:
|
if document_node.metadata is not None:
|
||||||
document_node.metadata["doc_id"] = doc_id
|
document_node.metadata["doc_id"] = doc_id
|
||||||
document_node.metadata["doc_hash"] = hash
|
document_node.metadata["doc_hash"] = hash
|
||||||
|
multimodal_documents = (
|
||||||
|
self._get_content_files(document_node, current_user) if document_node.metadata else None
|
||||||
|
)
|
||||||
|
if multimodal_documents:
|
||||||
|
document_node.attachments = multimodal_documents
|
||||||
# delete Splitter character
|
# delete Splitter character
|
||||||
page_content = remove_leading_symbols(document_node.page_content).strip()
|
page_content = remove_leading_symbols(document_node.page_content).strip()
|
||||||
if len(page_content) > 0:
|
if len(page_content) > 0:
|
||||||
@ -77,10 +85,19 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
|||||||
all_documents.extend(split_documents)
|
all_documents.extend(split_documents)
|
||||||
return all_documents
|
return all_documents
|
||||||
|
|
||||||
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
|
def load(
|
||||||
|
self,
|
||||||
|
dataset: Dataset,
|
||||||
|
documents: list[Document],
|
||||||
|
multimodal_documents: list[AttachmentDocument] | None = None,
|
||||||
|
with_keywords: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
if dataset.indexing_technique == "high_quality":
|
if dataset.indexing_technique == "high_quality":
|
||||||
vector = Vector(dataset)
|
vector = Vector(dataset)
|
||||||
vector.create(documents)
|
vector.create(documents)
|
||||||
|
if multimodal_documents and dataset.is_multimodal:
|
||||||
|
vector.create_multimodal(multimodal_documents)
|
||||||
with_keywords = False
|
with_keywords = False
|
||||||
if with_keywords:
|
if with_keywords:
|
||||||
keywords_list = kwargs.get("keywords_list")
|
keywords_list = kwargs.get("keywords_list")
|
||||||
@ -134,8 +151,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
|||||||
return docs
|
return docs
|
||||||
|
|
||||||
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
|
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
|
||||||
|
documents: list[Any] = []
|
||||||
|
all_multimodal_documents: list[Any] = []
|
||||||
if isinstance(chunks, list):
|
if isinstance(chunks, list):
|
||||||
documents = []
|
|
||||||
for content in chunks:
|
for content in chunks:
|
||||||
metadata = {
|
metadata = {
|
||||||
"dataset_id": dataset.id,
|
"dataset_id": dataset.id,
|
||||||
@ -144,26 +162,68 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
|||||||
"doc_hash": helper.generate_text_hash(content),
|
"doc_hash": helper.generate_text_hash(content),
|
||||||
}
|
}
|
||||||
doc = Document(page_content=content, metadata=metadata)
|
doc = Document(page_content=content, metadata=metadata)
|
||||||
|
attachments = self._get_content_files(doc)
|
||||||
|
if attachments:
|
||||||
|
doc.attachments = attachments
|
||||||
|
all_multimodal_documents.extend(attachments)
|
||||||
documents.append(doc)
|
documents.append(doc)
|
||||||
if documents:
|
|
||||||
# save node to document segment
|
|
||||||
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
|
|
||||||
# add document segments
|
|
||||||
doc_store.add_documents(docs=documents, save_child=False)
|
|
||||||
if dataset.indexing_technique == "high_quality":
|
|
||||||
vector = Vector(dataset)
|
|
||||||
vector.create(documents)
|
|
||||||
elif dataset.indexing_technique == "economy":
|
|
||||||
keyword = Keyword(dataset)
|
|
||||||
keyword.add_texts(documents)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Chunks is not a list")
|
multimodal_general_structure = MultimodalGeneralStructureChunk.model_validate(chunks)
|
||||||
|
for general_chunk in multimodal_general_structure.general_chunks:
|
||||||
|
metadata = {
|
||||||
|
"dataset_id": dataset.id,
|
||||||
|
"document_id": document.id,
|
||||||
|
"doc_id": str(uuid.uuid4()),
|
||||||
|
"doc_hash": helper.generate_text_hash(general_chunk.content),
|
||||||
|
}
|
||||||
|
doc = Document(page_content=general_chunk.content, metadata=metadata)
|
||||||
|
if general_chunk.files:
|
||||||
|
attachments = []
|
||||||
|
for file in general_chunk.files:
|
||||||
|
file_metadata = {
|
||||||
|
"doc_id": file.id,
|
||||||
|
"doc_hash": "",
|
||||||
|
"document_id": document.id,
|
||||||
|
"dataset_id": dataset.id,
|
||||||
|
"doc_type": DocType.IMAGE,
|
||||||
|
}
|
||||||
|
file_document = AttachmentDocument(
|
||||||
|
page_content=file.filename or "image_file", metadata=file_metadata
|
||||||
|
)
|
||||||
|
attachments.append(file_document)
|
||||||
|
all_multimodal_documents.append(file_document)
|
||||||
|
doc.attachments = attachments
|
||||||
|
else:
|
||||||
|
account = AccountService.load_user(document.created_by)
|
||||||
|
if not account:
|
||||||
|
raise ValueError("Invalid account")
|
||||||
|
doc.attachments = self._get_content_files(doc, current_user=account)
|
||||||
|
if doc.attachments:
|
||||||
|
all_multimodal_documents.extend(doc.attachments)
|
||||||
|
documents.append(doc)
|
||||||
|
if documents:
|
||||||
|
# save node to document segment
|
||||||
|
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
|
||||||
|
# add document segments
|
||||||
|
doc_store.add_documents(docs=documents, save_child=False)
|
||||||
|
if dataset.indexing_technique == "high_quality":
|
||||||
|
vector = Vector(dataset)
|
||||||
|
vector.create(documents)
|
||||||
|
if all_multimodal_documents:
|
||||||
|
vector.create_multimodal(all_multimodal_documents)
|
||||||
|
elif dataset.indexing_technique == "economy":
|
||||||
|
keyword = Keyword(dataset)
|
||||||
|
keyword.add_texts(documents)
|
||||||
|
|
||||||
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
|
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
|
||||||
if isinstance(chunks, list):
|
if isinstance(chunks, list):
|
||||||
preview = []
|
preview = []
|
||||||
for content in chunks:
|
for content in chunks:
|
||||||
preview.append({"content": content})
|
preview.append({"content": content})
|
||||||
return {"chunk_structure": IndexType.PARAGRAPH_INDEX, "preview": preview, "total_segments": len(chunks)}
|
return {
|
||||||
|
"chunk_structure": IndexStructureType.PARAGRAPH_INDEX,
|
||||||
|
"preview": preview,
|
||||||
|
"total_segments": len(chunks),
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
raise ValueError("Chunks is not a list")
|
raise ValueError("Chunks is not a list")
|
||||||
|
|||||||
@ -13,14 +13,17 @@ from core.rag.datasource.vdb.vector_factory import Vector
|
|||||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||||
from core.rag.index_processor.constant.index_type import IndexType
|
from core.rag.index_processor.constant.doc_type import DocType
|
||||||
|
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||||
from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk
|
from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs import helper
|
from libs import helper
|
||||||
|
from models import Account
|
||||||
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
|
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
|
||||||
from models.dataset import Document as DatasetDocument
|
from models.dataset import Document as DatasetDocument
|
||||||
|
from services.account_service import AccountService
|
||||||
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
|
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
|
||||||
|
|
||||||
|
|
||||||
@ -35,7 +38,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
|||||||
|
|
||||||
return text_docs
|
return text_docs
|
||||||
|
|
||||||
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
|
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
|
||||||
process_rule = kwargs.get("process_rule")
|
process_rule = kwargs.get("process_rule")
|
||||||
if not process_rule:
|
if not process_rule:
|
||||||
raise ValueError("No process rule found.")
|
raise ValueError("No process rule found.")
|
||||||
@ -77,6 +80,9 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
|||||||
page_content = page_content
|
page_content = page_content
|
||||||
if len(page_content) > 0:
|
if len(page_content) > 0:
|
||||||
document_node.page_content = page_content
|
document_node.page_content = page_content
|
||||||
|
multimodel_documents = self._get_content_files(document_node, current_user)
|
||||||
|
if multimodel_documents:
|
||||||
|
document_node.attachments = multimodel_documents
|
||||||
# parse document to child nodes
|
# parse document to child nodes
|
||||||
child_nodes = self._split_child_nodes(
|
child_nodes = self._split_child_nodes(
|
||||||
document_node, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
|
document_node, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
|
||||||
@ -87,6 +93,9 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
|||||||
elif rules.parent_mode == ParentMode.FULL_DOC:
|
elif rules.parent_mode == ParentMode.FULL_DOC:
|
||||||
page_content = "\n".join([document.page_content for document in documents])
|
page_content = "\n".join([document.page_content for document in documents])
|
||||||
document = Document(page_content=page_content, metadata=documents[0].metadata)
|
document = Document(page_content=page_content, metadata=documents[0].metadata)
|
||||||
|
multimodel_documents = self._get_content_files(document)
|
||||||
|
if multimodel_documents:
|
||||||
|
document.attachments = multimodel_documents
|
||||||
# parse document to child nodes
|
# parse document to child nodes
|
||||||
child_nodes = self._split_child_nodes(
|
child_nodes = self._split_child_nodes(
|
||||||
document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
|
document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
|
||||||
@ -104,7 +113,14 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
|||||||
|
|
||||||
return all_documents
|
return all_documents
|
||||||
|
|
||||||
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
|
def load(
|
||||||
|
self,
|
||||||
|
dataset: Dataset,
|
||||||
|
documents: list[Document],
|
||||||
|
multimodal_documents: list[AttachmentDocument] | None = None,
|
||||||
|
with_keywords: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
if dataset.indexing_technique == "high_quality":
|
if dataset.indexing_technique == "high_quality":
|
||||||
vector = Vector(dataset)
|
vector = Vector(dataset)
|
||||||
for document in documents:
|
for document in documents:
|
||||||
@ -114,6 +130,8 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
|||||||
Document.model_validate(child_document.model_dump()) for child_document in child_documents
|
Document.model_validate(child_document.model_dump()) for child_document in child_documents
|
||||||
]
|
]
|
||||||
vector.create(formatted_child_documents)
|
vector.create(formatted_child_documents)
|
||||||
|
if multimodal_documents and dataset.is_multimodal:
|
||||||
|
vector.create_multimodal(multimodal_documents)
|
||||||
|
|
||||||
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
|
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
|
||||||
# node_ids is segment's node_ids
|
# node_ids is segment's node_ids
|
||||||
@ -244,6 +262,24 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
|||||||
}
|
}
|
||||||
child_documents.append(ChildDocument(page_content=child, metadata=child_metadata))
|
child_documents.append(ChildDocument(page_content=child, metadata=child_metadata))
|
||||||
doc = Document(page_content=parent_child.parent_content, metadata=metadata, children=child_documents)
|
doc = Document(page_content=parent_child.parent_content, metadata=metadata, children=child_documents)
|
||||||
|
if parent_child.files and len(parent_child.files) > 0:
|
||||||
|
attachments = []
|
||||||
|
for file in parent_child.files:
|
||||||
|
file_metadata = {
|
||||||
|
"doc_id": file.id,
|
||||||
|
"doc_hash": "",
|
||||||
|
"document_id": document.id,
|
||||||
|
"dataset_id": dataset.id,
|
||||||
|
"doc_type": DocType.IMAGE,
|
||||||
|
}
|
||||||
|
file_document = AttachmentDocument(page_content=file.filename or "", metadata=file_metadata)
|
||||||
|
attachments.append(file_document)
|
||||||
|
doc.attachments = attachments
|
||||||
|
else:
|
||||||
|
account = AccountService.load_user(document.created_by)
|
||||||
|
if not account:
|
||||||
|
raise ValueError("Invalid account")
|
||||||
|
doc.attachments = self._get_content_files(doc, current_user=account)
|
||||||
documents.append(doc)
|
documents.append(doc)
|
||||||
if documents:
|
if documents:
|
||||||
# update document parent mode
|
# update document parent mode
|
||||||
@ -267,12 +303,17 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
|||||||
doc_store.add_documents(docs=documents, save_child=True)
|
doc_store.add_documents(docs=documents, save_child=True)
|
||||||
if dataset.indexing_technique == "high_quality":
|
if dataset.indexing_technique == "high_quality":
|
||||||
all_child_documents = []
|
all_child_documents = []
|
||||||
|
all_multimodal_documents = []
|
||||||
for doc in documents:
|
for doc in documents:
|
||||||
if doc.children:
|
if doc.children:
|
||||||
all_child_documents.extend(doc.children)
|
all_child_documents.extend(doc.children)
|
||||||
|
if doc.attachments:
|
||||||
|
all_multimodal_documents.extend(doc.attachments)
|
||||||
|
vector = Vector(dataset)
|
||||||
if all_child_documents:
|
if all_child_documents:
|
||||||
vector = Vector(dataset)
|
|
||||||
vector.create(all_child_documents)
|
vector.create(all_child_documents)
|
||||||
|
if all_multimodal_documents:
|
||||||
|
vector.create_multimodal(all_multimodal_documents)
|
||||||
|
|
||||||
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
|
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
|
||||||
parent_childs = ParentChildStructureChunk.model_validate(chunks)
|
parent_childs = ParentChildStructureChunk.model_validate(chunks)
|
||||||
@ -280,7 +321,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
|||||||
for parent_child in parent_childs.parent_child_chunks:
|
for parent_child in parent_childs.parent_child_chunks:
|
||||||
preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
|
preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
|
||||||
return {
|
return {
|
||||||
"chunk_structure": IndexType.PARENT_CHILD_INDEX,
|
"chunk_structure": IndexStructureType.PARENT_CHILD_INDEX,
|
||||||
"parent_mode": parent_childs.parent_mode,
|
"parent_mode": parent_childs.parent_mode,
|
||||||
"preview": preview,
|
"preview": preview,
|
||||||
"total_segments": len(parent_childs.parent_child_chunks),
|
"total_segments": len(parent_childs.parent_child_chunks),
|
||||||
|
|||||||
@ -18,12 +18,13 @@ from core.rag.datasource.vdb.vector_factory import Vector
|
|||||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||||
from core.rag.index_processor.constant.index_type import IndexType
|
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||||
from core.rag.models.document import Document, QAStructureChunk
|
from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||||
from libs import helper
|
from libs import helper
|
||||||
|
from models.account import Account
|
||||||
from models.dataset import Dataset
|
from models.dataset import Dataset
|
||||||
from models.dataset import Document as DatasetDocument
|
from models.dataset import Document as DatasetDocument
|
||||||
from services.entities.knowledge_entities.knowledge_entities import Rule
|
from services.entities.knowledge_entities.knowledge_entities import Rule
|
||||||
@ -41,7 +42,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
|||||||
)
|
)
|
||||||
return text_docs
|
return text_docs
|
||||||
|
|
||||||
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
|
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
|
||||||
preview = kwargs.get("preview")
|
preview = kwargs.get("preview")
|
||||||
process_rule = kwargs.get("process_rule")
|
process_rule = kwargs.get("process_rule")
|
||||||
if not process_rule:
|
if not process_rule:
|
||||||
@ -116,7 +117,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Skip the first row
|
# Skip the first row
|
||||||
df = pd.read_csv(file)
|
df = pd.read_csv(file) # type: ignore
|
||||||
text_docs = []
|
text_docs = []
|
||||||
for _, row in df.iterrows():
|
for _, row in df.iterrows():
|
||||||
data = Document(page_content=row.iloc[0], metadata={"answer": row.iloc[1]})
|
data = Document(page_content=row.iloc[0], metadata={"answer": row.iloc[1]})
|
||||||
@ -128,10 +129,19 @@ class QAIndexProcessor(BaseIndexProcessor):
|
|||||||
raise ValueError(str(e))
|
raise ValueError(str(e))
|
||||||
return text_docs
|
return text_docs
|
||||||
|
|
||||||
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
|
def load(
|
||||||
|
self,
|
||||||
|
dataset: Dataset,
|
||||||
|
documents: list[Document],
|
||||||
|
multimodal_documents: list[AttachmentDocument] | None = None,
|
||||||
|
with_keywords: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
if dataset.indexing_technique == "high_quality":
|
if dataset.indexing_technique == "high_quality":
|
||||||
vector = Vector(dataset)
|
vector = Vector(dataset)
|
||||||
vector.create(documents)
|
vector.create(documents)
|
||||||
|
if multimodal_documents and dataset.is_multimodal:
|
||||||
|
vector.create_multimodal(multimodal_documents)
|
||||||
|
|
||||||
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
|
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
|
||||||
vector = Vector(dataset)
|
vector = Vector(dataset)
|
||||||
@ -197,7 +207,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
|||||||
for qa_chunk in qa_chunks.qa_chunks:
|
for qa_chunk in qa_chunks.qa_chunks:
|
||||||
preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer})
|
preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer})
|
||||||
return {
|
return {
|
||||||
"chunk_structure": IndexType.QA_INDEX,
|
"chunk_structure": IndexStructureType.QA_INDEX,
|
||||||
"qa_preview": preview,
|
"qa_preview": preview,
|
||||||
"total_segments": len(qa_chunks.qa_chunks),
|
"total_segments": len(qa_chunks.qa_chunks),
|
||||||
}
|
}
|
||||||
|
|||||||
@ -4,6 +4,8 @@ from typing import Any
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from core.file import File
|
||||||
|
|
||||||
|
|
||||||
class ChildDocument(BaseModel):
|
class ChildDocument(BaseModel):
|
||||||
"""Class for storing a piece of text and associated metadata."""
|
"""Class for storing a piece of text and associated metadata."""
|
||||||
@ -15,7 +17,19 @@ class ChildDocument(BaseModel):
|
|||||||
"""Arbitrary metadata about the page content (e.g., source, relationships to other
|
"""Arbitrary metadata about the page content (e.g., source, relationships to other
|
||||||
documents, etc.).
|
documents, etc.).
|
||||||
"""
|
"""
|
||||||
metadata: dict = Field(default_factory=dict)
|
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class AttachmentDocument(BaseModel):
|
||||||
|
"""Class for storing a piece of text and associated metadata."""
|
||||||
|
|
||||||
|
page_content: str
|
||||||
|
|
||||||
|
provider: str | None = "dify"
|
||||||
|
|
||||||
|
vector: list[float] | None = None
|
||||||
|
|
||||||
|
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class Document(BaseModel):
|
class Document(BaseModel):
|
||||||
@ -28,12 +42,31 @@ class Document(BaseModel):
|
|||||||
"""Arbitrary metadata about the page content (e.g., source, relationships to other
|
"""Arbitrary metadata about the page content (e.g., source, relationships to other
|
||||||
documents, etc.).
|
documents, etc.).
|
||||||
"""
|
"""
|
||||||
metadata: dict = Field(default_factory=dict)
|
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
provider: str | None = "dify"
|
provider: str | None = "dify"
|
||||||
|
|
||||||
children: list[ChildDocument] | None = None
|
children: list[ChildDocument] | None = None
|
||||||
|
|
||||||
|
attachments: list[AttachmentDocument] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class GeneralChunk(BaseModel):
|
||||||
|
"""
|
||||||
|
General Chunk.
|
||||||
|
"""
|
||||||
|
|
||||||
|
content: str
|
||||||
|
files: list[File] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class MultimodalGeneralStructureChunk(BaseModel):
|
||||||
|
"""
|
||||||
|
Multimodal General Structure Chunk.
|
||||||
|
"""
|
||||||
|
|
||||||
|
general_chunks: list[GeneralChunk]
|
||||||
|
|
||||||
|
|
||||||
class GeneralStructureChunk(BaseModel):
|
class GeneralStructureChunk(BaseModel):
|
||||||
"""
|
"""
|
||||||
@ -50,6 +83,7 @@ class ParentChildChunk(BaseModel):
|
|||||||
|
|
||||||
parent_content: str
|
parent_content: str
|
||||||
child_contents: list[str]
|
child_contents: list[str]
|
||||||
|
files: list[File] | None = None
|
||||||
|
|
||||||
|
|
||||||
class ParentChildStructureChunk(BaseModel):
|
class ParentChildStructureChunk(BaseModel):
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from core.rag.index_processor.constant.query_type import QueryType
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
|
|
||||||
|
|
||||||
@ -12,6 +13,7 @@ class BaseRerankRunner(ABC):
|
|||||||
score_threshold: float | None = None,
|
score_threshold: float | None = None,
|
||||||
top_n: int | None = None,
|
top_n: int | None = None,
|
||||||
user: str | None = None,
|
user: str | None = None,
|
||||||
|
query_type: QueryType = QueryType.TEXT_QUERY,
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
"""
|
"""
|
||||||
Run rerank model
|
Run rerank model
|
||||||
|
|||||||
@ -1,6 +1,15 @@
|
|||||||
from core.model_manager import ModelInstance
|
import base64
|
||||||
|
|
||||||
|
from core.model_manager import ModelInstance, ModelManager
|
||||||
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
|
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||||
|
from core.rag.index_processor.constant.doc_type import DocType
|
||||||
|
from core.rag.index_processor.constant.query_type import QueryType
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
from core.rag.rerank.rerank_base import BaseRerankRunner
|
from core.rag.rerank.rerank_base import BaseRerankRunner
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from extensions.ext_storage import storage
|
||||||
|
from models.model import UploadFile
|
||||||
|
|
||||||
|
|
||||||
class RerankModelRunner(BaseRerankRunner):
|
class RerankModelRunner(BaseRerankRunner):
|
||||||
@ -14,6 +23,7 @@ class RerankModelRunner(BaseRerankRunner):
|
|||||||
score_threshold: float | None = None,
|
score_threshold: float | None = None,
|
||||||
top_n: int | None = None,
|
top_n: int | None = None,
|
||||||
user: str | None = None,
|
user: str | None = None,
|
||||||
|
query_type: QueryType = QueryType.TEXT_QUERY,
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
"""
|
"""
|
||||||
Run rerank model
|
Run rerank model
|
||||||
@ -24,6 +34,56 @@ class RerankModelRunner(BaseRerankRunner):
|
|||||||
:param user: unique user id if needed
|
:param user: unique user id if needed
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
model_manager = ModelManager()
|
||||||
|
is_support_vision = model_manager.check_model_support_vision(
|
||||||
|
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id,
|
||||||
|
provider=self.rerank_model_instance.provider,
|
||||||
|
model=self.rerank_model_instance.model,
|
||||||
|
model_type=ModelType.RERANK,
|
||||||
|
)
|
||||||
|
if not is_support_vision:
|
||||||
|
if query_type == QueryType.TEXT_QUERY:
|
||||||
|
rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user)
|
||||||
|
else:
|
||||||
|
return documents
|
||||||
|
else:
|
||||||
|
rerank_result, unique_documents = self.fetch_multimodal_rerank(
|
||||||
|
query, documents, score_threshold, top_n, user, query_type
|
||||||
|
)
|
||||||
|
|
||||||
|
rerank_documents = []
|
||||||
|
for result in rerank_result.docs:
|
||||||
|
if score_threshold is None or result.score >= score_threshold:
|
||||||
|
# format document
|
||||||
|
rerank_document = Document(
|
||||||
|
page_content=result.text,
|
||||||
|
metadata=unique_documents[result.index].metadata,
|
||||||
|
provider=unique_documents[result.index].provider,
|
||||||
|
)
|
||||||
|
if rerank_document.metadata is not None:
|
||||||
|
rerank_document.metadata["score"] = result.score
|
||||||
|
rerank_documents.append(rerank_document)
|
||||||
|
|
||||||
|
rerank_documents.sort(key=lambda x: x.metadata.get("score", 0.0), reverse=True)
|
||||||
|
return rerank_documents[:top_n] if top_n else rerank_documents
|
||||||
|
|
||||||
|
def fetch_text_rerank(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
documents: list[Document],
|
||||||
|
score_threshold: float | None = None,
|
||||||
|
top_n: int | None = None,
|
||||||
|
user: str | None = None,
|
||||||
|
) -> tuple[RerankResult, list[Document]]:
|
||||||
|
"""
|
||||||
|
Fetch text rerank
|
||||||
|
:param query: search query
|
||||||
|
:param documents: documents for reranking
|
||||||
|
:param score_threshold: score threshold
|
||||||
|
:param top_n: top n
|
||||||
|
:param user: unique user id if needed
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
docs = []
|
docs = []
|
||||||
doc_ids = set()
|
doc_ids = set()
|
||||||
unique_documents = []
|
unique_documents = []
|
||||||
@ -33,33 +93,99 @@ class RerankModelRunner(BaseRerankRunner):
|
|||||||
and document.metadata is not None
|
and document.metadata is not None
|
||||||
and document.metadata["doc_id"] not in doc_ids
|
and document.metadata["doc_id"] not in doc_ids
|
||||||
):
|
):
|
||||||
doc_ids.add(document.metadata["doc_id"])
|
if not document.metadata.get("doc_type") or document.metadata.get("doc_type") == DocType.TEXT:
|
||||||
docs.append(document.page_content)
|
doc_ids.add(document.metadata["doc_id"])
|
||||||
unique_documents.append(document)
|
docs.append(document.page_content)
|
||||||
|
unique_documents.append(document)
|
||||||
elif document.provider == "external":
|
elif document.provider == "external":
|
||||||
if document not in unique_documents:
|
if document not in unique_documents:
|
||||||
docs.append(document.page_content)
|
docs.append(document.page_content)
|
||||||
unique_documents.append(document)
|
unique_documents.append(document)
|
||||||
|
|
||||||
documents = unique_documents
|
|
||||||
|
|
||||||
rerank_result = self.rerank_model_instance.invoke_rerank(
|
rerank_result = self.rerank_model_instance.invoke_rerank(
|
||||||
query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
|
query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
|
||||||
)
|
)
|
||||||
|
return rerank_result, unique_documents
|
||||||
|
|
||||||
rerank_documents = []
|
def fetch_multimodal_rerank(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
documents: list[Document],
|
||||||
|
score_threshold: float | None = None,
|
||||||
|
top_n: int | None = None,
|
||||||
|
user: str | None = None,
|
||||||
|
query_type: QueryType = QueryType.TEXT_QUERY,
|
||||||
|
) -> tuple[RerankResult, list[Document]]:
|
||||||
|
"""
|
||||||
|
Fetch multimodal rerank
|
||||||
|
:param query: search query
|
||||||
|
:param documents: documents for reranking
|
||||||
|
:param score_threshold: score threshold
|
||||||
|
:param top_n: top n
|
||||||
|
:param user: unique user id if needed
|
||||||
|
:param query_type: query type
|
||||||
|
:return: rerank result
|
||||||
|
"""
|
||||||
|
docs = []
|
||||||
|
doc_ids = set()
|
||||||
|
unique_documents = []
|
||||||
|
for document in documents:
|
||||||
|
if (
|
||||||
|
document.provider == "dify"
|
||||||
|
and document.metadata is not None
|
||||||
|
and document.metadata["doc_id"] not in doc_ids
|
||||||
|
):
|
||||||
|
if document.metadata.get("doc_type") == DocType.IMAGE:
|
||||||
|
# Query file info within db.session context to ensure thread-safe access
|
||||||
|
upload_file = (
|
||||||
|
db.session.query(UploadFile).where(UploadFile.id == document.metadata["doc_id"]).first()
|
||||||
|
)
|
||||||
|
if upload_file:
|
||||||
|
blob = storage.load_once(upload_file.key)
|
||||||
|
document_file_base64 = base64.b64encode(blob).decode()
|
||||||
|
document_file_dict = {
|
||||||
|
"content": document_file_base64,
|
||||||
|
"content_type": document.metadata["doc_type"],
|
||||||
|
}
|
||||||
|
docs.append(document_file_dict)
|
||||||
|
else:
|
||||||
|
document_text_dict = {
|
||||||
|
"content": document.page_content,
|
||||||
|
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
|
||||||
|
}
|
||||||
|
docs.append(document_text_dict)
|
||||||
|
doc_ids.add(document.metadata["doc_id"])
|
||||||
|
unique_documents.append(document)
|
||||||
|
elif document.provider == "external":
|
||||||
|
if document not in unique_documents:
|
||||||
|
docs.append(
|
||||||
|
{
|
||||||
|
"content": document.page_content,
|
||||||
|
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
unique_documents.append(document)
|
||||||
|
|
||||||
for result in rerank_result.docs:
|
documents = unique_documents
|
||||||
if score_threshold is None or result.score >= score_threshold:
|
if query_type == QueryType.TEXT_QUERY:
|
||||||
# format document
|
rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user)
|
||||||
rerank_document = Document(
|
return rerank_result, unique_documents
|
||||||
page_content=result.text,
|
elif query_type == QueryType.IMAGE_QUERY:
|
||||||
metadata=documents[result.index].metadata,
|
# Query file info within db.session context to ensure thread-safe access
|
||||||
provider=documents[result.index].provider,
|
upload_file = db.session.query(UploadFile).where(UploadFile.id == query).first()
|
||||||
|
if upload_file:
|
||||||
|
blob = storage.load_once(upload_file.key)
|
||||||
|
file_query = base64.b64encode(blob).decode()
|
||||||
|
file_query_dict = {
|
||||||
|
"content": file_query,
|
||||||
|
"content_type": DocType.IMAGE,
|
||||||
|
}
|
||||||
|
rerank_result = self.rerank_model_instance.invoke_multimodal_rerank(
|
||||||
|
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
|
||||||
)
|
)
|
||||||
if rerank_document.metadata is not None:
|
return rerank_result, unique_documents
|
||||||
rerank_document.metadata["score"] = result.score
|
else:
|
||||||
rerank_documents.append(rerank_document)
|
raise ValueError(f"Upload file not found for query: {query}")
|
||||||
|
|
||||||
rerank_documents.sort(key=lambda x: x.metadata.get("score", 0.0), reverse=True)
|
else:
|
||||||
return rerank_documents[:top_n] if top_n else rerank_documents
|
raise ValueError(f"Query type {query_type} is not supported")
|
||||||
|
|||||||
@ -7,6 +7,8 @@ from core.model_manager import ModelManager
|
|||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||||
from core.rag.embedding.cached_embedding import CacheEmbedding
|
from core.rag.embedding.cached_embedding import CacheEmbedding
|
||||||
|
from core.rag.index_processor.constant.doc_type import DocType
|
||||||
|
from core.rag.index_processor.constant.query_type import QueryType
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
from core.rag.rerank.entity.weight import VectorSetting, Weights
|
from core.rag.rerank.entity.weight import VectorSetting, Weights
|
||||||
from core.rag.rerank.rerank_base import BaseRerankRunner
|
from core.rag.rerank.rerank_base import BaseRerankRunner
|
||||||
@ -24,6 +26,7 @@ class WeightRerankRunner(BaseRerankRunner):
|
|||||||
score_threshold: float | None = None,
|
score_threshold: float | None = None,
|
||||||
top_n: int | None = None,
|
top_n: int | None = None,
|
||||||
user: str | None = None,
|
user: str | None = None,
|
||||||
|
query_type: QueryType = QueryType.TEXT_QUERY,
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
"""
|
"""
|
||||||
Run rerank model
|
Run rerank model
|
||||||
@ -43,8 +46,10 @@ class WeightRerankRunner(BaseRerankRunner):
|
|||||||
and document.metadata is not None
|
and document.metadata is not None
|
||||||
and document.metadata["doc_id"] not in doc_ids
|
and document.metadata["doc_id"] not in doc_ids
|
||||||
):
|
):
|
||||||
doc_ids.add(document.metadata["doc_id"])
|
# weight rerank only support text documents
|
||||||
unique_documents.append(document)
|
if not document.metadata.get("doc_type") or document.metadata.get("doc_type") == DocType.TEXT:
|
||||||
|
doc_ids.add(document.metadata["doc_id"])
|
||||||
|
unique_documents.append(document)
|
||||||
else:
|
else:
|
||||||
if document not in unique_documents:
|
if document not in unique_documents:
|
||||||
unique_documents.append(document)
|
unique_documents.append(document)
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from typing import Any, Union, cast
|
|||||||
|
|
||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
from sqlalchemy import and_, or_, select
|
from sqlalchemy import and_, or_, select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from core.app.app_config.entities import (
|
from core.app.app_config.entities import (
|
||||||
DatasetEntity,
|
DatasetEntity,
|
||||||
@ -19,6 +20,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCre
|
|||||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||||
from core.entities.agent_entities import PlanningStrategy
|
from core.entities.agent_entities import PlanningStrategy
|
||||||
from core.entities.model_entities import ModelStatus
|
from core.entities.model_entities import ModelStatus
|
||||||
|
from core.file import File, FileTransferMethod, FileType
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_manager import ModelInstance, ModelManager
|
from core.model_manager import ModelInstance, ModelManager
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||||
@ -37,7 +39,9 @@ from core.rag.datasource.retrieval_service import RetrievalService
|
|||||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||||
from core.rag.entities.context_entities import DocumentContext
|
from core.rag.entities.context_entities import DocumentContext
|
||||||
from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
||||||
from core.rag.index_processor.constant.index_type import IndexType
|
from core.rag.index_processor.constant.doc_type import DocType
|
||||||
|
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||||
|
from core.rag.index_processor.constant.query_type import QueryType
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
from core.rag.rerank.rerank_type import RerankMode
|
from core.rag.rerank.rerank_type import RerankMode
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
@ -52,10 +56,12 @@ from core.rag.retrieval.template_prompts import (
|
|||||||
METADATA_FILTER_USER_PROMPT_2,
|
METADATA_FILTER_USER_PROMPT_2,
|
||||||
METADATA_FILTER_USER_PROMPT_3,
|
METADATA_FILTER_USER_PROMPT_3,
|
||||||
)
|
)
|
||||||
|
from core.tools.signature import sign_upload_file
|
||||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||||
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment
|
from models import UploadFile
|
||||||
|
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding
|
||||||
from models.dataset import Document as DatasetDocument
|
from models.dataset import Document as DatasetDocument
|
||||||
from services.external_knowledge_service import ExternalDatasetService
|
from services.external_knowledge_service import ExternalDatasetService
|
||||||
|
|
||||||
@ -99,7 +105,8 @@ class DatasetRetrieval:
|
|||||||
message_id: str,
|
message_id: str,
|
||||||
memory: TokenBufferMemory | None = None,
|
memory: TokenBufferMemory | None = None,
|
||||||
inputs: Mapping[str, Any] | None = None,
|
inputs: Mapping[str, Any] | None = None,
|
||||||
) -> str | None:
|
vision_enabled: bool = False,
|
||||||
|
) -> tuple[str | None, list[File] | None]:
|
||||||
"""
|
"""
|
||||||
Retrieve dataset.
|
Retrieve dataset.
|
||||||
:param app_id: app_id
|
:param app_id: app_id
|
||||||
@ -118,7 +125,7 @@ class DatasetRetrieval:
|
|||||||
"""
|
"""
|
||||||
dataset_ids = config.dataset_ids
|
dataset_ids = config.dataset_ids
|
||||||
if len(dataset_ids) == 0:
|
if len(dataset_ids) == 0:
|
||||||
return None
|
return None, []
|
||||||
retrieve_config = config.retrieve_config
|
retrieve_config = config.retrieve_config
|
||||||
|
|
||||||
# check model is support tool calling
|
# check model is support tool calling
|
||||||
@ -136,7 +143,7 @@ class DatasetRetrieval:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not model_schema:
|
if not model_schema:
|
||||||
return None
|
return None, []
|
||||||
|
|
||||||
planning_strategy = PlanningStrategy.REACT_ROUTER
|
planning_strategy = PlanningStrategy.REACT_ROUTER
|
||||||
features = model_schema.features
|
features = model_schema.features
|
||||||
@ -182,8 +189,8 @@ class DatasetRetrieval:
|
|||||||
tenant_id,
|
tenant_id,
|
||||||
user_id,
|
user_id,
|
||||||
user_from,
|
user_from,
|
||||||
available_datasets,
|
|
||||||
query,
|
query,
|
||||||
|
available_datasets,
|
||||||
model_instance,
|
model_instance,
|
||||||
model_config,
|
model_config,
|
||||||
planning_strategy,
|
planning_strategy,
|
||||||
@ -213,6 +220,7 @@ class DatasetRetrieval:
|
|||||||
dify_documents = [item for item in all_documents if item.provider == "dify"]
|
dify_documents = [item for item in all_documents if item.provider == "dify"]
|
||||||
external_documents = [item for item in all_documents if item.provider == "external"]
|
external_documents = [item for item in all_documents if item.provider == "external"]
|
||||||
document_context_list: list[DocumentContext] = []
|
document_context_list: list[DocumentContext] = []
|
||||||
|
context_files: list[File] = []
|
||||||
retrieval_resource_list: list[RetrievalSourceMetadata] = []
|
retrieval_resource_list: list[RetrievalSourceMetadata] = []
|
||||||
# deal with external documents
|
# deal with external documents
|
||||||
for item in external_documents:
|
for item in external_documents:
|
||||||
@ -248,6 +256,31 @@ class DatasetRetrieval:
|
|||||||
score=record.score,
|
score=record.score,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if vision_enabled:
|
||||||
|
attachments_with_bindings = db.session.execute(
|
||||||
|
select(SegmentAttachmentBinding, UploadFile)
|
||||||
|
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
|
||||||
|
.where(
|
||||||
|
SegmentAttachmentBinding.segment_id == segment.id,
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
if attachments_with_bindings:
|
||||||
|
for _, upload_file in attachments_with_bindings:
|
||||||
|
attchment_info = File(
|
||||||
|
id=upload_file.id,
|
||||||
|
filename=upload_file.name,
|
||||||
|
extension="." + upload_file.extension,
|
||||||
|
mime_type=upload_file.mime_type,
|
||||||
|
tenant_id=segment.tenant_id,
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
|
remote_url=upload_file.source_url,
|
||||||
|
related_id=upload_file.id,
|
||||||
|
size=upload_file.size,
|
||||||
|
storage_key=upload_file.key,
|
||||||
|
url=sign_upload_file(upload_file.id, upload_file.extension),
|
||||||
|
)
|
||||||
|
context_files.append(attchment_info)
|
||||||
if show_retrieve_source:
|
if show_retrieve_source:
|
||||||
for record in records:
|
for record in records:
|
||||||
segment = record.segment
|
segment = record.segment
|
||||||
@ -288,8 +321,10 @@ class DatasetRetrieval:
|
|||||||
hit_callback.return_retriever_resource_info(retrieval_resource_list)
|
hit_callback.return_retriever_resource_info(retrieval_resource_list)
|
||||||
if document_context_list:
|
if document_context_list:
|
||||||
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
|
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
|
||||||
return str("\n".join([document_context.content for document_context in document_context_list]))
|
return str(
|
||||||
return ""
|
"\n".join([document_context.content for document_context in document_context_list])
|
||||||
|
), context_files
|
||||||
|
return "", context_files
|
||||||
|
|
||||||
def single_retrieve(
|
def single_retrieve(
|
||||||
self,
|
self,
|
||||||
@ -297,8 +332,8 @@ class DatasetRetrieval:
|
|||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
user_from: str,
|
user_from: str,
|
||||||
available_datasets: list,
|
|
||||||
query: str,
|
query: str,
|
||||||
|
available_datasets: list,
|
||||||
model_instance: ModelInstance,
|
model_instance: ModelInstance,
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
planning_strategy: PlanningStrategy,
|
planning_strategy: PlanningStrategy,
|
||||||
@ -336,7 +371,7 @@ class DatasetRetrieval:
|
|||||||
dataset_id, router_usage = function_call_router.invoke(query, tools, model_config, model_instance)
|
dataset_id, router_usage = function_call_router.invoke(query, tools, model_config, model_instance)
|
||||||
|
|
||||||
self._record_usage(router_usage)
|
self._record_usage(router_usage)
|
||||||
|
timer = None
|
||||||
if dataset_id:
|
if dataset_id:
|
||||||
# get retrieval model config
|
# get retrieval model config
|
||||||
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
|
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
|
||||||
@ -406,10 +441,19 @@ class DatasetRetrieval:
|
|||||||
weights=retrieval_model_config.get("weights", None),
|
weights=retrieval_model_config.get("weights", None),
|
||||||
document_ids_filter=document_ids_filter,
|
document_ids_filter=document_ids_filter,
|
||||||
)
|
)
|
||||||
self._on_query(query, [dataset_id], app_id, user_from, user_id)
|
self._on_query(query, None, [dataset_id], app_id, user_from, user_id)
|
||||||
|
|
||||||
if results:
|
if results:
|
||||||
self._on_retrieval_end(results, message_id, timer)
|
thread = threading.Thread(
|
||||||
|
target=self._on_retrieval_end,
|
||||||
|
kwargs={
|
||||||
|
"flask_app": current_app._get_current_object(), # type: ignore
|
||||||
|
"documents": results,
|
||||||
|
"message_id": message_id,
|
||||||
|
"timer": timer,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
return results
|
return results
|
||||||
return []
|
return []
|
||||||
@ -421,7 +465,7 @@ class DatasetRetrieval:
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
user_from: str,
|
user_from: str,
|
||||||
available_datasets: list,
|
available_datasets: list,
|
||||||
query: str,
|
query: str | None,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
score_threshold: float,
|
score_threshold: float,
|
||||||
reranking_mode: str,
|
reranking_mode: str,
|
||||||
@ -431,10 +475,11 @@ class DatasetRetrieval:
|
|||||||
message_id: str | None = None,
|
message_id: str | None = None,
|
||||||
metadata_filter_document_ids: dict[str, list[str]] | None = None,
|
metadata_filter_document_ids: dict[str, list[str]] | None = None,
|
||||||
metadata_condition: MetadataCondition | None = None,
|
metadata_condition: MetadataCondition | None = None,
|
||||||
|
attachment_ids: list[str] | None = None,
|
||||||
):
|
):
|
||||||
if not available_datasets:
|
if not available_datasets:
|
||||||
return []
|
return []
|
||||||
threads = []
|
all_threads = []
|
||||||
all_documents: list[Document] = []
|
all_documents: list[Document] = []
|
||||||
dataset_ids = [dataset.id for dataset in available_datasets]
|
dataset_ids = [dataset.id for dataset in available_datasets]
|
||||||
index_type_check = all(
|
index_type_check = all(
|
||||||
@ -467,131 +512,226 @@ class DatasetRetrieval:
|
|||||||
0
|
0
|
||||||
].embedding_model_provider
|
].embedding_model_provider
|
||||||
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
|
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
|
||||||
|
|
||||||
for dataset in available_datasets:
|
|
||||||
index_type = dataset.indexing_technique
|
|
||||||
document_ids_filter = None
|
|
||||||
if dataset.provider != "external":
|
|
||||||
if metadata_condition and not metadata_filter_document_ids:
|
|
||||||
continue
|
|
||||||
if metadata_filter_document_ids:
|
|
||||||
document_ids = metadata_filter_document_ids.get(dataset.id, [])
|
|
||||||
if document_ids:
|
|
||||||
document_ids_filter = document_ids
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
retrieval_thread = threading.Thread(
|
|
||||||
target=self._retriever,
|
|
||||||
kwargs={
|
|
||||||
"flask_app": current_app._get_current_object(), # type: ignore
|
|
||||||
"dataset_id": dataset.id,
|
|
||||||
"query": query,
|
|
||||||
"top_k": top_k,
|
|
||||||
"all_documents": all_documents,
|
|
||||||
"document_ids_filter": document_ids_filter,
|
|
||||||
"metadata_condition": metadata_condition,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
threads.append(retrieval_thread)
|
|
||||||
retrieval_thread.start()
|
|
||||||
for thread in threads:
|
|
||||||
thread.join()
|
|
||||||
|
|
||||||
with measure_time() as timer:
|
with measure_time() as timer:
|
||||||
if reranking_enable:
|
if query:
|
||||||
# do rerank for searched documents
|
query_thread = threading.Thread(
|
||||||
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
|
target=self._multiple_retrieve_thread,
|
||||||
|
kwargs={
|
||||||
all_documents = data_post_processor.invoke(
|
"flask_app": current_app._get_current_object(), # type: ignore
|
||||||
query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
|
"available_datasets": available_datasets,
|
||||||
|
"metadata_condition": metadata_condition,
|
||||||
|
"metadata_filter_document_ids": metadata_filter_document_ids,
|
||||||
|
"all_documents": all_documents,
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"reranking_enable": reranking_enable,
|
||||||
|
"reranking_mode": reranking_mode,
|
||||||
|
"reranking_model": reranking_model,
|
||||||
|
"weights": weights,
|
||||||
|
"top_k": top_k,
|
||||||
|
"score_threshold": score_threshold,
|
||||||
|
"query": query,
|
||||||
|
"attachment_id": None,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
else:
|
all_threads.append(query_thread)
|
||||||
if index_type == "economy":
|
query_thread.start()
|
||||||
all_documents = self.calculate_keyword_score(query, all_documents, top_k)
|
if attachment_ids:
|
||||||
elif index_type == "high_quality":
|
for attachment_id in attachment_ids:
|
||||||
all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold)
|
attachment_thread = threading.Thread(
|
||||||
else:
|
target=self._multiple_retrieve_thread,
|
||||||
all_documents = all_documents[:top_k] if top_k else all_documents
|
kwargs={
|
||||||
|
"flask_app": current_app._get_current_object(), # type: ignore
|
||||||
self._on_query(query, dataset_ids, app_id, user_from, user_id)
|
"available_datasets": available_datasets,
|
||||||
|
"metadata_condition": metadata_condition,
|
||||||
|
"metadata_filter_document_ids": metadata_filter_document_ids,
|
||||||
|
"all_documents": all_documents,
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"reranking_enable": reranking_enable,
|
||||||
|
"reranking_mode": reranking_mode,
|
||||||
|
"reranking_model": reranking_model,
|
||||||
|
"weights": weights,
|
||||||
|
"top_k": top_k,
|
||||||
|
"score_threshold": score_threshold,
|
||||||
|
"query": None,
|
||||||
|
"attachment_id": attachment_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
all_threads.append(attachment_thread)
|
||||||
|
attachment_thread.start()
|
||||||
|
for thread in all_threads:
|
||||||
|
thread.join()
|
||||||
|
self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id)
|
||||||
|
|
||||||
if all_documents:
|
if all_documents:
|
||||||
self._on_retrieval_end(all_documents, message_id, timer)
|
# add thread to call _on_retrieval_end
|
||||||
|
retrieval_end_thread = threading.Thread(
|
||||||
return all_documents
|
target=self._on_retrieval_end,
|
||||||
|
kwargs={
|
||||||
def _on_retrieval_end(self, documents: list[Document], message_id: str | None = None, timer: dict | None = None):
|
"flask_app": current_app._get_current_object(), # type: ignore
|
||||||
"""Handle retrieval end."""
|
"documents": all_documents,
|
||||||
dify_documents = [document for document in documents if document.provider == "dify"]
|
"message_id": message_id,
|
||||||
for document in dify_documents:
|
"timer": timer,
|
||||||
if document.metadata is not None:
|
},
|
||||||
dataset_document_stmt = select(DatasetDocument).where(
|
|
||||||
DatasetDocument.id == document.metadata["document_id"]
|
|
||||||
)
|
|
||||||
dataset_document = db.session.scalar(dataset_document_stmt)
|
|
||||||
if dataset_document:
|
|
||||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
|
||||||
child_chunk_stmt = select(ChildChunk).where(
|
|
||||||
ChildChunk.index_node_id == document.metadata["doc_id"],
|
|
||||||
ChildChunk.dataset_id == dataset_document.dataset_id,
|
|
||||||
ChildChunk.document_id == dataset_document.id,
|
|
||||||
)
|
|
||||||
child_chunk = db.session.scalar(child_chunk_stmt)
|
|
||||||
if child_chunk:
|
|
||||||
_ = (
|
|
||||||
db.session.query(DocumentSegment)
|
|
||||||
.where(DocumentSegment.id == child_chunk.segment_id)
|
|
||||||
.update(
|
|
||||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
|
|
||||||
synchronize_session=False,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
query = db.session.query(DocumentSegment).where(
|
|
||||||
DocumentSegment.index_node_id == document.metadata["doc_id"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# if 'dataset_id' in document.metadata:
|
|
||||||
if "dataset_id" in document.metadata:
|
|
||||||
query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"])
|
|
||||||
|
|
||||||
# add hit count to document segment
|
|
||||||
query.update(
|
|
||||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
|
|
||||||
)
|
|
||||||
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# get tracing instance
|
|
||||||
trace_manager: TraceQueueManager | None = (
|
|
||||||
self.application_generate_entity.trace_manager if self.application_generate_entity else None
|
|
||||||
)
|
|
||||||
if trace_manager:
|
|
||||||
trace_manager.add_trace_task(
|
|
||||||
TraceTask(
|
|
||||||
TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
retrieval_end_thread.start()
|
||||||
|
retrieval_resource_list = []
|
||||||
|
doc_ids_filter = []
|
||||||
|
for document in all_documents:
|
||||||
|
if document.provider == "dify":
|
||||||
|
doc_id = document.metadata.get("doc_id")
|
||||||
|
if doc_id and doc_id not in doc_ids_filter:
|
||||||
|
doc_ids_filter.append(doc_id)
|
||||||
|
retrieval_resource_list.append(document)
|
||||||
|
elif document.provider == "external":
|
||||||
|
retrieval_resource_list.append(document)
|
||||||
|
return retrieval_resource_list
|
||||||
|
|
||||||
def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str):
|
def _on_retrieval_end(
|
||||||
|
self, flask_app: Flask, documents: list[Document], message_id: str | None = None, timer: dict | None = None
|
||||||
|
):
|
||||||
|
"""Handle retrieval end."""
|
||||||
|
with flask_app.app_context():
|
||||||
|
dify_documents = [document for document in documents if document.provider == "dify"]
|
||||||
|
segment_ids = []
|
||||||
|
segment_index_node_ids = []
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
for document in dify_documents:
|
||||||
|
if document.metadata is not None:
|
||||||
|
dataset_document_stmt = select(DatasetDocument).where(
|
||||||
|
DatasetDocument.id == document.metadata["document_id"]
|
||||||
|
)
|
||||||
|
dataset_document = session.scalar(dataset_document_stmt)
|
||||||
|
if dataset_document:
|
||||||
|
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||||
|
segment_id = None
|
||||||
|
if (
|
||||||
|
"doc_type" not in document.metadata
|
||||||
|
or document.metadata.get("doc_type") == DocType.TEXT
|
||||||
|
):
|
||||||
|
child_chunk_stmt = select(ChildChunk).where(
|
||||||
|
ChildChunk.index_node_id == document.metadata["doc_id"],
|
||||||
|
ChildChunk.dataset_id == dataset_document.dataset_id,
|
||||||
|
ChildChunk.document_id == dataset_document.id,
|
||||||
|
)
|
||||||
|
child_chunk = session.scalar(child_chunk_stmt)
|
||||||
|
if child_chunk:
|
||||||
|
segment_id = child_chunk.segment_id
|
||||||
|
elif (
|
||||||
|
"doc_type" in document.metadata
|
||||||
|
and document.metadata.get("doc_type") == DocType.IMAGE
|
||||||
|
):
|
||||||
|
attachment_info_dict = RetrievalService.get_segment_attachment_info(
|
||||||
|
dataset_document.dataset_id,
|
||||||
|
dataset_document.tenant_id,
|
||||||
|
document.metadata.get("doc_id") or "",
|
||||||
|
session,
|
||||||
|
)
|
||||||
|
if attachment_info_dict:
|
||||||
|
segment_id = attachment_info_dict["segment_id"]
|
||||||
|
if segment_id:
|
||||||
|
if segment_id not in segment_ids:
|
||||||
|
segment_ids.append(segment_id)
|
||||||
|
_ = (
|
||||||
|
session.query(DocumentSegment)
|
||||||
|
.where(DocumentSegment.id == segment_id)
|
||||||
|
.update(
|
||||||
|
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
|
||||||
|
synchronize_session=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
query = None
|
||||||
|
if (
|
||||||
|
"doc_type" not in document.metadata
|
||||||
|
or document.metadata.get("doc_type") == DocType.TEXT
|
||||||
|
):
|
||||||
|
if document.metadata["doc_id"] not in segment_index_node_ids:
|
||||||
|
segment = (
|
||||||
|
session.query(DocumentSegment)
|
||||||
|
.where(DocumentSegment.index_node_id == document.metadata["doc_id"])
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if segment:
|
||||||
|
segment_index_node_ids.append(document.metadata["doc_id"])
|
||||||
|
segment_ids.append(segment.id)
|
||||||
|
query = session.query(DocumentSegment).where(
|
||||||
|
DocumentSegment.id == segment.id
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
"doc_type" in document.metadata
|
||||||
|
and document.metadata.get("doc_type") == DocType.IMAGE
|
||||||
|
):
|
||||||
|
attachment_info_dict = RetrievalService.get_segment_attachment_info(
|
||||||
|
dataset_document.dataset_id,
|
||||||
|
dataset_document.tenant_id,
|
||||||
|
document.metadata.get("doc_id") or "",
|
||||||
|
session,
|
||||||
|
)
|
||||||
|
if attachment_info_dict:
|
||||||
|
segment_id = attachment_info_dict["segment_id"]
|
||||||
|
if segment_id not in segment_ids:
|
||||||
|
segment_ids.append(segment_id)
|
||||||
|
query = session.query(DocumentSegment).where(DocumentSegment.id == segment_id)
|
||||||
|
if query:
|
||||||
|
# if 'dataset_id' in document.metadata:
|
||||||
|
if "dataset_id" in document.metadata:
|
||||||
|
query = query.where(
|
||||||
|
DocumentSegment.dataset_id == document.metadata["dataset_id"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# add hit count to document segment
|
||||||
|
query.update(
|
||||||
|
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
|
||||||
|
synchronize_session=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# get tracing instance
|
||||||
|
trace_manager: TraceQueueManager | None = (
|
||||||
|
self.application_generate_entity.trace_manager if self.application_generate_entity else None
|
||||||
|
)
|
||||||
|
if trace_manager:
|
||||||
|
trace_manager.add_trace_task(
|
||||||
|
TraceTask(
|
||||||
|
TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _on_query(
|
||||||
|
self,
|
||||||
|
query: str | None,
|
||||||
|
attachment_ids: list[str] | None,
|
||||||
|
dataset_ids: list[str],
|
||||||
|
app_id: str,
|
||||||
|
user_from: str,
|
||||||
|
user_id: str,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Handle query.
|
Handle query.
|
||||||
"""
|
"""
|
||||||
if not query:
|
if not query and not attachment_ids:
|
||||||
return
|
return
|
||||||
dataset_queries = []
|
dataset_queries = []
|
||||||
for dataset_id in dataset_ids:
|
for dataset_id in dataset_ids:
|
||||||
dataset_query = DatasetQuery(
|
contents = []
|
||||||
dataset_id=dataset_id,
|
if query:
|
||||||
content=query,
|
contents.append({"content_type": QueryType.TEXT_QUERY, "content": query})
|
||||||
source="app",
|
if attachment_ids:
|
||||||
source_app_id=app_id,
|
for attachment_id in attachment_ids:
|
||||||
created_by_role=user_from,
|
contents.append({"content_type": QueryType.IMAGE_QUERY, "content": attachment_id})
|
||||||
created_by=user_id,
|
if contents:
|
||||||
)
|
dataset_query = DatasetQuery(
|
||||||
dataset_queries.append(dataset_query)
|
dataset_id=dataset_id,
|
||||||
if dataset_queries:
|
content=json.dumps(contents),
|
||||||
db.session.add_all(dataset_queries)
|
source="app",
|
||||||
|
source_app_id=app_id,
|
||||||
|
created_by_role=user_from,
|
||||||
|
created_by=user_id,
|
||||||
|
)
|
||||||
|
dataset_queries.append(dataset_query)
|
||||||
|
if dataset_queries:
|
||||||
|
db.session.add_all(dataset_queries)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
def _retriever(
|
def _retriever(
|
||||||
@ -603,6 +743,7 @@ class DatasetRetrieval:
|
|||||||
all_documents: list,
|
all_documents: list,
|
||||||
document_ids_filter: list[str] | None = None,
|
document_ids_filter: list[str] | None = None,
|
||||||
metadata_condition: MetadataCondition | None = None,
|
metadata_condition: MetadataCondition | None = None,
|
||||||
|
attachment_ids: list[str] | None = None,
|
||||||
):
|
):
|
||||||
with flask_app.app_context():
|
with flask_app.app_context():
|
||||||
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
|
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
|
||||||
@ -611,7 +752,7 @@ class DatasetRetrieval:
|
|||||||
if not dataset:
|
if not dataset:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
if dataset.provider == "external":
|
if dataset.provider == "external" and query:
|
||||||
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||||
tenant_id=dataset.tenant_id,
|
tenant_id=dataset.tenant_id,
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
@ -663,6 +804,7 @@ class DatasetRetrieval:
|
|||||||
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
||||||
weights=retrieval_model.get("weights", None),
|
weights=retrieval_model.get("weights", None),
|
||||||
document_ids_filter=document_ids_filter,
|
document_ids_filter=document_ids_filter,
|
||||||
|
attachment_ids=attachment_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
all_documents.extend(documents)
|
all_documents.extend(documents)
|
||||||
@ -1222,3 +1364,86 @@ class DatasetRetrieval:
|
|||||||
usage = LLMUsage.empty_usage()
|
usage = LLMUsage.empty_usage()
|
||||||
|
|
||||||
return full_text, usage
|
return full_text, usage
|
||||||
|
|
||||||
|
def _multiple_retrieve_thread(
|
||||||
|
self,
|
||||||
|
flask_app: Flask,
|
||||||
|
available_datasets: list,
|
||||||
|
metadata_condition: MetadataCondition | None,
|
||||||
|
metadata_filter_document_ids: dict[str, list[str]] | None,
|
||||||
|
all_documents: list[Document],
|
||||||
|
tenant_id: str,
|
||||||
|
reranking_enable: bool,
|
||||||
|
reranking_mode: str,
|
||||||
|
reranking_model: dict | None,
|
||||||
|
weights: dict[str, Any] | None,
|
||||||
|
top_k: int,
|
||||||
|
score_threshold: float,
|
||||||
|
query: str | None,
|
||||||
|
attachment_id: str | None,
|
||||||
|
):
|
||||||
|
with flask_app.app_context():
|
||||||
|
threads = []
|
||||||
|
all_documents_item: list[Document] = []
|
||||||
|
index_type = None
|
||||||
|
for dataset in available_datasets:
|
||||||
|
index_type = dataset.indexing_technique
|
||||||
|
document_ids_filter = None
|
||||||
|
if dataset.provider != "external":
|
||||||
|
if metadata_condition and not metadata_filter_document_ids:
|
||||||
|
continue
|
||||||
|
if metadata_filter_document_ids:
|
||||||
|
document_ids = metadata_filter_document_ids.get(dataset.id, [])
|
||||||
|
if document_ids:
|
||||||
|
document_ids_filter = document_ids
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
retrieval_thread = threading.Thread(
|
||||||
|
target=self._retriever,
|
||||||
|
kwargs={
|
||||||
|
"flask_app": flask_app,
|
||||||
|
"dataset_id": dataset.id,
|
||||||
|
"query": query,
|
||||||
|
"top_k": top_k,
|
||||||
|
"all_documents": all_documents_item,
|
||||||
|
"document_ids_filter": document_ids_filter,
|
||||||
|
"metadata_condition": metadata_condition,
|
||||||
|
"attachment_ids": [attachment_id] if attachment_id else None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
threads.append(retrieval_thread)
|
||||||
|
retrieval_thread.start()
|
||||||
|
for thread in threads:
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
if reranking_enable:
|
||||||
|
# do rerank for searched documents
|
||||||
|
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
|
||||||
|
if query:
|
||||||
|
all_documents_item = data_post_processor.invoke(
|
||||||
|
query=query,
|
||||||
|
documents=all_documents_item,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
top_n=top_k,
|
||||||
|
query_type=QueryType.TEXT_QUERY,
|
||||||
|
)
|
||||||
|
if attachment_id:
|
||||||
|
all_documents_item = data_post_processor.invoke(
|
||||||
|
documents=all_documents_item,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
top_n=top_k,
|
||||||
|
query_type=QueryType.IMAGE_QUERY,
|
||||||
|
query=attachment_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if index_type == IndexTechniqueType.ECONOMY:
|
||||||
|
if not query:
|
||||||
|
all_documents_item = []
|
||||||
|
else:
|
||||||
|
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
|
||||||
|
elif index_type == IndexTechniqueType.HIGH_QUALITY:
|
||||||
|
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
|
||||||
|
else:
|
||||||
|
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
|
||||||
|
if all_documents_item:
|
||||||
|
all_documents.extend(all_documents_item)
|
||||||
|
|||||||
@ -0,0 +1,65 @@
|
|||||||
|
{
|
||||||
|
"$id": "https://dify.ai/schemas/v1/multimodal_general_structure.json",
|
||||||
|
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"type": "array",
|
||||||
|
"title": "Multimodal General Structure",
|
||||||
|
"description": "Schema for multimodal general structure (v1) - array of objects",
|
||||||
|
"properties": {
|
||||||
|
"general_chunks": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"content": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The content"
|
||||||
|
},
|
||||||
|
"files": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "file name"
|
||||||
|
},
|
||||||
|
"size": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "file size"
|
||||||
|
},
|
||||||
|
"extension": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "file extension"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "file type"
|
||||||
|
},
|
||||||
|
"mime_type": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "file mime type"
|
||||||
|
},
|
||||||
|
"transfer_method": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "file transfer method"
|
||||||
|
},
|
||||||
|
"url": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "file url"
|
||||||
|
},
|
||||||
|
"related_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "file related id"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"description": "List of files"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["content"]
|
||||||
|
},
|
||||||
|
"description": "List of content and files"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -0,0 +1,78 @@
|
|||||||
|
{
|
||||||
|
"$id": "https://dify.ai/schemas/v1/multimodal_parent_child_structure.json",
|
||||||
|
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"type": "object",
|
||||||
|
"title": "Multimodal Parent-Child Structure",
|
||||||
|
"description": "Schema for multimodal parent-child structure (v1)",
|
||||||
|
"properties": {
|
||||||
|
"parent_mode": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The mode of parent-child relationship"
|
||||||
|
},
|
||||||
|
"parent_child_chunks": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"parent_content": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The parent content"
|
||||||
|
},
|
||||||
|
"files": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "file name"
|
||||||
|
},
|
||||||
|
"size": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "file size"
|
||||||
|
},
|
||||||
|
"extension": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "file extension"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "file type"
|
||||||
|
},
|
||||||
|
"mime_type": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "file mime type"
|
||||||
|
},
|
||||||
|
"transfer_method": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "file transfer method"
|
||||||
|
},
|
||||||
|
"url": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "file url"
|
||||||
|
},
|
||||||
|
"related_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "file related id"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["name", "size", "extension", "type", "mime_type", "transfer_method", "url", "related_id"]
|
||||||
|
},
|
||||||
|
"description": "List of files"
|
||||||
|
},
|
||||||
|
"child_contents": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"description": "List of child contents"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["parent_content", "child_contents"]
|
||||||
|
},
|
||||||
|
"description": "List of parent-child chunk pairs"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["parent_mode", "parent_child_chunks"]
|
||||||
|
}
|
||||||
@ -25,6 +25,24 @@ def sign_tool_file(tool_file_id: str, extension: str) -> str:
|
|||||||
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||||
|
|
||||||
|
|
||||||
|
def sign_upload_file(upload_file_id: str, extension: str) -> str:
|
||||||
|
"""
|
||||||
|
sign file to get a temporary url for plugin access
|
||||||
|
"""
|
||||||
|
# Use internal URL for plugin/tool file access in Docker environments
|
||||||
|
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
|
||||||
|
file_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
|
||||||
|
|
||||||
|
timestamp = str(int(time.time()))
|
||||||
|
nonce = os.urandom(16).hex()
|
||||||
|
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||||
|
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||||
|
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||||
|
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||||
|
|
||||||
|
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||||
|
|
||||||
|
|
||||||
def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||||
"""
|
"""
|
||||||
verify signature
|
verify signature
|
||||||
|
|||||||
@ -13,5 +13,5 @@ def remove_leading_symbols(text: str) -> str:
|
|||||||
"""
|
"""
|
||||||
# Match Unicode ranges for punctuation and symbols
|
# Match Unicode ranges for punctuation and symbols
|
||||||
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
|
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
|
||||||
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"
|
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F\"#$%&'()*+,./:;<=>?@^_`~]+"
|
||||||
return re.sub(pattern, "", text)
|
return re.sub(pattern, "", text)
|
||||||
|
|||||||
@ -3,6 +3,7 @@ from datetime import datetime
|
|||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
|
from core.file import File
|
||||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||||
from core.workflow.entities.pause_reason import PauseReason
|
from core.workflow.entities.pause_reason import PauseReason
|
||||||
@ -14,6 +15,7 @@ from .base import NodeEventBase
|
|||||||
class RunRetrieverResourceEvent(NodeEventBase):
|
class RunRetrieverResourceEvent(NodeEventBase):
|
||||||
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
||||||
context: str = Field(..., description="context")
|
context: str = Field(..., description="context")
|
||||||
|
context_files: list[File] | None = Field(default=None, description="context files")
|
||||||
|
|
||||||
|
|
||||||
class ModelInvokeCompletedEvent(NodeEventBase):
|
class ModelInvokeCompletedEvent(NodeEventBase):
|
||||||
|
|||||||
@ -3,6 +3,7 @@ from collections.abc import Sequence
|
|||||||
from email.message import Message
|
from email.message import Message
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
import charset_normalizer
|
||||||
import httpx
|
import httpx
|
||||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||||
|
|
||||||
@ -96,10 +97,12 @@ class HttpRequestNodeData(BaseNodeData):
|
|||||||
class Response:
|
class Response:
|
||||||
headers: dict[str, str]
|
headers: dict[str, str]
|
||||||
response: httpx.Response
|
response: httpx.Response
|
||||||
|
_cached_text: str | None
|
||||||
|
|
||||||
def __init__(self, response: httpx.Response):
|
def __init__(self, response: httpx.Response):
|
||||||
self.response = response
|
self.response = response
|
||||||
self.headers = dict(response.headers)
|
self.headers = dict(response.headers)
|
||||||
|
self._cached_text = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_file(self):
|
def is_file(self):
|
||||||
@ -159,7 +162,31 @@ class Response:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def text(self) -> str:
|
def text(self) -> str:
|
||||||
return self.response.text
|
"""
|
||||||
|
Get response text with robust encoding detection.
|
||||||
|
|
||||||
|
Uses charset_normalizer for better encoding detection than httpx's default,
|
||||||
|
which helps handle Chinese and other non-ASCII characters properly.
|
||||||
|
"""
|
||||||
|
# Check cache first
|
||||||
|
if hasattr(self, "_cached_text") and self._cached_text is not None:
|
||||||
|
return self._cached_text
|
||||||
|
|
||||||
|
# Try charset_normalizer for robust encoding detection first
|
||||||
|
detected_encoding = charset_normalizer.from_bytes(self.response.content).best()
|
||||||
|
if detected_encoding and detected_encoding.encoding:
|
||||||
|
try:
|
||||||
|
text = self.response.content.decode(detected_encoding.encoding)
|
||||||
|
self._cached_text = text
|
||||||
|
return text
|
||||||
|
except (UnicodeDecodeError, TypeError, LookupError):
|
||||||
|
# Fallback to httpx's encoding detection if charset_normalizer fails
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Fallback to httpx's built-in encoding detection
|
||||||
|
text = self.response.text
|
||||||
|
self._cached_text = text
|
||||||
|
return text
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def content(self) -> bytes:
|
def content(self) -> bytes:
|
||||||
|
|||||||
@ -114,7 +114,8 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "knowledge-retrieval"
|
type: str = "knowledge-retrieval"
|
||||||
query_variable_selector: list[str]
|
query_variable_selector: list[str] | None | str = None
|
||||||
|
query_attachment_selector: list[str] | None | str = None
|
||||||
dataset_ids: list[str]
|
dataset_ids: list[str]
|
||||||
retrieval_mode: Literal["single", "multiple"]
|
retrieval_mode: Literal["single", "multiple"]
|
||||||
multiple_retrieval_config: MultipleRetrievalConfig | None = None
|
multiple_retrieval_config: MultipleRetrievalConfig | None = None
|
||||||
|
|||||||
@ -25,6 +25,8 @@ from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
|||||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
from core.variables import (
|
from core.variables import (
|
||||||
|
ArrayFileSegment,
|
||||||
|
FileSegment,
|
||||||
StringSegment,
|
StringSegment,
|
||||||
)
|
)
|
||||||
from core.variables.segments import ArrayObjectSegment
|
from core.variables.segments import ArrayObjectSegment
|
||||||
@ -119,20 +121,41 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
|||||||
return "1"
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult:
|
def _run(self) -> NodeRunResult:
|
||||||
# extract variables
|
if not self._node_data.query_variable_selector and not self._node_data.query_attachment_selector:
|
||||||
variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector)
|
|
||||||
if not isinstance(variable, StringSegment):
|
|
||||||
return NodeRunResult(
|
return NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.FAILED,
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
inputs={},
|
inputs={},
|
||||||
error="Query variable is not string type.",
|
process_data={},
|
||||||
)
|
outputs={},
|
||||||
query = variable.value
|
metadata={},
|
||||||
variables = {"query": query}
|
llm_usage=LLMUsage.empty_usage(),
|
||||||
if not query:
|
|
||||||
return NodeRunResult(
|
|
||||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required."
|
|
||||||
)
|
)
|
||||||
|
variables: dict[str, Any] = {}
|
||||||
|
# extract variables
|
||||||
|
if self._node_data.query_variable_selector:
|
||||||
|
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
|
||||||
|
if not isinstance(variable, StringSegment):
|
||||||
|
return NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
|
inputs={},
|
||||||
|
error="Query variable is not string type.",
|
||||||
|
)
|
||||||
|
query = variable.value
|
||||||
|
variables["query"] = query
|
||||||
|
|
||||||
|
if self._node_data.query_attachment_selector:
|
||||||
|
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_attachment_selector)
|
||||||
|
if not isinstance(variable, ArrayFileSegment) and not isinstance(variable, FileSegment):
|
||||||
|
return NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
|
inputs={},
|
||||||
|
error="Attachments variable is not array file or file type.",
|
||||||
|
)
|
||||||
|
if isinstance(variable, ArrayFileSegment):
|
||||||
|
variables["attachments"] = variable.value
|
||||||
|
else:
|
||||||
|
variables["attachments"] = [variable.value]
|
||||||
|
|
||||||
# TODO(-LAN-): Move this check outside.
|
# TODO(-LAN-): Move this check outside.
|
||||||
# check rate limit
|
# check rate limit
|
||||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
|
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
|
||||||
@ -161,7 +184,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
|||||||
# retrieve knowledge
|
# retrieve knowledge
|
||||||
usage = LLMUsage.empty_usage()
|
usage = LLMUsage.empty_usage()
|
||||||
try:
|
try:
|
||||||
results, usage = self._fetch_dataset_retriever(node_data=self.node_data, query=query)
|
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables)
|
||||||
outputs = {"result": ArrayObjectSegment(value=results)}
|
outputs = {"result": ArrayObjectSegment(value=results)}
|
||||||
return NodeRunResult(
|
return NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
@ -198,12 +221,16 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
|||||||
db.session.close()
|
db.session.close()
|
||||||
|
|
||||||
def _fetch_dataset_retriever(
|
def _fetch_dataset_retriever(
|
||||||
self, node_data: KnowledgeRetrievalNodeData, query: str
|
self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]
|
||||||
) -> tuple[list[dict[str, Any]], LLMUsage]:
|
) -> tuple[list[dict[str, Any]], LLMUsage]:
|
||||||
usage = LLMUsage.empty_usage()
|
usage = LLMUsage.empty_usage()
|
||||||
available_datasets = []
|
available_datasets = []
|
||||||
dataset_ids = node_data.dataset_ids
|
dataset_ids = node_data.dataset_ids
|
||||||
|
query = variables.get("query")
|
||||||
|
attachments = variables.get("attachments")
|
||||||
|
metadata_filter_document_ids = None
|
||||||
|
metadata_condition = None
|
||||||
|
metadata_usage = LLMUsage.empty_usage()
|
||||||
# Subquery: Count the number of available documents for each dataset
|
# Subquery: Count the number of available documents for each dataset
|
||||||
subquery = (
|
subquery = (
|
||||||
db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count"))
|
db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count"))
|
||||||
@ -234,13 +261,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
|||||||
if not dataset:
|
if not dataset:
|
||||||
continue
|
continue
|
||||||
available_datasets.append(dataset)
|
available_datasets.append(dataset)
|
||||||
metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition(
|
if query:
|
||||||
[dataset.id for dataset in available_datasets], query, node_data
|
metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition(
|
||||||
)
|
[dataset.id for dataset in available_datasets], query, node_data
|
||||||
usage = self._merge_usage(usage, metadata_usage)
|
)
|
||||||
|
usage = self._merge_usage(usage, metadata_usage)
|
||||||
all_documents = []
|
all_documents = []
|
||||||
dataset_retrieval = DatasetRetrieval()
|
dataset_retrieval = DatasetRetrieval()
|
||||||
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query:
|
||||||
# fetch model config
|
# fetch model config
|
||||||
if node_data.single_retrieval_config is None:
|
if node_data.single_retrieval_config is None:
|
||||||
raise ValueError("single_retrieval_config is required")
|
raise ValueError("single_retrieval_config is required")
|
||||||
@ -272,7 +300,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
|||||||
metadata_filter_document_ids=metadata_filter_document_ids,
|
metadata_filter_document_ids=metadata_filter_document_ids,
|
||||||
metadata_condition=metadata_condition,
|
metadata_condition=metadata_condition,
|
||||||
)
|
)
|
||||||
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
|
elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
|
||||||
if node_data.multiple_retrieval_config is None:
|
if node_data.multiple_retrieval_config is None:
|
||||||
raise ValueError("multiple_retrieval_config is required")
|
raise ValueError("multiple_retrieval_config is required")
|
||||||
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
|
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
|
||||||
@ -319,6 +347,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
|||||||
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
|
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
|
||||||
metadata_filter_document_ids=metadata_filter_document_ids,
|
metadata_filter_document_ids=metadata_filter_document_ids,
|
||||||
metadata_condition=metadata_condition,
|
metadata_condition=metadata_condition,
|
||||||
|
attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
|
||||||
)
|
)
|
||||||
usage = self._merge_usage(usage, dataset_retrieval.llm_usage)
|
usage = self._merge_usage(usage, dataset_retrieval.llm_usage)
|
||||||
|
|
||||||
@ -327,7 +356,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
|||||||
retrieval_resource_list = []
|
retrieval_resource_list = []
|
||||||
# deal with external documents
|
# deal with external documents
|
||||||
for item in external_documents:
|
for item in external_documents:
|
||||||
source = {
|
source: dict[str, dict[str, str | Any | dict[Any, Any] | None] | Any | str | None] = {
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"_source": "knowledge",
|
"_source": "knowledge",
|
||||||
"dataset_id": item.metadata.get("dataset_id"),
|
"dataset_id": item.metadata.get("dataset_id"),
|
||||||
@ -384,6 +413,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
|||||||
"doc_metadata": document.doc_metadata,
|
"doc_metadata": document.doc_metadata,
|
||||||
},
|
},
|
||||||
"title": document.name,
|
"title": document.name,
|
||||||
|
"files": list(record.files) if record.files else None,
|
||||||
}
|
}
|
||||||
if segment.answer:
|
if segment.answer:
|
||||||
source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
|
source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
|
||||||
@ -393,13 +423,21 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
|||||||
if retrieval_resource_list:
|
if retrieval_resource_list:
|
||||||
retrieval_resource_list = sorted(
|
retrieval_resource_list = sorted(
|
||||||
retrieval_resource_list,
|
retrieval_resource_list,
|
||||||
key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0,
|
key=self._score, # type: ignore[arg-type, return-value]
|
||||||
reverse=True,
|
reverse=True,
|
||||||
)
|
)
|
||||||
for position, item in enumerate(retrieval_resource_list, start=1):
|
for position, item in enumerate(retrieval_resource_list, start=1):
|
||||||
item["metadata"]["position"] = position
|
item["metadata"]["position"] = position # type: ignore[index]
|
||||||
return retrieval_resource_list, usage
|
return retrieval_resource_list, usage
|
||||||
|
|
||||||
|
def _score(self, item: dict[str, Any]) -> float:
|
||||||
|
meta = item.get("metadata")
|
||||||
|
if isinstance(meta, dict):
|
||||||
|
s = meta.get("score")
|
||||||
|
if isinstance(s, (int, float)):
|
||||||
|
return float(s)
|
||||||
|
return 0.0
|
||||||
|
|
||||||
def _get_metadata_filter_condition(
|
def _get_metadata_filter_condition(
|
||||||
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
||||||
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]:
|
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]:
|
||||||
@ -659,7 +697,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
|||||||
typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data)
|
typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data)
|
||||||
|
|
||||||
variable_mapping = {}
|
variable_mapping = {}
|
||||||
variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
|
if typed_node_data.query_variable_selector:
|
||||||
|
variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
|
||||||
|
if typed_node_data.query_attachment_selector:
|
||||||
|
variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector
|
||||||
return variable_mapping
|
return variable_mapping
|
||||||
|
|
||||||
def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||||
|
|||||||
@ -7,8 +7,10 @@ import time
|
|||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from typing import TYPE_CHECKING, Any, Literal
|
from typing import TYPE_CHECKING, Any, Literal
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||||
from core.file import FileType, file_manager
|
from core.file import File, FileTransferMethod, FileType, file_manager
|
||||||
from core.helper.code_executor import CodeExecutor, CodeLanguage
|
from core.helper.code_executor import CodeExecutor, CodeLanguage
|
||||||
from core.llm_generator.output_parser.errors import OutputParserError
|
from core.llm_generator.output_parser.errors import OutputParserError
|
||||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||||
@ -44,6 +46,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
|||||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||||
|
from core.tools.signature import sign_upload_file
|
||||||
from core.variables import (
|
from core.variables import (
|
||||||
ArrayFileSegment,
|
ArrayFileSegment,
|
||||||
ArraySegment,
|
ArraySegment,
|
||||||
@ -72,6 +75,9 @@ from core.workflow.nodes.base.entities import VariableSelector
|
|||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||||
from core.workflow.runtime import VariablePool
|
from core.workflow.runtime import VariablePool
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.dataset import SegmentAttachmentBinding
|
||||||
|
from models.model import UploadFile
|
||||||
|
|
||||||
from . import llm_utils
|
from . import llm_utils
|
||||||
from .entities import (
|
from .entities import (
|
||||||
@ -179,12 +185,17 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
# fetch context value
|
# fetch context value
|
||||||
generator = self._fetch_context(node_data=self.node_data)
|
generator = self._fetch_context(node_data=self.node_data)
|
||||||
context = None
|
context = None
|
||||||
|
context_files: list[File] = []
|
||||||
for event in generator:
|
for event in generator:
|
||||||
context = event.context
|
context = event.context
|
||||||
|
context_files = event.context_files or []
|
||||||
yield event
|
yield event
|
||||||
if context:
|
if context:
|
||||||
node_inputs["#context#"] = context
|
node_inputs["#context#"] = context
|
||||||
|
|
||||||
|
if context_files:
|
||||||
|
node_inputs["#context_files#"] = [file.model_dump() for file in context_files]
|
||||||
|
|
||||||
# fetch model config
|
# fetch model config
|
||||||
model_instance, model_config = LLMNode._fetch_model_config(
|
model_instance, model_config = LLMNode._fetch_model_config(
|
||||||
node_data_model=self.node_data.model,
|
node_data_model=self.node_data.model,
|
||||||
@ -220,6 +231,7 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
variable_pool=variable_pool,
|
variable_pool=variable_pool,
|
||||||
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
|
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
|
context_files=context_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
# handle invoke result
|
# handle invoke result
|
||||||
@ -654,10 +666,13 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector)
|
context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector)
|
||||||
if context_value_variable:
|
if context_value_variable:
|
||||||
if isinstance(context_value_variable, StringSegment):
|
if isinstance(context_value_variable, StringSegment):
|
||||||
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value)
|
yield RunRetrieverResourceEvent(
|
||||||
|
retriever_resources=[], context=context_value_variable.value, context_files=[]
|
||||||
|
)
|
||||||
elif isinstance(context_value_variable, ArraySegment):
|
elif isinstance(context_value_variable, ArraySegment):
|
||||||
context_str = ""
|
context_str = ""
|
||||||
original_retriever_resource: list[RetrievalSourceMetadata] = []
|
original_retriever_resource: list[RetrievalSourceMetadata] = []
|
||||||
|
context_files: list[File] = []
|
||||||
for item in context_value_variable.value:
|
for item in context_value_variable.value:
|
||||||
if isinstance(item, str):
|
if isinstance(item, str):
|
||||||
context_str += item + "\n"
|
context_str += item + "\n"
|
||||||
@ -670,9 +685,34 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
retriever_resource = self._convert_to_original_retriever_resource(item)
|
retriever_resource = self._convert_to_original_retriever_resource(item)
|
||||||
if retriever_resource:
|
if retriever_resource:
|
||||||
original_retriever_resource.append(retriever_resource)
|
original_retriever_resource.append(retriever_resource)
|
||||||
|
attachments_with_bindings = db.session.execute(
|
||||||
|
select(SegmentAttachmentBinding, UploadFile)
|
||||||
|
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
|
||||||
|
.where(
|
||||||
|
SegmentAttachmentBinding.segment_id == retriever_resource.segment_id,
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
if attachments_with_bindings:
|
||||||
|
for _, upload_file in attachments_with_bindings:
|
||||||
|
attchment_info = File(
|
||||||
|
id=upload_file.id,
|
||||||
|
filename=upload_file.name,
|
||||||
|
extension="." + upload_file.extension,
|
||||||
|
mime_type=upload_file.mime_type,
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
|
remote_url=upload_file.source_url,
|
||||||
|
related_id=upload_file.id,
|
||||||
|
size=upload_file.size,
|
||||||
|
storage_key=upload_file.key,
|
||||||
|
url=sign_upload_file(upload_file.id, upload_file.extension),
|
||||||
|
)
|
||||||
|
context_files.append(attchment_info)
|
||||||
yield RunRetrieverResourceEvent(
|
yield RunRetrieverResourceEvent(
|
||||||
retriever_resources=original_retriever_resource, context=context_str.strip()
|
retriever_resources=original_retriever_resource,
|
||||||
|
context=context_str.strip(),
|
||||||
|
context_files=context_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _convert_to_original_retriever_resource(self, context_dict: dict) -> RetrievalSourceMetadata | None:
|
def _convert_to_original_retriever_resource(self, context_dict: dict) -> RetrievalSourceMetadata | None:
|
||||||
@ -700,6 +740,7 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
content=context_dict.get("content"),
|
content=context_dict.get("content"),
|
||||||
page=metadata.get("page"),
|
page=metadata.get("page"),
|
||||||
doc_metadata=metadata.get("doc_metadata"),
|
doc_metadata=metadata.get("doc_metadata"),
|
||||||
|
files=context_dict.get("files"),
|
||||||
)
|
)
|
||||||
|
|
||||||
return source
|
return source
|
||||||
@ -741,6 +782,7 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
variable_pool: VariablePool,
|
variable_pool: VariablePool,
|
||||||
jinja2_variables: Sequence[VariableSelector],
|
jinja2_variables: Sequence[VariableSelector],
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
|
context_files: list["File"] | None = None,
|
||||||
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
|
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
|
||||||
prompt_messages: list[PromptMessage] = []
|
prompt_messages: list[PromptMessage] = []
|
||||||
|
|
||||||
@ -853,6 +895,23 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
else:
|
else:
|
||||||
prompt_messages.append(UserPromptMessage(content=file_prompts))
|
prompt_messages.append(UserPromptMessage(content=file_prompts))
|
||||||
|
|
||||||
|
# The context_files
|
||||||
|
if vision_enabled and context_files:
|
||||||
|
file_prompts = []
|
||||||
|
for file in context_files:
|
||||||
|
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
|
||||||
|
file_prompts.append(file_prompt)
|
||||||
|
# If last prompt is a user prompt, add files into its contents,
|
||||||
|
# otherwise append a new user prompt
|
||||||
|
if (
|
||||||
|
len(prompt_messages) > 0
|
||||||
|
and isinstance(prompt_messages[-1], UserPromptMessage)
|
||||||
|
and isinstance(prompt_messages[-1].content, list)
|
||||||
|
):
|
||||||
|
prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
|
||||||
|
else:
|
||||||
|
prompt_messages.append(UserPromptMessage(content=file_prompts))
|
||||||
|
|
||||||
# Remove empty messages and filter unsupported content
|
# Remove empty messages and filter unsupported content
|
||||||
filtered_prompt_messages = []
|
filtered_prompt_messages = []
|
||||||
for prompt_message in prompt_messages:
|
for prompt_message in prompt_messages:
|
||||||
|
|||||||
@ -97,11 +97,27 @@ dataset_detail_fields = {
|
|||||||
"total_documents": fields.Integer,
|
"total_documents": fields.Integer,
|
||||||
"total_available_documents": fields.Integer,
|
"total_available_documents": fields.Integer,
|
||||||
"enable_api": fields.Boolean,
|
"enable_api": fields.Boolean,
|
||||||
|
"is_multimodal": fields.Boolean,
|
||||||
|
}
|
||||||
|
|
||||||
|
file_info_fields = {
|
||||||
|
"id": fields.String,
|
||||||
|
"name": fields.String,
|
||||||
|
"size": fields.Integer,
|
||||||
|
"extension": fields.String,
|
||||||
|
"mime_type": fields.String,
|
||||||
|
"source_url": fields.String,
|
||||||
|
}
|
||||||
|
|
||||||
|
content_fields = {
|
||||||
|
"content_type": fields.String,
|
||||||
|
"content": fields.String,
|
||||||
|
"file_info": fields.Nested(file_info_fields, allow_null=True),
|
||||||
}
|
}
|
||||||
|
|
||||||
dataset_query_detail_fields = {
|
dataset_query_detail_fields = {
|
||||||
"id": fields.String,
|
"id": fields.String,
|
||||||
"content": fields.String,
|
"queries": fields.Nested(content_fields),
|
||||||
"source": fields.String,
|
"source": fields.String,
|
||||||
"source_app_id": fields.String,
|
"source_app_id": fields.String,
|
||||||
"created_by_role": fields.String,
|
"created_by_role": fields.String,
|
||||||
|
|||||||
@ -9,6 +9,8 @@ upload_config_fields = {
|
|||||||
"video_file_size_limit": fields.Integer,
|
"video_file_size_limit": fields.Integer,
|
||||||
"audio_file_size_limit": fields.Integer,
|
"audio_file_size_limit": fields.Integer,
|
||||||
"workflow_file_upload_limit": fields.Integer,
|
"workflow_file_upload_limit": fields.Integer,
|
||||||
|
"image_file_batch_limit": fields.Integer,
|
||||||
|
"single_chunk_attachment_limit": fields.Integer,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -43,9 +43,19 @@ child_chunk_fields = {
|
|||||||
"score": fields.Float,
|
"score": fields.Float,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
files_fields = {
|
||||||
|
"id": fields.String,
|
||||||
|
"name": fields.String,
|
||||||
|
"size": fields.Integer,
|
||||||
|
"extension": fields.String,
|
||||||
|
"mime_type": fields.String,
|
||||||
|
"source_url": fields.String,
|
||||||
|
}
|
||||||
|
|
||||||
hit_testing_record_fields = {
|
hit_testing_record_fields = {
|
||||||
"segment": fields.Nested(segment_fields),
|
"segment": fields.Nested(segment_fields),
|
||||||
"child_chunks": fields.List(fields.Nested(child_chunk_fields)),
|
"child_chunks": fields.List(fields.Nested(child_chunk_fields)),
|
||||||
"score": fields.Float,
|
"score": fields.Float,
|
||||||
"tsne_position": fields.Raw,
|
"tsne_position": fields.Raw,
|
||||||
|
"files": fields.List(fields.Nested(files_fields)),
|
||||||
}
|
}
|
||||||
|
|||||||
@ -13,6 +13,15 @@ child_chunk_fields = {
|
|||||||
"updated_at": TimestampField,
|
"updated_at": TimestampField,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
attachment_fields = {
|
||||||
|
"id": fields.String,
|
||||||
|
"name": fields.String,
|
||||||
|
"size": fields.Integer,
|
||||||
|
"extension": fields.String,
|
||||||
|
"mime_type": fields.String,
|
||||||
|
"source_url": fields.String,
|
||||||
|
}
|
||||||
|
|
||||||
segment_fields = {
|
segment_fields = {
|
||||||
"id": fields.String,
|
"id": fields.String,
|
||||||
"position": fields.Integer,
|
"position": fields.Integer,
|
||||||
@ -39,4 +48,5 @@ segment_fields = {
|
|||||||
"error": fields.String,
|
"error": fields.String,
|
||||||
"stopped_at": TimestampField,
|
"stopped_at": TimestampField,
|
||||||
"child_chunks": fields.List(fields.Nested(child_chunk_fields)),
|
"child_chunks": fields.List(fields.Nested(child_chunk_fields)),
|
||||||
|
"attachments": fields.List(fields.Nested(attachment_fields)),
|
||||||
}
|
}
|
||||||
|
|||||||
@ -0,0 +1,57 @@
|
|||||||
|
"""support-multi-modal
|
||||||
|
|
||||||
|
Revision ID: d57accd375ae
|
||||||
|
Revises: 03f8dcbc611e
|
||||||
|
Create Date: 2025-11-12 15:37:12.363670
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = 'd57accd375ae'
|
||||||
|
down_revision = '7bb281b7a422'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('segment_attachment_bindings',
|
||||||
|
sa.Column('id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('document_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('segment_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('attachment_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='segment_attachment_binding_pkey')
|
||||||
|
)
|
||||||
|
with op.batch_alter_table('segment_attachment_bindings', schema=None) as batch_op:
|
||||||
|
batch_op.create_index(
|
||||||
|
'segment_attachment_binding_tenant_dataset_document_segment_idx',
|
||||||
|
['tenant_id', 'dataset_id', 'document_id', 'segment_id'],
|
||||||
|
unique=False
|
||||||
|
)
|
||||||
|
batch_op.create_index('segment_attachment_binding_attachment_idx', ['attachment_id'], unique=False)
|
||||||
|
|
||||||
|
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('is_multimodal', sa.Boolean(), server_default=sa.text('false'), nullable=False))
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please
|
||||||
|
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||||
|
batch_op.drop_column('is_multimodal')
|
||||||
|
|
||||||
|
|
||||||
|
with op.batch_alter_table('segment_attachment_bindings', schema=None) as batch_op:
|
||||||
|
batch_op.drop_index('segment_attachment_binding_attachment_idx')
|
||||||
|
batch_op.drop_index('segment_attachment_binding_tenant_dataset_document_segment_idx')
|
||||||
|
|
||||||
|
op.drop_table('segment_attachment_bindings')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@ -19,7 +19,9 @@ from sqlalchemy.orm import Mapped, Session, mapped_column
|
|||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
|
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
|
||||||
|
from core.rag.index_processor.constant.query_type import QueryType
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
|
from core.tools.signature import sign_upload_file
|
||||||
from extensions.ext_storage import storage
|
from extensions.ext_storage import storage
|
||||||
from libs.uuid_utils import uuidv7
|
from libs.uuid_utils import uuidv7
|
||||||
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
|
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
|
||||||
@ -76,6 +78,7 @@ class Dataset(Base):
|
|||||||
pipeline_id = mapped_column(StringUUID, nullable=True)
|
pipeline_id = mapped_column(StringUUID, nullable=True)
|
||||||
chunk_structure = mapped_column(sa.String(255), nullable=True)
|
chunk_structure = mapped_column(sa.String(255), nullable=True)
|
||||||
enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||||
|
is_multimodal = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def total_documents(self):
|
def total_documents(self):
|
||||||
@ -728,9 +731,7 @@ class DocumentSegment(Base):
|
|||||||
created_by = mapped_column(StringUUID, nullable=False)
|
created_by = mapped_column(StringUUID, nullable=False)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_by = mapped_column(StringUUID, nullable=True)
|
updated_by = mapped_column(StringUUID, nullable=True)
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
|
||||||
)
|
|
||||||
indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||||
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||||
error = mapped_column(LongText, nullable=True)
|
error = mapped_column(LongText, nullable=True)
|
||||||
@ -866,6 +867,47 @@ class DocumentSegment(Base):
|
|||||||
|
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
@property
|
||||||
|
def attachments(self) -> list[dict[str, Any]]:
|
||||||
|
# Use JOIN to fetch attachments in a single query instead of two separate queries
|
||||||
|
attachments_with_bindings = db.session.execute(
|
||||||
|
select(SegmentAttachmentBinding, UploadFile)
|
||||||
|
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
|
||||||
|
.where(
|
||||||
|
SegmentAttachmentBinding.tenant_id == self.tenant_id,
|
||||||
|
SegmentAttachmentBinding.dataset_id == self.dataset_id,
|
||||||
|
SegmentAttachmentBinding.document_id == self.document_id,
|
||||||
|
SegmentAttachmentBinding.segment_id == self.id,
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
if not attachments_with_bindings:
|
||||||
|
return []
|
||||||
|
attachment_list = []
|
||||||
|
for _, attachment in attachments_with_bindings:
|
||||||
|
upload_file_id = attachment.id
|
||||||
|
nonce = os.urandom(16).hex()
|
||||||
|
timestamp = str(int(time.time()))
|
||||||
|
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||||
|
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||||
|
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||||
|
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||||
|
|
||||||
|
params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||||
|
reference_url = dify_config.CONSOLE_API_URL or ""
|
||||||
|
base_url = f"{reference_url}/files/{upload_file_id}/image-preview"
|
||||||
|
source_url = f"{base_url}?{params}"
|
||||||
|
attachment_list.append(
|
||||||
|
{
|
||||||
|
"id": attachment.id,
|
||||||
|
"name": attachment.name,
|
||||||
|
"size": attachment.size,
|
||||||
|
"extension": attachment.extension,
|
||||||
|
"mime_type": attachment.mime_type,
|
||||||
|
"source_url": source_url,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return attachment_list
|
||||||
|
|
||||||
|
|
||||||
class ChildChunk(Base):
|
class ChildChunk(Base):
|
||||||
__tablename__ = "child_chunks"
|
__tablename__ = "child_chunks"
|
||||||
@ -963,6 +1005,38 @@ class DatasetQuery(TypeBase):
|
|||||||
DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
|
DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def queries(self) -> list[dict[str, Any]]:
|
||||||
|
try:
|
||||||
|
queries = json.loads(self.content)
|
||||||
|
if isinstance(queries, list):
|
||||||
|
for query in queries:
|
||||||
|
if query["content_type"] == QueryType.IMAGE_QUERY:
|
||||||
|
file_info = db.session.query(UploadFile).filter_by(id=query["content"]).first()
|
||||||
|
if file_info:
|
||||||
|
query["file_info"] = {
|
||||||
|
"id": file_info.id,
|
||||||
|
"name": file_info.name,
|
||||||
|
"size": file_info.size,
|
||||||
|
"extension": file_info.extension,
|
||||||
|
"mime_type": file_info.mime_type,
|
||||||
|
"source_url": sign_upload_file(file_info.id, file_info.extension),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
query["file_info"] = None
|
||||||
|
|
||||||
|
return queries
|
||||||
|
else:
|
||||||
|
return [queries]
|
||||||
|
except JSONDecodeError:
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"content_type": QueryType.TEXT_QUERY,
|
||||||
|
"content": self.content,
|
||||||
|
"file_info": None,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class DatasetKeywordTable(TypeBase):
|
class DatasetKeywordTable(TypeBase):
|
||||||
__tablename__ = "dataset_keyword_tables"
|
__tablename__ = "dataset_keyword_tables"
|
||||||
@ -1470,3 +1544,25 @@ class PipelineRecommendedPlugin(TypeBase):
|
|||||||
onupdate=func.current_timestamp(),
|
onupdate=func.current_timestamp(),
|
||||||
init=False,
|
init=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentAttachmentBinding(Base):
|
||||||
|
__tablename__ = "segment_attachment_bindings"
|
||||||
|
__table_args__ = (
|
||||||
|
sa.PrimaryKeyConstraint("id", name="segment_attachment_binding_pkey"),
|
||||||
|
sa.Index(
|
||||||
|
"segment_attachment_binding_tenant_dataset_document_segment_idx",
|
||||||
|
"tenant_id",
|
||||||
|
"dataset_id",
|
||||||
|
"document_id",
|
||||||
|
"segment_id",
|
||||||
|
),
|
||||||
|
sa.Index("segment_attachment_binding_attachment_idx", "attachment_id"),
|
||||||
|
)
|
||||||
|
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||||
|
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
|
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
|
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
|
segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
|
attachment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
|
|||||||
31
api/services/attachment_service.py
Normal file
31
api/services/attachment_service.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
import base64
|
||||||
|
|
||||||
|
from sqlalchemy import Engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
|
from extensions.ext_storage import storage
|
||||||
|
from models.model import UploadFile
|
||||||
|
|
||||||
|
PREVIEW_WORDS_LIMIT = 3000
|
||||||
|
|
||||||
|
|
||||||
|
class AttachmentService:
|
||||||
|
_session_maker: sessionmaker
|
||||||
|
|
||||||
|
def __init__(self, session_factory: sessionmaker | Engine | None = None):
|
||||||
|
if isinstance(session_factory, Engine):
|
||||||
|
self._session_maker = sessionmaker(bind=session_factory)
|
||||||
|
elif isinstance(session_factory, sessionmaker):
|
||||||
|
self._session_maker = session_factory
|
||||||
|
else:
|
||||||
|
raise AssertionError("must be a sessionmaker or an Engine.")
|
||||||
|
|
||||||
|
def get_file_base64(self, file_id: str) -> str:
|
||||||
|
upload_file = (
|
||||||
|
self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first()
|
||||||
|
)
|
||||||
|
if not upload_file:
|
||||||
|
raise NotFound("File not found")
|
||||||
|
blob = storage.load_once(upload_file.key)
|
||||||
|
return base64.b64encode(blob).decode()
|
||||||
@ -7,7 +7,7 @@ import time
|
|||||||
import uuid
|
import uuid
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, cast
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from redis.exceptions import LockNotOwnedError
|
from redis.exceptions import LockNotOwnedError
|
||||||
@ -19,9 +19,10 @@ from configs import dify_config
|
|||||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||||
from core.helper.name_generator import generate_incremental_name
|
from core.helper.name_generator import generate_incremental_name
|
||||||
from core.model_manager import ModelManager
|
from core.model_manager import ModelManager
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||||
|
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||||
from core.rag.index_processor.constant.index_type import IndexType
|
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
from enums.cloud_plan import CloudPlan
|
from enums.cloud_plan import CloudPlan
|
||||||
from events.dataset_event import dataset_was_deleted
|
from events.dataset_event import dataset_was_deleted
|
||||||
@ -46,12 +47,14 @@ from models.dataset import (
|
|||||||
DocumentSegment,
|
DocumentSegment,
|
||||||
ExternalKnowledgeBindings,
|
ExternalKnowledgeBindings,
|
||||||
Pipeline,
|
Pipeline,
|
||||||
|
SegmentAttachmentBinding,
|
||||||
)
|
)
|
||||||
from models.model import UploadFile
|
from models.model import UploadFile
|
||||||
from models.provider_ids import ModelProviderID
|
from models.provider_ids import ModelProviderID
|
||||||
from models.source import DataSourceOauthBinding
|
from models.source import DataSourceOauthBinding
|
||||||
from models.workflow import Workflow
|
from models.workflow import Workflow
|
||||||
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
|
from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy
|
||||||
|
from services.document_indexing_proxy.duplicate_document_indexing_task_proxy import DuplicateDocumentIndexingTaskProxy
|
||||||
from services.entities.knowledge_entities.knowledge_entities import (
|
from services.entities.knowledge_entities.knowledge_entities import (
|
||||||
ChildChunkUpdateArgs,
|
ChildChunkUpdateArgs,
|
||||||
KnowledgeConfig,
|
KnowledgeConfig,
|
||||||
@ -82,7 +85,6 @@ from tasks.delete_segment_from_index_task import delete_segment_from_index_task
|
|||||||
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
|
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
|
||||||
from tasks.disable_segments_from_index_task import disable_segments_from_index_task
|
from tasks.disable_segments_from_index_task import disable_segments_from_index_task
|
||||||
from tasks.document_indexing_update_task import document_indexing_update_task
|
from tasks.document_indexing_update_task import document_indexing_update_task
|
||||||
from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
|
|
||||||
from tasks.enable_segments_to_index_task import enable_segments_to_index_task
|
from tasks.enable_segments_to_index_task import enable_segments_to_index_task
|
||||||
from tasks.recover_document_indexing_task import recover_document_indexing_task
|
from tasks.recover_document_indexing_task import recover_document_indexing_task
|
||||||
from tasks.remove_document_from_index_task import remove_document_from_index_task
|
from tasks.remove_document_from_index_task import remove_document_from_index_task
|
||||||
@ -363,6 +365,27 @@ class DatasetService:
|
|||||||
except ProviderTokenNotInitError as ex:
|
except ProviderTokenNotInitError as ex:
|
||||||
raise ValueError(ex.description)
|
raise ValueError(ex.description)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str):
|
||||||
|
try:
|
||||||
|
model_manager = ModelManager()
|
||||||
|
model_instance = model_manager.get_model_instance(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider=model_provider,
|
||||||
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
text_embedding_model = cast(TextEmbeddingModel, model_instance.model_type_instance)
|
||||||
|
model_schema = text_embedding_model.get_model_schema(model_instance.model, model_instance.credentials)
|
||||||
|
if not model_schema:
|
||||||
|
raise ValueError("Model schema not found")
|
||||||
|
if model_schema.features and ModelFeature.VISION in model_schema.features:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
except LLMBadRequestError:
|
||||||
|
raise ValueError("No Model available. Please configure a valid provider in the Settings -> Model Provider.")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str):
|
def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str):
|
||||||
try:
|
try:
|
||||||
@ -402,13 +425,13 @@ class DatasetService:
|
|||||||
if not dataset:
|
if not dataset:
|
||||||
raise ValueError("Dataset not found")
|
raise ValueError("Dataset not found")
|
||||||
# check if dataset name is exists
|
# check if dataset name is exists
|
||||||
|
if data.get("name") and data.get("name") != dataset.name:
|
||||||
if DatasetService._has_dataset_same_name(
|
if DatasetService._has_dataset_same_name(
|
||||||
tenant_id=dataset.tenant_id,
|
tenant_id=dataset.tenant_id,
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
name=data.get("name", dataset.name),
|
name=data.get("name", dataset.name),
|
||||||
):
|
):
|
||||||
raise ValueError("Dataset name already exists")
|
raise ValueError("Dataset name already exists")
|
||||||
|
|
||||||
# Verify user has permission to update this dataset
|
# Verify user has permission to update this dataset
|
||||||
DatasetService.check_dataset_permission(dataset, user)
|
DatasetService.check_dataset_permission(dataset, user)
|
||||||
@ -844,6 +867,12 @@ class DatasetService:
|
|||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=knowledge_configuration.embedding_model or "",
|
model=knowledge_configuration.embedding_model or "",
|
||||||
)
|
)
|
||||||
|
is_multimodal = DatasetService.check_is_multimodal_model(
|
||||||
|
current_user.current_tenant_id,
|
||||||
|
knowledge_configuration.embedding_model_provider,
|
||||||
|
knowledge_configuration.embedding_model,
|
||||||
|
)
|
||||||
|
dataset.is_multimodal = is_multimodal
|
||||||
dataset.embedding_model = embedding_model.model
|
dataset.embedding_model = embedding_model.model
|
||||||
dataset.embedding_model_provider = embedding_model.provider
|
dataset.embedding_model_provider = embedding_model.provider
|
||||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||||
@ -880,6 +909,12 @@ class DatasetService:
|
|||||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||||
embedding_model.provider, embedding_model.model
|
embedding_model.provider, embedding_model.model
|
||||||
)
|
)
|
||||||
|
is_multimodal = DatasetService.check_is_multimodal_model(
|
||||||
|
current_user.current_tenant_id,
|
||||||
|
knowledge_configuration.embedding_model_provider,
|
||||||
|
knowledge_configuration.embedding_model,
|
||||||
|
)
|
||||||
|
dataset.is_multimodal = is_multimodal
|
||||||
dataset.collection_binding_id = dataset_collection_binding.id
|
dataset.collection_binding_id = dataset_collection_binding.id
|
||||||
dataset.indexing_technique = knowledge_configuration.indexing_technique
|
dataset.indexing_technique = knowledge_configuration.indexing_technique
|
||||||
except LLMBadRequestError:
|
except LLMBadRequestError:
|
||||||
@ -937,6 +972,12 @@ class DatasetService:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
dataset.collection_binding_id = dataset_collection_binding.id
|
dataset.collection_binding_id = dataset_collection_binding.id
|
||||||
|
is_multimodal = DatasetService.check_is_multimodal_model(
|
||||||
|
current_user.current_tenant_id,
|
||||||
|
knowledge_configuration.embedding_model_provider,
|
||||||
|
knowledge_configuration.embedding_model,
|
||||||
|
)
|
||||||
|
dataset.is_multimodal = is_multimodal
|
||||||
except LLMBadRequestError:
|
except LLMBadRequestError:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"No Embedding Model available. Please configure a valid provider "
|
"No Embedding Model available. Please configure a valid provider "
|
||||||
@ -1761,7 +1802,9 @@ class DocumentService:
|
|||||||
if document_ids:
|
if document_ids:
|
||||||
DocumentIndexingTaskProxy(dataset.tenant_id, dataset.id, document_ids).delay()
|
DocumentIndexingTaskProxy(dataset.tenant_id, dataset.id, document_ids).delay()
|
||||||
if duplicate_document_ids:
|
if duplicate_document_ids:
|
||||||
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
|
DuplicateDocumentIndexingTaskProxy(
|
||||||
|
dataset.tenant_id, dataset.id, duplicate_document_ids
|
||||||
|
).delay()
|
||||||
except LockNotOwnedError:
|
except LockNotOwnedError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -2303,6 +2346,7 @@ class DocumentService:
|
|||||||
embedding_model_provider=knowledge_config.embedding_model_provider,
|
embedding_model_provider=knowledge_config.embedding_model_provider,
|
||||||
collection_binding_id=dataset_collection_binding_id,
|
collection_binding_id=dataset_collection_binding_id,
|
||||||
retrieval_model=retrieval_model.model_dump() if retrieval_model else None,
|
retrieval_model=retrieval_model.model_dump() if retrieval_model else None,
|
||||||
|
is_multimodal=knowledge_config.is_multimodal,
|
||||||
)
|
)
|
||||||
|
|
||||||
db.session.add(dataset)
|
db.session.add(dataset)
|
||||||
@ -2683,6 +2727,13 @@ class SegmentService:
|
|||||||
if "content" not in args or not args["content"] or not args["content"].strip():
|
if "content" not in args or not args["content"] or not args["content"].strip():
|
||||||
raise ValueError("Content is empty")
|
raise ValueError("Content is empty")
|
||||||
|
|
||||||
|
if args.get("attachment_ids"):
|
||||||
|
if not isinstance(args["attachment_ids"], list):
|
||||||
|
raise ValueError("Attachment IDs is invalid")
|
||||||
|
single_chunk_attachment_limit = dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT
|
||||||
|
if len(args["attachment_ids"]) > single_chunk_attachment_limit:
|
||||||
|
raise ValueError(f"Exceeded maximum attachment limit of {single_chunk_attachment_limit}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_segment(cls, args: dict, document: Document, dataset: Dataset):
|
def create_segment(cls, args: dict, document: Document, dataset: Dataset):
|
||||||
assert isinstance(current_user, Account)
|
assert isinstance(current_user, Account)
|
||||||
@ -2729,11 +2780,23 @@ class SegmentService:
|
|||||||
segment_document.word_count += len(args["answer"])
|
segment_document.word_count += len(args["answer"])
|
||||||
segment_document.answer = args["answer"]
|
segment_document.answer = args["answer"]
|
||||||
|
|
||||||
db.session.add(segment_document)
|
db.session.add(segment_document)
|
||||||
# update document word count
|
# update document word count
|
||||||
assert document.word_count is not None
|
assert document.word_count is not None
|
||||||
document.word_count += segment_document.word_count
|
document.word_count += segment_document.word_count
|
||||||
db.session.add(document)
|
db.session.add(document)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
if args["attachment_ids"]:
|
||||||
|
for attachment_id in args["attachment_ids"]:
|
||||||
|
binding = SegmentAttachmentBinding(
|
||||||
|
tenant_id=current_user.current_tenant_id,
|
||||||
|
dataset_id=document.dataset_id,
|
||||||
|
document_id=document.id,
|
||||||
|
segment_id=segment_document.id,
|
||||||
|
attachment_id=attachment_id,
|
||||||
|
)
|
||||||
|
db.session.add(binding)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
# save vector index
|
# save vector index
|
||||||
@ -2897,7 +2960,7 @@ class SegmentService:
|
|||||||
document.word_count = max(0, document.word_count + word_count_change)
|
document.word_count = max(0, document.word_count + word_count_change)
|
||||||
db.session.add(document)
|
db.session.add(document)
|
||||||
# update segment index task
|
# update segment index task
|
||||||
if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
|
if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
|
||||||
# regenerate child chunks
|
# regenerate child chunks
|
||||||
# get embedding model instance
|
# get embedding model instance
|
||||||
if dataset.indexing_technique == "high_quality":
|
if dataset.indexing_technique == "high_quality":
|
||||||
@ -2924,12 +2987,11 @@ class SegmentService:
|
|||||||
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
|
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if not processing_rule:
|
if processing_rule:
|
||||||
raise ValueError("No processing rule found.")
|
VectorService.generate_child_chunks(
|
||||||
VectorService.generate_child_chunks(
|
segment, document, dataset, embedding_model_instance, processing_rule, True
|
||||||
segment, document, dataset, embedding_model_instance, processing_rule, True
|
)
|
||||||
)
|
elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX):
|
||||||
elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
|
|
||||||
if args.enabled or keyword_changed:
|
if args.enabled or keyword_changed:
|
||||||
# update segment vector index
|
# update segment vector index
|
||||||
VectorService.update_segment_vector(args.keywords, segment, dataset)
|
VectorService.update_segment_vector(args.keywords, segment, dataset)
|
||||||
@ -2974,7 +3036,7 @@ class SegmentService:
|
|||||||
db.session.add(document)
|
db.session.add(document)
|
||||||
db.session.add(segment)
|
db.session.add(segment)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
|
if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
|
||||||
# get embedding model instance
|
# get embedding model instance
|
||||||
if dataset.indexing_technique == "high_quality":
|
if dataset.indexing_technique == "high_quality":
|
||||||
# check embedding model setting
|
# check embedding model setting
|
||||||
@ -3000,15 +3062,15 @@ class SegmentService:
|
|||||||
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
|
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if not processing_rule:
|
if processing_rule:
|
||||||
raise ValueError("No processing rule found.")
|
VectorService.generate_child_chunks(
|
||||||
VectorService.generate_child_chunks(
|
segment, document, dataset, embedding_model_instance, processing_rule, True
|
||||||
segment, document, dataset, embedding_model_instance, processing_rule, True
|
)
|
||||||
)
|
elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX):
|
||||||
elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
|
|
||||||
# update segment vector index
|
# update segment vector index
|
||||||
VectorService.update_segment_vector(args.keywords, segment, dataset)
|
VectorService.update_segment_vector(args.keywords, segment, dataset)
|
||||||
|
# update multimodel vector index
|
||||||
|
VectorService.update_multimodel_vector(segment, args.attachment_ids or [], dataset)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("update segment index failed")
|
logger.exception("update segment index failed")
|
||||||
segment.enabled = False
|
segment.enabled = False
|
||||||
@ -3046,7 +3108,9 @@ class SegmentService:
|
|||||||
)
|
)
|
||||||
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
|
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
|
||||||
|
|
||||||
delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id, child_node_ids)
|
delete_segment_from_index_task.delay(
|
||||||
|
[segment.index_node_id], dataset.id, document.id, [segment.id], child_node_ids
|
||||||
|
)
|
||||||
|
|
||||||
db.session.delete(segment)
|
db.session.delete(segment)
|
||||||
# update document word count
|
# update document word count
|
||||||
@ -3095,7 +3159,9 @@ class SegmentService:
|
|||||||
|
|
||||||
# Start async cleanup with both parent and child node IDs
|
# Start async cleanup with both parent and child node IDs
|
||||||
if index_node_ids or child_node_ids:
|
if index_node_ids or child_node_ids:
|
||||||
delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id, child_node_ids)
|
delete_segment_from_index_task.delay(
|
||||||
|
index_node_ids, dataset.id, document.id, segment_db_ids, child_node_ids
|
||||||
|
)
|
||||||
|
|
||||||
if document.word_count is None:
|
if document.word_count is None:
|
||||||
document.word_count = 0
|
document.word_count = 0
|
||||||
|
|||||||
@ -29,8 +29,14 @@ def get_current_user():
|
|||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.model import EndUser
|
from models.model import EndUser
|
||||||
|
|
||||||
if not isinstance(current_user._get_current_object(), (Account, EndUser)): # type: ignore
|
try:
|
||||||
raise TypeError(f"current_user must be Account or EndUser, got {type(current_user).__name__}")
|
user_object = current_user._get_current_object()
|
||||||
|
except AttributeError:
|
||||||
|
# Handle case where current_user might not be a LocalProxy in test environments
|
||||||
|
user_object = current_user
|
||||||
|
|
||||||
|
if not isinstance(user_object, (Account, EndUser)):
|
||||||
|
raise TypeError(f"current_user must be Account or EndUser, got {type(user_object).__name__}")
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
11
api/services/document_indexing_proxy/__init__.py
Normal file
11
api/services/document_indexing_proxy/__init__.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
from .base import DocumentTaskProxyBase
|
||||||
|
from .batch_indexing_base import BatchDocumentIndexingProxy
|
||||||
|
from .document_indexing_task_proxy import DocumentIndexingTaskProxy
|
||||||
|
from .duplicate_document_indexing_task_proxy import DuplicateDocumentIndexingTaskProxy
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BatchDocumentIndexingProxy",
|
||||||
|
"DocumentIndexingTaskProxy",
|
||||||
|
"DocumentTaskProxyBase",
|
||||||
|
"DuplicateDocumentIndexingTaskProxy",
|
||||||
|
]
|
||||||
111
api/services/document_indexing_proxy/base.py
Normal file
111
api/services/document_indexing_proxy/base.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import Callable
|
||||||
|
from functools import cached_property
|
||||||
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
|
from enums.cloud_plan import CloudPlan
|
||||||
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentTaskProxyBase(ABC):
|
||||||
|
"""
|
||||||
|
Base proxy for all document processing tasks.
|
||||||
|
|
||||||
|
Handles common logic:
|
||||||
|
- Feature/billing checks
|
||||||
|
- Dispatch routing based on plan
|
||||||
|
|
||||||
|
Subclasses must define:
|
||||||
|
- QUEUE_NAME: Redis queue identifier
|
||||||
|
- NORMAL_TASK_FUNC: Task function for normal priority
|
||||||
|
- PRIORITY_TASK_FUNC: Task function for high priority
|
||||||
|
"""
|
||||||
|
|
||||||
|
QUEUE_NAME: ClassVar[str]
|
||||||
|
NORMAL_TASK_FUNC: ClassVar[Callable[..., Any]]
|
||||||
|
PRIORITY_TASK_FUNC: ClassVar[Callable[..., Any]]
|
||||||
|
|
||||||
|
def __init__(self, tenant_id: str, dataset_id: str):
|
||||||
|
"""
|
||||||
|
Initialize with minimal required parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: Tenant identifier for billing/features
|
||||||
|
dataset_id: Dataset identifier for logging
|
||||||
|
"""
|
||||||
|
self._tenant_id = tenant_id
|
||||||
|
self._dataset_id = dataset_id
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def features(self):
|
||||||
|
return FeatureService.get_features(self._tenant_id)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _send_to_direct_queue(self, task_func: Callable[..., Any]):
|
||||||
|
"""
|
||||||
|
Send task directly to Celery queue without tenant isolation.
|
||||||
|
|
||||||
|
Subclasses implement this to pass task-specific parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_func: The Celery task function to call
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _send_to_tenant_queue(self, task_func: Callable[..., Any]):
|
||||||
|
"""
|
||||||
|
Send task to tenant-isolated queue.
|
||||||
|
|
||||||
|
Subclasses implement this to handle queue management.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_func: The Celery task function to call
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _send_to_default_tenant_queue(self):
|
||||||
|
"""Route to normal priority with tenant isolation."""
|
||||||
|
self._send_to_tenant_queue(self.NORMAL_TASK_FUNC)
|
||||||
|
|
||||||
|
def _send_to_priority_tenant_queue(self):
|
||||||
|
"""Route to priority queue with tenant isolation."""
|
||||||
|
self._send_to_tenant_queue(self.PRIORITY_TASK_FUNC)
|
||||||
|
|
||||||
|
def _send_to_priority_direct_queue(self):
|
||||||
|
"""Route to priority queue without tenant isolation."""
|
||||||
|
self._send_to_direct_queue(self.PRIORITY_TASK_FUNC)
|
||||||
|
|
||||||
|
def _dispatch(self):
|
||||||
|
"""
|
||||||
|
Dispatch task based on billing plan.
|
||||||
|
|
||||||
|
Routing logic:
|
||||||
|
- Sandbox plan → normal queue + tenant isolation
|
||||||
|
- Paid plans → priority queue + tenant isolation
|
||||||
|
- Self-hosted → priority queue, no isolation
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
"dispatch args: %s - %s - %s",
|
||||||
|
self._tenant_id,
|
||||||
|
self.features.billing.enabled,
|
||||||
|
self.features.billing.subscription.plan,
|
||||||
|
)
|
||||||
|
# dispatch to different indexing queue with tenant isolation when billing enabled
|
||||||
|
if self.features.billing.enabled:
|
||||||
|
if self.features.billing.subscription.plan == CloudPlan.SANDBOX:
|
||||||
|
# dispatch to normal pipeline queue with tenant self sub queue for sandbox plan
|
||||||
|
self._send_to_default_tenant_queue()
|
||||||
|
else:
|
||||||
|
# dispatch to priority pipeline queue with tenant self sub queue for other plans
|
||||||
|
self._send_to_priority_tenant_queue()
|
||||||
|
else:
|
||||||
|
# dispatch to priority queue without tenant isolation for others, e.g.: self-hosted or enterprise
|
||||||
|
self._send_to_priority_direct_queue()
|
||||||
|
|
||||||
|
def delay(self):
|
||||||
|
"""Public API: Queue the task asynchronously."""
|
||||||
|
self._dispatch()
|
||||||
76
api/services/document_indexing_proxy/batch_indexing_base.py
Normal file
76
api/services/document_indexing_proxy/batch_indexing_base.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
import logging
|
||||||
|
from collections.abc import Callable, Sequence
|
||||||
|
from dataclasses import asdict
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from core.entities.document_task import DocumentTask
|
||||||
|
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||||
|
|
||||||
|
from .base import DocumentTaskProxyBase
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchDocumentIndexingProxy(DocumentTaskProxyBase):
|
||||||
|
"""
|
||||||
|
Base proxy for batch document indexing tasks (document_ids in plural).
|
||||||
|
|
||||||
|
Adds:
|
||||||
|
- Tenant isolated queue management
|
||||||
|
- Batch document handling
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
|
||||||
|
"""
|
||||||
|
Initialize with batch documents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: Tenant identifier
|
||||||
|
dataset_id: Dataset identifier
|
||||||
|
document_ids: List of document IDs to process
|
||||||
|
"""
|
||||||
|
super().__init__(tenant_id, dataset_id)
|
||||||
|
self._document_ids = document_ids
|
||||||
|
self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, self.QUEUE_NAME)
|
||||||
|
|
||||||
|
def _send_to_direct_queue(self, task_func: Callable[[str, str, Sequence[str]], Any]):
|
||||||
|
"""
|
||||||
|
Send batch task to direct queue.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_func: The Celery task function to call with (tenant_id, dataset_id, document_ids)
|
||||||
|
"""
|
||||||
|
logger.info("tenant %s send documents %s to direct queue", self._tenant_id, self._document_ids)
|
||||||
|
task_func.delay( # type: ignore
|
||||||
|
tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
def _send_to_tenant_queue(self, task_func: Callable[[str, str, Sequence[str]], Any]):
|
||||||
|
"""
|
||||||
|
Send batch task to tenant-isolated queue.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_func: The Celery task function to call with (tenant_id, dataset_id, document_ids)
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
"tenant %s send documents %s to tenant queue %s", self._tenant_id, self._document_ids, self.QUEUE_NAME
|
||||||
|
)
|
||||||
|
if self._tenant_isolated_task_queue.get_task_key():
|
||||||
|
# Add to waiting queue using List operations (lpush)
|
||||||
|
self._tenant_isolated_task_queue.push_tasks(
|
||||||
|
[
|
||||||
|
asdict(
|
||||||
|
DocumentTask(
|
||||||
|
tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
|
||||||
|
)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
logger.info("tenant %s push tasks: %s - %s", self._tenant_id, self._dataset_id, self._document_ids)
|
||||||
|
else:
|
||||||
|
# Set flag and execute task
|
||||||
|
self._tenant_isolated_task_queue.set_task_waiting_time()
|
||||||
|
task_func.delay( # type: ignore
|
||||||
|
tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
|
||||||
|
)
|
||||||
|
logger.info("tenant %s init tasks: %s - %s", self._tenant_id, self._dataset_id, self._document_ids)
|
||||||
@ -0,0 +1,12 @@
|
|||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
|
from services.document_indexing_proxy.batch_indexing_base import BatchDocumentIndexingProxy
|
||||||
|
from tasks.document_indexing_task import normal_document_indexing_task, priority_document_indexing_task
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentIndexingTaskProxy(BatchDocumentIndexingProxy):
|
||||||
|
"""Proxy for document indexing tasks."""
|
||||||
|
|
||||||
|
QUEUE_NAME: ClassVar[str] = "document_indexing"
|
||||||
|
NORMAL_TASK_FUNC = normal_document_indexing_task
|
||||||
|
PRIORITY_TASK_FUNC = priority_document_indexing_task
|
||||||
@ -0,0 +1,15 @@
|
|||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
|
from services.document_indexing_proxy.batch_indexing_base import BatchDocumentIndexingProxy
|
||||||
|
from tasks.duplicate_document_indexing_task import (
|
||||||
|
normal_duplicate_document_indexing_task,
|
||||||
|
priority_duplicate_document_indexing_task,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DuplicateDocumentIndexingTaskProxy(BatchDocumentIndexingProxy):
|
||||||
|
"""Proxy for duplicate document indexing tasks."""
|
||||||
|
|
||||||
|
QUEUE_NAME: ClassVar[str] = "duplicate_document_indexing"
|
||||||
|
NORMAL_TASK_FUNC = normal_duplicate_document_indexing_task
|
||||||
|
PRIORITY_TASK_FUNC = priority_duplicate_document_indexing_task
|
||||||
@ -1,83 +0,0 @@
|
|||||||
import logging
|
|
||||||
from collections.abc import Callable, Sequence
|
|
||||||
from dataclasses import asdict
|
|
||||||
from functools import cached_property
|
|
||||||
|
|
||||||
from core.entities.document_task import DocumentTask
|
|
||||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
|
||||||
from enums.cloud_plan import CloudPlan
|
|
||||||
from services.feature_service import FeatureService
|
|
||||||
from tasks.document_indexing_task import normal_document_indexing_task, priority_document_indexing_task
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class DocumentIndexingTaskProxy:
|
|
||||||
def __init__(self, tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
|
|
||||||
self._tenant_id = tenant_id
|
|
||||||
self._dataset_id = dataset_id
|
|
||||||
self._document_ids = document_ids
|
|
||||||
self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def features(self):
|
|
||||||
return FeatureService.get_features(self._tenant_id)
|
|
||||||
|
|
||||||
def _send_to_direct_queue(self, task_func: Callable[[str, str, Sequence[str]], None]):
|
|
||||||
logger.info("send dataset %s to direct queue", self._dataset_id)
|
|
||||||
task_func.delay( # type: ignore
|
|
||||||
tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
def _send_to_tenant_queue(self, task_func: Callable[[str, str, Sequence[str]], None]):
|
|
||||||
logger.info("send dataset %s to tenant queue", self._dataset_id)
|
|
||||||
if self._tenant_isolated_task_queue.get_task_key():
|
|
||||||
# Add to waiting queue using List operations (lpush)
|
|
||||||
self._tenant_isolated_task_queue.push_tasks(
|
|
||||||
[
|
|
||||||
asdict(
|
|
||||||
DocumentTask(
|
|
||||||
tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
|
|
||||||
)
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
logger.info("push tasks: %s - %s", self._dataset_id, self._document_ids)
|
|
||||||
else:
|
|
||||||
# Set flag and execute task
|
|
||||||
self._tenant_isolated_task_queue.set_task_waiting_time()
|
|
||||||
task_func.delay( # type: ignore
|
|
||||||
tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
|
|
||||||
)
|
|
||||||
logger.info("init tasks: %s - %s", self._dataset_id, self._document_ids)
|
|
||||||
|
|
||||||
def _send_to_default_tenant_queue(self):
|
|
||||||
self._send_to_tenant_queue(normal_document_indexing_task)
|
|
||||||
|
|
||||||
def _send_to_priority_tenant_queue(self):
|
|
||||||
self._send_to_tenant_queue(priority_document_indexing_task)
|
|
||||||
|
|
||||||
def _send_to_priority_direct_queue(self):
|
|
||||||
self._send_to_direct_queue(priority_document_indexing_task)
|
|
||||||
|
|
||||||
def _dispatch(self):
|
|
||||||
logger.info(
|
|
||||||
"dispatch args: %s - %s - %s",
|
|
||||||
self._tenant_id,
|
|
||||||
self.features.billing.enabled,
|
|
||||||
self.features.billing.subscription.plan,
|
|
||||||
)
|
|
||||||
# dispatch to different indexing queue with tenant isolation when billing enabled
|
|
||||||
if self.features.billing.enabled:
|
|
||||||
if self.features.billing.subscription.plan == CloudPlan.SANDBOX:
|
|
||||||
# dispatch to normal pipeline queue with tenant self sub queue for sandbox plan
|
|
||||||
self._send_to_default_tenant_queue()
|
|
||||||
else:
|
|
||||||
# dispatch to priority pipeline queue with tenant self sub queue for other plans
|
|
||||||
self._send_to_priority_tenant_queue()
|
|
||||||
else:
|
|
||||||
# dispatch to priority queue without tenant isolation for others, e.g.: self-hosted or enterprise
|
|
||||||
self._send_to_priority_direct_queue()
|
|
||||||
|
|
||||||
def delay(self):
|
|
||||||
self._dispatch()
|
|
||||||
@ -124,6 +124,14 @@ class KnowledgeConfig(BaseModel):
|
|||||||
embedding_model: str | None = None
|
embedding_model: str | None = None
|
||||||
embedding_model_provider: str | None = None
|
embedding_model_provider: str | None = None
|
||||||
name: str | None = None
|
name: str | None = None
|
||||||
|
is_multimodal: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentCreateArgs(BaseModel):
|
||||||
|
content: str | None = None
|
||||||
|
answer: str | None = None
|
||||||
|
keywords: list[str] | None = None
|
||||||
|
attachment_ids: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
class SegmentUpdateArgs(BaseModel):
|
class SegmentUpdateArgs(BaseModel):
|
||||||
@ -132,6 +140,7 @@ class SegmentUpdateArgs(BaseModel):
|
|||||||
keywords: list[str] | None = None
|
keywords: list[str] | None = None
|
||||||
regenerate_child_chunks: bool = False
|
regenerate_child_chunks: bool = False
|
||||||
enabled: bool | None = None
|
enabled: bool | None = None
|
||||||
|
attachment_ids: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
class ChildChunkUpdateArgs(BaseModel):
|
class ChildChunkUpdateArgs(BaseModel):
|
||||||
|
|||||||
@ -324,4 +324,5 @@ class ExternalDatasetService:
|
|||||||
)
|
)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return cast(list[Any], response.json().get("records", []))
|
return cast(list[Any], response.json().get("records", []))
|
||||||
return []
|
else:
|
||||||
|
raise ValueError(response.text)
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import base64
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
@ -123,6 +124,15 @@ class FileService:
|
|||||||
|
|
||||||
return file_size <= file_size_limit
|
return file_size <= file_size_limit
|
||||||
|
|
||||||
|
def get_file_base64(self, file_id: str) -> str:
|
||||||
|
upload_file = (
|
||||||
|
self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first()
|
||||||
|
)
|
||||||
|
if not upload_file:
|
||||||
|
raise NotFound("File not found")
|
||||||
|
blob = storage.load_once(upload_file.key)
|
||||||
|
return base64.b64encode(blob).decode()
|
||||||
|
|
||||||
def upload_text(self, text: str, text_name: str, user_id: str, tenant_id: str) -> UploadFile:
|
def upload_text(self, text: str, text_name: str, user_id: str, tenant_id: str) -> UploadFile:
|
||||||
if len(text_name) > 200:
|
if len(text_name) > 200:
|
||||||
text_name = text_name[:200]
|
text_name = text_name[:200]
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@ -5,6 +6,7 @@ from typing import Any
|
|||||||
from core.app.app_config.entities import ModelConfig
|
from core.app.app_config.entities import ModelConfig
|
||||||
from core.model_runtime.entities import LLMMode
|
from core.model_runtime.entities import LLMMode
|
||||||
from core.rag.datasource.retrieval_service import RetrievalService
|
from core.rag.datasource.retrieval_service import RetrievalService
|
||||||
|
from core.rag.index_processor.constant.query_type import QueryType
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
@ -32,6 +34,7 @@ class HitTestingService:
|
|||||||
account: Account,
|
account: Account,
|
||||||
retrieval_model: Any, # FIXME drop this any
|
retrieval_model: Any, # FIXME drop this any
|
||||||
external_retrieval_model: dict,
|
external_retrieval_model: dict,
|
||||||
|
attachment_ids: list | None = None,
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
):
|
):
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
@ -41,7 +44,7 @@ class HitTestingService:
|
|||||||
retrieval_model = dataset.retrieval_model or default_retrieval_model
|
retrieval_model = dataset.retrieval_model or default_retrieval_model
|
||||||
document_ids_filter = None
|
document_ids_filter = None
|
||||||
metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {})
|
metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {})
|
||||||
if metadata_filtering_conditions:
|
if metadata_filtering_conditions and query:
|
||||||
dataset_retrieval = DatasetRetrieval()
|
dataset_retrieval = DatasetRetrieval()
|
||||||
|
|
||||||
from core.app.app_config.entities import MetadataFilteringCondition
|
from core.app.app_config.entities import MetadataFilteringCondition
|
||||||
@ -66,6 +69,7 @@ class HitTestingService:
|
|||||||
retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)),
|
retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)),
|
||||||
dataset_id=dataset.id,
|
dataset_id=dataset.id,
|
||||||
query=query,
|
query=query,
|
||||||
|
attachment_ids=attachment_ids,
|
||||||
top_k=retrieval_model.get("top_k", 4),
|
top_k=retrieval_model.get("top_k", 4),
|
||||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||||
if retrieval_model["score_threshold_enabled"]
|
if retrieval_model["score_threshold_enabled"]
|
||||||
@ -80,17 +84,24 @@ class HitTestingService:
|
|||||||
|
|
||||||
end = time.perf_counter()
|
end = time.perf_counter()
|
||||||
logger.debug("Hit testing retrieve in %s seconds", end - start)
|
logger.debug("Hit testing retrieve in %s seconds", end - start)
|
||||||
|
dataset_queries = []
|
||||||
dataset_query = DatasetQuery(
|
if query:
|
||||||
dataset_id=dataset.id,
|
content = {"content_type": QueryType.TEXT_QUERY, "content": query}
|
||||||
content=query,
|
dataset_queries.append(content)
|
||||||
source="hit_testing",
|
if attachment_ids:
|
||||||
source_app_id=None,
|
for attachment_id in attachment_ids:
|
||||||
created_by_role="account",
|
content = {"content_type": QueryType.IMAGE_QUERY, "content": attachment_id}
|
||||||
created_by=account.id,
|
dataset_queries.append(content)
|
||||||
)
|
if dataset_queries:
|
||||||
|
dataset_query = DatasetQuery(
|
||||||
db.session.add(dataset_query)
|
dataset_id=dataset.id,
|
||||||
|
content=json.dumps(dataset_queries),
|
||||||
|
source="hit_testing",
|
||||||
|
source_app_id=None,
|
||||||
|
created_by_role="account",
|
||||||
|
created_by=account.id,
|
||||||
|
)
|
||||||
|
db.session.add(dataset_query)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
return cls.compact_retrieve_response(query, all_documents)
|
return cls.compact_retrieve_response(query, all_documents)
|
||||||
@ -168,9 +179,14 @@ class HitTestingService:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def hit_testing_args_check(cls, args):
|
def hit_testing_args_check(cls, args):
|
||||||
query = args["query"]
|
query = args["query"]
|
||||||
|
attachment_ids = args["attachment_ids"]
|
||||||
|
|
||||||
if not query or len(query) > 250:
|
if not attachment_ids and not query:
|
||||||
raise ValueError("Query is required and cannot exceed 250 characters")
|
raise ValueError("Query or attachment_ids is required")
|
||||||
|
if query and len(query) > 250:
|
||||||
|
raise ValueError("Query cannot exceed 250 characters")
|
||||||
|
if attachment_ids and not isinstance(attachment_ids, list):
|
||||||
|
raise ValueError("Attachment_ids must be a list")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def escape_query_for_search(query: str) -> str:
|
def escape_query_for_search(query: str) -> str:
|
||||||
|
|||||||
@ -38,21 +38,24 @@ class RagPipelineTaskProxy:
|
|||||||
upload_file = FileService(db.engine).upload_text(
|
upload_file = FileService(db.engine).upload_text(
|
||||||
json_text, self._RAG_PIPELINE_INVOKE_ENTITIES_FILE_NAME, self._user_id, self._dataset_tenant_id
|
json_text, self._RAG_PIPELINE_INVOKE_ENTITIES_FILE_NAME, self._user_id, self._dataset_tenant_id
|
||||||
)
|
)
|
||||||
|
logger.info(
|
||||||
|
"tenant %s upload %d invoke entities", self._dataset_tenant_id, len(self._rag_pipeline_invoke_entities)
|
||||||
|
)
|
||||||
return upload_file.id
|
return upload_file.id
|
||||||
|
|
||||||
def _send_to_direct_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]):
|
def _send_to_direct_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]):
|
||||||
logger.info("send file %s to direct queue", upload_file_id)
|
logger.info("tenant %s send file %s to direct queue", self._dataset_tenant_id, upload_file_id)
|
||||||
task_func.delay( # type: ignore
|
task_func.delay( # type: ignore
|
||||||
rag_pipeline_invoke_entities_file_id=upload_file_id,
|
rag_pipeline_invoke_entities_file_id=upload_file_id,
|
||||||
tenant_id=self._dataset_tenant_id,
|
tenant_id=self._dataset_tenant_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _send_to_tenant_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]):
|
def _send_to_tenant_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]):
|
||||||
logger.info("send file %s to tenant queue", upload_file_id)
|
logger.info("tenant %s send file %s to tenant queue", self._dataset_tenant_id, upload_file_id)
|
||||||
if self._tenant_isolated_task_queue.get_task_key():
|
if self._tenant_isolated_task_queue.get_task_key():
|
||||||
# Add to waiting queue using List operations (lpush)
|
# Add to waiting queue using List operations (lpush)
|
||||||
self._tenant_isolated_task_queue.push_tasks([upload_file_id])
|
self._tenant_isolated_task_queue.push_tasks([upload_file_id])
|
||||||
logger.info("push tasks: %s", upload_file_id)
|
logger.info("tenant %s push tasks: %s", self._dataset_tenant_id, upload_file_id)
|
||||||
else:
|
else:
|
||||||
# Set flag and execute task
|
# Set flag and execute task
|
||||||
self._tenant_isolated_task_queue.set_task_waiting_time()
|
self._tenant_isolated_task_queue.set_task_waiting_time()
|
||||||
@ -60,7 +63,7 @@ class RagPipelineTaskProxy:
|
|||||||
rag_pipeline_invoke_entities_file_id=upload_file_id,
|
rag_pipeline_invoke_entities_file_id=upload_file_id,
|
||||||
tenant_id=self._dataset_tenant_id,
|
tenant_id=self._dataset_tenant_id,
|
||||||
)
|
)
|
||||||
logger.info("init tasks: %s", upload_file_id)
|
logger.info("tenant %s init tasks: %s", self._dataset_tenant_id, upload_file_id)
|
||||||
|
|
||||||
def _send_to_default_tenant_queue(self, upload_file_id: str):
|
def _send_to_default_tenant_queue(self, upload_file_id: str):
|
||||||
self._send_to_tenant_queue(upload_file_id, rag_pipeline_run_task)
|
self._send_to_tenant_queue(upload_file_id, rag_pipeline_run_task)
|
||||||
|
|||||||
@ -4,11 +4,14 @@ from core.model_manager import ModelInstance, ModelManager
|
|||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||||
from core.rag.datasource.vdb.vector_factory import Vector
|
from core.rag.datasource.vdb.vector_factory import Vector
|
||||||
from core.rag.index_processor.constant.index_type import IndexType
|
from core.rag.index_processor.constant.doc_type import DocType
|
||||||
|
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||||
|
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import AttachmentDocument, Document
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
|
from models import UploadFile
|
||||||
|
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
|
||||||
from models.dataset import Document as DatasetDocument
|
from models.dataset import Document as DatasetDocument
|
||||||
from services.entities.knowledge_entities.knowledge_entities import ParentMode
|
from services.entities.knowledge_entities.knowledge_entities import ParentMode
|
||||||
|
|
||||||
@ -21,9 +24,10 @@ class VectorService:
|
|||||||
cls, keywords_list: list[list[str]] | None, segments: list[DocumentSegment], dataset: Dataset, doc_form: str
|
cls, keywords_list: list[list[str]] | None, segments: list[DocumentSegment], dataset: Dataset, doc_form: str
|
||||||
):
|
):
|
||||||
documents: list[Document] = []
|
documents: list[Document] = []
|
||||||
|
multimodal_documents: list[AttachmentDocument] = []
|
||||||
|
|
||||||
for segment in segments:
|
for segment in segments:
|
||||||
if doc_form == IndexType.PARENT_CHILD_INDEX:
|
if doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||||
dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
|
dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
|
||||||
if not dataset_document:
|
if not dataset_document:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@ -70,12 +74,29 @@ class VectorService:
|
|||||||
"doc_hash": segment.index_node_hash,
|
"doc_hash": segment.index_node_hash,
|
||||||
"document_id": segment.document_id,
|
"document_id": segment.document_id,
|
||||||
"dataset_id": segment.dataset_id,
|
"dataset_id": segment.dataset_id,
|
||||||
|
"doc_type": DocType.TEXT,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
documents.append(rag_document)
|
documents.append(rag_document)
|
||||||
|
if dataset.is_multimodal:
|
||||||
|
for attachment in segment.attachments:
|
||||||
|
multimodal_document: AttachmentDocument = AttachmentDocument(
|
||||||
|
page_content=attachment["name"],
|
||||||
|
metadata={
|
||||||
|
"doc_id": attachment["id"],
|
||||||
|
"doc_hash": "",
|
||||||
|
"document_id": segment.document_id,
|
||||||
|
"dataset_id": segment.dataset_id,
|
||||||
|
"doc_type": DocType.IMAGE,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
multimodal_documents.append(multimodal_document)
|
||||||
|
index_processor: BaseIndexProcessor = IndexProcessorFactory(doc_form).init_index_processor()
|
||||||
|
|
||||||
if len(documents) > 0:
|
if len(documents) > 0:
|
||||||
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
index_processor.load(dataset, documents, None, with_keywords=True, keywords_list=keywords_list)
|
||||||
index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)
|
if len(multimodal_documents) > 0:
|
||||||
|
index_processor.load(dataset, [], multimodal_documents, with_keywords=False)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def update_segment_vector(cls, keywords: list[str] | None, segment: DocumentSegment, dataset: Dataset):
|
def update_segment_vector(cls, keywords: list[str] | None, segment: DocumentSegment, dataset: Dataset):
|
||||||
@ -130,6 +151,7 @@ class VectorService:
|
|||||||
"doc_hash": segment.index_node_hash,
|
"doc_hash": segment.index_node_hash,
|
||||||
"document_id": segment.document_id,
|
"document_id": segment.document_id,
|
||||||
"dataset_id": segment.dataset_id,
|
"dataset_id": segment.dataset_id,
|
||||||
|
"doc_type": DocType.TEXT,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
# use full doc mode to generate segment's child chunk
|
# use full doc mode to generate segment's child chunk
|
||||||
@ -226,3 +248,92 @@ class VectorService:
|
|||||||
def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset):
|
def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset):
|
||||||
vector = Vector(dataset=dataset)
|
vector = Vector(dataset=dataset)
|
||||||
vector.delete_by_ids([child_chunk.index_node_id])
|
vector.delete_by_ids([child_chunk.index_node_id])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def update_multimodel_vector(cls, segment: DocumentSegment, attachment_ids: list[str], dataset: Dataset):
|
||||||
|
if dataset.indexing_technique != "high_quality":
|
||||||
|
return
|
||||||
|
|
||||||
|
attachments = segment.attachments
|
||||||
|
old_attachment_ids = [attachment["id"] for attachment in attachments] if attachments else []
|
||||||
|
|
||||||
|
# Check if there's any actual change needed
|
||||||
|
if set(attachment_ids) == set(old_attachment_ids):
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
vector = Vector(dataset=dataset)
|
||||||
|
if dataset.is_multimodal:
|
||||||
|
# Delete old vectors if they exist
|
||||||
|
if old_attachment_ids:
|
||||||
|
vector.delete_by_ids(old_attachment_ids)
|
||||||
|
|
||||||
|
# Delete existing segment attachment bindings in one operation
|
||||||
|
db.session.query(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id == segment.id).delete(
|
||||||
|
synchronize_session=False
|
||||||
|
)
|
||||||
|
|
||||||
|
if not attachment_ids:
|
||||||
|
db.session.commit()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Bulk fetch upload files - only fetch needed fields
|
||||||
|
upload_file_list = db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
|
||||||
|
|
||||||
|
if not upload_file_list:
|
||||||
|
db.session.commit()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create a mapping for quick lookup
|
||||||
|
upload_file_map = {upload_file.id: upload_file for upload_file in upload_file_list}
|
||||||
|
|
||||||
|
# Prepare batch operations
|
||||||
|
bindings = []
|
||||||
|
documents = []
|
||||||
|
|
||||||
|
# Create common metadata base to avoid repetition
|
||||||
|
base_metadata = {
|
||||||
|
"doc_hash": "",
|
||||||
|
"document_id": segment.document_id,
|
||||||
|
"dataset_id": segment.dataset_id,
|
||||||
|
"doc_type": DocType.IMAGE,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Process attachments in the order specified by attachment_ids
|
||||||
|
for attachment_id in attachment_ids:
|
||||||
|
upload_file = upload_file_map.get(attachment_id)
|
||||||
|
if not upload_file:
|
||||||
|
logger.warning("Upload file not found for attachment_id: %s", attachment_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Create segment attachment binding
|
||||||
|
bindings.append(
|
||||||
|
SegmentAttachmentBinding(
|
||||||
|
tenant_id=segment.tenant_id,
|
||||||
|
dataset_id=segment.dataset_id,
|
||||||
|
document_id=segment.document_id,
|
||||||
|
segment_id=segment.id,
|
||||||
|
attachment_id=upload_file.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create document for vector indexing
|
||||||
|
documents.append(
|
||||||
|
Document(page_content=upload_file.name, metadata={**base_metadata, "doc_id": upload_file.id})
|
||||||
|
)
|
||||||
|
|
||||||
|
# Bulk insert all bindings at once
|
||||||
|
if bindings:
|
||||||
|
db.session.add_all(bindings)
|
||||||
|
|
||||||
|
# Add documents to vector store if any
|
||||||
|
if documents and dataset.is_multimodal:
|
||||||
|
vector.add_texts(documents, duplicate_check=True)
|
||||||
|
|
||||||
|
# Single commit for all operations
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to update multimodal vector for segment %s", segment.id)
|
||||||
|
db.session.rollback()
|
||||||
|
raise
|
||||||
|
|||||||
@ -4,9 +4,10 @@ import time
|
|||||||
import click
|
import click
|
||||||
from celery import shared_task
|
from celery import shared_task
|
||||||
|
|
||||||
from core.rag.index_processor.constant.index_type import IndexType
|
from core.rag.index_processor.constant.doc_type import DocType
|
||||||
|
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||||
from core.rag.models.document import ChildDocument, Document
|
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
@ -55,6 +56,7 @@ def add_document_to_index_task(dataset_document_id: str):
|
|||||||
)
|
)
|
||||||
|
|
||||||
documents = []
|
documents = []
|
||||||
|
multimodal_documents = []
|
||||||
for segment in segments:
|
for segment in segments:
|
||||||
document = Document(
|
document = Document(
|
||||||
page_content=segment.content,
|
page_content=segment.content,
|
||||||
@ -65,7 +67,7 @@ def add_document_to_index_task(dataset_document_id: str):
|
|||||||
"dataset_id": segment.dataset_id,
|
"dataset_id": segment.dataset_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||||
child_chunks = segment.get_child_chunks()
|
child_chunks = segment.get_child_chunks()
|
||||||
if child_chunks:
|
if child_chunks:
|
||||||
child_documents = []
|
child_documents = []
|
||||||
@ -81,11 +83,25 @@ def add_document_to_index_task(dataset_document_id: str):
|
|||||||
)
|
)
|
||||||
child_documents.append(child_document)
|
child_documents.append(child_document)
|
||||||
document.children = child_documents
|
document.children = child_documents
|
||||||
|
if dataset.is_multimodal:
|
||||||
|
for attachment in segment.attachments:
|
||||||
|
multimodal_documents.append(
|
||||||
|
AttachmentDocument(
|
||||||
|
page_content=attachment["name"],
|
||||||
|
metadata={
|
||||||
|
"doc_id": attachment["id"],
|
||||||
|
"doc_hash": "",
|
||||||
|
"document_id": segment.document_id,
|
||||||
|
"dataset_id": segment.dataset_id,
|
||||||
|
"doc_type": DocType.IMAGE,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
documents.append(document)
|
documents.append(document)
|
||||||
|
|
||||||
index_type = dataset.doc_form
|
index_type = dataset.doc_form
|
||||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||||
index_processor.load(dataset, documents)
|
index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
|
||||||
|
|
||||||
# delete auto disable log
|
# delete auto disable log
|
||||||
db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id).delete()
|
db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id).delete()
|
||||||
|
|||||||
@ -18,6 +18,7 @@ from models.dataset import (
|
|||||||
DatasetQuery,
|
DatasetQuery,
|
||||||
Document,
|
Document,
|
||||||
DocumentSegment,
|
DocumentSegment,
|
||||||
|
SegmentAttachmentBinding,
|
||||||
)
|
)
|
||||||
from models.model import UploadFile
|
from models.model import UploadFile
|
||||||
|
|
||||||
@ -58,14 +59,20 @@ def clean_dataset_task(
|
|||||||
)
|
)
|
||||||
documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all()
|
documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all()
|
||||||
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all()
|
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all()
|
||||||
|
# Use JOIN to fetch attachments with bindings in a single query
|
||||||
|
attachments_with_bindings = db.session.execute(
|
||||||
|
select(SegmentAttachmentBinding, UploadFile)
|
||||||
|
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
|
||||||
|
.where(SegmentAttachmentBinding.tenant_id == tenant_id, SegmentAttachmentBinding.dataset_id == dataset_id)
|
||||||
|
).all()
|
||||||
|
|
||||||
# Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
|
# Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
|
||||||
# This ensures all invalid doc_form values are properly handled
|
# This ensures all invalid doc_form values are properly handled
|
||||||
if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()):
|
if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()):
|
||||||
# Use default paragraph index type for empty/invalid datasets to enable vector database cleanup
|
# Use default paragraph index type for empty/invalid datasets to enable vector database cleanup
|
||||||
from core.rag.index_processor.constant.index_type import IndexType
|
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||||
|
|
||||||
doc_form = IndexType.PARAGRAPH_INDEX
|
doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||||
logger.info(
|
logger.info(
|
||||||
click.style(f"Invalid doc_form detected, using default index type for cleanup: {doc_form}", fg="yellow")
|
click.style(f"Invalid doc_form detected, using default index type for cleanup: {doc_form}", fg="yellow")
|
||||||
)
|
)
|
||||||
@ -90,6 +97,7 @@ def clean_dataset_task(
|
|||||||
|
|
||||||
for document in documents:
|
for document in documents:
|
||||||
db.session.delete(document)
|
db.session.delete(document)
|
||||||
|
# delete document file
|
||||||
|
|
||||||
for segment in segments:
|
for segment in segments:
|
||||||
image_upload_file_ids = get_image_upload_file_ids(segment.content)
|
image_upload_file_ids = get_image_upload_file_ids(segment.content)
|
||||||
@ -107,6 +115,19 @@ def clean_dataset_task(
|
|||||||
)
|
)
|
||||||
db.session.delete(image_file)
|
db.session.delete(image_file)
|
||||||
db.session.delete(segment)
|
db.session.delete(segment)
|
||||||
|
# delete segment attachments
|
||||||
|
if attachments_with_bindings:
|
||||||
|
for binding, attachment_file in attachments_with_bindings:
|
||||||
|
try:
|
||||||
|
storage.delete(attachment_file.key)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Delete attachment_file failed when storage deleted, \
|
||||||
|
attachment_file_id: %s",
|
||||||
|
binding.attachment_id,
|
||||||
|
)
|
||||||
|
db.session.delete(attachment_file)
|
||||||
|
db.session.delete(binding)
|
||||||
|
|
||||||
db.session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
|
db.session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
|
||||||
db.session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()
|
db.session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto
|
|||||||
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
|
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_storage import storage
|
from extensions.ext_storage import storage
|
||||||
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment
|
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment, SegmentAttachmentBinding
|
||||||
from models.model import UploadFile
|
from models.model import UploadFile
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -36,6 +36,16 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
|
|||||||
raise Exception("Document has no dataset")
|
raise Exception("Document has no dataset")
|
||||||
|
|
||||||
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
|
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
|
||||||
|
# Use JOIN to fetch attachments with bindings in a single query
|
||||||
|
attachments_with_bindings = db.session.execute(
|
||||||
|
select(SegmentAttachmentBinding, UploadFile)
|
||||||
|
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
|
||||||
|
.where(
|
||||||
|
SegmentAttachmentBinding.tenant_id == dataset.tenant_id,
|
||||||
|
SegmentAttachmentBinding.dataset_id == dataset_id,
|
||||||
|
SegmentAttachmentBinding.document_id == document_id,
|
||||||
|
)
|
||||||
|
).all()
|
||||||
# check segment is exist
|
# check segment is exist
|
||||||
if segments:
|
if segments:
|
||||||
index_node_ids = [segment.index_node_id for segment in segments]
|
index_node_ids = [segment.index_node_id for segment in segments]
|
||||||
@ -69,6 +79,19 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
|
|||||||
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
|
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
|
||||||
db.session.delete(file)
|
db.session.delete(file)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
# delete segment attachments
|
||||||
|
if attachments_with_bindings:
|
||||||
|
for binding, attachment_file in attachments_with_bindings:
|
||||||
|
try:
|
||||||
|
storage.delete(attachment_file.key)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Delete attachment_file failed when storage deleted, \
|
||||||
|
attachment_file_id: %s",
|
||||||
|
binding.attachment_id,
|
||||||
|
)
|
||||||
|
db.session.delete(attachment_file)
|
||||||
|
db.session.delete(binding)
|
||||||
|
|
||||||
# delete dataset metadata binding
|
# delete dataset metadata binding
|
||||||
db.session.query(DatasetMetadataBinding).where(
|
db.session.query(DatasetMetadataBinding).where(
|
||||||
|
|||||||
@ -4,9 +4,10 @@ import time
|
|||||||
import click
|
import click
|
||||||
from celery import shared_task # type: ignore
|
from celery import shared_task # type: ignore
|
||||||
|
|
||||||
from core.rag.index_processor.constant.index_type import IndexType
|
from core.rag.index_processor.constant.doc_type import DocType
|
||||||
|
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||||
from core.rag.models.document import ChildDocument, Document
|
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.dataset import Dataset, DocumentSegment
|
from models.dataset import Dataset, DocumentSegment
|
||||||
from models.dataset import Document as DatasetDocument
|
from models.dataset import Document as DatasetDocument
|
||||||
@ -28,7 +29,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
|
|||||||
|
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise Exception("Dataset not found")
|
raise Exception("Dataset not found")
|
||||||
index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
|
index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
|
||||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||||
if action == "upgrade":
|
if action == "upgrade":
|
||||||
dataset_documents = (
|
dataset_documents = (
|
||||||
@ -119,6 +120,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
|
|||||||
)
|
)
|
||||||
if segments:
|
if segments:
|
||||||
documents = []
|
documents = []
|
||||||
|
multimodal_documents = []
|
||||||
for segment in segments:
|
for segment in segments:
|
||||||
document = Document(
|
document = Document(
|
||||||
page_content=segment.content,
|
page_content=segment.content,
|
||||||
@ -129,7 +131,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
|
|||||||
"dataset_id": segment.dataset_id,
|
"dataset_id": segment.dataset_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||||
child_chunks = segment.get_child_chunks()
|
child_chunks = segment.get_child_chunks()
|
||||||
if child_chunks:
|
if child_chunks:
|
||||||
child_documents = []
|
child_documents = []
|
||||||
@ -145,9 +147,25 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
|
|||||||
)
|
)
|
||||||
child_documents.append(child_document)
|
child_documents.append(child_document)
|
||||||
document.children = child_documents
|
document.children = child_documents
|
||||||
|
if dataset.is_multimodal:
|
||||||
|
for attachment in segment.attachments:
|
||||||
|
multimodal_documents.append(
|
||||||
|
AttachmentDocument(
|
||||||
|
page_content=attachment["name"],
|
||||||
|
metadata={
|
||||||
|
"doc_id": attachment["id"],
|
||||||
|
"doc_hash": "",
|
||||||
|
"document_id": segment.document_id,
|
||||||
|
"dataset_id": segment.dataset_id,
|
||||||
|
"doc_type": DocType.IMAGE,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
documents.append(document)
|
documents.append(document)
|
||||||
# save vector index
|
# save vector index
|
||||||
index_processor.load(dataset, documents, with_keywords=False)
|
index_processor.load(
|
||||||
|
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
|
||||||
|
)
|
||||||
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||||
{"indexing_status": "completed"}, synchronize_session=False
|
{"indexing_status": "completed"}, synchronize_session=False
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,14 +1,14 @@
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from celery import shared_task
|
from celery import shared_task
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from core.rag.index_processor.constant.index_type import IndexType
|
from core.rag.index_processor.constant.doc_type import DocType
|
||||||
|
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||||
from core.rag.models.document import ChildDocument, Document
|
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.dataset import Dataset, DocumentSegment
|
from models.dataset import Dataset, DocumentSegment
|
||||||
from models.dataset import Document as DatasetDocument
|
from models.dataset import Document as DatasetDocument
|
||||||
@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
@shared_task(queue="dataset")
|
@shared_task(queue="dataset")
|
||||||
def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "add", "update"]):
|
def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
||||||
"""
|
"""
|
||||||
Async deal dataset from index
|
Async deal dataset from index
|
||||||
:param dataset_id: dataset_id
|
:param dataset_id: dataset_id
|
||||||
@ -32,7 +32,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
|
|||||||
|
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise Exception("Dataset not found")
|
raise Exception("Dataset not found")
|
||||||
index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
|
index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
|
||||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||||
if action == "remove":
|
if action == "remove":
|
||||||
index_processor.clean(dataset, None, with_keywords=False)
|
index_processor.clean(dataset, None, with_keywords=False)
|
||||||
@ -119,6 +119,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
|
|||||||
)
|
)
|
||||||
if segments:
|
if segments:
|
||||||
documents = []
|
documents = []
|
||||||
|
multimodal_documents = []
|
||||||
for segment in segments:
|
for segment in segments:
|
||||||
document = Document(
|
document = Document(
|
||||||
page_content=segment.content,
|
page_content=segment.content,
|
||||||
@ -129,7 +130,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
|
|||||||
"dataset_id": segment.dataset_id,
|
"dataset_id": segment.dataset_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||||
child_chunks = segment.get_child_chunks()
|
child_chunks = segment.get_child_chunks()
|
||||||
if child_chunks:
|
if child_chunks:
|
||||||
child_documents = []
|
child_documents = []
|
||||||
@ -145,9 +146,25 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
|
|||||||
)
|
)
|
||||||
child_documents.append(child_document)
|
child_documents.append(child_document)
|
||||||
document.children = child_documents
|
document.children = child_documents
|
||||||
|
if dataset.is_multimodal:
|
||||||
|
for attachment in segment.attachments:
|
||||||
|
multimodal_documents.append(
|
||||||
|
AttachmentDocument(
|
||||||
|
page_content=attachment["name"],
|
||||||
|
metadata={
|
||||||
|
"doc_id": attachment["id"],
|
||||||
|
"doc_hash": "",
|
||||||
|
"document_id": segment.document_id,
|
||||||
|
"dataset_id": segment.dataset_id,
|
||||||
|
"doc_type": DocType.IMAGE,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
documents.append(document)
|
documents.append(document)
|
||||||
# save vector index
|
# save vector index
|
||||||
index_processor.load(dataset, documents, with_keywords=False)
|
index_processor.load(
|
||||||
|
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
|
||||||
|
)
|
||||||
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||||
{"indexing_status": "completed"}, synchronize_session=False
|
{"indexing_status": "completed"}, synchronize_session=False
|
||||||
)
|
)
|
||||||
|
|||||||
@ -6,14 +6,15 @@ from celery import shared_task
|
|||||||
|
|
||||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.dataset import Dataset, Document
|
from models.dataset import Dataset, Document, SegmentAttachmentBinding
|
||||||
|
from models.model import UploadFile
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@shared_task(queue="dataset")
|
@shared_task(queue="dataset")
|
||||||
def delete_segment_from_index_task(
|
def delete_segment_from_index_task(
|
||||||
index_node_ids: list, dataset_id: str, document_id: str, child_node_ids: list | None = None
|
index_node_ids: list, dataset_id: str, document_id: str, segment_ids: list, child_node_ids: list | None = None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Async Remove segment from index
|
Async Remove segment from index
|
||||||
@ -49,6 +50,21 @@ def delete_segment_from_index_task(
|
|||||||
delete_child_chunks=True,
|
delete_child_chunks=True,
|
||||||
precomputed_child_node_ids=child_node_ids,
|
precomputed_child_node_ids=child_node_ids,
|
||||||
)
|
)
|
||||||
|
if dataset.is_multimodal:
|
||||||
|
# delete segment attachment binding
|
||||||
|
segment_attachment_bindings = (
|
||||||
|
db.session.query(SegmentAttachmentBinding)
|
||||||
|
.where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
if segment_attachment_bindings:
|
||||||
|
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
|
||||||
|
index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
|
||||||
|
for binding in segment_attachment_bindings:
|
||||||
|
db.session.delete(binding)
|
||||||
|
# delete upload file
|
||||||
|
db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
end_at = time.perf_counter()
|
end_at = time.perf_counter()
|
||||||
logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green"))
|
logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green"))
|
||||||
|
|||||||
@ -8,7 +8,7 @@ from sqlalchemy import select
|
|||||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from models.dataset import Dataset, DocumentSegment
|
from models.dataset import Dataset, DocumentSegment, SegmentAttachmentBinding
|
||||||
from models.dataset import Document as DatasetDocument
|
from models.dataset import Document as DatasetDocument
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -59,6 +59,16 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
index_node_ids = [segment.index_node_id for segment in segments]
|
index_node_ids = [segment.index_node_id for segment in segments]
|
||||||
|
if dataset.is_multimodal:
|
||||||
|
segment_ids = [segment.id for segment in segments]
|
||||||
|
segment_attachment_bindings = (
|
||||||
|
db.session.query(SegmentAttachmentBinding)
|
||||||
|
.where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
if segment_attachment_bindings:
|
||||||
|
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
|
||||||
|
index_node_ids.extend(attachment_ids)
|
||||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
|
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
|
||||||
|
|
||||||
end_at = time.perf_counter()
|
end_at = time.perf_counter()
|
||||||
|
|||||||
@ -114,7 +114,13 @@ def _document_indexing_with_tenant_queue(
|
|||||||
try:
|
try:
|
||||||
_document_indexing(dataset_id, document_ids)
|
_document_indexing(dataset_id, document_ids)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error processing document indexing %s for tenant %s: %s", dataset_id, tenant_id)
|
logger.exception(
|
||||||
|
"Error processing document indexing %s for tenant %s: %s",
|
||||||
|
dataset_id,
|
||||||
|
tenant_id,
|
||||||
|
document_ids,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")
|
tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")
|
||||||
|
|
||||||
@ -122,7 +128,7 @@ def _document_indexing_with_tenant_queue(
|
|||||||
# Use rpop to get the next task from the queue (FIFO order)
|
# Use rpop to get the next task from the queue (FIFO order)
|
||||||
next_tasks = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
|
next_tasks = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
|
||||||
|
|
||||||
logger.info("document indexing tenant isolation queue next tasks: %s", next_tasks)
|
logger.info("document indexing tenant isolation queue %s next tasks: %s", tenant_id, next_tasks)
|
||||||
|
|
||||||
if next_tasks:
|
if next_tasks:
|
||||||
for next_task in next_tasks:
|
for next_task in next_tasks:
|
||||||
|
|||||||
@ -1,13 +1,16 @@
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
from collections.abc import Callable, Sequence
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from celery import shared_task
|
from celery import shared_task
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from core.entities.document_task import DocumentTask
|
||||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||||
|
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||||
from enums.cloud_plan import CloudPlan
|
from enums.cloud_plan import CloudPlan
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
@ -24,8 +27,55 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
|
|||||||
:param dataset_id:
|
:param dataset_id:
|
||||||
:param document_ids:
|
:param document_ids:
|
||||||
|
|
||||||
|
.. warning:: TO BE DEPRECATED
|
||||||
|
This function will be deprecated and removed in a future version.
|
||||||
|
Use normal_duplicate_document_indexing_task or priority_duplicate_document_indexing_task instead.
|
||||||
|
|
||||||
Usage: duplicate_document_indexing_task.delay(dataset_id, document_ids)
|
Usage: duplicate_document_indexing_task.delay(dataset_id, document_ids)
|
||||||
"""
|
"""
|
||||||
|
logger.warning("duplicate document indexing task received: %s - %s", dataset_id, document_ids)
|
||||||
|
_duplicate_document_indexing_task(dataset_id, document_ids)
|
||||||
|
|
||||||
|
|
||||||
|
def _duplicate_document_indexing_task_with_tenant_queue(
|
||||||
|
tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: Callable[[str, str, Sequence[str]], None]
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
_duplicate_document_indexing_task(dataset_id, document_ids)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Error processing duplicate document indexing %s for tenant %s: %s",
|
||||||
|
dataset_id,
|
||||||
|
tenant_id,
|
||||||
|
document_ids,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "duplicate_document_indexing")
|
||||||
|
|
||||||
|
# Check if there are waiting tasks in the queue
|
||||||
|
# Use rpop to get the next task from the queue (FIFO order)
|
||||||
|
next_tasks = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
|
||||||
|
|
||||||
|
logger.info("duplicate document indexing tenant isolation queue %s next tasks: %s", tenant_id, next_tasks)
|
||||||
|
|
||||||
|
if next_tasks:
|
||||||
|
for next_task in next_tasks:
|
||||||
|
document_task = DocumentTask(**next_task)
|
||||||
|
# Process the next waiting task
|
||||||
|
# Keep the flag set to indicate a task is running
|
||||||
|
tenant_isolated_task_queue.set_task_waiting_time()
|
||||||
|
task_func.delay( # type: ignore
|
||||||
|
tenant_id=document_task.tenant_id,
|
||||||
|
dataset_id=document_task.dataset_id,
|
||||||
|
document_ids=document_task.document_ids,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# No more waiting tasks, clear the flag
|
||||||
|
tenant_isolated_task_queue.delete_task_key()
|
||||||
|
|
||||||
|
|
||||||
|
def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[str]):
|
||||||
documents = []
|
documents = []
|
||||||
start_at = time.perf_counter()
|
start_at = time.perf_counter()
|
||||||
|
|
||||||
@ -110,3 +160,35 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
|
|||||||
logger.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id)
|
logger.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id)
|
||||||
finally:
|
finally:
|
||||||
db.session.close()
|
db.session.close()
|
||||||
|
|
||||||
|
|
||||||
|
@shared_task(queue="dataset")
|
||||||
|
def normal_duplicate_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
|
||||||
|
"""
|
||||||
|
Async process duplicate documents
|
||||||
|
:param tenant_id:
|
||||||
|
:param dataset_id:
|
||||||
|
:param document_ids:
|
||||||
|
|
||||||
|
Usage: normal_duplicate_document_indexing_task.delay(tenant_id, dataset_id, document_ids)
|
||||||
|
"""
|
||||||
|
logger.info("normal duplicate document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids)
|
||||||
|
_duplicate_document_indexing_task_with_tenant_queue(
|
||||||
|
tenant_id, dataset_id, document_ids, normal_duplicate_document_indexing_task
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@shared_task(queue="priority_dataset")
|
||||||
|
def priority_duplicate_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
|
||||||
|
"""
|
||||||
|
Async process duplicate documents
|
||||||
|
:param tenant_id:
|
||||||
|
:param dataset_id:
|
||||||
|
:param document_ids:
|
||||||
|
|
||||||
|
Usage: priority_duplicate_document_indexing_task.delay(tenant_id, dataset_id, document_ids)
|
||||||
|
"""
|
||||||
|
logger.info("priority duplicate document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids)
|
||||||
|
_duplicate_document_indexing_task_with_tenant_queue(
|
||||||
|
tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task
|
||||||
|
)
|
||||||
|
|||||||
@ -4,9 +4,10 @@ import time
|
|||||||
import click
|
import click
|
||||||
from celery import shared_task
|
from celery import shared_task
|
||||||
|
|
||||||
from core.rag.index_processor.constant.index_type import IndexType
|
from core.rag.index_processor.constant.doc_type import DocType
|
||||||
|
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||||
from core.rag.models.document import ChildDocument, Document
|
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
@ -67,7 +68,7 @@ def enable_segment_to_index_task(segment_id: str):
|
|||||||
return
|
return
|
||||||
|
|
||||||
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
|
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
|
||||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||||
child_chunks = segment.get_child_chunks()
|
child_chunks = segment.get_child_chunks()
|
||||||
if child_chunks:
|
if child_chunks:
|
||||||
child_documents = []
|
child_documents = []
|
||||||
@ -83,8 +84,24 @@ def enable_segment_to_index_task(segment_id: str):
|
|||||||
)
|
)
|
||||||
child_documents.append(child_document)
|
child_documents.append(child_document)
|
||||||
document.children = child_documents
|
document.children = child_documents
|
||||||
|
multimodel_documents = []
|
||||||
|
if dataset.is_multimodal:
|
||||||
|
for attachment in segment.attachments:
|
||||||
|
multimodel_documents.append(
|
||||||
|
AttachmentDocument(
|
||||||
|
page_content=attachment["name"],
|
||||||
|
metadata={
|
||||||
|
"doc_id": attachment["id"],
|
||||||
|
"doc_hash": "",
|
||||||
|
"document_id": segment.document_id,
|
||||||
|
"dataset_id": segment.dataset_id,
|
||||||
|
"doc_type": DocType.IMAGE,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# save vector index
|
# save vector index
|
||||||
index_processor.load(dataset, [document])
|
index_processor.load(dataset, [document], multimodal_documents=multimodel_documents)
|
||||||
|
|
||||||
end_at = time.perf_counter()
|
end_at = time.perf_counter()
|
||||||
logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green"))
|
logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green"))
|
||||||
|
|||||||
@ -5,9 +5,10 @@ import click
|
|||||||
from celery import shared_task
|
from celery import shared_task
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from core.rag.index_processor.constant.index_type import IndexType
|
from core.rag.index_processor.constant.doc_type import DocType
|
||||||
|
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||||
from core.rag.models.document import ChildDocument, Document
|
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
@ -60,6 +61,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
documents = []
|
documents = []
|
||||||
|
multimodal_documents = []
|
||||||
for segment in segments:
|
for segment in segments:
|
||||||
document = Document(
|
document = Document(
|
||||||
page_content=segment.content,
|
page_content=segment.content,
|
||||||
@ -71,7 +73,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||||
child_chunks = segment.get_child_chunks()
|
child_chunks = segment.get_child_chunks()
|
||||||
if child_chunks:
|
if child_chunks:
|
||||||
child_documents = []
|
child_documents = []
|
||||||
@ -87,9 +89,24 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
|
|||||||
)
|
)
|
||||||
child_documents.append(child_document)
|
child_documents.append(child_document)
|
||||||
document.children = child_documents
|
document.children = child_documents
|
||||||
|
|
||||||
|
if dataset.is_multimodal:
|
||||||
|
for attachment in segment.attachments:
|
||||||
|
multimodal_documents.append(
|
||||||
|
AttachmentDocument(
|
||||||
|
page_content=attachment["name"],
|
||||||
|
metadata={
|
||||||
|
"doc_id": attachment["id"],
|
||||||
|
"doc_hash": "",
|
||||||
|
"document_id": segment.document_id,
|
||||||
|
"dataset_id": segment.dataset_id,
|
||||||
|
"doc_type": DocType.IMAGE,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
documents.append(document)
|
documents.append(document)
|
||||||
# save vector index
|
# save vector index
|
||||||
index_processor.load(dataset, documents)
|
index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
|
||||||
|
|
||||||
end_at = time.perf_counter()
|
end_at = time.perf_counter()
|
||||||
logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green"))
|
logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green"))
|
||||||
|
|||||||
@ -47,6 +47,8 @@ def priority_rag_pipeline_run_task(
|
|||||||
)
|
)
|
||||||
rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)
|
rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)
|
||||||
|
|
||||||
|
logger.info("tenant %s received %d rag pipeline invoke entities", tenant_id, len(rag_pipeline_invoke_entities))
|
||||||
|
|
||||||
# Get Flask app object for thread context
|
# Get Flask app object for thread context
|
||||||
flask_app = current_app._get_current_object() # type: ignore
|
flask_app = current_app._get_current_object() # type: ignore
|
||||||
|
|
||||||
@ -66,7 +68,7 @@ def priority_rag_pipeline_run_task(
|
|||||||
end_at = time.perf_counter()
|
end_at = time.perf_counter()
|
||||||
logging.info(
|
logging.info(
|
||||||
click.style(
|
click.style(
|
||||||
f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
|
f"tenant_id: {tenant_id}, Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -78,7 +80,7 @@ def priority_rag_pipeline_run_task(
|
|||||||
# Check if there are waiting tasks in the queue
|
# Check if there are waiting tasks in the queue
|
||||||
# Use rpop to get the next task from the queue (FIFO order)
|
# Use rpop to get the next task from the queue (FIFO order)
|
||||||
next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
|
next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
|
||||||
logger.info("priority rag pipeline tenant isolation queue next files: %s", next_file_ids)
|
logger.info("priority rag pipeline tenant isolation queue %s next files: %s", tenant_id, next_file_ids)
|
||||||
|
|
||||||
if next_file_ids:
|
if next_file_ids:
|
||||||
for next_file_id in next_file_ids:
|
for next_file_id in next_file_ids:
|
||||||
|
|||||||
@ -47,6 +47,8 @@ def rag_pipeline_run_task(
|
|||||||
)
|
)
|
||||||
rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)
|
rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)
|
||||||
|
|
||||||
|
logger.info("tenant %s received %d rag pipeline invoke entities", tenant_id, len(rag_pipeline_invoke_entities))
|
||||||
|
|
||||||
# Get Flask app object for thread context
|
# Get Flask app object for thread context
|
||||||
flask_app = current_app._get_current_object() # type: ignore
|
flask_app = current_app._get_current_object() # type: ignore
|
||||||
|
|
||||||
@ -66,7 +68,7 @@ def rag_pipeline_run_task(
|
|||||||
end_at = time.perf_counter()
|
end_at = time.perf_counter()
|
||||||
logging.info(
|
logging.info(
|
||||||
click.style(
|
click.style(
|
||||||
f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
|
f"tenant_id: {tenant_id}, Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -78,7 +80,7 @@ def rag_pipeline_run_task(
|
|||||||
# Check if there are waiting tasks in the queue
|
# Check if there are waiting tasks in the queue
|
||||||
# Use rpop to get the next task from the queue (FIFO order)
|
# Use rpop to get the next task from the queue (FIFO order)
|
||||||
next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
|
next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
|
||||||
logger.info("rag pipeline tenant isolation queue next files: %s", next_file_ids)
|
logger.info("rag pipeline tenant isolation queue %s next files: %s", tenant_id, next_file_ids)
|
||||||
|
|
||||||
if next_file_ids:
|
if next_file_ids:
|
||||||
for next_file_id in next_file_ids:
|
for next_file_id in next_file_ids:
|
||||||
|
|||||||
@ -0,0 +1,244 @@
|
|||||||
|
"""Integration tests for Trigger Provider subscription permission verification."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from flask.testing import FlaskClient
|
||||||
|
|
||||||
|
from controllers.console.workspace import trigger_providers as trigger_providers_api
|
||||||
|
from libs.datetime_utils import naive_utc_now
|
||||||
|
from models import Tenant
|
||||||
|
from models.account import Account, TenantAccountJoin, TenantAccountRole
|
||||||
|
|
||||||
|
|
||||||
|
class TestTriggerProviderSubscriptionPermissions:
|
||||||
|
"""Test permission verification for Trigger Provider subscription endpoints."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_account(self, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
"""Create a mock Account for testing."""
|
||||||
|
|
||||||
|
account = Account(name="Test User", email="test@example.com")
|
||||||
|
account.id = str(uuid.uuid4())
|
||||||
|
account.last_active_at = naive_utc_now()
|
||||||
|
account.created_at = naive_utc_now()
|
||||||
|
account.updated_at = naive_utc_now()
|
||||||
|
|
||||||
|
# Create mock tenant
|
||||||
|
tenant = Tenant(name="Test Tenant")
|
||||||
|
tenant.id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
mock_session_instance = mock.Mock()
|
||||||
|
|
||||||
|
mock_tenant_join = TenantAccountJoin(role=TenantAccountRole.OWNER)
|
||||||
|
monkeypatch.setattr(mock_session_instance, "scalar", mock.Mock(return_value=mock_tenant_join))
|
||||||
|
|
||||||
|
mock_scalars_result = mock.Mock()
|
||||||
|
mock_scalars_result.one.return_value = tenant
|
||||||
|
monkeypatch.setattr(mock_session_instance, "scalars", mock.Mock(return_value=mock_scalars_result))
|
||||||
|
|
||||||
|
mock_session_context = mock.Mock()
|
||||||
|
mock_session_context.__enter__.return_value = mock_session_instance
|
||||||
|
monkeypatch.setattr("models.account.Session", lambda _, expire_on_commit: mock_session_context)
|
||||||
|
|
||||||
|
account.current_tenant = tenant
|
||||||
|
account.current_tenant_id = tenant.id
|
||||||
|
return account
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("role", "list_status", "get_status", "update_status", "create_status", "build_status", "delete_status"),
|
||||||
|
[
|
||||||
|
# Admin/Owner can do everything
|
||||||
|
(TenantAccountRole.OWNER, 200, 200, 200, 200, 200, 200),
|
||||||
|
(TenantAccountRole.ADMIN, 200, 200, 200, 200, 200, 200),
|
||||||
|
# Editor can list, get, update (parameters), but not create, build, or delete
|
||||||
|
(TenantAccountRole.EDITOR, 200, 200, 200, 403, 403, 403),
|
||||||
|
# Normal user cannot do anything
|
||||||
|
(TenantAccountRole.NORMAL, 403, 403, 403, 403, 403, 403),
|
||||||
|
# Dataset operator cannot do anything
|
||||||
|
(TenantAccountRole.DATASET_OPERATOR, 403, 403, 403, 403, 403, 403),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_trigger_subscription_permissions(
|
||||||
|
self,
|
||||||
|
test_client: FlaskClient,
|
||||||
|
auth_header,
|
||||||
|
monkeypatch,
|
||||||
|
mock_account,
|
||||||
|
role: TenantAccountRole,
|
||||||
|
list_status: int,
|
||||||
|
get_status: int,
|
||||||
|
update_status: int,
|
||||||
|
create_status: int,
|
||||||
|
build_status: int,
|
||||||
|
delete_status: int,
|
||||||
|
):
|
||||||
|
"""Test that different roles have appropriate permissions for trigger subscription operations."""
|
||||||
|
# Set user role
|
||||||
|
mock_account.role = role
|
||||||
|
|
||||||
|
# Mock current user
|
||||||
|
monkeypatch.setattr(trigger_providers_api, "current_user", mock_account)
|
||||||
|
|
||||||
|
# Mock AccountService.load_user to prevent authentication issues
|
||||||
|
from services.account_service import AccountService
|
||||||
|
|
||||||
|
mock_load_user = mock.Mock(return_value=mock_account)
|
||||||
|
monkeypatch.setattr(AccountService, "load_user", mock_load_user)
|
||||||
|
|
||||||
|
# Test data
|
||||||
|
provider = "some_provider/some_trigger"
|
||||||
|
subscription_builder_id = str(uuid.uuid4())
|
||||||
|
subscription_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Mock service methods
|
||||||
|
mock_list_subscriptions = mock.Mock(return_value=[])
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"services.trigger.trigger_provider_service.TriggerProviderService.list_trigger_provider_subscriptions",
|
||||||
|
mock_list_subscriptions,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_get_subscription_builder = mock.Mock(return_value={"id": subscription_builder_id})
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.get_subscription_builder_by_id",
|
||||||
|
mock_get_subscription_builder,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_update_subscription_builder = mock.Mock(return_value={"id": subscription_builder_id})
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.update_trigger_subscription_builder",
|
||||||
|
mock_update_subscription_builder,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_create_subscription_builder = mock.Mock(return_value={"id": subscription_builder_id})
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.create_trigger_subscription_builder",
|
||||||
|
mock_create_subscription_builder,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_update_and_build_builder = mock.Mock()
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.update_and_build_builder",
|
||||||
|
mock_update_and_build_builder,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_delete_provider = mock.Mock()
|
||||||
|
mock_delete_plugin_trigger = mock.Mock()
|
||||||
|
mock_db_session = mock.Mock()
|
||||||
|
mock_db_session.commit = mock.Mock()
|
||||||
|
|
||||||
|
def mock_session_func(engine=None):
|
||||||
|
return mock_session_context
|
||||||
|
|
||||||
|
mock_session_context = mock.Mock()
|
||||||
|
mock_session_context.__enter__.return_value = mock_db_session
|
||||||
|
mock_session_context.__exit__.return_value = None
|
||||||
|
|
||||||
|
monkeypatch.setattr("services.trigger.trigger_provider_service.Session", mock_session_func)
|
||||||
|
monkeypatch.setattr("services.trigger.trigger_subscription_operator_service.Session", mock_session_func)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"services.trigger.trigger_provider_service.TriggerProviderService.delete_trigger_provider",
|
||||||
|
mock_delete_provider,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"services.trigger.trigger_subscription_operator_service.TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription",
|
||||||
|
mock_delete_plugin_trigger,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test 1: List subscriptions (should work for Editor, Admin, Owner)
|
||||||
|
response = test_client.get(
|
||||||
|
f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/list",
|
||||||
|
headers=auth_header,
|
||||||
|
)
|
||||||
|
assert response.status_code == list_status
|
||||||
|
|
||||||
|
# Test 2: Get subscription builder (should work for Editor, Admin, Owner)
|
||||||
|
response = test_client.get(
|
||||||
|
f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/{subscription_builder_id}",
|
||||||
|
headers=auth_header,
|
||||||
|
)
|
||||||
|
assert response.status_code == get_status
|
||||||
|
|
||||||
|
# Test 3: Update subscription builder parameters (should work for Editor, Admin, Owner)
|
||||||
|
response = test_client.post(
|
||||||
|
f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/update/{subscription_builder_id}",
|
||||||
|
headers=auth_header,
|
||||||
|
json={"parameters": {"webhook_url": "https://example.com/webhook"}},
|
||||||
|
)
|
||||||
|
assert response.status_code == update_status
|
||||||
|
|
||||||
|
# Test 4: Create subscription builder (should only work for Admin, Owner)
|
||||||
|
response = test_client.post(
|
||||||
|
f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/create",
|
||||||
|
headers=auth_header,
|
||||||
|
json={"credential_type": "api_key"},
|
||||||
|
)
|
||||||
|
assert response.status_code == create_status
|
||||||
|
|
||||||
|
# Test 5: Build/activate subscription (should only work for Admin, Owner)
|
||||||
|
response = test_client.post(
|
||||||
|
f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/build/{subscription_builder_id}",
|
||||||
|
headers=auth_header,
|
||||||
|
json={"name": "Test Subscription"},
|
||||||
|
)
|
||||||
|
assert response.status_code == build_status
|
||||||
|
|
||||||
|
# Test 6: Delete subscription (should only work for Admin, Owner)
|
||||||
|
response = test_client.post(
|
||||||
|
f"/console/api/workspaces/current/trigger-provider/{subscription_id}/subscriptions/delete",
|
||||||
|
headers=auth_header,
|
||||||
|
)
|
||||||
|
assert response.status_code == delete_status
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("role", "status"),
|
||||||
|
[
|
||||||
|
(TenantAccountRole.OWNER, 200),
|
||||||
|
(TenantAccountRole.ADMIN, 200),
|
||||||
|
# Editor should be able to access logs for debugging
|
||||||
|
(TenantAccountRole.EDITOR, 200),
|
||||||
|
(TenantAccountRole.NORMAL, 403),
|
||||||
|
(TenantAccountRole.DATASET_OPERATOR, 403),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_trigger_subscription_logs_permissions(
|
||||||
|
self,
|
||||||
|
test_client: FlaskClient,
|
||||||
|
auth_header,
|
||||||
|
monkeypatch,
|
||||||
|
mock_account,
|
||||||
|
role: TenantAccountRole,
|
||||||
|
status: int,
|
||||||
|
):
|
||||||
|
"""Test that different roles have appropriate permissions for accessing subscription logs."""
|
||||||
|
# Set user role
|
||||||
|
mock_account.role = role
|
||||||
|
|
||||||
|
# Mock current user
|
||||||
|
monkeypatch.setattr(trigger_providers_api, "current_user", mock_account)
|
||||||
|
|
||||||
|
# Mock AccountService.load_user to prevent authentication issues
|
||||||
|
from services.account_service import AccountService
|
||||||
|
|
||||||
|
mock_load_user = mock.Mock(return_value=mock_account)
|
||||||
|
monkeypatch.setattr(AccountService, "load_user", mock_load_user)
|
||||||
|
|
||||||
|
# Test data
|
||||||
|
provider = "some_provider/some_trigger"
|
||||||
|
subscription_builder_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Mock service method
|
||||||
|
mock_list_logs = mock.Mock(return_value=[])
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.list_logs",
|
||||||
|
mock_list_logs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test access to logs
|
||||||
|
response = test_client.get(
|
||||||
|
f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/logs/{subscription_builder_id}",
|
||||||
|
headers=auth_header,
|
||||||
|
)
|
||||||
|
assert response.status_code == status
|
||||||
@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
from faker import Faker
|
from faker import Faker
|
||||||
|
|
||||||
from core.rag.index_processor.constant.index_type import IndexType
|
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||||
@ -95,7 +95,7 @@ class TestAddDocumentToIndexTask:
|
|||||||
created_by=account.id,
|
created_by=account.id,
|
||||||
indexing_status="completed",
|
indexing_status="completed",
|
||||||
enabled=True,
|
enabled=True,
|
||||||
doc_form=IndexType.PARAGRAPH_INDEX,
|
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||||
)
|
)
|
||||||
db.session.add(document)
|
db.session.add(document)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
@ -172,7 +172,9 @@ class TestAddDocumentToIndexTask:
|
|||||||
|
|
||||||
# Assert: Verify the expected outcomes
|
# Assert: Verify the expected outcomes
|
||||||
# Verify index processor was called correctly
|
# Verify index processor was called correctly
|
||||||
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX)
|
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
|
||||||
|
IndexStructureType.PARAGRAPH_INDEX
|
||||||
|
)
|
||||||
mock_external_service_dependencies["index_processor"].load.assert_called_once()
|
mock_external_service_dependencies["index_processor"].load.assert_called_once()
|
||||||
|
|
||||||
# Verify database state changes
|
# Verify database state changes
|
||||||
@ -204,7 +206,7 @@ class TestAddDocumentToIndexTask:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Update document to use different index type
|
# Update document to use different index type
|
||||||
document.doc_form = IndexType.QA_INDEX
|
document.doc_form = IndexStructureType.QA_INDEX
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
# Refresh dataset to ensure doc_form property reflects the updated document
|
# Refresh dataset to ensure doc_form property reflects the updated document
|
||||||
@ -221,7 +223,9 @@ class TestAddDocumentToIndexTask:
|
|||||||
add_document_to_index_task(document.id)
|
add_document_to_index_task(document.id)
|
||||||
|
|
||||||
# Assert: Verify different index type handling
|
# Assert: Verify different index type handling
|
||||||
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.QA_INDEX)
|
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
|
||||||
|
IndexStructureType.QA_INDEX
|
||||||
|
)
|
||||||
mock_external_service_dependencies["index_processor"].load.assert_called_once()
|
mock_external_service_dependencies["index_processor"].load.assert_called_once()
|
||||||
|
|
||||||
# Verify the load method was called with correct parameters
|
# Verify the load method was called with correct parameters
|
||||||
@ -360,7 +364,7 @@ class TestAddDocumentToIndexTask:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Update document to use parent-child index type
|
# Update document to use parent-child index type
|
||||||
document.doc_form = IndexType.PARENT_CHILD_INDEX
|
document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
# Refresh dataset to ensure doc_form property reflects the updated document
|
# Refresh dataset to ensure doc_form property reflects the updated document
|
||||||
@ -391,7 +395,7 @@ class TestAddDocumentToIndexTask:
|
|||||||
|
|
||||||
# Assert: Verify parent-child index processing
|
# Assert: Verify parent-child index processing
|
||||||
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
|
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
|
||||||
IndexType.PARENT_CHILD_INDEX
|
IndexStructureType.PARENT_CHILD_INDEX
|
||||||
)
|
)
|
||||||
mock_external_service_dependencies["index_processor"].load.assert_called_once()
|
mock_external_service_dependencies["index_processor"].load.assert_called_once()
|
||||||
|
|
||||||
@ -465,8 +469,10 @@ class TestAddDocumentToIndexTask:
|
|||||||
# Act: Execute the task
|
# Act: Execute the task
|
||||||
add_document_to_index_task(document.id)
|
add_document_to_index_task(document.id)
|
||||||
|
|
||||||
# Assert: Verify index processing occurred with all completed segments
|
# Assert: Verify index processing occurred but with empty documents list
|
||||||
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX)
|
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
|
||||||
|
IndexStructureType.PARAGRAPH_INDEX
|
||||||
|
)
|
||||||
mock_external_service_dependencies["index_processor"].load.assert_called_once()
|
mock_external_service_dependencies["index_processor"].load.assert_called_once()
|
||||||
|
|
||||||
# Verify the load method was called with all completed segments
|
# Verify the load method was called with all completed segments
|
||||||
@ -532,7 +538,9 @@ class TestAddDocumentToIndexTask:
|
|||||||
assert len(remaining_logs) == 0
|
assert len(remaining_logs) == 0
|
||||||
|
|
||||||
# Verify index processing occurred normally
|
# Verify index processing occurred normally
|
||||||
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX)
|
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
|
||||||
|
IndexStructureType.PARAGRAPH_INDEX
|
||||||
|
)
|
||||||
mock_external_service_dependencies["index_processor"].load.assert_called_once()
|
mock_external_service_dependencies["index_processor"].load.assert_called_once()
|
||||||
|
|
||||||
# Verify segments were enabled
|
# Verify segments were enabled
|
||||||
@ -699,7 +707,9 @@ class TestAddDocumentToIndexTask:
|
|||||||
add_document_to_index_task(document.id)
|
add_document_to_index_task(document.id)
|
||||||
|
|
||||||
# Assert: Verify only eligible segments were processed
|
# Assert: Verify only eligible segments were processed
|
||||||
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX)
|
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
|
||||||
|
IndexStructureType.PARAGRAPH_INDEX
|
||||||
|
)
|
||||||
mock_external_service_dependencies["index_processor"].load.assert_called_once()
|
mock_external_service_dependencies["index_processor"].load.assert_called_once()
|
||||||
|
|
||||||
# Verify the load method was called with correct parameters
|
# Verify the load method was called with correct parameters
|
||||||
|
|||||||
@ -12,7 +12,7 @@ from unittest.mock import MagicMock, patch
|
|||||||
|
|
||||||
from faker import Faker
|
from faker import Faker
|
||||||
|
|
||||||
from core.rag.index_processor.constant.index_type import IndexType
|
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||||
from models import Account, Dataset, Document, DocumentSegment, Tenant
|
from models import Account, Dataset, Document, DocumentSegment, Tenant
|
||||||
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
|
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
|
||||||
|
|
||||||
@ -164,7 +164,7 @@ class TestDeleteSegmentFromIndexTask:
|
|||||||
document.updated_at = fake.date_time_this_year()
|
document.updated_at = fake.date_time_this_year()
|
||||||
document.doc_type = kwargs.get("doc_type", "text")
|
document.doc_type = kwargs.get("doc_type", "text")
|
||||||
document.doc_metadata = kwargs.get("doc_metadata", {})
|
document.doc_metadata = kwargs.get("doc_metadata", {})
|
||||||
document.doc_form = kwargs.get("doc_form", IndexType.PARAGRAPH_INDEX)
|
document.doc_form = kwargs.get("doc_form", IndexStructureType.PARAGRAPH_INDEX)
|
||||||
document.doc_language = kwargs.get("doc_language", "en")
|
document.doc_language = kwargs.get("doc_language", "en")
|
||||||
|
|
||||||
db_session_with_containers.add(document)
|
db_session_with_containers.add(document)
|
||||||
@ -244,8 +244,11 @@ class TestDeleteSegmentFromIndexTask:
|
|||||||
mock_processor = MagicMock()
|
mock_processor = MagicMock()
|
||||||
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
|
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
|
||||||
|
|
||||||
|
# Extract segment IDs for the task
|
||||||
|
segment_ids = [segment.id for segment in segments]
|
||||||
|
|
||||||
# Execute the task
|
# Execute the task
|
||||||
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
|
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
|
||||||
|
|
||||||
# Verify the task completed successfully
|
# Verify the task completed successfully
|
||||||
assert result is None # Task should return None on success
|
assert result is None # Task should return None on success
|
||||||
@ -279,7 +282,7 @@ class TestDeleteSegmentFromIndexTask:
|
|||||||
index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)]
|
index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)]
|
||||||
|
|
||||||
# Execute the task with non-existent dataset
|
# Execute the task with non-existent dataset
|
||||||
result = delete_segment_from_index_task(index_node_ids, non_existent_dataset_id, non_existent_document_id)
|
result = delete_segment_from_index_task(index_node_ids, non_existent_dataset_id, non_existent_document_id, [])
|
||||||
|
|
||||||
# Verify the task completed without exceptions
|
# Verify the task completed without exceptions
|
||||||
assert result is None # Task should return None when dataset not found
|
assert result is None # Task should return None when dataset not found
|
||||||
@ -305,7 +308,7 @@ class TestDeleteSegmentFromIndexTask:
|
|||||||
index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)]
|
index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)]
|
||||||
|
|
||||||
# Execute the task with non-existent document
|
# Execute the task with non-existent document
|
||||||
result = delete_segment_from_index_task(index_node_ids, dataset.id, non_existent_document_id)
|
result = delete_segment_from_index_task(index_node_ids, dataset.id, non_existent_document_id, [])
|
||||||
|
|
||||||
# Verify the task completed without exceptions
|
# Verify the task completed without exceptions
|
||||||
assert result is None # Task should return None when document not found
|
assert result is None # Task should return None when document not found
|
||||||
@ -330,9 +333,10 @@ class TestDeleteSegmentFromIndexTask:
|
|||||||
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
|
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
|
||||||
|
|
||||||
index_node_ids = [segment.index_node_id for segment in segments]
|
index_node_ids = [segment.index_node_id for segment in segments]
|
||||||
|
segment_ids = [segment.id for segment in segments]
|
||||||
|
|
||||||
# Execute the task with disabled document
|
# Execute the task with disabled document
|
||||||
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
|
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
|
||||||
|
|
||||||
# Verify the task completed without exceptions
|
# Verify the task completed without exceptions
|
||||||
assert result is None # Task should return None when document is disabled
|
assert result is None # Task should return None when document is disabled
|
||||||
@ -357,9 +361,10 @@ class TestDeleteSegmentFromIndexTask:
|
|||||||
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
|
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
|
||||||
|
|
||||||
index_node_ids = [segment.index_node_id for segment in segments]
|
index_node_ids = [segment.index_node_id for segment in segments]
|
||||||
|
segment_ids = [segment.id for segment in segments]
|
||||||
|
|
||||||
# Execute the task with archived document
|
# Execute the task with archived document
|
||||||
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
|
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
|
||||||
|
|
||||||
# Verify the task completed without exceptions
|
# Verify the task completed without exceptions
|
||||||
assert result is None # Task should return None when document is archived
|
assert result is None # Task should return None when document is archived
|
||||||
@ -386,9 +391,10 @@ class TestDeleteSegmentFromIndexTask:
|
|||||||
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
|
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
|
||||||
|
|
||||||
index_node_ids = [segment.index_node_id for segment in segments]
|
index_node_ids = [segment.index_node_id for segment in segments]
|
||||||
|
segment_ids = [segment.id for segment in segments]
|
||||||
|
|
||||||
# Execute the task with incomplete indexing
|
# Execute the task with incomplete indexing
|
||||||
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
|
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
|
||||||
|
|
||||||
# Verify the task completed without exceptions
|
# Verify the task completed without exceptions
|
||||||
assert result is None # Task should return None when indexing is not completed
|
assert result is None # Task should return None when indexing is not completed
|
||||||
@ -409,7 +415,11 @@ class TestDeleteSegmentFromIndexTask:
|
|||||||
fake = Faker()
|
fake = Faker()
|
||||||
|
|
||||||
# Test different document forms
|
# Test different document forms
|
||||||
document_forms = [IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX, IndexType.PARENT_CHILD_INDEX]
|
document_forms = [
|
||||||
|
IndexStructureType.PARAGRAPH_INDEX,
|
||||||
|
IndexStructureType.QA_INDEX,
|
||||||
|
IndexStructureType.PARENT_CHILD_INDEX,
|
||||||
|
]
|
||||||
|
|
||||||
for doc_form in document_forms:
|
for doc_form in document_forms:
|
||||||
# Create test data for each document form
|
# Create test data for each document form
|
||||||
@ -420,13 +430,14 @@ class TestDeleteSegmentFromIndexTask:
|
|||||||
segments = self._create_test_document_segments(db_session_with_containers, document, account, 2, fake)
|
segments = self._create_test_document_segments(db_session_with_containers, document, account, 2, fake)
|
||||||
|
|
||||||
index_node_ids = [segment.index_node_id for segment in segments]
|
index_node_ids = [segment.index_node_id for segment in segments]
|
||||||
|
segment_ids = [segment.id for segment in segments]
|
||||||
|
|
||||||
# Mock the index processor
|
# Mock the index processor
|
||||||
mock_processor = MagicMock()
|
mock_processor = MagicMock()
|
||||||
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
|
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
|
||||||
|
|
||||||
# Execute the task
|
# Execute the task
|
||||||
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
|
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
|
||||||
|
|
||||||
# Verify the task completed successfully
|
# Verify the task completed successfully
|
||||||
assert result is None
|
assert result is None
|
||||||
@ -469,6 +480,7 @@ class TestDeleteSegmentFromIndexTask:
|
|||||||
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
|
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
|
||||||
|
|
||||||
index_node_ids = [segment.index_node_id for segment in segments]
|
index_node_ids = [segment.index_node_id for segment in segments]
|
||||||
|
segment_ids = [segment.id for segment in segments]
|
||||||
|
|
||||||
# Mock the index processor to raise an exception
|
# Mock the index processor to raise an exception
|
||||||
mock_processor = MagicMock()
|
mock_processor = MagicMock()
|
||||||
@ -476,7 +488,7 @@ class TestDeleteSegmentFromIndexTask:
|
|||||||
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
|
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
|
||||||
|
|
||||||
# Execute the task - should not raise exception
|
# Execute the task - should not raise exception
|
||||||
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
|
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
|
||||||
|
|
||||||
# Verify the task completed without raising exceptions
|
# Verify the task completed without raising exceptions
|
||||||
assert result is None # Task should return None even when exceptions occur
|
assert result is None # Task should return None even when exceptions occur
|
||||||
@ -518,7 +530,7 @@ class TestDeleteSegmentFromIndexTask:
|
|||||||
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
|
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
|
||||||
|
|
||||||
# Execute the task
|
# Execute the task
|
||||||
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
|
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, [])
|
||||||
|
|
||||||
# Verify the task completed successfully
|
# Verify the task completed successfully
|
||||||
assert result is None
|
assert result is None
|
||||||
@ -555,13 +567,14 @@ class TestDeleteSegmentFromIndexTask:
|
|||||||
# Create large number of segments
|
# Create large number of segments
|
||||||
segments = self._create_test_document_segments(db_session_with_containers, document, account, 50, fake)
|
segments = self._create_test_document_segments(db_session_with_containers, document, account, 50, fake)
|
||||||
index_node_ids = [segment.index_node_id for segment in segments]
|
index_node_ids = [segment.index_node_id for segment in segments]
|
||||||
|
segment_ids = [segment.id for segment in segments]
|
||||||
|
|
||||||
# Mock the index processor
|
# Mock the index processor
|
||||||
mock_processor = MagicMock()
|
mock_processor = MagicMock()
|
||||||
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
|
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
|
||||||
|
|
||||||
# Execute the task
|
# Execute the task
|
||||||
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
|
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
|
||||||
|
|
||||||
# Verify the task completed successfully
|
# Verify the task completed successfully
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|||||||
@ -0,0 +1,763 @@
|
|||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from faker import Faker
|
||||||
|
|
||||||
|
from enums.cloud_plan import CloudPlan
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||||
|
from models.dataset import Dataset, Document, DocumentSegment
|
||||||
|
from tasks.duplicate_document_indexing_task import (
|
||||||
|
_duplicate_document_indexing_task, # Core function
|
||||||
|
_duplicate_document_indexing_task_with_tenant_queue, # Tenant queue wrapper function
|
||||||
|
duplicate_document_indexing_task, # Deprecated old interface
|
||||||
|
normal_duplicate_document_indexing_task, # New normal task
|
||||||
|
priority_duplicate_document_indexing_task, # New priority task
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDuplicateDocumentIndexingTasks:
|
||||||
|
"""Integration tests for duplicate document indexing tasks using testcontainers.
|
||||||
|
|
||||||
|
This test class covers:
|
||||||
|
- Core _duplicate_document_indexing_task function
|
||||||
|
- Deprecated duplicate_document_indexing_task function
|
||||||
|
- New normal_duplicate_document_indexing_task function
|
||||||
|
- New priority_duplicate_document_indexing_task function
|
||||||
|
- Tenant queue wrapper _duplicate_document_indexing_task_with_tenant_queue function
|
||||||
|
- Document segment cleanup logic
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_external_service_dependencies(self):
|
||||||
|
"""Mock setup for external service dependencies."""
|
||||||
|
with (
|
||||||
|
patch("tasks.duplicate_document_indexing_task.IndexingRunner") as mock_indexing_runner,
|
||||||
|
patch("tasks.duplicate_document_indexing_task.FeatureService") as mock_feature_service,
|
||||||
|
patch("tasks.duplicate_document_indexing_task.IndexProcessorFactory") as mock_index_processor_factory,
|
||||||
|
):
|
||||||
|
# Setup mock indexing runner
|
||||||
|
mock_runner_instance = MagicMock()
|
||||||
|
mock_indexing_runner.return_value = mock_runner_instance
|
||||||
|
|
||||||
|
# Setup mock feature service
|
||||||
|
mock_features = MagicMock()
|
||||||
|
mock_features.billing.enabled = False
|
||||||
|
mock_feature_service.get_features.return_value = mock_features
|
||||||
|
|
||||||
|
# Setup mock index processor factory
|
||||||
|
mock_processor = MagicMock()
|
||||||
|
mock_processor.clean = MagicMock()
|
||||||
|
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
|
||||||
|
|
||||||
|
yield {
|
||||||
|
"indexing_runner": mock_indexing_runner,
|
||||||
|
"indexing_runner_instance": mock_runner_instance,
|
||||||
|
"feature_service": mock_feature_service,
|
||||||
|
"features": mock_features,
|
||||||
|
"index_processor_factory": mock_index_processor_factory,
|
||||||
|
"index_processor": mock_processor,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _create_test_dataset_and_documents(
|
||||||
|
self, db_session_with_containers, mock_external_service_dependencies, document_count=3
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Helper method to create a test dataset and documents for testing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session_with_containers: Database session from testcontainers infrastructure
|
||||||
|
mock_external_service_dependencies: Mock dependencies
|
||||||
|
document_count: Number of documents to create
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (dataset, documents) - Created dataset and document instances
|
||||||
|
"""
|
||||||
|
fake = Faker()
|
||||||
|
|
||||||
|
# Create account and tenant
|
||||||
|
account = Account(
|
||||||
|
email=fake.email(),
|
||||||
|
name=fake.name(),
|
||||||
|
interface_language="en-US",
|
||||||
|
status="active",
|
||||||
|
)
|
||||||
|
db.session.add(account)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
tenant = Tenant(
|
||||||
|
name=fake.company(),
|
||||||
|
status="normal",
|
||||||
|
)
|
||||||
|
db.session.add(tenant)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Create tenant-account join
|
||||||
|
join = TenantAccountJoin(
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
account_id=account.id,
|
||||||
|
role=TenantAccountRole.OWNER,
|
||||||
|
current=True,
|
||||||
|
)
|
||||||
|
db.session.add(join)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Create dataset
|
||||||
|
dataset = Dataset(
|
||||||
|
id=fake.uuid4(),
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
name=fake.company(),
|
||||||
|
description=fake.text(max_nb_chars=100),
|
||||||
|
data_source_type="upload_file",
|
||||||
|
indexing_technique="high_quality",
|
||||||
|
created_by=account.id,
|
||||||
|
)
|
||||||
|
db.session.add(dataset)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Create documents
|
||||||
|
documents = []
|
||||||
|
for i in range(document_count):
|
||||||
|
document = Document(
|
||||||
|
id=fake.uuid4(),
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
dataset_id=dataset.id,
|
||||||
|
position=i,
|
||||||
|
data_source_type="upload_file",
|
||||||
|
batch="test_batch",
|
||||||
|
name=fake.file_name(),
|
||||||
|
created_from="upload_file",
|
||||||
|
created_by=account.id,
|
||||||
|
indexing_status="waiting",
|
||||||
|
enabled=True,
|
||||||
|
doc_form="text_model",
|
||||||
|
)
|
||||||
|
db.session.add(document)
|
||||||
|
documents.append(document)
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Refresh dataset to ensure it's properly loaded
|
||||||
|
db.session.refresh(dataset)
|
||||||
|
|
||||||
|
return dataset, documents
|
||||||
|
|
||||||
|
def _create_test_dataset_with_segments(
|
||||||
|
self, db_session_with_containers, mock_external_service_dependencies, document_count=3, segments_per_doc=2
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Helper method to create a test dataset with documents and segments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session_with_containers: Database session from testcontainers infrastructure
|
||||||
|
mock_external_service_dependencies: Mock dependencies
|
||||||
|
document_count: Number of documents to create
|
||||||
|
segments_per_doc: Number of segments per document
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (dataset, documents, segments) - Created dataset, documents and segments
|
||||||
|
"""
|
||||||
|
dataset, documents = self._create_test_dataset_and_documents(
|
||||||
|
db_session_with_containers, mock_external_service_dependencies, document_count
|
||||||
|
)
|
||||||
|
|
||||||
|
fake = Faker()
|
||||||
|
segments = []
|
||||||
|
|
||||||
|
# Create segments for each document
|
||||||
|
for document in documents:
|
||||||
|
for i in range(segments_per_doc):
|
||||||
|
segment = DocumentSegment(
|
||||||
|
id=fake.uuid4(),
|
||||||
|
tenant_id=dataset.tenant_id,
|
||||||
|
dataset_id=dataset.id,
|
||||||
|
document_id=document.id,
|
||||||
|
position=i,
|
||||||
|
index_node_id=f"{document.id}-node-{i}",
|
||||||
|
index_node_hash=fake.sha256(),
|
||||||
|
content=fake.text(max_nb_chars=200),
|
||||||
|
word_count=50,
|
||||||
|
tokens=100,
|
||||||
|
status="completed",
|
||||||
|
enabled=True,
|
||||||
|
indexing_at=fake.date_time_this_year(),
|
||||||
|
created_by=dataset.created_by, # Add required field
|
||||||
|
)
|
||||||
|
db.session.add(segment)
|
||||||
|
segments.append(segment)
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Refresh to ensure all relationships are loaded
|
||||||
|
for document in documents:
|
||||||
|
db.session.refresh(document)
|
||||||
|
|
||||||
|
return dataset, documents, segments
|
||||||
|
|
||||||
|
def _create_test_dataset_with_billing_features(
|
||||||
|
self, db_session_with_containers, mock_external_service_dependencies, billing_enabled=True
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Helper method to create a test dataset with billing features configured.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session_with_containers: Database session from testcontainers infrastructure
|
||||||
|
mock_external_service_dependencies: Mock dependencies
|
||||||
|
billing_enabled: Whether billing is enabled
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (dataset, documents) - Created dataset and document instances
|
||||||
|
"""
|
||||||
|
fake = Faker()
|
||||||
|
|
||||||
|
# Create account and tenant
|
||||||
|
account = Account(
|
||||||
|
email=fake.email(),
|
||||||
|
name=fake.name(),
|
||||||
|
interface_language="en-US",
|
||||||
|
status="active",
|
||||||
|
)
|
||||||
|
db.session.add(account)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
tenant = Tenant(
|
||||||
|
name=fake.company(),
|
||||||
|
status="normal",
|
||||||
|
)
|
||||||
|
db.session.add(tenant)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Create tenant-account join
|
||||||
|
join = TenantAccountJoin(
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
account_id=account.id,
|
||||||
|
role=TenantAccountRole.OWNER,
|
||||||
|
current=True,
|
||||||
|
)
|
||||||
|
db.session.add(join)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Create dataset
|
||||||
|
dataset = Dataset(
|
||||||
|
id=fake.uuid4(),
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
name=fake.company(),
|
||||||
|
description=fake.text(max_nb_chars=100),
|
||||||
|
data_source_type="upload_file",
|
||||||
|
indexing_technique="high_quality",
|
||||||
|
created_by=account.id,
|
||||||
|
)
|
||||||
|
db.session.add(dataset)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Create documents
|
||||||
|
documents = []
|
||||||
|
for i in range(3):
|
||||||
|
document = Document(
|
||||||
|
id=fake.uuid4(),
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
dataset_id=dataset.id,
|
||||||
|
position=i,
|
||||||
|
data_source_type="upload_file",
|
||||||
|
batch="test_batch",
|
||||||
|
name=fake.file_name(),
|
||||||
|
created_from="upload_file",
|
||||||
|
created_by=account.id,
|
||||||
|
indexing_status="waiting",
|
||||||
|
enabled=True,
|
||||||
|
doc_form="text_model",
|
||||||
|
)
|
||||||
|
db.session.add(document)
|
||||||
|
documents.append(document)
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Configure billing features
|
||||||
|
mock_external_service_dependencies["features"].billing.enabled = billing_enabled
|
||||||
|
if billing_enabled:
|
||||||
|
mock_external_service_dependencies["features"].billing.subscription.plan = CloudPlan.SANDBOX
|
||||||
|
mock_external_service_dependencies["features"].vector_space.limit = 100
|
||||||
|
mock_external_service_dependencies["features"].vector_space.size = 50
|
||||||
|
|
||||||
|
# Refresh dataset to ensure it's properly loaded
|
||||||
|
db.session.refresh(dataset)
|
||||||
|
|
||||||
|
return dataset, documents
|
||||||
|
|
||||||
|
def test_duplicate_document_indexing_task_success(
|
||||||
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test successful duplicate document indexing with multiple documents.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- Proper dataset retrieval from database
|
||||||
|
- Correct document processing and status updates
|
||||||
|
- IndexingRunner integration
|
||||||
|
- Database state updates
|
||||||
|
"""
|
||||||
|
# Arrange: Create test data
|
||||||
|
dataset, documents = self._create_test_dataset_and_documents(
|
||||||
|
db_session_with_containers, mock_external_service_dependencies, document_count=3
|
||||||
|
)
|
||||||
|
document_ids = [doc.id for doc in documents]
|
||||||
|
|
||||||
|
# Act: Execute the task
|
||||||
|
_duplicate_document_indexing_task(dataset.id, document_ids)
|
||||||
|
|
||||||
|
# Assert: Verify the expected outcomes
|
||||||
|
# Verify indexing runner was called correctly
|
||||||
|
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||||
|
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||||
|
|
||||||
|
# Verify documents were updated to parsing status
|
||||||
|
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
|
||||||
|
for doc_id in document_ids:
|
||||||
|
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||||
|
assert updated_document.indexing_status == "parsing"
|
||||||
|
assert updated_document.processing_started_at is not None
|
||||||
|
|
||||||
|
# Verify the run method was called with correct documents
|
||||||
|
call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
|
||||||
|
assert call_args is not None
|
||||||
|
processed_documents = call_args[0][0] # First argument should be documents list
|
||||||
|
assert len(processed_documents) == 3
|
||||||
|
|
||||||
|
def test_duplicate_document_indexing_task_with_segment_cleanup(
|
||||||
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test duplicate document indexing with existing segments that need cleanup.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- Old segments are identified and cleaned
|
||||||
|
- Index processor clean method is called
|
||||||
|
- Segments are deleted from database
|
||||||
|
- New indexing proceeds after cleanup
|
||||||
|
"""
|
||||||
|
# Arrange: Create test data with existing segments
|
||||||
|
dataset, documents, segments = self._create_test_dataset_with_segments(
|
||||||
|
db_session_with_containers, mock_external_service_dependencies, document_count=2, segments_per_doc=3
|
||||||
|
)
|
||||||
|
document_ids = [doc.id for doc in documents]
|
||||||
|
|
||||||
|
# Act: Execute the task
|
||||||
|
_duplicate_document_indexing_task(dataset.id, document_ids)
|
||||||
|
|
||||||
|
# Assert: Verify segment cleanup
|
||||||
|
# Verify index processor clean was called for each document with segments
|
||||||
|
assert mock_external_service_dependencies["index_processor"].clean.call_count == len(documents)
|
||||||
|
|
||||||
|
# Verify segments were deleted from database
|
||||||
|
# Re-query segments from database since _duplicate_document_indexing_task uses a different session
|
||||||
|
for segment in segments:
|
||||||
|
deleted_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first()
|
||||||
|
assert deleted_segment is None
|
||||||
|
|
||||||
|
# Verify documents were updated to parsing status
|
||||||
|
for doc_id in document_ids:
|
||||||
|
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||||
|
assert updated_document.indexing_status == "parsing"
|
||||||
|
assert updated_document.processing_started_at is not None
|
||||||
|
|
||||||
|
# Verify indexing runner was called
|
||||||
|
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||||
|
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||||
|
|
||||||
|
def test_duplicate_document_indexing_task_dataset_not_found(
|
||||||
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test handling of non-existent dataset.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- Proper error handling for missing datasets
|
||||||
|
- Early return without processing
|
||||||
|
- Database session cleanup
|
||||||
|
- No unnecessary indexing runner calls
|
||||||
|
"""
|
||||||
|
# Arrange: Use non-existent dataset ID
|
||||||
|
fake = Faker()
|
||||||
|
non_existent_dataset_id = fake.uuid4()
|
||||||
|
document_ids = [fake.uuid4() for _ in range(3)]
|
||||||
|
|
||||||
|
# Act: Execute the task with non-existent dataset
|
||||||
|
_duplicate_document_indexing_task(non_existent_dataset_id, document_ids)
|
||||||
|
|
||||||
|
# Assert: Verify no processing occurred
|
||||||
|
mock_external_service_dependencies["indexing_runner"].assert_not_called()
|
||||||
|
mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called()
|
||||||
|
mock_external_service_dependencies["index_processor"].clean.assert_not_called()
|
||||||
|
|
||||||
|
def test_duplicate_document_indexing_task_document_not_found_in_dataset(
|
||||||
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test handling when some documents don't exist in the dataset.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- Only existing documents are processed
|
||||||
|
- Non-existent documents are ignored
|
||||||
|
- Indexing runner receives only valid documents
|
||||||
|
- Database state updates correctly
|
||||||
|
"""
|
||||||
|
# Arrange: Create test data
|
||||||
|
dataset, documents = self._create_test_dataset_and_documents(
|
||||||
|
db_session_with_containers, mock_external_service_dependencies, document_count=2
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mix existing and non-existent document IDs
|
||||||
|
fake = Faker()
|
||||||
|
existing_document_ids = [doc.id for doc in documents]
|
||||||
|
non_existent_document_ids = [fake.uuid4() for _ in range(2)]
|
||||||
|
all_document_ids = existing_document_ids + non_existent_document_ids
|
||||||
|
|
||||||
|
# Act: Execute the task with mixed document IDs
|
||||||
|
_duplicate_document_indexing_task(dataset.id, all_document_ids)
|
||||||
|
|
||||||
|
# Assert: Verify only existing documents were processed
|
||||||
|
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||||
|
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||||
|
|
||||||
|
# Verify only existing documents were updated
|
||||||
|
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
|
||||||
|
for doc_id in existing_document_ids:
|
||||||
|
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||||
|
assert updated_document.indexing_status == "parsing"
|
||||||
|
assert updated_document.processing_started_at is not None
|
||||||
|
|
||||||
|
# Verify the run method was called with only existing documents
|
||||||
|
call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
|
||||||
|
assert call_args is not None
|
||||||
|
processed_documents = call_args[0][0] # First argument should be documents list
|
||||||
|
assert len(processed_documents) == 2 # Only existing documents
|
||||||
|
|
||||||
|
def test_duplicate_document_indexing_task_indexing_runner_exception(
|
||||||
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test handling of IndexingRunner exceptions.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- Exceptions from IndexingRunner are properly caught
|
||||||
|
- Task completes without raising exceptions
|
||||||
|
- Database session is properly closed
|
||||||
|
- Error logging occurs
|
||||||
|
"""
|
||||||
|
# Arrange: Create test data
|
||||||
|
dataset, documents = self._create_test_dataset_and_documents(
|
||||||
|
db_session_with_containers, mock_external_service_dependencies, document_count=2
|
||||||
|
)
|
||||||
|
document_ids = [doc.id for doc in documents]
|
||||||
|
|
||||||
|
# Mock IndexingRunner to raise an exception
|
||||||
|
mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = Exception(
|
||||||
|
"Indexing runner failed"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act: Execute the task
|
||||||
|
_duplicate_document_indexing_task(dataset.id, document_ids)
|
||||||
|
|
||||||
|
# Assert: Verify exception was handled gracefully
|
||||||
|
# The task should complete without raising exceptions
|
||||||
|
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||||
|
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||||
|
|
||||||
|
# Verify documents were still updated to parsing status before the exception
|
||||||
|
# Re-query documents from database since _duplicate_document_indexing_task close the session
|
||||||
|
for doc_id in document_ids:
|
||||||
|
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||||
|
assert updated_document.indexing_status == "parsing"
|
||||||
|
assert updated_document.processing_started_at is not None
|
||||||
|
|
||||||
|
def test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit(
|
||||||
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test billing validation for sandbox plan batch upload limit.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- Sandbox plan batch upload limit enforcement
|
||||||
|
- Error handling for batch upload limit exceeded
|
||||||
|
- Document status updates to error state
|
||||||
|
- Proper error message recording
|
||||||
|
"""
|
||||||
|
# Arrange: Create test data with billing enabled
|
||||||
|
dataset, documents = self._create_test_dataset_with_billing_features(
|
||||||
|
db_session_with_containers, mock_external_service_dependencies, billing_enabled=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Configure sandbox plan with batch limit
|
||||||
|
mock_external_service_dependencies["features"].billing.subscription.plan = CloudPlan.SANDBOX
|
||||||
|
|
||||||
|
# Create more documents than sandbox plan allows (limit is 1)
|
||||||
|
fake = Faker()
|
||||||
|
extra_documents = []
|
||||||
|
for i in range(2): # Total will be 5 documents (3 existing + 2 new)
|
||||||
|
document = Document(
|
||||||
|
id=fake.uuid4(),
|
||||||
|
tenant_id=dataset.tenant_id,
|
||||||
|
dataset_id=dataset.id,
|
||||||
|
position=i + 3,
|
||||||
|
data_source_type="upload_file",
|
||||||
|
batch="test_batch",
|
||||||
|
name=fake.file_name(),
|
||||||
|
created_from="upload_file",
|
||||||
|
created_by=dataset.created_by,
|
||||||
|
indexing_status="waiting",
|
||||||
|
enabled=True,
|
||||||
|
doc_form="text_model",
|
||||||
|
)
|
||||||
|
db.session.add(document)
|
||||||
|
extra_documents.append(document)
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
all_documents = documents + extra_documents
|
||||||
|
document_ids = [doc.id for doc in all_documents]
|
||||||
|
|
||||||
|
# Act: Execute the task with too many documents for sandbox plan
|
||||||
|
_duplicate_document_indexing_task(dataset.id, document_ids)
|
||||||
|
|
||||||
|
# Assert: Verify error handling
|
||||||
|
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
|
||||||
|
for doc_id in document_ids:
|
||||||
|
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||||
|
assert updated_document.indexing_status == "error"
|
||||||
|
assert updated_document.error is not None
|
||||||
|
assert "batch upload" in updated_document.error.lower()
|
||||||
|
assert updated_document.stopped_at is not None
|
||||||
|
|
||||||
|
# Verify indexing runner was not called due to early validation error
|
||||||
|
mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called()
|
||||||
|
|
||||||
|
def test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded(
|
||||||
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test billing validation for vector space limit.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- Vector space limit enforcement
|
||||||
|
- Error handling for vector space limit exceeded
|
||||||
|
- Document status updates to error state
|
||||||
|
- Proper error message recording
|
||||||
|
"""
|
||||||
|
# Arrange: Create test data with billing enabled
|
||||||
|
dataset, documents = self._create_test_dataset_with_billing_features(
|
||||||
|
db_session_with_containers, mock_external_service_dependencies, billing_enabled=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Configure TEAM plan with vector space limit exceeded
|
||||||
|
mock_external_service_dependencies["features"].billing.subscription.plan = CloudPlan.TEAM
|
||||||
|
mock_external_service_dependencies["features"].vector_space.limit = 100
|
||||||
|
mock_external_service_dependencies["features"].vector_space.size = 98 # Almost at limit
|
||||||
|
|
||||||
|
document_ids = [doc.id for doc in documents] # 3 documents will exceed limit
|
||||||
|
|
||||||
|
# Act: Execute the task with documents that will exceed vector space limit
|
||||||
|
_duplicate_document_indexing_task(dataset.id, document_ids)
|
||||||
|
|
||||||
|
# Assert: Verify error handling
|
||||||
|
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
|
||||||
|
for doc_id in document_ids:
|
||||||
|
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||||
|
assert updated_document.indexing_status == "error"
|
||||||
|
assert updated_document.error is not None
|
||||||
|
assert "limit" in updated_document.error.lower()
|
||||||
|
assert updated_document.stopped_at is not None
|
||||||
|
|
||||||
|
# Verify indexing runner was not called due to early validation error
|
||||||
|
mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called()
|
||||||
|
|
||||||
|
def test_duplicate_document_indexing_task_with_empty_document_list(
|
||||||
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test handling of empty document list.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- Empty document list is handled gracefully
|
||||||
|
- No processing occurs
|
||||||
|
- No errors are raised
|
||||||
|
- Database session is properly closed
|
||||||
|
"""
|
||||||
|
# Arrange: Create test dataset
|
||||||
|
dataset, _ = self._create_test_dataset_and_documents(
|
||||||
|
db_session_with_containers, mock_external_service_dependencies, document_count=0
|
||||||
|
)
|
||||||
|
document_ids = []
|
||||||
|
|
||||||
|
# Act: Execute the task with empty document list
|
||||||
|
_duplicate_document_indexing_task(dataset.id, document_ids)
|
||||||
|
|
||||||
|
# Assert: Verify IndexingRunner was called with empty list
|
||||||
|
# Note: The actual implementation does call run([]) with empty list
|
||||||
|
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||||
|
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once_with([])
|
||||||
|
|
||||||
|
def test_deprecated_duplicate_document_indexing_task_delegates_to_core(
|
||||||
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test that deprecated duplicate_document_indexing_task delegates to core function.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- Deprecated function calls core _duplicate_document_indexing_task
|
||||||
|
- Proper parameter passing
|
||||||
|
- Backward compatibility
|
||||||
|
"""
|
||||||
|
# Arrange: Create test data
|
||||||
|
dataset, documents = self._create_test_dataset_and_documents(
|
||||||
|
db_session_with_containers, mock_external_service_dependencies, document_count=2
|
||||||
|
)
|
||||||
|
document_ids = [doc.id for doc in documents]
|
||||||
|
|
||||||
|
# Act: Execute the deprecated task
|
||||||
|
duplicate_document_indexing_task(dataset.id, document_ids)
|
||||||
|
|
||||||
|
# Assert: Verify core function was executed
|
||||||
|
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||||
|
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||||
|
|
||||||
|
# Clear session cache to see database updates from task's session
|
||||||
|
db.session.expire_all()
|
||||||
|
|
||||||
|
# Verify documents were processed
|
||||||
|
for doc_id in document_ids:
|
||||||
|
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||||
|
assert updated_document.indexing_status == "parsing"
|
||||||
|
|
||||||
|
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
|
||||||
|
def test_normal_duplicate_document_indexing_task_with_tenant_queue(
|
||||||
|
self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test normal_duplicate_document_indexing_task with tenant isolation queue.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- Task uses tenant isolation queue correctly
|
||||||
|
- Core processing function is called
|
||||||
|
- Queue management (pull tasks, delete key) works properly
|
||||||
|
"""
|
||||||
|
# Arrange: Create test data
|
||||||
|
dataset, documents = self._create_test_dataset_and_documents(
|
||||||
|
db_session_with_containers, mock_external_service_dependencies, document_count=2
|
||||||
|
)
|
||||||
|
document_ids = [doc.id for doc in documents]
|
||||||
|
|
||||||
|
# Mock tenant isolated queue to return no next tasks
|
||||||
|
mock_queue = MagicMock()
|
||||||
|
mock_queue.pull_tasks.return_value = []
|
||||||
|
mock_queue_class.return_value = mock_queue
|
||||||
|
|
||||||
|
# Act: Execute the normal task
|
||||||
|
normal_duplicate_document_indexing_task(dataset.tenant_id, dataset.id, document_ids)
|
||||||
|
|
||||||
|
# Assert: Verify processing occurred
|
||||||
|
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||||
|
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||||
|
|
||||||
|
# Verify tenant queue was used
|
||||||
|
mock_queue_class.assert_called_with(dataset.tenant_id, "duplicate_document_indexing")
|
||||||
|
mock_queue.pull_tasks.assert_called_once()
|
||||||
|
mock_queue.delete_task_key.assert_called_once()
|
||||||
|
|
||||||
|
# Clear session cache to see database updates from task's session
|
||||||
|
db.session.expire_all()
|
||||||
|
|
||||||
|
# Verify documents were processed
|
||||||
|
for doc_id in document_ids:
|
||||||
|
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||||
|
assert updated_document.indexing_status == "parsing"
|
||||||
|
|
||||||
|
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
|
||||||
|
def test_priority_duplicate_document_indexing_task_with_tenant_queue(
|
||||||
|
self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test priority_duplicate_document_indexing_task with tenant isolation queue.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- Task uses tenant isolation queue correctly
|
||||||
|
- Core processing function is called
|
||||||
|
- Queue management works properly
|
||||||
|
- Same behavior as normal task with different queue assignment
|
||||||
|
"""
|
||||||
|
# Arrange: Create test data
|
||||||
|
dataset, documents = self._create_test_dataset_and_documents(
|
||||||
|
db_session_with_containers, mock_external_service_dependencies, document_count=2
|
||||||
|
)
|
||||||
|
document_ids = [doc.id for doc in documents]
|
||||||
|
|
||||||
|
# Mock tenant isolated queue to return no next tasks
|
||||||
|
mock_queue = MagicMock()
|
||||||
|
mock_queue.pull_tasks.return_value = []
|
||||||
|
mock_queue_class.return_value = mock_queue
|
||||||
|
|
||||||
|
# Act: Execute the priority task
|
||||||
|
priority_duplicate_document_indexing_task(dataset.tenant_id, dataset.id, document_ids)
|
||||||
|
|
||||||
|
# Assert: Verify processing occurred
|
||||||
|
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||||
|
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||||
|
|
||||||
|
# Verify tenant queue was used
|
||||||
|
mock_queue_class.assert_called_with(dataset.tenant_id, "duplicate_document_indexing")
|
||||||
|
mock_queue.pull_tasks.assert_called_once()
|
||||||
|
mock_queue.delete_task_key.assert_called_once()
|
||||||
|
|
||||||
|
# Clear session cache to see database updates from task's session
|
||||||
|
db.session.expire_all()
|
||||||
|
|
||||||
|
# Verify documents were processed
|
||||||
|
for doc_id in document_ids:
|
||||||
|
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||||
|
assert updated_document.indexing_status == "parsing"
|
||||||
|
|
||||||
|
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
|
||||||
|
def test_tenant_queue_wrapper_processes_next_tasks(
|
||||||
|
self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test tenant queue wrapper processes next queued tasks.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
- After completing current task, next tasks are pulled from queue
|
||||||
|
- Next tasks are executed correctly
|
||||||
|
- Task waiting time is set for next tasks
|
||||||
|
"""
|
||||||
|
# Arrange: Create test data
|
||||||
|
dataset, documents = self._create_test_dataset_and_documents(
|
||||||
|
db_session_with_containers, mock_external_service_dependencies, document_count=2
|
||||||
|
)
|
||||||
|
document_ids = [doc.id for doc in documents]
|
||||||
|
|
||||||
|
# Extract values before session detachment
|
||||||
|
tenant_id = dataset.tenant_id
|
||||||
|
dataset_id = dataset.id
|
||||||
|
|
||||||
|
# Mock tenant isolated queue to return next task
|
||||||
|
mock_queue = MagicMock()
|
||||||
|
next_task = {
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"dataset_id": dataset_id,
|
||||||
|
"document_ids": document_ids,
|
||||||
|
}
|
||||||
|
mock_queue.pull_tasks.return_value = [next_task]
|
||||||
|
mock_queue_class.return_value = mock_queue
|
||||||
|
|
||||||
|
# Mock the task function to track calls
|
||||||
|
mock_task_func = MagicMock()
|
||||||
|
|
||||||
|
# Act: Execute the wrapper function
|
||||||
|
_duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func)
|
||||||
|
|
||||||
|
# Assert: Verify next task was scheduled
|
||||||
|
mock_queue.pull_tasks.assert_called_once()
|
||||||
|
mock_queue.set_task_waiting_time.assert_called_once()
|
||||||
|
mock_task_func.delay.assert_called_once_with(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
dataset_id=dataset_id,
|
||||||
|
document_ids=document_ids,
|
||||||
|
)
|
||||||
|
mock_queue.delete_task_key.assert_not_called()
|
||||||
@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
from faker import Faker
|
from faker import Faker
|
||||||
|
|
||||||
from core.rag.index_processor.constant.index_type import IndexType
|
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||||
@ -95,7 +95,7 @@ class TestEnableSegmentsToIndexTask:
|
|||||||
created_by=account.id,
|
created_by=account.id,
|
||||||
indexing_status="completed",
|
indexing_status="completed",
|
||||||
enabled=True,
|
enabled=True,
|
||||||
doc_form=IndexType.PARAGRAPH_INDEX,
|
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||||
)
|
)
|
||||||
db.session.add(document)
|
db.session.add(document)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
@ -166,7 +166,7 @@ class TestEnableSegmentsToIndexTask:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Update document to use different index type
|
# Update document to use different index type
|
||||||
document.doc_form = IndexType.QA_INDEX
|
document.doc_form = IndexStructureType.QA_INDEX
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
# Refresh dataset to ensure doc_form property reflects the updated document
|
# Refresh dataset to ensure doc_form property reflects the updated document
|
||||||
@ -185,7 +185,9 @@ class TestEnableSegmentsToIndexTask:
|
|||||||
enable_segments_to_index_task(segment_ids, dataset.id, document.id)
|
enable_segments_to_index_task(segment_ids, dataset.id, document.id)
|
||||||
|
|
||||||
# Assert: Verify different index type handling
|
# Assert: Verify different index type handling
|
||||||
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.QA_INDEX)
|
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
|
||||||
|
IndexStructureType.QA_INDEX
|
||||||
|
)
|
||||||
mock_external_service_dependencies["index_processor"].load.assert_called_once()
|
mock_external_service_dependencies["index_processor"].load.assert_called_once()
|
||||||
|
|
||||||
# Verify the load method was called with correct parameters
|
# Verify the load method was called with correct parameters
|
||||||
@ -328,7 +330,9 @@ class TestEnableSegmentsToIndexTask:
|
|||||||
enable_segments_to_index_task(non_existent_segment_ids, dataset.id, document.id)
|
enable_segments_to_index_task(non_existent_segment_ids, dataset.id, document.id)
|
||||||
|
|
||||||
# Assert: Verify index processor was created but load was not called
|
# Assert: Verify index processor was created but load was not called
|
||||||
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX)
|
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
|
||||||
|
IndexStructureType.PARAGRAPH_INDEX
|
||||||
|
)
|
||||||
mock_external_service_dependencies["index_processor"].load.assert_not_called()
|
mock_external_service_dependencies["index_processor"].load.assert_not_called()
|
||||||
|
|
||||||
def test_enable_segments_to_index_with_parent_child_structure(
|
def test_enable_segments_to_index_with_parent_child_structure(
|
||||||
@ -350,7 +354,7 @@ class TestEnableSegmentsToIndexTask:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Update document to use parent-child index type
|
# Update document to use parent-child index type
|
||||||
document.doc_form = IndexType.PARENT_CHILD_INDEX
|
document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
# Refresh dataset to ensure doc_form property reflects the updated document
|
# Refresh dataset to ensure doc_form property reflects the updated document
|
||||||
@ -383,7 +387,7 @@ class TestEnableSegmentsToIndexTask:
|
|||||||
|
|
||||||
# Assert: Verify parent-child index processing
|
# Assert: Verify parent-child index processing
|
||||||
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
|
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
|
||||||
IndexType.PARENT_CHILD_INDEX
|
IndexStructureType.PARENT_CHILD_INDEX
|
||||||
)
|
)
|
||||||
mock_external_service_dependencies["index_processor"].load.assert_called_once()
|
mock_external_service_dependencies["index_processor"].load.assert_called_once()
|
||||||
|
|
||||||
|
|||||||
@ -53,7 +53,7 @@ from sqlalchemy.exc import IntegrityError
|
|||||||
|
|
||||||
from core.entities.embedding_type import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage
|
||||||
from core.model_runtime.errors.invoke import (
|
from core.model_runtime.errors.invoke import (
|
||||||
InvokeAuthorizationError,
|
InvokeAuthorizationError,
|
||||||
InvokeConnectionError,
|
InvokeConnectionError,
|
||||||
@ -99,10 +99,10 @@ class TestCacheEmbeddingDocuments:
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_embedding_result(self):
|
def sample_embedding_result(self):
|
||||||
"""Create a sample TextEmbeddingResult for testing.
|
"""Create a sample EmbeddingResult for testing.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
TextEmbeddingResult: Mock embedding result with proper structure
|
EmbeddingResult: Mock embedding result with proper structure
|
||||||
"""
|
"""
|
||||||
# Create normalized embedding vectors (dimension 1536 for ada-002)
|
# Create normalized embedding vectors (dimension 1536 for ada-002)
|
||||||
embedding_vector = np.random.randn(1536)
|
embedding_vector = np.random.randn(1536)
|
||||||
@ -118,7 +118,7 @@ class TestCacheEmbeddingDocuments:
|
|||||||
latency=0.5,
|
latency=0.5,
|
||||||
)
|
)
|
||||||
|
|
||||||
return TextEmbeddingResult(
|
return EmbeddingResult(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
embeddings=[normalized_vector],
|
embeddings=[normalized_vector],
|
||||||
usage=usage,
|
usage=usage,
|
||||||
@ -197,7 +197,7 @@ class TestCacheEmbeddingDocuments:
|
|||||||
latency=0.8,
|
latency=0.8,
|
||||||
)
|
)
|
||||||
|
|
||||||
embedding_result = TextEmbeddingResult(
|
embedding_result = EmbeddingResult(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
embeddings=embeddings,
|
embeddings=embeddings,
|
||||||
usage=usage,
|
usage=usage,
|
||||||
@ -296,7 +296,7 @@ class TestCacheEmbeddingDocuments:
|
|||||||
latency=0.6,
|
latency=0.6,
|
||||||
)
|
)
|
||||||
|
|
||||||
embedding_result = TextEmbeddingResult(
|
embedding_result = EmbeddingResult(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
embeddings=new_embeddings,
|
embeddings=new_embeddings,
|
||||||
usage=usage,
|
usage=usage,
|
||||||
@ -386,7 +386,7 @@ class TestCacheEmbeddingDocuments:
|
|||||||
latency=0.5,
|
latency=0.5,
|
||||||
)
|
)
|
||||||
|
|
||||||
return TextEmbeddingResult(
|
return EmbeddingResult(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
embeddings=embeddings,
|
embeddings=embeddings,
|
||||||
usage=usage,
|
usage=usage,
|
||||||
@ -449,7 +449,7 @@ class TestCacheEmbeddingDocuments:
|
|||||||
latency=0.5,
|
latency=0.5,
|
||||||
)
|
)
|
||||||
|
|
||||||
embedding_result = TextEmbeddingResult(
|
embedding_result = EmbeddingResult(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
embeddings=[valid_vector.tolist(), nan_vector],
|
embeddings=[valid_vector.tolist(), nan_vector],
|
||||||
usage=usage,
|
usage=usage,
|
||||||
@ -629,7 +629,7 @@ class TestCacheEmbeddingQuery:
|
|||||||
latency=0.3,
|
latency=0.3,
|
||||||
)
|
)
|
||||||
|
|
||||||
embedding_result = TextEmbeddingResult(
|
embedding_result = EmbeddingResult(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
embeddings=[normalized],
|
embeddings=[normalized],
|
||||||
usage=usage,
|
usage=usage,
|
||||||
@ -728,7 +728,7 @@ class TestCacheEmbeddingQuery:
|
|||||||
latency=0.3,
|
latency=0.3,
|
||||||
)
|
)
|
||||||
|
|
||||||
embedding_result = TextEmbeddingResult(
|
embedding_result = EmbeddingResult(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
embeddings=[nan_vector],
|
embeddings=[nan_vector],
|
||||||
usage=usage,
|
usage=usage,
|
||||||
@ -793,7 +793,7 @@ class TestCacheEmbeddingQuery:
|
|||||||
latency=0.3,
|
latency=0.3,
|
||||||
)
|
)
|
||||||
|
|
||||||
embedding_result = TextEmbeddingResult(
|
embedding_result = EmbeddingResult(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
embeddings=[normalized],
|
embeddings=[normalized],
|
||||||
usage=usage,
|
usage=usage,
|
||||||
@ -873,13 +873,13 @@ class TestEmbeddingModelSwitching:
|
|||||||
latency=0.3,
|
latency=0.3,
|
||||||
)
|
)
|
||||||
|
|
||||||
result_ada = TextEmbeddingResult(
|
result_ada = EmbeddingResult(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
embeddings=[normalized_ada],
|
embeddings=[normalized_ada],
|
||||||
usage=usage,
|
usage=usage,
|
||||||
)
|
)
|
||||||
|
|
||||||
result_3_small = TextEmbeddingResult(
|
result_3_small = EmbeddingResult(
|
||||||
model="text-embedding-3-small",
|
model="text-embedding-3-small",
|
||||||
embeddings=[normalized_3_small],
|
embeddings=[normalized_3_small],
|
||||||
usage=usage,
|
usage=usage,
|
||||||
@ -953,13 +953,13 @@ class TestEmbeddingModelSwitching:
|
|||||||
latency=0.4,
|
latency=0.4,
|
||||||
)
|
)
|
||||||
|
|
||||||
result_openai = TextEmbeddingResult(
|
result_openai = EmbeddingResult(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
embeddings=[normalized_openai],
|
embeddings=[normalized_openai],
|
||||||
usage=usage_openai,
|
usage=usage_openai,
|
||||||
)
|
)
|
||||||
|
|
||||||
result_cohere = TextEmbeddingResult(
|
result_cohere = EmbeddingResult(
|
||||||
model="embed-english-v3.0",
|
model="embed-english-v3.0",
|
||||||
embeddings=[normalized_cohere],
|
embeddings=[normalized_cohere],
|
||||||
usage=usage_cohere,
|
usage=usage_cohere,
|
||||||
@ -1042,7 +1042,7 @@ class TestEmbeddingDimensionValidation:
|
|||||||
latency=0.7,
|
latency=0.7,
|
||||||
)
|
)
|
||||||
|
|
||||||
embedding_result = TextEmbeddingResult(
|
embedding_result = EmbeddingResult(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
embeddings=embeddings,
|
embeddings=embeddings,
|
||||||
usage=usage,
|
usage=usage,
|
||||||
@ -1095,7 +1095,7 @@ class TestEmbeddingDimensionValidation:
|
|||||||
latency=0.5,
|
latency=0.5,
|
||||||
)
|
)
|
||||||
|
|
||||||
embedding_result = TextEmbeddingResult(
|
embedding_result = EmbeddingResult(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
embeddings=embeddings,
|
embeddings=embeddings,
|
||||||
usage=usage,
|
usage=usage,
|
||||||
@ -1148,7 +1148,7 @@ class TestEmbeddingDimensionValidation:
|
|||||||
latency=0.3,
|
latency=0.3,
|
||||||
)
|
)
|
||||||
|
|
||||||
result_ada = TextEmbeddingResult(
|
result_ada = EmbeddingResult(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
embeddings=[normalized_ada],
|
embeddings=[normalized_ada],
|
||||||
usage=usage_ada,
|
usage=usage_ada,
|
||||||
@ -1181,7 +1181,7 @@ class TestEmbeddingDimensionValidation:
|
|||||||
latency=0.4,
|
latency=0.4,
|
||||||
)
|
)
|
||||||
|
|
||||||
result_cohere = TextEmbeddingResult(
|
result_cohere = EmbeddingResult(
|
||||||
model="embed-english-v3.0",
|
model="embed-english-v3.0",
|
||||||
embeddings=[normalized_cohere],
|
embeddings=[normalized_cohere],
|
||||||
usage=usage_cohere,
|
usage=usage_cohere,
|
||||||
@ -1279,7 +1279,7 @@ class TestEmbeddingEdgeCases:
|
|||||||
latency=0.1,
|
latency=0.1,
|
||||||
)
|
)
|
||||||
|
|
||||||
embedding_result = TextEmbeddingResult(
|
embedding_result = EmbeddingResult(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
embeddings=[normalized],
|
embeddings=[normalized],
|
||||||
usage=usage,
|
usage=usage,
|
||||||
@ -1322,7 +1322,7 @@ class TestEmbeddingEdgeCases:
|
|||||||
latency=1.5,
|
latency=1.5,
|
||||||
)
|
)
|
||||||
|
|
||||||
embedding_result = TextEmbeddingResult(
|
embedding_result = EmbeddingResult(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
embeddings=[normalized],
|
embeddings=[normalized],
|
||||||
usage=usage,
|
usage=usage,
|
||||||
@ -1370,7 +1370,7 @@ class TestEmbeddingEdgeCases:
|
|||||||
latency=0.5,
|
latency=0.5,
|
||||||
)
|
)
|
||||||
|
|
||||||
embedding_result = TextEmbeddingResult(
|
embedding_result = EmbeddingResult(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
embeddings=embeddings,
|
embeddings=embeddings,
|
||||||
usage=usage,
|
usage=usage,
|
||||||
@ -1422,7 +1422,7 @@ class TestEmbeddingEdgeCases:
|
|||||||
latency=0.2,
|
latency=0.2,
|
||||||
)
|
)
|
||||||
|
|
||||||
embedding_result = TextEmbeddingResult(
|
embedding_result = EmbeddingResult(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
embeddings=embeddings,
|
embeddings=embeddings,
|
||||||
usage=usage,
|
usage=usage,
|
||||||
@ -1478,7 +1478,7 @@ class TestEmbeddingEdgeCases:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Model returns embeddings for all texts
|
# Model returns embeddings for all texts
|
||||||
embedding_result = TextEmbeddingResult(
|
embedding_result = EmbeddingResult(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
embeddings=embeddings,
|
embeddings=embeddings,
|
||||||
usage=usage,
|
usage=usage,
|
||||||
@ -1546,7 +1546,7 @@ class TestEmbeddingEdgeCases:
|
|||||||
latency=0.8,
|
latency=0.8,
|
||||||
)
|
)
|
||||||
|
|
||||||
embedding_result = TextEmbeddingResult(
|
embedding_result = EmbeddingResult(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
embeddings=embeddings,
|
embeddings=embeddings,
|
||||||
usage=usage,
|
usage=usage,
|
||||||
@ -1603,7 +1603,7 @@ class TestEmbeddingEdgeCases:
|
|||||||
latency=0.3,
|
latency=0.3,
|
||||||
)
|
)
|
||||||
|
|
||||||
embedding_result = TextEmbeddingResult(
|
embedding_result = EmbeddingResult(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
embeddings=[normalized],
|
embeddings=[normalized],
|
||||||
usage=usage,
|
usage=usage,
|
||||||
@ -1657,7 +1657,7 @@ class TestEmbeddingEdgeCases:
|
|||||||
latency=0.5,
|
latency=0.5,
|
||||||
)
|
)
|
||||||
|
|
||||||
embedding_result = TextEmbeddingResult(
|
embedding_result = EmbeddingResult(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
embeddings=embeddings,
|
embeddings=embeddings,
|
||||||
usage=usage,
|
usage=usage,
|
||||||
@ -1757,7 +1757,7 @@ class TestEmbeddingCachePerformance:
|
|||||||
latency=0.3,
|
latency=0.3,
|
||||||
)
|
)
|
||||||
|
|
||||||
embedding_result = TextEmbeddingResult(
|
embedding_result = EmbeddingResult(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
embeddings=[normalized],
|
embeddings=[normalized],
|
||||||
usage=usage,
|
usage=usage,
|
||||||
@ -1826,7 +1826,7 @@ class TestEmbeddingCachePerformance:
|
|||||||
latency=0.5,
|
latency=0.5,
|
||||||
)
|
)
|
||||||
|
|
||||||
return TextEmbeddingResult(
|
return EmbeddingResult(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
embeddings=embeddings,
|
embeddings=embeddings,
|
||||||
usage=usage,
|
usage=usage,
|
||||||
@ -1888,7 +1888,7 @@ class TestEmbeddingCachePerformance:
|
|||||||
latency=0.3,
|
latency=0.3,
|
||||||
)
|
)
|
||||||
|
|
||||||
embedding_result = TextEmbeddingResult(
|
embedding_result = EmbeddingResult(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
embeddings=[normalized],
|
embeddings=[normalized],
|
||||||
usage=usage,
|
usage=usage,
|
||||||
|
|||||||
@ -62,7 +62,7 @@ from core.indexing_runner import (
|
|||||||
IndexingRunner,
|
IndexingRunner,
|
||||||
)
|
)
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.rag.index_processor.constant.index_type import IndexType
|
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||||
from core.rag.models.document import ChildDocument, Document
|
from core.rag.models.document import ChildDocument, Document
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from models.dataset import Dataset, DatasetProcessRule
|
from models.dataset import Dataset, DatasetProcessRule
|
||||||
@ -112,7 +112,7 @@ def create_mock_dataset_document(
|
|||||||
document_id: str | None = None,
|
document_id: str | None = None,
|
||||||
dataset_id: str | None = None,
|
dataset_id: str | None = None,
|
||||||
tenant_id: str | None = None,
|
tenant_id: str | None = None,
|
||||||
doc_form: str = IndexType.PARAGRAPH_INDEX,
|
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
|
||||||
data_source_type: str = "upload_file",
|
data_source_type: str = "upload_file",
|
||||||
doc_language: str = "English",
|
doc_language: str = "English",
|
||||||
) -> Mock:
|
) -> Mock:
|
||||||
@ -133,8 +133,8 @@ def create_mock_dataset_document(
|
|||||||
Mock: A configured mock DatasetDocument object with all required attributes.
|
Mock: A configured mock DatasetDocument object with all required attributes.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> doc = create_mock_dataset_document(doc_form=IndexType.QA_INDEX)
|
>>> doc = create_mock_dataset_document(doc_form=IndexStructureType.QA_INDEX)
|
||||||
>>> assert doc.doc_form == IndexType.QA_INDEX
|
>>> assert doc.doc_form == IndexStructureType.QA_INDEX
|
||||||
"""
|
"""
|
||||||
doc = Mock(spec=DatasetDocument)
|
doc = Mock(spec=DatasetDocument)
|
||||||
doc.id = document_id or str(uuid.uuid4())
|
doc.id = document_id or str(uuid.uuid4())
|
||||||
@ -276,7 +276,7 @@ class TestIndexingRunnerExtract:
|
|||||||
doc.id = str(uuid.uuid4())
|
doc.id = str(uuid.uuid4())
|
||||||
doc.dataset_id = str(uuid.uuid4())
|
doc.dataset_id = str(uuid.uuid4())
|
||||||
doc.tenant_id = str(uuid.uuid4())
|
doc.tenant_id = str(uuid.uuid4())
|
||||||
doc.doc_form = IndexType.PARAGRAPH_INDEX
|
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||||
doc.data_source_type = "upload_file"
|
doc.data_source_type = "upload_file"
|
||||||
doc.data_source_info_dict = {"upload_file_id": str(uuid.uuid4())}
|
doc.data_source_info_dict = {"upload_file_id": str(uuid.uuid4())}
|
||||||
return doc
|
return doc
|
||||||
@ -616,7 +616,7 @@ class TestIndexingRunnerLoad:
|
|||||||
doc = Mock(spec=DatasetDocument)
|
doc = Mock(spec=DatasetDocument)
|
||||||
doc.id = str(uuid.uuid4())
|
doc.id = str(uuid.uuid4())
|
||||||
doc.dataset_id = str(uuid.uuid4())
|
doc.dataset_id = str(uuid.uuid4())
|
||||||
doc.doc_form = IndexType.PARAGRAPH_INDEX
|
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||||
return doc
|
return doc
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -700,7 +700,7 @@ class TestIndexingRunnerLoad:
|
|||||||
"""Test loading with parent-child index structure."""
|
"""Test loading with parent-child index structure."""
|
||||||
# Arrange
|
# Arrange
|
||||||
runner = IndexingRunner()
|
runner = IndexingRunner()
|
||||||
sample_dataset_document.doc_form = IndexType.PARENT_CHILD_INDEX
|
sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
|
||||||
sample_dataset.indexing_technique = "high_quality"
|
sample_dataset.indexing_technique = "high_quality"
|
||||||
|
|
||||||
# Add child documents
|
# Add child documents
|
||||||
@ -775,7 +775,7 @@ class TestIndexingRunnerRun:
|
|||||||
doc.id = str(uuid.uuid4())
|
doc.id = str(uuid.uuid4())
|
||||||
doc.dataset_id = str(uuid.uuid4())
|
doc.dataset_id = str(uuid.uuid4())
|
||||||
doc.tenant_id = str(uuid.uuid4())
|
doc.tenant_id = str(uuid.uuid4())
|
||||||
doc.doc_form = IndexType.PARAGRAPH_INDEX
|
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||||
doc.doc_language = "English"
|
doc.doc_language = "English"
|
||||||
doc.data_source_type = "upload_file"
|
doc.data_source_type = "upload_file"
|
||||||
doc.data_source_info_dict = {"upload_file_id": str(uuid.uuid4())}
|
doc.data_source_info_dict = {"upload_file_id": str(uuid.uuid4())}
|
||||||
@ -802,6 +802,21 @@ class TestIndexingRunnerRun:
|
|||||||
mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}}
|
mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}}
|
||||||
mock_dependencies["db"].session.scalar.return_value = mock_process_rule
|
mock_dependencies["db"].session.scalar.return_value = mock_process_rule
|
||||||
|
|
||||||
|
# Mock current_user (Account) for _transform
|
||||||
|
mock_current_user = MagicMock()
|
||||||
|
mock_current_user.set_tenant_id = MagicMock()
|
||||||
|
|
||||||
|
# Setup db.session.query to return different results based on the model
|
||||||
|
def mock_query_side_effect(model):
|
||||||
|
mock_query_result = MagicMock()
|
||||||
|
if model.__name__ == "Dataset":
|
||||||
|
mock_query_result.filter_by.return_value.first.return_value = mock_dataset
|
||||||
|
elif model.__name__ == "Account":
|
||||||
|
mock_query_result.filter_by.return_value.first.return_value = mock_current_user
|
||||||
|
return mock_query_result
|
||||||
|
|
||||||
|
mock_dependencies["db"].session.query.side_effect = mock_query_side_effect
|
||||||
|
|
||||||
# Mock processor
|
# Mock processor
|
||||||
mock_processor = MagicMock()
|
mock_processor = MagicMock()
|
||||||
mock_dependencies["factory"].return_value.init_index_processor.return_value = mock_processor
|
mock_dependencies["factory"].return_value.init_index_processor.return_value = mock_processor
|
||||||
@ -1268,7 +1283,7 @@ class TestIndexingRunnerLoadSegments:
|
|||||||
doc.id = str(uuid.uuid4())
|
doc.id = str(uuid.uuid4())
|
||||||
doc.dataset_id = str(uuid.uuid4())
|
doc.dataset_id = str(uuid.uuid4())
|
||||||
doc.created_by = str(uuid.uuid4())
|
doc.created_by = str(uuid.uuid4())
|
||||||
doc.doc_form = IndexType.PARAGRAPH_INDEX
|
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||||
return doc
|
return doc
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -1316,7 +1331,7 @@ class TestIndexingRunnerLoadSegments:
|
|||||||
"""Test loading segments for parent-child index."""
|
"""Test loading segments for parent-child index."""
|
||||||
# Arrange
|
# Arrange
|
||||||
runner = IndexingRunner()
|
runner = IndexingRunner()
|
||||||
sample_dataset_document.doc_form = IndexType.PARENT_CHILD_INDEX
|
sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
|
||||||
|
|
||||||
# Add child documents
|
# Add child documents
|
||||||
for doc in sample_documents:
|
for doc in sample_documents:
|
||||||
@ -1413,7 +1428,7 @@ class TestIndexingRunnerEstimate:
|
|||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
extract_settings=extract_settings,
|
extract_settings=extract_settings,
|
||||||
tmp_processing_rule={"mode": "automatic", "rules": {}},
|
tmp_processing_rule={"mode": "automatic", "rules": {}},
|
||||||
doc_form=IndexType.PARAGRAPH_INDEX,
|
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -26,6 +26,18 @@ from core.rag.rerank.rerank_type import RerankMode
|
|||||||
from core.rag.rerank.weight_rerank import WeightRerankRunner
|
from core.rag.rerank.weight_rerank import WeightRerankRunner
|
||||||
|
|
||||||
|
|
||||||
|
def create_mock_model_instance():
|
||||||
|
"""Create a properly configured mock ModelInstance for reranking tests."""
|
||||||
|
mock_instance = Mock(spec=ModelInstance)
|
||||||
|
# Setup provider_model_bundle chain for check_model_support_vision
|
||||||
|
mock_instance.provider_model_bundle = Mock()
|
||||||
|
mock_instance.provider_model_bundle.configuration = Mock()
|
||||||
|
mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id"
|
||||||
|
mock_instance.provider = "test-provider"
|
||||||
|
mock_instance.model = "test-model"
|
||||||
|
return mock_instance
|
||||||
|
|
||||||
|
|
||||||
class TestRerankModelRunner:
|
class TestRerankModelRunner:
|
||||||
"""Unit tests for RerankModelRunner.
|
"""Unit tests for RerankModelRunner.
|
||||||
|
|
||||||
@ -37,10 +49,23 @@ class TestRerankModelRunner:
|
|||||||
- Metadata preservation and score injection
|
- Metadata preservation and score injection
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_model_manager(self):
|
||||||
|
"""Auto-use fixture to patch ModelManager for all tests in this class."""
|
||||||
|
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
|
||||||
|
mock_mm.return_value.check_model_support_vision.return_value = False
|
||||||
|
yield mock_mm
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_model_instance(self):
|
def mock_model_instance(self):
|
||||||
"""Create a mock ModelInstance for reranking."""
|
"""Create a mock ModelInstance for reranking."""
|
||||||
mock_instance = Mock(spec=ModelInstance)
|
mock_instance = Mock(spec=ModelInstance)
|
||||||
|
# Setup provider_model_bundle chain for check_model_support_vision
|
||||||
|
mock_instance.provider_model_bundle = Mock()
|
||||||
|
mock_instance.provider_model_bundle.configuration = Mock()
|
||||||
|
mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id"
|
||||||
|
mock_instance.provider = "test-provider"
|
||||||
|
mock_instance.model = "test-model"
|
||||||
return mock_instance
|
return mock_instance
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -803,7 +828,7 @@ class TestRerankRunnerFactory:
|
|||||||
- Parameters are forwarded to runner constructor
|
- Parameters are forwarded to runner constructor
|
||||||
"""
|
"""
|
||||||
# Arrange: Mock model instance
|
# Arrange: Mock model instance
|
||||||
mock_model_instance = Mock(spec=ModelInstance)
|
mock_model_instance = create_mock_model_instance()
|
||||||
|
|
||||||
# Act: Create runner via factory
|
# Act: Create runner via factory
|
||||||
runner = RerankRunnerFactory.create_rerank_runner(
|
runner = RerankRunnerFactory.create_rerank_runner(
|
||||||
@ -865,7 +890,7 @@ class TestRerankRunnerFactory:
|
|||||||
- String values are properly matched
|
- String values are properly matched
|
||||||
"""
|
"""
|
||||||
# Arrange: Mock model instance
|
# Arrange: Mock model instance
|
||||||
mock_model_instance = Mock(spec=ModelInstance)
|
mock_model_instance = create_mock_model_instance()
|
||||||
|
|
||||||
# Act: Create runner using enum value
|
# Act: Create runner using enum value
|
||||||
runner = RerankRunnerFactory.create_rerank_runner(
|
runner = RerankRunnerFactory.create_rerank_runner(
|
||||||
@ -886,6 +911,13 @@ class TestRerankIntegration:
|
|||||||
- Real-world usage scenarios
|
- Real-world usage scenarios
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_model_manager(self):
|
||||||
|
"""Auto-use fixture to patch ModelManager for all tests in this class."""
|
||||||
|
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
|
||||||
|
mock_mm.return_value.check_model_support_vision.return_value = False
|
||||||
|
yield mock_mm
|
||||||
|
|
||||||
def test_model_reranking_full_workflow(self):
|
def test_model_reranking_full_workflow(self):
|
||||||
"""Test complete model-based reranking workflow.
|
"""Test complete model-based reranking workflow.
|
||||||
|
|
||||||
@ -895,7 +927,7 @@ class TestRerankIntegration:
|
|||||||
- Top results are returned correctly
|
- Top results are returned correctly
|
||||||
"""
|
"""
|
||||||
# Arrange: Create mock model and documents
|
# Arrange: Create mock model and documents
|
||||||
mock_model_instance = Mock(spec=ModelInstance)
|
mock_model_instance = create_mock_model_instance()
|
||||||
mock_rerank_result = RerankResult(
|
mock_rerank_result = RerankResult(
|
||||||
model="bge-reranker-base",
|
model="bge-reranker-base",
|
||||||
docs=[
|
docs=[
|
||||||
@ -951,7 +983,7 @@ class TestRerankIntegration:
|
|||||||
- Normalization is consistent
|
- Normalization is consistent
|
||||||
"""
|
"""
|
||||||
# Arrange: Create mock model with various scores
|
# Arrange: Create mock model with various scores
|
||||||
mock_model_instance = Mock(spec=ModelInstance)
|
mock_model_instance = create_mock_model_instance()
|
||||||
mock_rerank_result = RerankResult(
|
mock_rerank_result = RerankResult(
|
||||||
model="bge-reranker-base",
|
model="bge-reranker-base",
|
||||||
docs=[
|
docs=[
|
||||||
@ -991,6 +1023,13 @@ class TestRerankEdgeCases:
|
|||||||
- Concurrent reranking scenarios
|
- Concurrent reranking scenarios
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_model_manager(self):
|
||||||
|
"""Auto-use fixture to patch ModelManager for all tests in this class."""
|
||||||
|
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
|
||||||
|
mock_mm.return_value.check_model_support_vision.return_value = False
|
||||||
|
yield mock_mm
|
||||||
|
|
||||||
def test_rerank_with_empty_metadata(self):
|
def test_rerank_with_empty_metadata(self):
|
||||||
"""Test reranking when documents have empty metadata.
|
"""Test reranking when documents have empty metadata.
|
||||||
|
|
||||||
@ -1000,7 +1039,7 @@ class TestRerankEdgeCases:
|
|||||||
- Empty metadata documents are processed correctly
|
- Empty metadata documents are processed correctly
|
||||||
"""
|
"""
|
||||||
# Arrange: Create documents with empty metadata
|
# Arrange: Create documents with empty metadata
|
||||||
mock_model_instance = Mock(spec=ModelInstance)
|
mock_model_instance = create_mock_model_instance()
|
||||||
mock_rerank_result = RerankResult(
|
mock_rerank_result = RerankResult(
|
||||||
model="bge-reranker-base",
|
model="bge-reranker-base",
|
||||||
docs=[
|
docs=[
|
||||||
@ -1046,7 +1085,7 @@ class TestRerankEdgeCases:
|
|||||||
- Score comparison logic works at boundary
|
- Score comparison logic works at boundary
|
||||||
"""
|
"""
|
||||||
# Arrange: Create mock with various scores including negatives
|
# Arrange: Create mock with various scores including negatives
|
||||||
mock_model_instance = Mock(spec=ModelInstance)
|
mock_model_instance = create_mock_model_instance()
|
||||||
mock_rerank_result = RerankResult(
|
mock_rerank_result = RerankResult(
|
||||||
model="bge-reranker-base",
|
model="bge-reranker-base",
|
||||||
docs=[
|
docs=[
|
||||||
@ -1082,7 +1121,7 @@ class TestRerankEdgeCases:
|
|||||||
- No overflow or precision issues
|
- No overflow or precision issues
|
||||||
"""
|
"""
|
||||||
# Arrange: All documents with perfect scores
|
# Arrange: All documents with perfect scores
|
||||||
mock_model_instance = Mock(spec=ModelInstance)
|
mock_model_instance = create_mock_model_instance()
|
||||||
mock_rerank_result = RerankResult(
|
mock_rerank_result = RerankResult(
|
||||||
model="bge-reranker-base",
|
model="bge-reranker-base",
|
||||||
docs=[
|
docs=[
|
||||||
@ -1117,7 +1156,7 @@ class TestRerankEdgeCases:
|
|||||||
- Content encoding is preserved
|
- Content encoding is preserved
|
||||||
"""
|
"""
|
||||||
# Arrange: Documents with special characters
|
# Arrange: Documents with special characters
|
||||||
mock_model_instance = Mock(spec=ModelInstance)
|
mock_model_instance = create_mock_model_instance()
|
||||||
mock_rerank_result = RerankResult(
|
mock_rerank_result = RerankResult(
|
||||||
model="bge-reranker-base",
|
model="bge-reranker-base",
|
||||||
docs=[
|
docs=[
|
||||||
@ -1159,7 +1198,7 @@ class TestRerankEdgeCases:
|
|||||||
- Content is not truncated unexpectedly
|
- Content is not truncated unexpectedly
|
||||||
"""
|
"""
|
||||||
# Arrange: Documents with very long content
|
# Arrange: Documents with very long content
|
||||||
mock_model_instance = Mock(spec=ModelInstance)
|
mock_model_instance = create_mock_model_instance()
|
||||||
long_content = "This is a very long document. " * 1000 # ~30,000 characters
|
long_content = "This is a very long document. " * 1000 # ~30,000 characters
|
||||||
|
|
||||||
mock_rerank_result = RerankResult(
|
mock_rerank_result = RerankResult(
|
||||||
@ -1196,7 +1235,7 @@ class TestRerankEdgeCases:
|
|||||||
- All documents are processed correctly
|
- All documents are processed correctly
|
||||||
"""
|
"""
|
||||||
# Arrange: Create 100 documents
|
# Arrange: Create 100 documents
|
||||||
mock_model_instance = Mock(spec=ModelInstance)
|
mock_model_instance = create_mock_model_instance()
|
||||||
num_docs = 100
|
num_docs = 100
|
||||||
|
|
||||||
# Create rerank results for all documents
|
# Create rerank results for all documents
|
||||||
@ -1287,7 +1326,7 @@ class TestRerankEdgeCases:
|
|||||||
- Documents can still be ranked
|
- Documents can still be ranked
|
||||||
"""
|
"""
|
||||||
# Arrange: Empty query
|
# Arrange: Empty query
|
||||||
mock_model_instance = Mock(spec=ModelInstance)
|
mock_model_instance = create_mock_model_instance()
|
||||||
mock_rerank_result = RerankResult(
|
mock_rerank_result = RerankResult(
|
||||||
model="bge-reranker-base",
|
model="bge-reranker-base",
|
||||||
docs=[
|
docs=[
|
||||||
@ -1325,6 +1364,13 @@ class TestRerankPerformance:
|
|||||||
- Score calculation optimization
|
- Score calculation optimization
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_model_manager(self):
|
||||||
|
"""Auto-use fixture to patch ModelManager for all tests in this class."""
|
||||||
|
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
|
||||||
|
mock_mm.return_value.check_model_support_vision.return_value = False
|
||||||
|
yield mock_mm
|
||||||
|
|
||||||
def test_rerank_batch_processing(self):
|
def test_rerank_batch_processing(self):
|
||||||
"""Test that documents are processed in a single batch.
|
"""Test that documents are processed in a single batch.
|
||||||
|
|
||||||
@ -1334,7 +1380,7 @@ class TestRerankPerformance:
|
|||||||
- Efficient batch processing
|
- Efficient batch processing
|
||||||
"""
|
"""
|
||||||
# Arrange: Multiple documents
|
# Arrange: Multiple documents
|
||||||
mock_model_instance = Mock(spec=ModelInstance)
|
mock_model_instance = create_mock_model_instance()
|
||||||
mock_rerank_result = RerankResult(
|
mock_rerank_result = RerankResult(
|
||||||
model="bge-reranker-base",
|
model="bge-reranker-base",
|
||||||
docs=[RerankDocument(index=i, text=f"Doc {i}", score=0.9 - i * 0.1) for i in range(5)],
|
docs=[RerankDocument(index=i, text=f"Doc {i}", score=0.9 - i * 0.1) for i in range(5)],
|
||||||
@ -1435,6 +1481,13 @@ class TestRerankErrorHandling:
|
|||||||
- Error propagation
|
- Error propagation
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_model_manager(self):
|
||||||
|
"""Auto-use fixture to patch ModelManager for all tests in this class."""
|
||||||
|
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
|
||||||
|
mock_mm.return_value.check_model_support_vision.return_value = False
|
||||||
|
yield mock_mm
|
||||||
|
|
||||||
def test_rerank_model_invocation_error(self):
|
def test_rerank_model_invocation_error(self):
|
||||||
"""Test handling of model invocation errors.
|
"""Test handling of model invocation errors.
|
||||||
|
|
||||||
@ -1444,7 +1497,7 @@ class TestRerankErrorHandling:
|
|||||||
- Error context is preserved
|
- Error context is preserved
|
||||||
"""
|
"""
|
||||||
# Arrange: Mock model that raises exception
|
# Arrange: Mock model that raises exception
|
||||||
mock_model_instance = Mock(spec=ModelInstance)
|
mock_model_instance = create_mock_model_instance()
|
||||||
mock_model_instance.invoke_rerank.side_effect = RuntimeError("Model invocation failed")
|
mock_model_instance.invoke_rerank.side_effect = RuntimeError("Model invocation failed")
|
||||||
|
|
||||||
documents = [
|
documents = [
|
||||||
@ -1470,7 +1523,7 @@ class TestRerankErrorHandling:
|
|||||||
- Invalid results don't corrupt output
|
- Invalid results don't corrupt output
|
||||||
"""
|
"""
|
||||||
# Arrange: Rerank result with invalid index
|
# Arrange: Rerank result with invalid index
|
||||||
mock_model_instance = Mock(spec=ModelInstance)
|
mock_model_instance = create_mock_model_instance()
|
||||||
mock_rerank_result = RerankResult(
|
mock_rerank_result = RerankResult(
|
||||||
model="bge-reranker-base",
|
model="bge-reranker-base",
|
||||||
docs=[
|
docs=[
|
||||||
|
|||||||
@ -425,15 +425,15 @@ class TestRetrievalService:
|
|||||||
|
|
||||||
# ==================== Vector Search Tests ====================
|
# ==================== Vector Search Tests ====================
|
||||||
|
|
||||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
|
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
|
||||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
||||||
def test_vector_search_basic(self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents):
|
def test_vector_search_basic(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
|
||||||
"""
|
"""
|
||||||
Test basic vector/semantic search functionality.
|
Test basic vector/semantic search functionality.
|
||||||
|
|
||||||
This test validates the core vector search flow:
|
This test validates the core vector search flow:
|
||||||
1. Dataset is retrieved from database
|
1. Dataset is retrieved from database
|
||||||
2. embedding_search is called via ThreadPoolExecutor
|
2. _retrieve is called via ThreadPoolExecutor
|
||||||
3. Documents are added to shared all_documents list
|
3. Documents are added to shared all_documents list
|
||||||
4. Results are returned to caller
|
4. Results are returned to caller
|
||||||
|
|
||||||
@ -447,28 +447,28 @@ class TestRetrievalService:
|
|||||||
# Set up the mock dataset that will be "retrieved" from database
|
# Set up the mock dataset that will be "retrieved" from database
|
||||||
mock_get_dataset.return_value = mock_dataset
|
mock_get_dataset.return_value = mock_dataset
|
||||||
|
|
||||||
# Create a side effect function that simulates embedding_search behavior
|
# Create a side effect function that simulates _retrieve behavior
|
||||||
# In the real implementation, embedding_search:
|
# _retrieve modifies the all_documents list in place
|
||||||
# 1. Gets the dataset
|
def side_effect_retrieve(
|
||||||
# 2. Creates a Vector instance
|
|
||||||
# 3. Calls search_by_vector with embeddings
|
|
||||||
# 4. Extends all_documents with results
|
|
||||||
def side_effect_embedding_search(
|
|
||||||
flask_app,
|
flask_app,
|
||||||
dataset_id,
|
|
||||||
query,
|
|
||||||
top_k,
|
|
||||||
score_threshold,
|
|
||||||
reranking_model,
|
|
||||||
all_documents,
|
|
||||||
retrieval_method,
|
retrieval_method,
|
||||||
exceptions,
|
dataset,
|
||||||
|
query=None,
|
||||||
|
top_k=4,
|
||||||
|
score_threshold=None,
|
||||||
|
reranking_model=None,
|
||||||
|
reranking_mode="reranking_model",
|
||||||
|
weights=None,
|
||||||
document_ids_filter=None,
|
document_ids_filter=None,
|
||||||
|
attachment_id=None,
|
||||||
|
all_documents=None,
|
||||||
|
exceptions=None,
|
||||||
):
|
):
|
||||||
"""Simulate embedding_search adding documents to the shared list."""
|
"""Simulate _retrieve adding documents to the shared list."""
|
||||||
all_documents.extend(sample_documents)
|
if all_documents is not None:
|
||||||
|
all_documents.extend(sample_documents)
|
||||||
|
|
||||||
mock_embedding_search.side_effect = side_effect_embedding_search
|
mock_retrieve.side_effect = side_effect_retrieve
|
||||||
|
|
||||||
# Define test parameters
|
# Define test parameters
|
||||||
query = "What is Python?" # Natural language query
|
query = "What is Python?" # Natural language query
|
||||||
@ -481,7 +481,7 @@ class TestRetrievalService:
|
|||||||
# 1. Check if query is empty (early return if so)
|
# 1. Check if query is empty (early return if so)
|
||||||
# 2. Get the dataset using _get_dataset
|
# 2. Get the dataset using _get_dataset
|
||||||
# 3. Create ThreadPoolExecutor
|
# 3. Create ThreadPoolExecutor
|
||||||
# 4. Submit embedding_search task
|
# 4. Submit _retrieve task
|
||||||
# 5. Wait for completion
|
# 5. Wait for completion
|
||||||
# 6. Return all_documents list
|
# 6. Return all_documents list
|
||||||
results = RetrievalService.retrieve(
|
results = RetrievalService.retrieve(
|
||||||
@ -502,15 +502,13 @@ class TestRetrievalService:
|
|||||||
# Verify documents maintain their scores (highest score first in sample_documents)
|
# Verify documents maintain their scores (highest score first in sample_documents)
|
||||||
assert results[0].metadata["score"] == 0.95, "First document should have highest score from sample_documents"
|
assert results[0].metadata["score"] == 0.95, "First document should have highest score from sample_documents"
|
||||||
|
|
||||||
# Verify embedding_search was called exactly once
|
# Verify _retrieve was called exactly once
|
||||||
# This confirms the search method was invoked by ThreadPoolExecutor
|
# This confirms the search method was invoked by ThreadPoolExecutor
|
||||||
mock_embedding_search.assert_called_once()
|
mock_retrieve.assert_called_once()
|
||||||
|
|
||||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
|
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
|
||||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
||||||
def test_vector_search_with_document_filter(
|
def test_vector_search_with_document_filter(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
|
||||||
self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Test vector search with document ID filtering.
|
Test vector search with document ID filtering.
|
||||||
|
|
||||||
@ -522,21 +520,25 @@ class TestRetrievalService:
|
|||||||
mock_get_dataset.return_value = mock_dataset
|
mock_get_dataset.return_value = mock_dataset
|
||||||
filtered_docs = [sample_documents[0]]
|
filtered_docs = [sample_documents[0]]
|
||||||
|
|
||||||
def side_effect_embedding_search(
|
def side_effect_retrieve(
|
||||||
flask_app,
|
flask_app,
|
||||||
dataset_id,
|
|
||||||
query,
|
|
||||||
top_k,
|
|
||||||
score_threshold,
|
|
||||||
reranking_model,
|
|
||||||
all_documents,
|
|
||||||
retrieval_method,
|
retrieval_method,
|
||||||
exceptions,
|
dataset,
|
||||||
|
query=None,
|
||||||
|
top_k=4,
|
||||||
|
score_threshold=None,
|
||||||
|
reranking_model=None,
|
||||||
|
reranking_mode="reranking_model",
|
||||||
|
weights=None,
|
||||||
document_ids_filter=None,
|
document_ids_filter=None,
|
||||||
|
attachment_id=None,
|
||||||
|
all_documents=None,
|
||||||
|
exceptions=None,
|
||||||
):
|
):
|
||||||
all_documents.extend(filtered_docs)
|
if all_documents is not None:
|
||||||
|
all_documents.extend(filtered_docs)
|
||||||
|
|
||||||
mock_embedding_search.side_effect = side_effect_embedding_search
|
mock_retrieve.side_effect = side_effect_retrieve
|
||||||
document_ids_filter = [sample_documents[0].metadata["document_id"]]
|
document_ids_filter = [sample_documents[0].metadata["document_id"]]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
@ -552,12 +554,12 @@ class TestRetrievalService:
|
|||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
assert results[0].metadata["doc_id"] == "doc1"
|
assert results[0].metadata["doc_id"] == "doc1"
|
||||||
# Verify document_ids_filter was passed
|
# Verify document_ids_filter was passed
|
||||||
call_kwargs = mock_embedding_search.call_args.kwargs
|
call_kwargs = mock_retrieve.call_args.kwargs
|
||||||
assert call_kwargs["document_ids_filter"] == document_ids_filter
|
assert call_kwargs["document_ids_filter"] == document_ids_filter
|
||||||
|
|
||||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
|
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
|
||||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
||||||
def test_vector_search_empty_results(self, mock_get_dataset, mock_embedding_search, mock_dataset):
|
def test_vector_search_empty_results(self, mock_get_dataset, mock_retrieve, mock_dataset):
|
||||||
"""
|
"""
|
||||||
Test vector search when no results match the query.
|
Test vector search when no results match the query.
|
||||||
|
|
||||||
@ -567,8 +569,8 @@ class TestRetrievalService:
|
|||||||
"""
|
"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_get_dataset.return_value = mock_dataset
|
mock_get_dataset.return_value = mock_dataset
|
||||||
# embedding_search doesn't add anything to all_documents
|
# _retrieve doesn't add anything to all_documents
|
||||||
mock_embedding_search.side_effect = lambda *args, **kwargs: None
|
mock_retrieve.side_effect = lambda *args, **kwargs: None
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
results = RetrievalService.retrieve(
|
results = RetrievalService.retrieve(
|
||||||
@ -583,9 +585,9 @@ class TestRetrievalService:
|
|||||||
|
|
||||||
# ==================== Keyword Search Tests ====================
|
# ==================== Keyword Search Tests ====================
|
||||||
|
|
||||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search")
|
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
|
||||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
||||||
def test_keyword_search_basic(self, mock_get_dataset, mock_keyword_search, mock_dataset, sample_documents):
|
def test_keyword_search_basic(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
|
||||||
"""
|
"""
|
||||||
Test basic keyword search functionality.
|
Test basic keyword search functionality.
|
||||||
|
|
||||||
@ -597,12 +599,25 @@ class TestRetrievalService:
|
|||||||
# Arrange
|
# Arrange
|
||||||
mock_get_dataset.return_value = mock_dataset
|
mock_get_dataset.return_value = mock_dataset
|
||||||
|
|
||||||
def side_effect_keyword_search(
|
def side_effect_retrieve(
|
||||||
flask_app, dataset_id, query, top_k, all_documents, exceptions, document_ids_filter=None
|
flask_app,
|
||||||
|
retrieval_method,
|
||||||
|
dataset,
|
||||||
|
query=None,
|
||||||
|
top_k=4,
|
||||||
|
score_threshold=None,
|
||||||
|
reranking_model=None,
|
||||||
|
reranking_mode="reranking_model",
|
||||||
|
weights=None,
|
||||||
|
document_ids_filter=None,
|
||||||
|
attachment_id=None,
|
||||||
|
all_documents=None,
|
||||||
|
exceptions=None,
|
||||||
):
|
):
|
||||||
all_documents.extend(sample_documents)
|
if all_documents is not None:
|
||||||
|
all_documents.extend(sample_documents)
|
||||||
|
|
||||||
mock_keyword_search.side_effect = side_effect_keyword_search
|
mock_retrieve.side_effect = side_effect_retrieve
|
||||||
|
|
||||||
query = "Python programming"
|
query = "Python programming"
|
||||||
top_k = 3
|
top_k = 3
|
||||||
@ -618,7 +633,7 @@ class TestRetrievalService:
|
|||||||
# Assert
|
# Assert
|
||||||
assert len(results) == 3
|
assert len(results) == 3
|
||||||
assert all(isinstance(doc, Document) for doc in results)
|
assert all(isinstance(doc, Document) for doc in results)
|
||||||
mock_keyword_search.assert_called_once()
|
mock_retrieve.assert_called_once()
|
||||||
|
|
||||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search")
|
@patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search")
|
||||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
||||||
@ -1147,11 +1162,9 @@ class TestRetrievalService:
|
|||||||
|
|
||||||
# ==================== Metadata Filtering Tests ====================
|
# ==================== Metadata Filtering Tests ====================
|
||||||
|
|
||||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
|
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
|
||||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
||||||
def test_vector_search_with_metadata_filter(
|
def test_vector_search_with_metadata_filter(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
|
||||||
self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Test vector search with metadata-based document filtering.
|
Test vector search with metadata-based document filtering.
|
||||||
|
|
||||||
@ -1166,21 +1179,25 @@ class TestRetrievalService:
|
|||||||
filtered_doc = sample_documents[0]
|
filtered_doc = sample_documents[0]
|
||||||
filtered_doc.metadata["category"] = "programming"
|
filtered_doc.metadata["category"] = "programming"
|
||||||
|
|
||||||
def side_effect_embedding(
|
def side_effect_retrieve(
|
||||||
flask_app,
|
flask_app,
|
||||||
dataset_id,
|
|
||||||
query,
|
|
||||||
top_k,
|
|
||||||
score_threshold,
|
|
||||||
reranking_model,
|
|
||||||
all_documents,
|
|
||||||
retrieval_method,
|
retrieval_method,
|
||||||
exceptions,
|
dataset,
|
||||||
|
query=None,
|
||||||
|
top_k=4,
|
||||||
|
score_threshold=None,
|
||||||
|
reranking_model=None,
|
||||||
|
reranking_mode="reranking_model",
|
||||||
|
weights=None,
|
||||||
document_ids_filter=None,
|
document_ids_filter=None,
|
||||||
|
attachment_id=None,
|
||||||
|
all_documents=None,
|
||||||
|
exceptions=None,
|
||||||
):
|
):
|
||||||
all_documents.append(filtered_doc)
|
if all_documents is not None:
|
||||||
|
all_documents.append(filtered_doc)
|
||||||
|
|
||||||
mock_embedding_search.side_effect = side_effect_embedding
|
mock_retrieve.side_effect = side_effect_retrieve
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
results = RetrievalService.retrieve(
|
results = RetrievalService.retrieve(
|
||||||
@ -1243,9 +1260,9 @@ class TestRetrievalService:
|
|||||||
# Assert
|
# Assert
|
||||||
assert results == []
|
assert results == []
|
||||||
|
|
||||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
|
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
|
||||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
||||||
def test_retrieve_with_exception_handling(self, mock_get_dataset, mock_embedding_search, mock_dataset):
|
def test_retrieve_with_exception_handling(self, mock_get_dataset, mock_retrieve, mock_dataset):
|
||||||
"""
|
"""
|
||||||
Test that exceptions during retrieval are properly handled.
|
Test that exceptions during retrieval are properly handled.
|
||||||
|
|
||||||
@ -1256,22 +1273,26 @@ class TestRetrievalService:
|
|||||||
# Arrange
|
# Arrange
|
||||||
mock_get_dataset.return_value = mock_dataset
|
mock_get_dataset.return_value = mock_dataset
|
||||||
|
|
||||||
# Make embedding_search add an exception to the exceptions list
|
# Make _retrieve add an exception to the exceptions list
|
||||||
def side_effect_with_exception(
|
def side_effect_with_exception(
|
||||||
flask_app,
|
flask_app,
|
||||||
dataset_id,
|
|
||||||
query,
|
|
||||||
top_k,
|
|
||||||
score_threshold,
|
|
||||||
reranking_model,
|
|
||||||
all_documents,
|
|
||||||
retrieval_method,
|
retrieval_method,
|
||||||
exceptions,
|
dataset,
|
||||||
|
query=None,
|
||||||
|
top_k=4,
|
||||||
|
score_threshold=None,
|
||||||
|
reranking_model=None,
|
||||||
|
reranking_mode="reranking_model",
|
||||||
|
weights=None,
|
||||||
document_ids_filter=None,
|
document_ids_filter=None,
|
||||||
|
attachment_id=None,
|
||||||
|
all_documents=None,
|
||||||
|
exceptions=None,
|
||||||
):
|
):
|
||||||
exceptions.append("Search failed")
|
if exceptions is not None:
|
||||||
|
exceptions.append("Search failed")
|
||||||
|
|
||||||
mock_embedding_search.side_effect = side_effect_with_exception
|
mock_retrieve.side_effect = side_effect_with_exception
|
||||||
|
|
||||||
# Act & Assert
|
# Act & Assert
|
||||||
with pytest.raises(ValueError) as exc_info:
|
with pytest.raises(ValueError) as exc_info:
|
||||||
@ -1286,9 +1307,9 @@ class TestRetrievalService:
|
|||||||
|
|
||||||
# ==================== Score Threshold Tests ====================
|
# ==================== Score Threshold Tests ====================
|
||||||
|
|
||||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
|
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
|
||||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
||||||
def test_vector_search_with_score_threshold(self, mock_get_dataset, mock_embedding_search, mock_dataset):
|
def test_vector_search_with_score_threshold(self, mock_get_dataset, mock_retrieve, mock_dataset):
|
||||||
"""
|
"""
|
||||||
Test vector search with score threshold filtering.
|
Test vector search with score threshold filtering.
|
||||||
|
|
||||||
@ -1306,21 +1327,25 @@ class TestRetrievalService:
|
|||||||
provider="dify",
|
provider="dify",
|
||||||
)
|
)
|
||||||
|
|
||||||
def side_effect_embedding(
|
def side_effect_retrieve(
|
||||||
flask_app,
|
flask_app,
|
||||||
dataset_id,
|
|
||||||
query,
|
|
||||||
top_k,
|
|
||||||
score_threshold,
|
|
||||||
reranking_model,
|
|
||||||
all_documents,
|
|
||||||
retrieval_method,
|
retrieval_method,
|
||||||
exceptions,
|
dataset,
|
||||||
|
query=None,
|
||||||
|
top_k=4,
|
||||||
|
score_threshold=None,
|
||||||
|
reranking_model=None,
|
||||||
|
reranking_mode="reranking_model",
|
||||||
|
weights=None,
|
||||||
document_ids_filter=None,
|
document_ids_filter=None,
|
||||||
|
attachment_id=None,
|
||||||
|
all_documents=None,
|
||||||
|
exceptions=None,
|
||||||
):
|
):
|
||||||
all_documents.append(high_score_doc)
|
if all_documents is not None:
|
||||||
|
all_documents.append(high_score_doc)
|
||||||
|
|
||||||
mock_embedding_search.side_effect = side_effect_embedding
|
mock_retrieve.side_effect = side_effect_retrieve
|
||||||
|
|
||||||
score_threshold = 0.8
|
score_threshold = 0.8
|
||||||
|
|
||||||
@ -1339,9 +1364,9 @@ class TestRetrievalService:
|
|||||||
|
|
||||||
# ==================== Top-K Limiting Tests ====================
|
# ==================== Top-K Limiting Tests ====================
|
||||||
|
|
||||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
|
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
|
||||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
||||||
def test_retrieve_respects_top_k_limit(self, mock_get_dataset, mock_embedding_search, mock_dataset):
|
def test_retrieve_respects_top_k_limit(self, mock_get_dataset, mock_retrieve, mock_dataset):
|
||||||
"""
|
"""
|
||||||
Test that retrieval respects top_k parameter.
|
Test that retrieval respects top_k parameter.
|
||||||
|
|
||||||
@ -1362,22 +1387,26 @@ class TestRetrievalService:
|
|||||||
for i in range(10)
|
for i in range(10)
|
||||||
]
|
]
|
||||||
|
|
||||||
def side_effect_embedding(
|
def side_effect_retrieve(
|
||||||
flask_app,
|
flask_app,
|
||||||
dataset_id,
|
|
||||||
query,
|
|
||||||
top_k,
|
|
||||||
score_threshold,
|
|
||||||
reranking_model,
|
|
||||||
all_documents,
|
|
||||||
retrieval_method,
|
retrieval_method,
|
||||||
exceptions,
|
dataset,
|
||||||
|
query=None,
|
||||||
|
top_k=4,
|
||||||
|
score_threshold=None,
|
||||||
|
reranking_model=None,
|
||||||
|
reranking_mode="reranking_model",
|
||||||
|
weights=None,
|
||||||
document_ids_filter=None,
|
document_ids_filter=None,
|
||||||
|
attachment_id=None,
|
||||||
|
all_documents=None,
|
||||||
|
exceptions=None,
|
||||||
):
|
):
|
||||||
# Return only top_k documents
|
# Return only top_k documents
|
||||||
all_documents.extend(many_docs[:top_k])
|
if all_documents is not None:
|
||||||
|
all_documents.extend(many_docs[:top_k])
|
||||||
|
|
||||||
mock_embedding_search.side_effect = side_effect_embedding
|
mock_retrieve.side_effect = side_effect_retrieve
|
||||||
|
|
||||||
top_k = 3
|
top_k = 3
|
||||||
|
|
||||||
@ -1390,9 +1419,9 @@ class TestRetrievalService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# Verify top_k was passed to embedding_search
|
# Verify _retrieve was called
|
||||||
assert mock_embedding_search.called
|
assert mock_retrieve.called
|
||||||
call_kwargs = mock_embedding_search.call_args.kwargs
|
call_kwargs = mock_retrieve.call_args.kwargs
|
||||||
assert call_kwargs["top_k"] == top_k
|
assert call_kwargs["top_k"] == top_k
|
||||||
# Verify we got the right number of results
|
# Verify we got the right number of results
|
||||||
assert len(results) == top_k
|
assert len(results) == top_k
|
||||||
@ -1421,11 +1450,9 @@ class TestRetrievalService:
|
|||||||
|
|
||||||
# ==================== Reranking Tests ====================
|
# ==================== Reranking Tests ====================
|
||||||
|
|
||||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
|
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
|
||||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
||||||
def test_semantic_search_with_reranking(
|
def test_semantic_search_with_reranking(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
|
||||||
self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Test semantic search with reranking model.
|
Test semantic search with reranking model.
|
||||||
|
|
||||||
@ -1439,22 +1466,26 @@ class TestRetrievalService:
|
|||||||
# Simulate reranking changing order
|
# Simulate reranking changing order
|
||||||
reranked_docs = list(reversed(sample_documents))
|
reranked_docs = list(reversed(sample_documents))
|
||||||
|
|
||||||
def side_effect_embedding(
|
def side_effect_retrieve(
|
||||||
flask_app,
|
flask_app,
|
||||||
dataset_id,
|
|
||||||
query,
|
|
||||||
top_k,
|
|
||||||
score_threshold,
|
|
||||||
reranking_model,
|
|
||||||
all_documents,
|
|
||||||
retrieval_method,
|
retrieval_method,
|
||||||
exceptions,
|
dataset,
|
||||||
|
query=None,
|
||||||
|
top_k=4,
|
||||||
|
score_threshold=None,
|
||||||
|
reranking_model=None,
|
||||||
|
reranking_mode="reranking_model",
|
||||||
|
weights=None,
|
||||||
document_ids_filter=None,
|
document_ids_filter=None,
|
||||||
|
attachment_id=None,
|
||||||
|
all_documents=None,
|
||||||
|
exceptions=None,
|
||||||
):
|
):
|
||||||
# embedding_search handles reranking internally
|
# _retrieve handles reranking internally
|
||||||
all_documents.extend(reranked_docs)
|
if all_documents is not None:
|
||||||
|
all_documents.extend(reranked_docs)
|
||||||
|
|
||||||
mock_embedding_search.side_effect = side_effect_embedding
|
mock_retrieve.side_effect = side_effect_retrieve
|
||||||
|
|
||||||
reranking_model = {
|
reranking_model = {
|
||||||
"reranking_provider_name": "cohere",
|
"reranking_provider_name": "cohere",
|
||||||
@ -1473,7 +1504,7 @@ class TestRetrievalService:
|
|||||||
# Assert
|
# Assert
|
||||||
# For semantic search with reranking, reranking_model should be passed
|
# For semantic search with reranking, reranking_model should be passed
|
||||||
assert len(results) == 3
|
assert len(results) == 3
|
||||||
call_kwargs = mock_embedding_search.call_args.kwargs
|
call_kwargs = mock_retrieve.call_args.kwargs
|
||||||
assert call_kwargs["reranking_model"] == reranking_model
|
assert call_kwargs["reranking_model"] == reranking_model
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import json
|
||||||
from unittest.mock import Mock, PropertyMock, patch
|
from unittest.mock import Mock, PropertyMock, patch
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@ -138,3 +139,95 @@ def test_is_file_with_no_content_disposition(mock_response):
|
|||||||
type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512))
|
type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512))
|
||||||
response = Response(mock_response)
|
response = Response(mock_response)
|
||||||
assert response.is_file
|
assert response.is_file
|
||||||
|
|
||||||
|
|
||||||
|
# UTF-8 Encoding Tests
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("content_bytes", "expected_text", "description"),
|
||||||
|
[
|
||||||
|
# Chinese UTF-8 bytes
|
||||||
|
(
|
||||||
|
b'{"message": "\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c"}',
|
||||||
|
'{"message": "你好世界"}',
|
||||||
|
"Chinese characters UTF-8",
|
||||||
|
),
|
||||||
|
# Japanese UTF-8 bytes
|
||||||
|
(
|
||||||
|
b'{"message": "\xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf"}',
|
||||||
|
'{"message": "こんにちは"}',
|
||||||
|
"Japanese characters UTF-8",
|
||||||
|
),
|
||||||
|
# Korean UTF-8 bytes
|
||||||
|
(
|
||||||
|
b'{"message": "\xec\x95\x88\xeb\x85\x95\xed\x95\x98\xec\x84\xb8\xec\x9a\x94"}',
|
||||||
|
'{"message": "안녕하세요"}',
|
||||||
|
"Korean characters UTF-8",
|
||||||
|
),
|
||||||
|
# Arabic UTF-8
|
||||||
|
(b'{"text": "\xd9\x85\xd8\xb1\xd8\xad\xd8\xa8\xd8\xa7"}', '{"text": "مرحبا"}', "Arabic characters UTF-8"),
|
||||||
|
# European characters UTF-8
|
||||||
|
(b'{"text": "Caf\xc3\xa9 M\xc3\xbcnchen"}', '{"text": "Café München"}', "European accented characters"),
|
||||||
|
# Simple ASCII
|
||||||
|
(b'{"text": "Hello World"}', '{"text": "Hello World"}', "Simple ASCII text"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_text_property_utf8_decoding(mock_response, content_bytes, expected_text, description):
|
||||||
|
"""Test that Response.text properly decodes UTF-8 content with charset_normalizer"""
|
||||||
|
mock_response.headers = {"content-type": "application/json; charset=utf-8"}
|
||||||
|
type(mock_response).content = PropertyMock(return_value=content_bytes)
|
||||||
|
# Mock httpx response.text to return something different (simulating potential encoding issues)
|
||||||
|
mock_response.text = "incorrect-fallback-text" # To ensure we are not falling back to httpx's text property
|
||||||
|
|
||||||
|
response = Response(mock_response)
|
||||||
|
|
||||||
|
# Our enhanced text property should decode properly using charset_normalizer
|
||||||
|
assert response.text == expected_text, (
|
||||||
|
f"Failed for {description}: got {repr(response.text)}, expected {repr(expected_text)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_text_property_fallback_to_httpx(mock_response):
|
||||||
|
"""Test that Response.text falls back to httpx.text when charset_normalizer fails"""
|
||||||
|
mock_response.headers = {"content-type": "application/json"}
|
||||||
|
|
||||||
|
# Create malformed UTF-8 bytes
|
||||||
|
malformed_bytes = b'{"text": "\xff\xfe\x00\x00 invalid"}'
|
||||||
|
type(mock_response).content = PropertyMock(return_value=malformed_bytes)
|
||||||
|
|
||||||
|
# Mock httpx.text to return some fallback value
|
||||||
|
fallback_text = '{"text": "fallback"}'
|
||||||
|
mock_response.text = fallback_text
|
||||||
|
|
||||||
|
response = Response(mock_response)
|
||||||
|
|
||||||
|
# Should fall back to httpx's text when charset_normalizer fails
|
||||||
|
assert response.text == fallback_text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("json_content", "description"),
|
||||||
|
[
|
||||||
|
# JSON with escaped Unicode (like Flask jsonify())
|
||||||
|
('{"message": "\\u4f60\\u597d\\u4e16\\u754c"}', "JSON with escaped Unicode"),
|
||||||
|
# JSON with mixed escape sequences and UTF-8
|
||||||
|
('{"mixed": "Hello \\u4f60\\u597d"}', "Mixed escaped and regular text"),
|
||||||
|
# JSON with complex escape sequences
|
||||||
|
('{"complex": "\\ud83d\\ude00\\u4f60\\u597d"}', "Emoji and Chinese escapes"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_text_property_with_escaped_unicode(mock_response, json_content, description):
|
||||||
|
"""Test Response.text with JSON containing Unicode escape sequences"""
|
||||||
|
mock_response.headers = {"content-type": "application/json"}
|
||||||
|
|
||||||
|
content_bytes = json_content.encode("utf-8")
|
||||||
|
type(mock_response).content = PropertyMock(return_value=content_bytes)
|
||||||
|
mock_response.text = json_content # httpx would return the same for valid UTF-8
|
||||||
|
|
||||||
|
response = Response(mock_response)
|
||||||
|
|
||||||
|
# Should preserve the escape sequences (valid JSON)
|
||||||
|
assert response.text == json_content, f"Failed for {description}"
|
||||||
|
|
||||||
|
# The text should be valid JSON that can be parsed back to proper Unicode
|
||||||
|
parsed = json.loads(response.text)
|
||||||
|
assert isinstance(parsed, dict), f"Invalid JSON for {description}"
|
||||||
|
|||||||
@ -117,7 +117,7 @@ import pytest
|
|||||||
from core.entities.document_task import DocumentTask
|
from core.entities.document_task import DocumentTask
|
||||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||||
from enums.cloud_plan import CloudPlan
|
from enums.cloud_plan import CloudPlan
|
||||||
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
|
from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Test Data Factory
|
# Test Data Factory
|
||||||
@ -370,7 +370,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
# Features Property Tests
|
# Features Property Tests
|
||||||
# ========================================================================
|
# ========================================================================
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||||
def test_features_property(self, mock_feature_service):
|
def test_features_property(self, mock_feature_service):
|
||||||
"""
|
"""
|
||||||
Test cached_property features.
|
Test cached_property features.
|
||||||
@ -400,7 +400,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
|
|
||||||
mock_feature_service.get_features.assert_called_once_with("tenant-123")
|
mock_feature_service.get_features.assert_called_once_with("tenant-123")
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||||
def test_features_property_with_different_tenants(self, mock_feature_service):
|
def test_features_property_with_different_tenants(self, mock_feature_service):
|
||||||
"""
|
"""
|
||||||
Test features property with different tenant IDs.
|
Test features property with different tenant IDs.
|
||||||
@ -438,7 +438,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
# Direct Queue Routing Tests
|
# Direct Queue Routing Tests
|
||||||
# ========================================================================
|
# ========================================================================
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||||
def test_send_to_direct_queue(self, mock_task):
|
def test_send_to_direct_queue(self, mock_task):
|
||||||
"""
|
"""
|
||||||
Test _send_to_direct_queue method.
|
Test _send_to_direct_queue method.
|
||||||
@ -460,7 +460,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
# Assert
|
# Assert
|
||||||
mock_task.delay.assert_called_once_with(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids)
|
mock_task.delay.assert_called_once_with(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids)
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
|
||||||
def test_send_to_direct_queue_with_priority_task(self, mock_task):
|
def test_send_to_direct_queue_with_priority_task(self, mock_task):
|
||||||
"""
|
"""
|
||||||
Test _send_to_direct_queue with priority task function.
|
Test _send_to_direct_queue with priority task function.
|
||||||
@ -481,7 +481,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
|
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||||
def test_send_to_direct_queue_with_single_document(self, mock_task):
|
def test_send_to_direct_queue_with_single_document(self, mock_task):
|
||||||
"""
|
"""
|
||||||
Test _send_to_direct_queue with single document ID.
|
Test _send_to_direct_queue with single document ID.
|
||||||
@ -502,7 +502,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1"]
|
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1"]
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||||
def test_send_to_direct_queue_with_empty_documents(self, mock_task):
|
def test_send_to_direct_queue_with_empty_documents(self, mock_task):
|
||||||
"""
|
"""
|
||||||
Test _send_to_direct_queue with empty document_ids list.
|
Test _send_to_direct_queue with empty document_ids list.
|
||||||
@ -525,7 +525,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
# Tenant Queue Routing Tests
|
# Tenant Queue Routing Tests
|
||||||
# ========================================================================
|
# ========================================================================
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||||
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
|
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
|
||||||
"""
|
"""
|
||||||
Test _send_to_tenant_queue when task key exists.
|
Test _send_to_tenant_queue when task key exists.
|
||||||
@ -564,7 +564,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
|
|
||||||
mock_task.delay.assert_not_called()
|
mock_task.delay.assert_not_called()
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||||
def test_send_to_tenant_queue_without_task_key(self, mock_task):
|
def test_send_to_tenant_queue_without_task_key(self, mock_task):
|
||||||
"""
|
"""
|
||||||
Test _send_to_tenant_queue when no task key exists.
|
Test _send_to_tenant_queue when no task key exists.
|
||||||
@ -594,7 +594,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
|
|
||||||
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
|
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
|
||||||
def test_send_to_tenant_queue_with_priority_task(self, mock_task):
|
def test_send_to_tenant_queue_with_priority_task(self, mock_task):
|
||||||
"""
|
"""
|
||||||
Test _send_to_tenant_queue with priority task function.
|
Test _send_to_tenant_queue with priority task function.
|
||||||
@ -621,7 +621,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
|
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||||
def test_send_to_tenant_queue_document_task_serialization(self, mock_task):
|
def test_send_to_tenant_queue_document_task_serialization(self, mock_task):
|
||||||
"""
|
"""
|
||||||
Test DocumentTask serialization in _send_to_tenant_queue.
|
Test DocumentTask serialization in _send_to_tenant_queue.
|
||||||
@ -659,7 +659,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
# Queue Type Selection Tests
|
# Queue Type Selection Tests
|
||||||
# ========================================================================
|
# ========================================================================
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||||
def test_send_to_default_tenant_queue(self, mock_task):
|
def test_send_to_default_tenant_queue(self, mock_task):
|
||||||
"""
|
"""
|
||||||
Test _send_to_default_tenant_queue method.
|
Test _send_to_default_tenant_queue method.
|
||||||
@ -678,7 +678,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
# Assert
|
# Assert
|
||||||
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
|
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
|
||||||
def test_send_to_priority_tenant_queue(self, mock_task):
|
def test_send_to_priority_tenant_queue(self, mock_task):
|
||||||
"""
|
"""
|
||||||
Test _send_to_priority_tenant_queue method.
|
Test _send_to_priority_tenant_queue method.
|
||||||
@ -697,7 +697,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
# Assert
|
# Assert
|
||||||
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
|
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
|
||||||
def test_send_to_priority_direct_queue(self, mock_task):
|
def test_send_to_priority_direct_queue(self, mock_task):
|
||||||
"""
|
"""
|
||||||
Test _send_to_priority_direct_queue method.
|
Test _send_to_priority_direct_queue method.
|
||||||
@ -720,7 +720,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
# Dispatch Logic Tests
|
# Dispatch Logic Tests
|
||||||
# ========================================================================
|
# ========================================================================
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||||
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service):
|
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service):
|
||||||
"""
|
"""
|
||||||
Test _dispatch method when billing is enabled with SANDBOX plan.
|
Test _dispatch method when billing is enabled with SANDBOX plan.
|
||||||
@ -745,7 +745,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
# Assert
|
# Assert
|
||||||
proxy._send_to_default_tenant_queue.assert_called_once()
|
proxy._send_to_default_tenant_queue.assert_called_once()
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||||
def test_dispatch_with_billing_enabled_team_plan(self, mock_feature_service):
|
def test_dispatch_with_billing_enabled_team_plan(self, mock_feature_service):
|
||||||
"""
|
"""
|
||||||
Test _dispatch method when billing is enabled with TEAM plan.
|
Test _dispatch method when billing is enabled with TEAM plan.
|
||||||
@ -770,7 +770,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
# Assert
|
# Assert
|
||||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||||
def test_dispatch_with_billing_enabled_professional_plan(self, mock_feature_service):
|
def test_dispatch_with_billing_enabled_professional_plan(self, mock_feature_service):
|
||||||
"""
|
"""
|
||||||
Test _dispatch method when billing is enabled with PROFESSIONAL plan.
|
Test _dispatch method when billing is enabled with PROFESSIONAL plan.
|
||||||
@ -795,7 +795,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
# Assert
|
# Assert
|
||||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||||
def test_dispatch_with_billing_disabled(self, mock_feature_service):
|
def test_dispatch_with_billing_disabled(self, mock_feature_service):
|
||||||
"""
|
"""
|
||||||
Test _dispatch method when billing is disabled.
|
Test _dispatch method when billing is disabled.
|
||||||
@ -818,7 +818,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
# Assert
|
# Assert
|
||||||
proxy._send_to_priority_direct_queue.assert_called_once()
|
proxy._send_to_priority_direct_queue.assert_called_once()
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||||
def test_dispatch_edge_case_empty_plan(self, mock_feature_service):
|
def test_dispatch_edge_case_empty_plan(self, mock_feature_service):
|
||||||
"""
|
"""
|
||||||
Test _dispatch method with empty plan string.
|
Test _dispatch method with empty plan string.
|
||||||
@ -842,7 +842,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
# Assert
|
# Assert
|
||||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||||
def test_dispatch_edge_case_none_plan(self, mock_feature_service):
|
def test_dispatch_edge_case_none_plan(self, mock_feature_service):
|
||||||
"""
|
"""
|
||||||
Test _dispatch method with None plan.
|
Test _dispatch method with None plan.
|
||||||
@ -870,7 +870,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
# Delay Method Tests
|
# Delay Method Tests
|
||||||
# ========================================================================
|
# ========================================================================
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||||
def test_delay_method(self, mock_feature_service):
|
def test_delay_method(self, mock_feature_service):
|
||||||
"""
|
"""
|
||||||
Test delay method integration.
|
Test delay method integration.
|
||||||
@ -895,7 +895,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
# Assert
|
# Assert
|
||||||
proxy._send_to_default_tenant_queue.assert_called_once()
|
proxy._send_to_default_tenant_queue.assert_called_once()
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||||
def test_delay_method_with_team_plan(self, mock_feature_service):
|
def test_delay_method_with_team_plan(self, mock_feature_service):
|
||||||
"""
|
"""
|
||||||
Test delay method with TEAM plan.
|
Test delay method with TEAM plan.
|
||||||
@ -920,7 +920,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
# Assert
|
# Assert
|
||||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||||
def test_delay_method_with_billing_disabled(self, mock_feature_service):
|
def test_delay_method_with_billing_disabled(self, mock_feature_service):
|
||||||
"""
|
"""
|
||||||
Test delay method with billing disabled.
|
Test delay method with billing disabled.
|
||||||
@ -1021,7 +1021,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
# Batch Operations Tests
|
# Batch Operations Tests
|
||||||
# ========================================================================
|
# ========================================================================
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||||
def test_batch_operation_with_multiple_documents(self, mock_task):
|
def test_batch_operation_with_multiple_documents(self, mock_task):
|
||||||
"""
|
"""
|
||||||
Test batch operation with multiple documents.
|
Test batch operation with multiple documents.
|
||||||
@ -1044,7 +1044,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=document_ids
|
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=document_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||||
def test_batch_operation_with_large_batch(self, mock_task):
|
def test_batch_operation_with_large_batch(self, mock_task):
|
||||||
"""
|
"""
|
||||||
Test batch operation with large batch of documents.
|
Test batch operation with large batch of documents.
|
||||||
@ -1073,7 +1073,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
# Error Handling Tests
|
# Error Handling Tests
|
||||||
# ========================================================================
|
# ========================================================================
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||||
def test_send_to_direct_queue_task_delay_failure(self, mock_task):
|
def test_send_to_direct_queue_task_delay_failure(self, mock_task):
|
||||||
"""
|
"""
|
||||||
Test _send_to_direct_queue when task.delay() raises an exception.
|
Test _send_to_direct_queue when task.delay() raises an exception.
|
||||||
@ -1090,7 +1090,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
with pytest.raises(Exception, match="Task delay failed"):
|
with pytest.raises(Exception, match="Task delay failed"):
|
||||||
proxy._send_to_direct_queue(mock_task)
|
proxy._send_to_direct_queue(mock_task)
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||||
def test_send_to_tenant_queue_push_tasks_failure(self, mock_task):
|
def test_send_to_tenant_queue_push_tasks_failure(self, mock_task):
|
||||||
"""
|
"""
|
||||||
Test _send_to_tenant_queue when push_tasks raises an exception.
|
Test _send_to_tenant_queue when push_tasks raises an exception.
|
||||||
@ -1111,7 +1111,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
with pytest.raises(Exception, match="Push tasks failed"):
|
with pytest.raises(Exception, match="Push tasks failed"):
|
||||||
proxy._send_to_tenant_queue(mock_task)
|
proxy._send_to_tenant_queue(mock_task)
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||||
def test_send_to_tenant_queue_set_waiting_time_failure(self, mock_task):
|
def test_send_to_tenant_queue_set_waiting_time_failure(self, mock_task):
|
||||||
"""
|
"""
|
||||||
Test _send_to_tenant_queue when set_task_waiting_time raises an exception.
|
Test _send_to_tenant_queue when set_task_waiting_time raises an exception.
|
||||||
@ -1132,7 +1132,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
with pytest.raises(Exception, match="Set waiting time failed"):
|
with pytest.raises(Exception, match="Set waiting time failed"):
|
||||||
proxy._send_to_tenant_queue(mock_task)
|
proxy._send_to_tenant_queue(mock_task)
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||||
def test_dispatch_feature_service_failure(self, mock_feature_service):
|
def test_dispatch_feature_service_failure(self, mock_feature_service):
|
||||||
"""
|
"""
|
||||||
Test _dispatch when FeatureService.get_features raises an exception.
|
Test _dispatch when FeatureService.get_features raises an exception.
|
||||||
@ -1153,8 +1153,8 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
# Integration Tests
|
# Integration Tests
|
||||||
# ========================================================================
|
# ========================================================================
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||||
def test_full_flow_sandbox_plan(self, mock_task, mock_feature_service):
|
def test_full_flow_sandbox_plan(self, mock_task, mock_feature_service):
|
||||||
"""
|
"""
|
||||||
Test full flow for SANDBOX plan with tenant queue.
|
Test full flow for SANDBOX plan with tenant queue.
|
||||||
@ -1187,8 +1187,8 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
|
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||||
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
|
||||||
def test_full_flow_team_plan(self, mock_task, mock_feature_service):
|
def test_full_flow_team_plan(self, mock_task, mock_feature_service):
|
||||||
"""
|
"""
|
||||||
Test full flow for TEAM plan with priority tenant queue.
|
Test full flow for TEAM plan with priority tenant queue.
|
||||||
@ -1221,8 +1221,8 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
|
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService")
|
||||||
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task")
|
||||||
def test_full_flow_billing_disabled(self, mock_task, mock_feature_service):
|
def test_full_flow_billing_disabled(self, mock_task, mock_feature_service):
|
||||||
"""
|
"""
|
||||||
Test full flow for billing disabled (self-hosted/enterprise).
|
Test full flow for billing disabled (self-hosted/enterprise).
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from unittest.mock import Mock, patch
|
|||||||
from core.entities.document_task import DocumentTask
|
from core.entities.document_task import DocumentTask
|
||||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||||
from enums.cloud_plan import CloudPlan
|
from enums.cloud_plan import CloudPlan
|
||||||
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
|
from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy
|
||||||
|
|
||||||
|
|
||||||
class DocumentIndexingTaskProxyTestDataFactory:
|
class DocumentIndexingTaskProxyTestDataFactory:
|
||||||
@ -59,7 +59,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id
|
assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id
|
||||||
assert proxy._tenant_isolated_task_queue._unique_key == "document_indexing"
|
assert proxy._tenant_isolated_task_queue._unique_key == "document_indexing"
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||||
def test_features_property(self, mock_feature_service):
|
def test_features_property(self, mock_feature_service):
|
||||||
"""Test cached_property features."""
|
"""Test cached_property features."""
|
||||||
# Arrange
|
# Arrange
|
||||||
@ -77,7 +77,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
assert features1 is features2 # Should be the same instance due to caching
|
assert features1 is features2 # Should be the same instance due to caching
|
||||||
mock_feature_service.get_features.assert_called_once_with("tenant-123")
|
mock_feature_service.get_features.assert_called_once_with("tenant-123")
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||||
def test_send_to_direct_queue(self, mock_task):
|
def test_send_to_direct_queue(self, mock_task):
|
||||||
"""Test _send_to_direct_queue method."""
|
"""Test _send_to_direct_queue method."""
|
||||||
# Arrange
|
# Arrange
|
||||||
@ -92,7 +92,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
|
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||||
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
|
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
|
||||||
"""Test _send_to_tenant_queue when task key exists."""
|
"""Test _send_to_tenant_queue when task key exists."""
|
||||||
# Arrange
|
# Arrange
|
||||||
@ -115,7 +115,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"]
|
assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"]
|
||||||
mock_task.delay.assert_not_called()
|
mock_task.delay.assert_not_called()
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
@patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task")
|
||||||
def test_send_to_tenant_queue_without_task_key(self, mock_task):
|
def test_send_to_tenant_queue_without_task_key(self, mock_task):
|
||||||
"""Test _send_to_tenant_queue when no task key exists."""
|
"""Test _send_to_tenant_queue when no task key exists."""
|
||||||
# Arrange
|
# Arrange
|
||||||
@ -135,8 +135,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
)
|
)
|
||||||
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
|
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
def test_send_to_default_tenant_queue(self):
|
||||||
def test_send_to_default_tenant_queue(self, mock_task):
|
|
||||||
"""Test _send_to_default_tenant_queue method."""
|
"""Test _send_to_default_tenant_queue method."""
|
||||||
# Arrange
|
# Arrange
|
||||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||||
@ -146,10 +145,9 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
proxy._send_to_default_tenant_queue()
|
proxy._send_to_default_tenant_queue()
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
|
proxy._send_to_tenant_queue.assert_called_once_with(proxy.NORMAL_TASK_FUNC)
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
|
def test_send_to_priority_tenant_queue(self):
|
||||||
def test_send_to_priority_tenant_queue(self, mock_task):
|
|
||||||
"""Test _send_to_priority_tenant_queue method."""
|
"""Test _send_to_priority_tenant_queue method."""
|
||||||
# Arrange
|
# Arrange
|
||||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||||
@ -159,10 +157,9 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
proxy._send_to_priority_tenant_queue()
|
proxy._send_to_priority_tenant_queue()
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
|
proxy._send_to_tenant_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC)
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
|
def test_send_to_priority_direct_queue(self):
|
||||||
def test_send_to_priority_direct_queue(self, mock_task):
|
|
||||||
"""Test _send_to_priority_direct_queue method."""
|
"""Test _send_to_priority_direct_queue method."""
|
||||||
# Arrange
|
# Arrange
|
||||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||||
@ -172,9 +169,9 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
proxy._send_to_priority_direct_queue()
|
proxy._send_to_priority_direct_queue()
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
proxy._send_to_direct_queue.assert_called_once_with(mock_task)
|
proxy._send_to_direct_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC)
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||||
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service):
|
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service):
|
||||||
"""Test _dispatch method when billing is enabled with sandbox plan."""
|
"""Test _dispatch method when billing is enabled with sandbox plan."""
|
||||||
# Arrange
|
# Arrange
|
||||||
@ -191,7 +188,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
# Assert
|
# Assert
|
||||||
proxy._send_to_default_tenant_queue.assert_called_once()
|
proxy._send_to_default_tenant_queue.assert_called_once()
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||||
def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service):
|
def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service):
|
||||||
"""Test _dispatch method when billing is enabled with non-sandbox plan."""
|
"""Test _dispatch method when billing is enabled with non-sandbox plan."""
|
||||||
# Arrange
|
# Arrange
|
||||||
@ -208,7 +205,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
# If billing enabled with non sandbox plan, should send to priority tenant queue
|
# If billing enabled with non sandbox plan, should send to priority tenant queue
|
||||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||||
def test_dispatch_with_billing_disabled(self, mock_feature_service):
|
def test_dispatch_with_billing_disabled(self, mock_feature_service):
|
||||||
"""Test _dispatch method when billing is disabled."""
|
"""Test _dispatch method when billing is disabled."""
|
||||||
# Arrange
|
# Arrange
|
||||||
@ -223,7 +220,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
# If billing disabled, for example: self-hosted or enterprise, should send to priority direct queue
|
# If billing disabled, for example: self-hosted or enterprise, should send to priority direct queue
|
||||||
proxy._send_to_priority_direct_queue.assert_called_once()
|
proxy._send_to_priority_direct_queue.assert_called_once()
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||||
def test_delay_method(self, mock_feature_service):
|
def test_delay_method(self, mock_feature_service):
|
||||||
"""Test delay method integration."""
|
"""Test delay method integration."""
|
||||||
# Arrange
|
# Arrange
|
||||||
@ -256,7 +253,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
assert task.dataset_id == dataset_id
|
assert task.dataset_id == dataset_id
|
||||||
assert task.document_ids == document_ids
|
assert task.document_ids == document_ids
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||||
def test_dispatch_edge_case_empty_plan(self, mock_feature_service):
|
def test_dispatch_edge_case_empty_plan(self, mock_feature_service):
|
||||||
"""Test _dispatch method with empty plan string."""
|
"""Test _dispatch method with empty plan string."""
|
||||||
# Arrange
|
# Arrange
|
||||||
@ -271,7 +268,7 @@ class TestDocumentIndexingTaskProxy:
|
|||||||
# Assert
|
# Assert
|
||||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||||
|
|
||||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||||
def test_dispatch_edge_case_none_plan(self, mock_feature_service):
|
def test_dispatch_edge_case_none_plan(self, mock_feature_service):
|
||||||
"""Test _dispatch method with None plan."""
|
"""Test _dispatch method with None plan."""
|
||||||
# Arrange
|
# Arrange
|
||||||
|
|||||||
@ -0,0 +1,363 @@
|
|||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
from core.entities.document_task import DocumentTask
|
||||||
|
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||||
|
from enums.cloud_plan import CloudPlan
|
||||||
|
from services.document_indexing_proxy.duplicate_document_indexing_task_proxy import (
|
||||||
|
DuplicateDocumentIndexingTaskProxy,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DuplicateDocumentIndexingTaskProxyTestDataFactory:
|
||||||
|
"""Factory class for creating test data and mock objects for DuplicateDocumentIndexingTaskProxy tests."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock:
|
||||||
|
"""Create mock features with billing configuration."""
|
||||||
|
features = Mock()
|
||||||
|
features.billing = Mock()
|
||||||
|
features.billing.enabled = billing_enabled
|
||||||
|
features.billing.subscription = Mock()
|
||||||
|
features.billing.subscription.plan = plan
|
||||||
|
return features
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_mock_tenant_queue(has_task_key: bool = False) -> Mock:
|
||||||
|
"""Create mock TenantIsolatedTaskQueue."""
|
||||||
|
queue = Mock(spec=TenantIsolatedTaskQueue)
|
||||||
|
queue.get_task_key.return_value = "task_key" if has_task_key else None
|
||||||
|
queue.push_tasks = Mock()
|
||||||
|
queue.set_task_waiting_time = Mock()
|
||||||
|
return queue
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_duplicate_document_task_proxy(
|
||||||
|
tenant_id: str = "tenant-123", dataset_id: str = "dataset-456", document_ids: list[str] | None = None
|
||||||
|
) -> DuplicateDocumentIndexingTaskProxy:
|
||||||
|
"""Create DuplicateDocumentIndexingTaskProxy instance for testing."""
|
||||||
|
if document_ids is None:
|
||||||
|
document_ids = ["doc-1", "doc-2", "doc-3"]
|
||||||
|
return DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDuplicateDocumentIndexingTaskProxy:
|
||||||
|
"""Test cases for DuplicateDocumentIndexingTaskProxy class."""
|
||||||
|
|
||||||
|
def test_initialization(self):
|
||||||
|
"""Test DuplicateDocumentIndexingTaskProxy initialization."""
|
||||||
|
# Arrange
|
||||||
|
tenant_id = "tenant-123"
|
||||||
|
dataset_id = "dataset-456"
|
||||||
|
document_ids = ["doc-1", "doc-2", "doc-3"]
|
||||||
|
|
||||||
|
# Act
|
||||||
|
proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert proxy._tenant_id == tenant_id
|
||||||
|
assert proxy._dataset_id == dataset_id
|
||||||
|
assert proxy._document_ids == document_ids
|
||||||
|
assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue)
|
||||||
|
assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id
|
||||||
|
assert proxy._tenant_isolated_task_queue._unique_key == "duplicate_document_indexing"
|
||||||
|
|
||||||
|
def test_queue_name(self):
|
||||||
|
"""Test QUEUE_NAME class variable."""
|
||||||
|
# Arrange & Act
|
||||||
|
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert proxy.QUEUE_NAME == "duplicate_document_indexing"
|
||||||
|
|
||||||
|
def test_task_functions(self):
|
||||||
|
"""Test NORMAL_TASK_FUNC and PRIORITY_TASK_FUNC class variables."""
|
||||||
|
# Arrange & Act
|
||||||
|
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert proxy.NORMAL_TASK_FUNC.__name__ == "normal_duplicate_document_indexing_task"
|
||||||
|
assert proxy.PRIORITY_TASK_FUNC.__name__ == "priority_duplicate_document_indexing_task"
|
||||||
|
|
||||||
|
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||||
|
def test_features_property(self, mock_feature_service):
|
||||||
|
"""Test cached_property features."""
|
||||||
|
# Arrange
|
||||||
|
mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features()
|
||||||
|
mock_feature_service.get_features.return_value = mock_features
|
||||||
|
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
features1 = proxy.features
|
||||||
|
features2 = proxy.features # Second call should use cached property
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert features1 == mock_features
|
||||||
|
assert features2 == mock_features
|
||||||
|
assert features1 is features2 # Should be the same instance due to caching
|
||||||
|
mock_feature_service.get_features.assert_called_once_with("tenant-123")
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"services.document_indexing_proxy.duplicate_document_indexing_task_proxy.normal_duplicate_document_indexing_task"
|
||||||
|
)
|
||||||
|
def test_send_to_direct_queue(self, mock_task):
|
||||||
|
"""Test _send_to_direct_queue method."""
|
||||||
|
# Arrange
|
||||||
|
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||||
|
mock_task.delay = Mock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
proxy._send_to_direct_queue(mock_task)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
mock_task.delay.assert_called_once_with(
|
||||||
|
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"services.document_indexing_proxy.duplicate_document_indexing_task_proxy.normal_duplicate_document_indexing_task"
|
||||||
|
)
|
||||||
|
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
|
||||||
|
"""Test _send_to_tenant_queue when task key exists."""
|
||||||
|
# Arrange
|
||||||
|
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||||
|
proxy._tenant_isolated_task_queue = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
|
||||||
|
has_task_key=True
|
||||||
|
)
|
||||||
|
mock_task.delay = Mock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
proxy._send_to_tenant_queue(mock_task)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
proxy._tenant_isolated_task_queue.push_tasks.assert_called_once()
|
||||||
|
pushed_tasks = proxy._tenant_isolated_task_queue.push_tasks.call_args[0][0]
|
||||||
|
assert len(pushed_tasks) == 1
|
||||||
|
assert isinstance(DocumentTask(**pushed_tasks[0]), DocumentTask)
|
||||||
|
assert pushed_tasks[0]["tenant_id"] == "tenant-123"
|
||||||
|
assert pushed_tasks[0]["dataset_id"] == "dataset-456"
|
||||||
|
assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"]
|
||||||
|
mock_task.delay.assert_not_called()
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"services.document_indexing_proxy.duplicate_document_indexing_task_proxy.normal_duplicate_document_indexing_task"
|
||||||
|
)
|
||||||
|
def test_send_to_tenant_queue_without_task_key(self, mock_task):
|
||||||
|
"""Test _send_to_tenant_queue when no task key exists."""
|
||||||
|
# Arrange
|
||||||
|
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||||
|
proxy._tenant_isolated_task_queue = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
|
||||||
|
has_task_key=False
|
||||||
|
)
|
||||||
|
mock_task.delay = Mock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
proxy._send_to_tenant_queue(mock_task)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once()
|
||||||
|
mock_task.delay.assert_called_once_with(
|
||||||
|
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
|
||||||
|
)
|
||||||
|
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
|
||||||
|
|
||||||
|
def test_send_to_default_tenant_queue(self):
|
||||||
|
"""Test _send_to_default_tenant_queue method."""
|
||||||
|
# Arrange
|
||||||
|
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||||
|
proxy._send_to_tenant_queue = Mock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
proxy._send_to_default_tenant_queue()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
proxy._send_to_tenant_queue.assert_called_once_with(proxy.NORMAL_TASK_FUNC)
|
||||||
|
|
||||||
|
def test_send_to_priority_tenant_queue(self):
|
||||||
|
"""Test _send_to_priority_tenant_queue method."""
|
||||||
|
# Arrange
|
||||||
|
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||||
|
proxy._send_to_tenant_queue = Mock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
proxy._send_to_priority_tenant_queue()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
proxy._send_to_tenant_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC)
|
||||||
|
|
||||||
|
def test_send_to_priority_direct_queue(self):
|
||||||
|
"""Test _send_to_priority_direct_queue method."""
|
||||||
|
# Arrange
|
||||||
|
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||||
|
proxy._send_to_direct_queue = Mock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
proxy._send_to_priority_direct_queue()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
proxy._send_to_direct_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC)
|
||||||
|
|
||||||
|
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||||
|
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service):
|
||||||
|
"""Test _dispatch method when billing is enabled with sandbox plan."""
|
||||||
|
# Arrange
|
||||||
|
mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
|
||||||
|
billing_enabled=True, plan=CloudPlan.SANDBOX
|
||||||
|
)
|
||||||
|
mock_feature_service.get_features.return_value = mock_features
|
||||||
|
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||||
|
proxy._send_to_default_tenant_queue = Mock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
proxy._dispatch()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
proxy._send_to_default_tenant_queue.assert_called_once()
|
||||||
|
|
||||||
|
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||||
|
def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service):
|
||||||
|
"""Test _dispatch method when billing is enabled with non-sandbox plan."""
|
||||||
|
# Arrange
|
||||||
|
mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
|
||||||
|
billing_enabled=True, plan=CloudPlan.TEAM
|
||||||
|
)
|
||||||
|
mock_feature_service.get_features.return_value = mock_features
|
||||||
|
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||||
|
proxy._send_to_priority_tenant_queue = Mock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
proxy._dispatch()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# If billing enabled with non sandbox plan, should send to priority tenant queue
|
||||||
|
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||||
|
|
||||||
|
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||||
|
def test_dispatch_with_billing_disabled(self, mock_feature_service):
|
||||||
|
"""Test _dispatch method when billing is disabled."""
|
||||||
|
# Arrange
|
||||||
|
mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=False)
|
||||||
|
mock_feature_service.get_features.return_value = mock_features
|
||||||
|
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||||
|
proxy._send_to_priority_direct_queue = Mock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
proxy._dispatch()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# If billing disabled, for example: self-hosted or enterprise, should send to priority direct queue
|
||||||
|
proxy._send_to_priority_direct_queue.assert_called_once()
|
||||||
|
|
||||||
|
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||||
|
def test_delay_method(self, mock_feature_service):
|
||||||
|
"""Test delay method integration."""
|
||||||
|
# Arrange
|
||||||
|
mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
|
||||||
|
billing_enabled=True, plan=CloudPlan.SANDBOX
|
||||||
|
)
|
||||||
|
mock_feature_service.get_features.return_value = mock_features
|
||||||
|
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||||
|
proxy._send_to_default_tenant_queue = Mock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
proxy.delay()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# If billing enabled with sandbox plan, should send to default tenant queue
|
||||||
|
proxy._send_to_default_tenant_queue.assert_called_once()
|
||||||
|
|
||||||
|
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||||
|
def test_dispatch_edge_case_empty_plan(self, mock_feature_service):
|
||||||
|
"""Test _dispatch method with empty plan string."""
|
||||||
|
# Arrange
|
||||||
|
mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
|
||||||
|
billing_enabled=True, plan=""
|
||||||
|
)
|
||||||
|
mock_feature_service.get_features.return_value = mock_features
|
||||||
|
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||||
|
proxy._send_to_priority_tenant_queue = Mock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
proxy._dispatch()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||||
|
|
||||||
|
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||||
|
def test_dispatch_edge_case_none_plan(self, mock_feature_service):
|
||||||
|
"""Test _dispatch method with None plan."""
|
||||||
|
# Arrange
|
||||||
|
mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
|
||||||
|
billing_enabled=True, plan=None
|
||||||
|
)
|
||||||
|
mock_feature_service.get_features.return_value = mock_features
|
||||||
|
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||||
|
proxy._send_to_priority_tenant_queue = Mock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
proxy._dispatch()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||||
|
|
||||||
|
def test_initialization_with_empty_document_ids(self):
|
||||||
|
"""Test initialization with empty document_ids list."""
|
||||||
|
# Arrange
|
||||||
|
tenant_id = "tenant-123"
|
||||||
|
dataset_id = "dataset-456"
|
||||||
|
document_ids = []
|
||||||
|
|
||||||
|
# Act
|
||||||
|
proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert proxy._tenant_id == tenant_id
|
||||||
|
assert proxy._dataset_id == dataset_id
|
||||||
|
assert proxy._document_ids == document_ids
|
||||||
|
|
||||||
|
def test_initialization_with_single_document_id(self):
|
||||||
|
"""Test initialization with single document_id."""
|
||||||
|
# Arrange
|
||||||
|
tenant_id = "tenant-123"
|
||||||
|
dataset_id = "dataset-456"
|
||||||
|
document_ids = ["doc-1"]
|
||||||
|
|
||||||
|
# Act
|
||||||
|
proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert proxy._tenant_id == tenant_id
|
||||||
|
assert proxy._dataset_id == dataset_id
|
||||||
|
assert proxy._document_ids == document_ids
|
||||||
|
|
||||||
|
def test_initialization_with_large_batch(self):
|
||||||
|
"""Test initialization with large batch of document IDs."""
|
||||||
|
# Arrange
|
||||||
|
tenant_id = "tenant-123"
|
||||||
|
dataset_id = "dataset-456"
|
||||||
|
document_ids = [f"doc-{i}" for i in range(100)]
|
||||||
|
|
||||||
|
# Act
|
||||||
|
proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert proxy._tenant_id == tenant_id
|
||||||
|
assert proxy._dataset_id == dataset_id
|
||||||
|
assert proxy._document_ids == document_ids
|
||||||
|
assert len(proxy._document_ids) == 100
|
||||||
|
|
||||||
|
@patch("services.document_indexing_proxy.base.FeatureService")
|
||||||
|
def test_dispatch_with_professional_plan(self, mock_feature_service):
|
||||||
|
"""Test _dispatch method when billing is enabled with professional plan."""
|
||||||
|
# Arrange
|
||||||
|
mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(
|
||||||
|
billing_enabled=True, plan=CloudPlan.PROFESSIONAL
|
||||||
|
)
|
||||||
|
mock_feature_service.get_features.return_value = mock_features
|
||||||
|
proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy()
|
||||||
|
proxy._send_to_priority_tenant_queue = Mock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
proxy._dispatch()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||||
@ -6,6 +6,7 @@ Target: 1500+ lines of comprehensive test coverage.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from unittest.mock import MagicMock, Mock, patch
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
@ -1791,8 +1792,8 @@ class TestExternalDatasetServiceFetchRetrieval:
|
|||||||
|
|
||||||
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
|
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
|
||||||
@patch("services.external_knowledge_service.db")
|
@patch("services.external_knowledge_service.db")
|
||||||
def test_fetch_external_knowledge_retrieval_non_200_status(self, mock_db, mock_process, factory):
|
def test_fetch_external_knowledge_retrieval_non_200_status_raises_exception(self, mock_db, mock_process, factory):
|
||||||
"""Test retrieval returns empty list on non-200 status."""
|
"""Test that non-200 status code raises Exception with response text."""
|
||||||
# Arrange
|
# Arrange
|
||||||
binding = factory.create_external_knowledge_binding_mock()
|
binding = factory.create_external_knowledge_binding_mock()
|
||||||
api = factory.create_external_knowledge_api_mock()
|
api = factory.create_external_knowledge_api_mock()
|
||||||
@ -1817,12 +1818,103 @@ class TestExternalDatasetServiceFetchRetrieval:
|
|||||||
|
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.status_code = 500
|
mock_response.status_code = 500
|
||||||
|
mock_response.text = "Internal Server Error: Database connection failed"
|
||||||
mock_process.return_value = mock_response
|
mock_process.return_value = mock_response
|
||||||
|
|
||||||
# Act
|
# Act & Assert
|
||||||
result = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
with pytest.raises(Exception, match="Internal Server Error: Database connection failed"):
|
||||||
"tenant-123", "dataset-123", "query", {"top_k": 5}
|
ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||||
)
|
"tenant-123", "dataset-123", "query", {"top_k": 5}
|
||||||
|
)
|
||||||
|
|
||||||
# Assert
|
@pytest.mark.parametrize(
|
||||||
assert result == []
|
("status_code", "error_message"),
|
||||||
|
[
|
||||||
|
(400, "Bad Request: Invalid query parameters"),
|
||||||
|
(401, "Unauthorized: Invalid API key"),
|
||||||
|
(403, "Forbidden: Access denied to resource"),
|
||||||
|
(404, "Not Found: Knowledge base not found"),
|
||||||
|
(429, "Too Many Requests: Rate limit exceeded"),
|
||||||
|
(500, "Internal Server Error: Database connection failed"),
|
||||||
|
(502, "Bad Gateway: External service unavailable"),
|
||||||
|
(503, "Service Unavailable: Maintenance mode"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
|
||||||
|
@patch("services.external_knowledge_service.db")
|
||||||
|
def test_fetch_external_knowledge_retrieval_various_error_status_codes(
|
||||||
|
self, mock_db, mock_process, factory, status_code, error_message
|
||||||
|
):
|
||||||
|
"""Test that various error status codes raise exceptions with response text."""
|
||||||
|
# Arrange
|
||||||
|
tenant_id = "tenant-123"
|
||||||
|
dataset_id = "dataset-123"
|
||||||
|
|
||||||
|
binding = factory.create_external_knowledge_binding_mock(
|
||||||
|
dataset_id=dataset_id, external_knowledge_api_id="api-123"
|
||||||
|
)
|
||||||
|
api = factory.create_external_knowledge_api_mock(api_id="api-123")
|
||||||
|
|
||||||
|
mock_binding_query = MagicMock()
|
||||||
|
mock_api_query = MagicMock()
|
||||||
|
|
||||||
|
def query_side_effect(model):
|
||||||
|
if model == ExternalKnowledgeBindings:
|
||||||
|
return mock_binding_query
|
||||||
|
elif model == ExternalKnowledgeApis:
|
||||||
|
return mock_api_query
|
||||||
|
return MagicMock()
|
||||||
|
|
||||||
|
mock_db.session.query.side_effect = query_side_effect
|
||||||
|
|
||||||
|
mock_binding_query.filter_by.return_value = mock_binding_query
|
||||||
|
mock_binding_query.first.return_value = binding
|
||||||
|
|
||||||
|
mock_api_query.filter_by.return_value = mock_api_query
|
||||||
|
mock_api_query.first.return_value = api
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = status_code
|
||||||
|
mock_response.text = error_message
|
||||||
|
mock_process.return_value = mock_response
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(ValueError, match=re.escape(error_message)):
|
||||||
|
ExternalDatasetService.fetch_external_knowledge_retrieval(tenant_id, dataset_id, "query", {"top_k": 5})
|
||||||
|
|
||||||
|
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
|
||||||
|
@patch("services.external_knowledge_service.db")
|
||||||
|
def test_fetch_external_knowledge_retrieval_empty_response_text(self, mock_db, mock_process, factory):
|
||||||
|
"""Test exception with empty response text."""
|
||||||
|
# Arrange
|
||||||
|
binding = factory.create_external_knowledge_binding_mock()
|
||||||
|
api = factory.create_external_knowledge_api_mock()
|
||||||
|
|
||||||
|
mock_binding_query = MagicMock()
|
||||||
|
mock_api_query = MagicMock()
|
||||||
|
|
||||||
|
def query_side_effect(model):
|
||||||
|
if model == ExternalKnowledgeBindings:
|
||||||
|
return mock_binding_query
|
||||||
|
elif model == ExternalKnowledgeApis:
|
||||||
|
return mock_api_query
|
||||||
|
return MagicMock()
|
||||||
|
|
||||||
|
mock_db.session.query.side_effect = query_side_effect
|
||||||
|
|
||||||
|
mock_binding_query.filter_by.return_value = mock_binding_query
|
||||||
|
mock_binding_query.first.return_value = binding
|
||||||
|
|
||||||
|
mock_api_query.filter_by.return_value = mock_api_query
|
||||||
|
mock_api_query.first.return_value = api
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 503
|
||||||
|
mock_response.text = ""
|
||||||
|
mock_process.return_value = mock_response
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(Exception, match=""):
|
||||||
|
ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||||
|
"tenant-123", "dataset-123", "query", {"top_k": 5}
|
||||||
|
)
|
||||||
|
|||||||
@ -19,7 +19,7 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
|||||||
from enums.cloud_plan import CloudPlan
|
from enums.cloud_plan import CloudPlan
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from models.dataset import Dataset, Document
|
from models.dataset import Dataset, Document
|
||||||
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
|
from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy
|
||||||
from tasks.document_indexing_task import (
|
from tasks.document_indexing_task import (
|
||||||
_document_indexing,
|
_document_indexing,
|
||||||
_document_indexing_with_tenant_queue,
|
_document_indexing_with_tenant_queue,
|
||||||
@ -138,7 +138,9 @@ class TestTaskEnqueuing:
|
|||||||
with patch.object(DocumentIndexingTaskProxy, "features") as mock_features:
|
with patch.object(DocumentIndexingTaskProxy, "features") as mock_features:
|
||||||
mock_features.billing.enabled = False
|
mock_features.billing.enabled = False
|
||||||
|
|
||||||
with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task:
|
# Mock the class variable directly
|
||||||
|
mock_task = Mock()
|
||||||
|
with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task):
|
||||||
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
@ -163,7 +165,9 @@ class TestTaskEnqueuing:
|
|||||||
mock_features.billing.enabled = True
|
mock_features.billing.enabled = True
|
||||||
mock_features.billing.subscription.plan = CloudPlan.SANDBOX
|
mock_features.billing.subscription.plan = CloudPlan.SANDBOX
|
||||||
|
|
||||||
with patch("services.document_indexing_task_proxy.normal_document_indexing_task") as mock_task:
|
# Mock the class variable directly
|
||||||
|
mock_task = Mock()
|
||||||
|
with patch.object(DocumentIndexingTaskProxy, "NORMAL_TASK_FUNC", mock_task):
|
||||||
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
@ -187,7 +191,9 @@ class TestTaskEnqueuing:
|
|||||||
mock_features.billing.enabled = True
|
mock_features.billing.enabled = True
|
||||||
mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL
|
mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL
|
||||||
|
|
||||||
with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task:
|
# Mock the class variable directly
|
||||||
|
mock_task = Mock()
|
||||||
|
with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task):
|
||||||
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
@ -211,7 +217,9 @@ class TestTaskEnqueuing:
|
|||||||
mock_features.billing.enabled = True
|
mock_features.billing.enabled = True
|
||||||
mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL
|
mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL
|
||||||
|
|
||||||
with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task:
|
# Mock the class variable directly
|
||||||
|
mock_task = Mock()
|
||||||
|
with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task):
|
||||||
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
@ -1493,7 +1501,9 @@ class TestEdgeCases:
|
|||||||
mock_features.billing.enabled = True
|
mock_features.billing.enabled = True
|
||||||
mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL
|
mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL
|
||||||
|
|
||||||
with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task:
|
# Mock the class variable directly
|
||||||
|
mock_task = Mock()
|
||||||
|
with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task):
|
||||||
# Act - Enqueue multiple tasks rapidly
|
# Act - Enqueue multiple tasks rapidly
|
||||||
for doc_ids in document_ids_list:
|
for doc_ids in document_ids_list:
|
||||||
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, doc_ids)
|
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, doc_ids)
|
||||||
@ -1898,7 +1908,7 @@ class TestRobustness:
|
|||||||
- Error is propagated appropriately
|
- Error is propagated appropriately
|
||||||
"""
|
"""
|
||||||
# Arrange
|
# Arrange
|
||||||
with patch("services.document_indexing_task_proxy.FeatureService.get_features") as mock_get_features:
|
with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_get_features:
|
||||||
# Simulate FeatureService failure
|
# Simulate FeatureService failure
|
||||||
mock_get_features.side_effect = Exception("Feature service unavailable")
|
mock_get_features.side_effect = Exception("Feature service unavailable")
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user