diff --git a/api/.env.example b/api/.env.example index 35aaabbc10..516a119d98 100644 --- a/api/.env.example +++ b/api/.env.example @@ -654,3 +654,9 @@ TENANT_ISOLATED_TASK_CONCURRENCY=1 # Maximum number of segments for dataset segments API (0 for unlimited) DATASET_MAX_SEGMENTS_PER_REQUEST=0 + +# Multimodal knowledgebase limit +SINGLE_CHUNK_ATTACHMENT_LIMIT=10 +ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2 +ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60 +IMAGE_FILE_BATCH_LIMIT=10 diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index b5ffd09d01..a5916241df 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -360,6 +360,26 @@ class FileUploadConfig(BaseSettings): default=10, ) + IMAGE_FILE_BATCH_LIMIT: PositiveInt = Field( + description="Maximum number of files allowed in a image batch upload operation", + default=10, + ) + + SINGLE_CHUNK_ATTACHMENT_LIMIT: PositiveInt = Field( + description="Maximum number of files allowed in a single chunk attachment", + default=10, + ) + + ATTACHMENT_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field( + description="Maximum allowed image file size for attachments in megabytes", + default=2, + ) + + ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT: NonNegativeInt = Field( + description="Timeout for downloading image attachments in seconds", + default=60, + ) + inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field( description=( "Comma-separated list of file extensions that are blocked from upload. " diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 377297c84c..12ada8b798 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -61,6 +61,7 @@ class ChatMessagesQuery(BaseModel): class MessageFeedbackPayload(BaseModel): message_id: str = Field(..., description="Message ID") rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating") + content: str | None = Field(default=None, description="Feedback content") @field_validator("message_id") @classmethod @@ -324,6 +325,7 @@ class MessageFeedbackApi(Resource): db.session.delete(feedback) elif args.rating and feedback: feedback.rating = args.rating + feedback.content = args.content elif not args.rating and not feedback: raise ValueError("rating cannot be None when feedback not exists") else: @@ -335,6 +337,7 @@ class MessageFeedbackApi(Resource): conversation_id=message.conversation_id, message_id=message.id, rating=rating_value, + content=args.content, from_source="admin", from_account_id=current_user.id, ) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 1fad8abd52..c0422ef6f4 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -151,6 +151,7 @@ class DatasetUpdatePayload(BaseModel): external_knowledge_id: str | None = None external_knowledge_api_id: str | None = None icon_info: dict[str, Any] | None = None + is_multimodal: bool | None = False @field_validator("indexing_technique") @classmethod @@ -423,17 +424,16 @@ class DatasetApi(Resource): payload = DatasetUpdatePayload.model_validate(console_ns.payload or {}) payload_data = payload.model_dump(exclude_unset=True) current_user, current_tenant_id = current_account_with_tenant() - # check embedding model setting if ( payload.indexing_technique == "high_quality" and payload.embedding_model_provider is not None and payload.embedding_model is not None ): - DatasetService.check_embedding_model_setting( + is_multimodal = DatasetService.check_is_multimodal_model( dataset.tenant_id, payload.embedding_model_provider, payload.embedding_model ) - + payload.is_multimodal = is_multimodal # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator DatasetPermissionService.check_permission( current_user, dataset, payload.permission, payload.partial_member_list diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 2520111281..6145da31a5 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -424,6 +424,10 @@ class DatasetInitApi(Resource): model_type=ModelType.TEXT_EMBEDDING, model=knowledge_config.embedding_model, ) + is_multimodal = DatasetService.check_is_multimodal_model( + current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model + ) + knowledge_config.is_multimodal = is_multimodal except InvokeAuthorizationError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index ee390cbfb7..e73abc2555 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -51,6 +51,7 @@ class SegmentCreatePayload(BaseModel): content: str answer: str | None = None keywords: list[str] | None = None + attachment_ids: list[str] | None = None class SegmentUpdatePayload(BaseModel): @@ -58,6 +59,7 @@ class SegmentUpdatePayload(BaseModel): answer: str | None = None keywords: list[str] | None = None regenerate_child_chunks: bool = False + attachment_ids: list[str] | None = None class BatchImportPayload(BaseModel): diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index fac90a0135..db7c50f422 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -1,7 +1,7 @@ import logging from typing import Any -from flask_restx import marshal +from flask_restx import marshal, reqparse from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -33,6 +33,7 @@ class HitTestingPayload(BaseModel): query: str = Field(max_length=250) retrieval_model: dict[str, Any] | None = None external_retrieval_model: dict[str, Any] | None = None + attachment_ids: list[str] | None = None class DatasetsHitTestingBase: @@ -54,16 +55,28 @@ class DatasetsHitTestingBase: def hit_testing_args_check(args: dict[str, Any]): HitTestingService.hit_testing_args_check(args) + @staticmethod + def parse_args(): + parser = ( + reqparse.RequestParser() + .add_argument("query", type=str, required=False, location="json") + .add_argument("attachment_ids", type=list, required=False, location="json") + .add_argument("retrieval_model", type=dict, required=False, location="json") + .add_argument("external_retrieval_model", type=dict, required=False, location="json") + ) + return parser.parse_args() + @staticmethod def perform_hit_testing(dataset, args): assert isinstance(current_user, Account) try: response = HitTestingService.retrieve( dataset=dataset, - query=args["query"], + query=args.get("query"), account=current_user, - retrieval_model=args["retrieval_model"], - external_retrieval_model=args["external_retrieval_model"], + retrieval_model=args.get("retrieval_model"), + external_retrieval_model=args.get("external_retrieval_model"), + attachment_ids=args.get("attachment_ids"), limit=10, ) return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)} diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index fdd7c2f479..29417dc896 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -45,6 +45,9 @@ class FileApi(Resource): "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT, + "image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT, + "single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT, + "attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT, }, 200 @setup_required diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index 69281c6214..268473d6d1 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -22,7 +22,12 @@ from services.trigger.trigger_subscription_builder_service import TriggerSubscri from services.trigger.trigger_subscription_operator_service import TriggerSubscriptionOperatorService from .. import console_ns -from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required +from ..wraps import ( + account_initialization_required, + edit_permission_required, + is_admin_or_owner_required, + setup_required, +) logger = logging.getLogger(__name__) @@ -72,7 +77,7 @@ class TriggerProviderInfoApi(Resource): class TriggerSubscriptionListApi(Resource): @setup_required @login_required - @is_admin_or_owner_required + @edit_permission_required @account_initialization_required def get(self, provider): """List all trigger subscriptions for the current tenant's provider""" @@ -104,7 +109,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource): @console_ns.expect(parser) @setup_required @login_required - @is_admin_or_owner_required + @edit_permission_required @account_initialization_required def post(self, provider): """Add a new subscription instance for a trigger provider""" @@ -133,6 +138,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource): class TriggerSubscriptionBuilderGetApi(Resource): @setup_required @login_required + @edit_permission_required @account_initialization_required def get(self, provider, subscription_builder_id): """Get a subscription instance for a trigger provider""" @@ -155,7 +161,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource): @console_ns.expect(parser_api) @setup_required @login_required - @is_admin_or_owner_required + @edit_permission_required @account_initialization_required def post(self, provider, subscription_builder_id): """Verify a subscription instance for a trigger provider""" @@ -200,6 +206,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource): @console_ns.expect(parser_update_api) @setup_required @login_required + @edit_permission_required @account_initialization_required def post(self, provider, subscription_builder_id): """Update a subscription instance for a trigger provider""" @@ -233,6 +240,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource): class TriggerSubscriptionBuilderLogsApi(Resource): @setup_required @login_required + @edit_permission_required @account_initialization_required def get(self, provider, subscription_builder_id): """Get the request logs for a subscription instance for a trigger provider""" @@ -255,7 +263,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource): @console_ns.expect(parser_update_api) @setup_required @login_required - @is_admin_or_owner_required + @edit_permission_required @account_initialization_required def post(self, provider, subscription_builder_id): """Build a subscription instance for a trigger provider""" diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 9a9832dd4a..e2e6c11480 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -83,6 +83,7 @@ class AppRunner: context: str | None = None, memory: TokenBufferMemory | None = None, image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + context_files: list["File"] | None = None, ) -> tuple[list[PromptMessage], list[str] | None]: """ Organize prompt messages @@ -111,6 +112,7 @@ class AppRunner: memory=memory, model_config=model_config, image_detail_config=image_detail_config, + context_files=context_files, ) else: memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 53188cf506..f8338b226b 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -11,6 +11,7 @@ from core.app.entities.app_invoke_entities import ( ) from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.file import File from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import ImagePromptMessageContent @@ -146,6 +147,7 @@ class ChatAppRunner(AppRunner): # get context from datasets context = None + context_files: list[File] = [] if app_config.dataset and app_config.dataset.dataset_ids: hit_callback = DatasetIndexToolCallbackHandler( queue_manager, @@ -156,7 +158,7 @@ class ChatAppRunner(AppRunner): ) dataset_retrieval = DatasetRetrieval(application_generate_entity) - context = dataset_retrieval.retrieve( + context, retrieved_files = dataset_retrieval.retrieve( app_id=app_record.id, user_id=application_generate_entity.user_id, tenant_id=app_record.tenant_id, @@ -171,7 +173,11 @@ class ChatAppRunner(AppRunner): memory=memory, message_id=message.id, inputs=inputs, + vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get( + "enabled", False + ), ) + context_files = retrieved_files or [] # reorganize all inputs and template to prompt messages # Include: prompt template, inputs, query(optional), files(optional) @@ -186,6 +192,7 @@ class ChatAppRunner(AppRunner): context=context, memory=memory, image_detail_config=image_detail_config, + context_files=context_files, ) # check hosting moderation diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index e2be4146e1..ddfb5725b4 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import ( CompletionAppGenerateEntity, ) from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.file import File from core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.moderation.base import ModerationError @@ -102,6 +103,7 @@ class CompletionAppRunner(AppRunner): # get context from datasets context = None + context_files: list[File] = [] if app_config.dataset and app_config.dataset.dataset_ids: hit_callback = DatasetIndexToolCallbackHandler( queue_manager, @@ -116,7 +118,7 @@ class CompletionAppRunner(AppRunner): query = inputs.get(dataset_config.retrieve_config.query_variable, "") dataset_retrieval = DatasetRetrieval(application_generate_entity) - context = dataset_retrieval.retrieve( + context, retrieved_files = dataset_retrieval.retrieve( app_id=app_record.id, user_id=application_generate_entity.user_id, tenant_id=app_record.tenant_id, @@ -130,7 +132,11 @@ class CompletionAppRunner(AppRunner): hit_callback=hit_callback, message_id=message.id, inputs=inputs, + vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get( + "enabled", False + ), ) + context_files = retrieved_files or [] # reorganize all inputs and template to prompt messages # Include: prompt template, inputs, query(optional), files(optional) @@ -144,6 +150,7 @@ class CompletionAppRunner(AppRunner): query=query, context=context, image_detail_config=image_detail_config, + context_files=context_files, ) # check hosting moderation diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 14d5f38dcd..d0279349ca 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -7,7 +7,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueRetrieverResourcesEvent from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import ChildChunk, DatasetQuery, DocumentSegment @@ -59,7 +59,7 @@ class DatasetIndexToolCallbackHandler: document_id, ) continue - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunk_stmt = select(ChildChunk).where( ChildChunk.index_node_id == document.metadata["doc_id"], ChildChunk.dataset_id == dataset_document.dataset_id, diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 36b38b7b45..59de4f403d 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -7,7 +7,7 @@ import time import uuid from typing import Any -from flask import current_app +from flask import Flask, current_app from sqlalchemy import select from sqlalchemy.orm.exc import ObjectDeletedError @@ -21,7 +21,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import ChildDocument, Document @@ -36,6 +36,7 @@ from extensions.ext_redis import redis_client from extensions.ext_storage import storage from libs import helper from libs.datetime_utils import naive_utc_now +from models import Account from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment from models.dataset import Document as DatasetDocument from models.model import UploadFile @@ -89,8 +90,17 @@ class IndexingRunner: text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict()) # transform + current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first() + if not current_user: + raise ValueError("no current user found") + current_user.set_tenant_id(dataset.tenant_id) documents = self._transform( - index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict() + index_processor, + dataset, + text_docs, + requeried_document.doc_language, + processing_rule.to_dict(), + current_user=current_user, ) # save segment self._load_segments(dataset, requeried_document, documents) @@ -136,7 +146,7 @@ class IndexingRunner: for document_segment in document_segments: db.session.delete(document_segment) - if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if requeried_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: # delete child chunks db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete() db.session.commit() @@ -152,8 +162,17 @@ class IndexingRunner: text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict()) # transform + current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first() + if not current_user: + raise ValueError("no current user found") + current_user.set_tenant_id(dataset.tenant_id) documents = self._transform( - index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict() + index_processor, + dataset, + text_docs, + requeried_document.doc_language, + processing_rule.to_dict(), + current_user=current_user, ) # save segment self._load_segments(dataset, requeried_document, documents) @@ -209,7 +228,7 @@ class IndexingRunner: "dataset_id": document_segment.dataset_id, }, ) - if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if requeried_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = document_segment.get_child_chunks() if child_chunks: child_documents = [] @@ -302,6 +321,7 @@ class IndexingRunner: text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) documents = index_processor.transform( text_docs, + current_user=None, embedding_model_instance=embedding_model_instance, process_rule=processing_rule.to_dict(), tenant_id=tenant_id, @@ -551,7 +571,10 @@ class IndexingRunner: indexing_start_at = time.perf_counter() tokens = 0 create_keyword_thread = None - if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy": + if ( + dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX + and dataset.indexing_technique == "economy" + ): # create keyword index create_keyword_thread = threading.Thread( target=self._process_keyword_index, @@ -590,7 +613,7 @@ class IndexingRunner: for future in futures: tokens += future.result() if ( - dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX + dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy" and create_keyword_thread is not None ): @@ -635,7 +658,13 @@ class IndexingRunner: db.session.commit() def _process_chunk( - self, flask_app, index_processor, chunk_documents, dataset, dataset_document, embedding_model_instance + self, + flask_app: Flask, + index_processor: BaseIndexProcessor, + chunk_documents: list[Document], + dataset: Dataset, + dataset_document: DatasetDocument, + embedding_model_instance: ModelInstance | None, ): with flask_app.app_context(): # check document is paused @@ -646,8 +675,15 @@ class IndexingRunner: page_content_list = [document.page_content for document in chunk_documents] tokens += sum(embedding_model_instance.get_text_embedding_num_tokens(page_content_list)) + multimodal_documents = [] + for document in chunk_documents: + if document.attachments and dataset.is_multimodal: + multimodal_documents.extend(document.attachments) + # load index - index_processor.load(dataset, chunk_documents, with_keywords=False) + index_processor.load( + dataset, chunk_documents, multimodal_documents=multimodal_documents, with_keywords=False + ) document_ids = [document.metadata["doc_id"] for document in chunk_documents] db.session.query(DocumentSegment).where( @@ -710,6 +746,7 @@ class IndexingRunner: text_docs: list[Document], doc_language: str, process_rule: dict, + current_user: Account | None = None, ) -> list[Document]: # get embedding model instance embedding_model_instance = None @@ -729,6 +766,7 @@ class IndexingRunner: documents = index_processor.transform( text_docs, + current_user, embedding_model_instance=embedding_model_instance, process_rule=process_rule, tenant_id=dataset.tenant_id, @@ -737,14 +775,16 @@ class IndexingRunner: return documents - def _load_segments(self, dataset, dataset_document, documents): + def _load_segments(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]): # save node to document segment doc_store = DatasetDocumentStore( dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id ) # add document segments - doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX) + doc_store.add_documents( + docs=documents, save_child=dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX + ) # update document status to indexing cur_time = naive_utc_now() diff --git a/api/core/model_manager.py b/api/core/model_manager.py index a63e94d59c..5a28bbcc3a 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -10,9 +10,9 @@ from core.errors.error import ProviderTokenNotInitError from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.entities.rerank_entities import RerankResult -from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.entities.text_embedding_entities import EmbeddingResult from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.moderation_model import ModerationModel @@ -200,7 +200,7 @@ class ModelInstance: def invoke_text_embedding( self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT - ) -> TextEmbeddingResult: + ) -> EmbeddingResult: """ Invoke large language model @@ -212,7 +212,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") return cast( - TextEmbeddingResult, + EmbeddingResult, self._round_robin_invoke( function=self.model_type_instance.invoke, model=self.model, @@ -223,6 +223,34 @@ class ModelInstance: ), ) + def invoke_multimodal_embedding( + self, + multimodel_documents: list[dict], + user: str | None = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> EmbeddingResult: + """ + Invoke large language model + + :param multimodel_documents: multimodel documents to embed + :param user: unique user id + :param input_type: input type + :return: embeddings result + """ + if not isinstance(self.model_type_instance, TextEmbeddingModel): + raise Exception("Model type instance is not TextEmbeddingModel") + return cast( + EmbeddingResult, + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + multimodel_documents=multimodel_documents, + user=user, + input_type=input_type, + ), + ) + def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]: """ Get number of tokens for text embedding @@ -276,6 +304,40 @@ class ModelInstance: ), ) + def invoke_multimodal_rerank( + self, + query: dict, + docs: list[dict], + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, + ) -> RerankResult: + """ + Invoke rerank model + + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id + :return: rerank result + """ + if not isinstance(self.model_type_instance, RerankModel): + raise Exception("Model type instance is not RerankModel") + return cast( + RerankResult, + self._round_robin_invoke( + function=self.model_type_instance.invoke_multimodal_rerank, + model=self.model, + credentials=self.credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, + user=user, + ), + ) + def invoke_moderation(self, text: str, user: str | None = None) -> bool: """ Invoke moderation model @@ -461,6 +523,32 @@ class ModelManager: model=default_model_entity.model, ) + def check_model_support_vision(self, tenant_id: str, provider: str, model: str, model_type: ModelType) -> bool: + """ + Check if model supports vision + :param tenant_id: tenant id + :param provider: provider name + :param model: model name + :return: True if model supports vision, False otherwise + """ + model_instance = self.get_model_instance(tenant_id, provider, model_type, model) + model_type_instance = model_instance.model_type_instance + match model_type: + case ModelType.LLM: + model_type_instance = cast(LargeLanguageModel, model_type_instance) + case ModelType.TEXT_EMBEDDING: + model_type_instance = cast(TextEmbeddingModel, model_type_instance) + case ModelType.RERANK: + model_type_instance = cast(RerankModel, model_type_instance) + case _: + raise ValueError(f"Model type {model_type} is not supported") + model_schema = model_type_instance.get_model_schema(model, model_instance.credentials) + if not model_schema: + return False + if model_schema.features and ModelFeature.VISION in model_schema.features: + return True + return False + class LBModelManager: def __init__( diff --git a/api/core/model_runtime/entities/text_embedding_entities.py b/api/core/model_runtime/entities/text_embedding_entities.py index 846b89d658..854c448250 100644 --- a/api/core/model_runtime/entities/text_embedding_entities.py +++ b/api/core/model_runtime/entities/text_embedding_entities.py @@ -19,7 +19,7 @@ class EmbeddingUsage(ModelUsage): latency: float -class TextEmbeddingResult(BaseModel): +class EmbeddingResult(BaseModel): """ Model class for text embedding result. """ @@ -27,3 +27,13 @@ class TextEmbeddingResult(BaseModel): model: str embeddings: list[list[float]] usage: EmbeddingUsage + + +class FileEmbeddingResult(BaseModel): + """ + Model class for file embedding result. + """ + + model: str + embeddings: list[list[float]] + usage: EmbeddingUsage diff --git a/api/core/model_runtime/model_providers/__base/rerank_model.py b/api/core/model_runtime/model_providers/__base/rerank_model.py index 36067118b0..0a576b832a 100644 --- a/api/core/model_runtime/model_providers/__base/rerank_model.py +++ b/api/core/model_runtime/model_providers/__base/rerank_model.py @@ -50,3 +50,43 @@ class RerankModel(AIModel): ) except Exception as e: raise self._transform_invoke_error(e) + + def invoke_multimodal_rerank( + self, + model: str, + credentials: dict, + query: dict, + docs: list[dict], + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, + ) -> RerankResult: + """ + Invoke multimodal rerank model + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id + :return: rerank result + """ + try: + from core.plugin.impl.model import PluginModelClient + + plugin_model_manager = PluginModelClient() + return plugin_model_manager.invoke_multimodal_rerank( + tenant_id=self.tenant_id, + user_id=user or "unknown", + plugin_id=self.plugin_id, + provider=self.provider_name, + model=model, + credentials=credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, + ) + except Exception as e: + raise self._transform_invoke_error(e) diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py index bd68ffe903..4c902e2c11 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py @@ -2,7 +2,7 @@ from pydantic import ConfigDict from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.entities.text_embedding_entities import EmbeddingResult from core.model_runtime.model_providers.__base.ai_model import AIModel @@ -20,16 +20,18 @@ class TextEmbeddingModel(AIModel): self, model: str, credentials: dict, - texts: list[str], + texts: list[str] | None = None, + multimodel_documents: list[dict] | None = None, user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, - ) -> TextEmbeddingResult: + ) -> EmbeddingResult: """ Invoke text embedding model :param model: model name :param credentials: model credentials :param texts: texts to embed + :param files: files to embed :param user: unique user id :param input_type: input type :return: embeddings result @@ -38,16 +40,29 @@ class TextEmbeddingModel(AIModel): try: plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_text_embedding( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - texts=texts, - input_type=input_type, - ) + if texts: + return plugin_model_manager.invoke_text_embedding( + tenant_id=self.tenant_id, + user_id=user or "unknown", + plugin_id=self.plugin_id, + provider=self.provider_name, + model=model, + credentials=credentials, + texts=texts, + input_type=input_type, + ) + if multimodel_documents: + return plugin_model_manager.invoke_multimodal_embedding( + tenant_id=self.tenant_id, + user_id=user or "unknown", + plugin_id=self.plugin_id, + provider=self.provider_name, + model=model, + credentials=credentials, + documents=multimodel_documents, + input_type=input_type, + ) + raise ValueError("No texts or files provided") except Exception as e: raise self._transform_invoke_error(e) diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py index 5dfc3c212e..5d70980967 100644 --- a/api/core/plugin/impl/model.py +++ b/api/core/plugin/impl/model.py @@ -6,7 +6,7 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.entities.rerank_entities import RerankResult -from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.entities.text_embedding_entities import EmbeddingResult from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, @@ -243,14 +243,14 @@ class PluginModelClient(BasePluginClient): credentials: dict, texts: list[str], input_type: str, - ) -> TextEmbeddingResult: + ) -> EmbeddingResult: """ Invoke text embedding """ response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke", - type_=TextEmbeddingResult, + type_=EmbeddingResult, data=jsonable_encoder( { "user_id": user_id, @@ -275,6 +275,48 @@ class PluginModelClient(BasePluginClient): raise ValueError("Failed to invoke text embedding") + def invoke_multimodal_embedding( + self, + tenant_id: str, + user_id: str, + plugin_id: str, + provider: str, + model: str, + credentials: dict, + documents: list[dict], + input_type: str, + ) -> EmbeddingResult: + """ + Invoke file embedding + """ + response = self._request_with_plugin_daemon_response_stream( + method="POST", + path=f"plugin/{tenant_id}/dispatch/multimodal_embedding/invoke", + type_=EmbeddingResult, + data=jsonable_encoder( + { + "user_id": user_id, + "data": { + "provider": provider, + "model_type": "text-embedding", + "model": model, + "credentials": credentials, + "documents": documents, + "input_type": input_type, + }, + } + ), + headers={ + "X-Plugin-ID": plugin_id, + "Content-Type": "application/json", + }, + ) + + for resp in response: + return resp + + raise ValueError("Failed to invoke file embedding") + def get_text_embedding_num_tokens( self, tenant_id: str, @@ -361,6 +403,51 @@ class PluginModelClient(BasePluginClient): raise ValueError("Failed to invoke rerank") + def invoke_multimodal_rerank( + self, + tenant_id: str, + user_id: str, + plugin_id: str, + provider: str, + model: str, + credentials: dict, + query: dict, + docs: list[dict], + score_threshold: float | None = None, + top_n: int | None = None, + ) -> RerankResult: + """ + Invoke multimodal rerank + """ + response = self._request_with_plugin_daemon_response_stream( + method="POST", + path=f"plugin/{tenant_id}/dispatch/multimodal_rerank/invoke", + type_=RerankResult, + data=jsonable_encoder( + { + "user_id": user_id, + "data": { + "provider": provider, + "model_type": "rerank", + "model": model, + "credentials": credentials, + "query": query, + "docs": docs, + "score_threshold": score_threshold, + "top_n": top_n, + }, + } + ), + headers={ + "X-Plugin-ID": plugin_id, + "Content-Type": "application/json", + }, + ) + for resp in response: + return resp + + raise ValueError("Failed to invoke multimodal rerank") + def invoke_tts( self, tenant_id: str, diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index d1d518a55d..f072092ea7 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -49,6 +49,7 @@ class SimplePromptTransform(PromptTransform): memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + context_files: list["File"] | None = None, ) -> tuple[list[PromptMessage], list[str] | None]: inputs = {key: str(value) for key, value in inputs.items()} @@ -64,6 +65,7 @@ class SimplePromptTransform(PromptTransform): memory=memory, model_config=model_config, image_detail_config=image_detail_config, + context_files=context_files, ) else: prompt_messages, stops = self._get_completion_model_prompt_messages( @@ -76,6 +78,7 @@ class SimplePromptTransform(PromptTransform): memory=memory, model_config=model_config, image_detail_config=image_detail_config, + context_files=context_files, ) return prompt_messages, stops @@ -187,6 +190,7 @@ class SimplePromptTransform(PromptTransform): memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + context_files: list["File"] | None = None, ) -> tuple[list[PromptMessage], list[str] | None]: prompt_messages: list[PromptMessage] = [] @@ -216,9 +220,9 @@ class SimplePromptTransform(PromptTransform): ) if query: - prompt_messages.append(self._get_last_user_message(query, files, image_detail_config)) + prompt_messages.append(self._get_last_user_message(query, files, image_detail_config, context_files)) else: - prompt_messages.append(self._get_last_user_message(prompt, files, image_detail_config)) + prompt_messages.append(self._get_last_user_message(prompt, files, image_detail_config, context_files)) return prompt_messages, None @@ -233,6 +237,7 @@ class SimplePromptTransform(PromptTransform): memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + context_files: list["File"] | None = None, ) -> tuple[list[PromptMessage], list[str] | None]: # get prompt prompt, prompt_rules = self._get_prompt_str_and_rules( @@ -275,20 +280,27 @@ class SimplePromptTransform(PromptTransform): if stops is not None and len(stops) == 0: stops = None - return [self._get_last_user_message(prompt, files, image_detail_config)], stops + return [self._get_last_user_message(prompt, files, image_detail_config, context_files)], stops def _get_last_user_message( self, prompt: str, files: Sequence["File"], image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + context_files: list["File"] | None = None, ) -> UserPromptMessage: + prompt_message_contents: list[PromptMessageContentUnionTypes] = [] if files: - prompt_message_contents: list[PromptMessageContentUnionTypes] = [] for file in files: prompt_message_contents.append( file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) ) + if context_files: + for file in context_files: + prompt_message_contents.append( + file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) + ) + if prompt_message_contents: prompt_message_contents.append(TextPromptMessageContent(data=prompt)) prompt_message = UserPromptMessage(content=prompt_message_contents) diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index cc946a72c3..bfa8781e9f 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -2,6 +2,7 @@ from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.rag.data_post_processor.reorder import ReorderRunner +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner @@ -30,9 +31,10 @@ class DataPostProcessor: score_threshold: float | None = None, top_n: int | None = None, user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: if self.rerank_runner: - documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user) + documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user, query_type) if self.reorder_runner: documents = self.reorder_runner.run(documents) diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 2290de19bc..cbd7cbeb64 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -1,23 +1,30 @@ import concurrent.futures from concurrent.futures import ThreadPoolExecutor +from typing import Any from flask import Flask, current_app from sqlalchemy import select from sqlalchemy.orm import Session, load_only from configs import dify_config +from core.model_manager import ModelManager +from core.model_runtime.entities.model_entities import ModelType from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector from core.rag.embedding.retrieval import RetrievalSegments from core.rag.entities.metadata_entities import MetadataCondition -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod +from core.tools.signature import sign_upload_file from extensions.ext_database import db -from models.dataset import ChildChunk, Dataset, DocumentSegment +from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument +from models.model import UploadFile from services.external_knowledge_service import ExternalDatasetService default_retrieval_model = { @@ -37,14 +44,15 @@ class RetrievalService: retrieval_method: RetrievalMethod, dataset_id: str, query: str, - top_k: int, + top_k: int = 4, score_threshold: float | None = 0.0, reranking_model: dict | None = None, reranking_mode: str = "reranking_model", weights: dict | None = None, document_ids_filter: list[str] | None = None, + attachment_ids: list | None = None, ): - if not query: + if not query and not attachment_ids: return [] dataset = cls._get_dataset(dataset_id) if not dataset: @@ -56,69 +64,52 @@ class RetrievalService: # Optimize multithreading with thread pools with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore futures = [] - if retrieval_method == RetrievalMethod.KEYWORD_SEARCH: + retrieval_service = RetrievalService() + if query: futures.append( executor.submit( - cls.keyword_search, + retrieval_service._retrieve, flask_app=current_app._get_current_object(), # type: ignore - dataset_id=dataset_id, - query=query, - top_k=top_k, - all_documents=all_documents, - exceptions=exceptions, - document_ids_filter=document_ids_filter, - ) - ) - if RetrievalMethod.is_support_semantic_search(retrieval_method): - futures.append( - executor.submit( - cls.embedding_search, - flask_app=current_app._get_current_object(), # type: ignore - dataset_id=dataset_id, + retrieval_method=retrieval_method, + dataset=dataset, query=query, top_k=top_k, score_threshold=score_threshold, reranking_model=reranking_model, - all_documents=all_documents, - retrieval_method=retrieval_method, - exceptions=exceptions, + reranking_mode=reranking_mode, + weights=weights, document_ids_filter=document_ids_filter, + attachment_id=None, + all_documents=all_documents, + exceptions=exceptions, ) ) - if RetrievalMethod.is_support_fulltext_search(retrieval_method): - futures.append( - executor.submit( - cls.full_text_index_search, - flask_app=current_app._get_current_object(), # type: ignore - dataset_id=dataset_id, - query=query, - top_k=top_k, - score_threshold=score_threshold, - reranking_model=reranking_model, - all_documents=all_documents, - retrieval_method=retrieval_method, - exceptions=exceptions, - document_ids_filter=document_ids_filter, + if attachment_ids: + for attachment_id in attachment_ids: + futures.append( + executor.submit( + retrieval_service._retrieve, + flask_app=current_app._get_current_object(), # type: ignore + retrieval_method=retrieval_method, + dataset=dataset, + query=None, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + reranking_mode=reranking_mode, + weights=weights, + document_ids_filter=document_ids_filter, + attachment_id=attachment_id, + all_documents=all_documents, + exceptions=exceptions, + ) ) - ) - concurrent.futures.wait(futures, timeout=30, return_when=concurrent.futures.ALL_COMPLETED) + + concurrent.futures.wait(futures, timeout=3600, return_when=concurrent.futures.ALL_COMPLETED) if exceptions: raise ValueError(";\n".join(exceptions)) - # Deduplicate documents for hybrid search to avoid duplicate chunks - if retrieval_method == RetrievalMethod.HYBRID_SEARCH: - all_documents = cls._deduplicate_documents(all_documents) - data_post_processor = DataPostProcessor( - str(dataset.tenant_id), reranking_mode, reranking_model, weights, False - ) - all_documents = data_post_processor.invoke( - query=query, - documents=all_documents, - score_threshold=score_threshold, - top_n=top_k, - ) - return all_documents @classmethod @@ -223,6 +214,7 @@ class RetrievalService: retrieval_method: RetrievalMethod, exceptions: list, document_ids_filter: list[str] | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, ): with flask_app.app_context(): try: @@ -231,14 +223,30 @@ class RetrievalService: raise ValueError("dataset not found") vector = Vector(dataset=dataset) - documents = vector.search_by_vector( - query, - search_type="similarity_score_threshold", - top_k=top_k, - score_threshold=score_threshold, - filter={"group_id": [dataset.id]}, - document_ids_filter=document_ids_filter, - ) + documents = [] + if query_type == QueryType.TEXT_QUERY: + documents.extend( + vector.search_by_vector( + query, + search_type="similarity_score_threshold", + top_k=top_k, + score_threshold=score_threshold, + filter={"group_id": [dataset.id]}, + document_ids_filter=document_ids_filter, + ) + ) + if query_type == QueryType.IMAGE_QUERY: + if not dataset.is_multimodal: + return + documents.extend( + vector.search_by_file( + file_id=query, + top_k=top_k, + score_threshold=score_threshold, + filter={"group_id": [dataset.id]}, + document_ids_filter=document_ids_filter, + ) + ) if documents: if ( @@ -250,14 +258,37 @@ class RetrievalService: data_post_processor = DataPostProcessor( str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False ) - all_documents.extend( - data_post_processor.invoke( - query=query, - documents=documents, - score_threshold=score_threshold, - top_n=len(documents), + if dataset.is_multimodal: + model_manager = ModelManager() + is_support_vision = model_manager.check_model_support_vision( + tenant_id=dataset.tenant_id, + provider=reranking_model.get("reranking_provider_name") or "", + model=reranking_model.get("reranking_model_name") or "", + model_type=ModelType.RERANK, + ) + if is_support_vision: + all_documents.extend( + data_post_processor.invoke( + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=len(documents), + query_type=query_type, + ) + ) + else: + # not effective, return original documents + all_documents.extend(documents) + else: + all_documents.extend( + data_post_processor.invoke( + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=len(documents), + query_type=query_type, + ) ) - ) else: all_documents.extend(documents) except Exception as e: @@ -339,103 +370,159 @@ class RetrievalService: records = [] include_segment_ids = set() segment_child_map = {} - - # Process documents - for document in documents: - document_id = document.metadata.get("document_id") - if document_id not in dataset_documents: - continue - - dataset_document = dataset_documents[document_id] - if not dataset_document: - continue - - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: - # Handle parent-child documents - child_index_node_id = document.metadata.get("doc_id") - child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id) - child_chunk = db.session.scalar(child_chunk_stmt) - - if not child_chunk: + segment_file_map = {} + with Session(db.engine) as session: + # Process documents + for document in documents: + segment_id = None + attachment_info = None + child_chunk = None + document_id = document.metadata.get("document_id") + if document_id not in dataset_documents: continue - segment = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.dataset_id == dataset_document.dataset_id, - DocumentSegment.enabled == True, - DocumentSegment.status == "completed", - DocumentSegment.id == child_chunk.segment_id, - ) - .options( - load_only( - DocumentSegment.id, - DocumentSegment.content, - DocumentSegment.answer, + dataset_document = dataset_documents[document_id] + if not dataset_document: + continue + + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: + # Handle parent-child documents + if document.metadata.get("doc_type") == DocType.IMAGE: + attachment_info_dict = cls.get_segment_attachment_info( + dataset_document.dataset_id, + dataset_document.tenant_id, + document.metadata.get("doc_id") or "", + session, ) + if attachment_info_dict: + attachment_info = attachment_info_dict["attchment_info"] + segment_id = attachment_info_dict["segment_id"] + else: + child_index_node_id = document.metadata.get("doc_id") + child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id) + child_chunk = session.scalar(child_chunk_stmt) + + if not child_chunk: + continue + segment_id = child_chunk.segment_id + + if not segment_id: + continue + + segment = ( + session.query(DocumentSegment) + .where( + DocumentSegment.dataset_id == dataset_document.dataset_id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.id == segment_id, + ) + .options( + load_only( + DocumentSegment.id, + DocumentSegment.content, + DocumentSegment.answer, + ) + ) + .first() ) - .first() - ) - if not segment: - continue + if not segment: + continue - if segment.id not in include_segment_ids: - include_segment_ids.add(segment.id) - child_chunk_detail = { - "id": child_chunk.id, - "content": child_chunk.content, - "position": child_chunk.position, - "score": document.metadata.get("score", 0.0), - } - map_detail = { - "max_score": document.metadata.get("score", 0.0), - "child_chunks": [child_chunk_detail], - } - segment_child_map[segment.id] = map_detail - record = { - "segment": segment, - } - records.append(record) + if segment.id not in include_segment_ids: + include_segment_ids.add(segment.id) + if child_chunk: + child_chunk_detail = { + "id": child_chunk.id, + "content": child_chunk.content, + "position": child_chunk.position, + "score": document.metadata.get("score", 0.0), + } + map_detail = { + "max_score": document.metadata.get("score", 0.0), + "child_chunks": [child_chunk_detail], + } + segment_child_map[segment.id] = map_detail + record = { + "segment": segment, + } + if attachment_info: + segment_file_map[segment.id] = [attachment_info] + records.append(record) + else: + if child_chunk: + child_chunk_detail = { + "id": child_chunk.id, + "content": child_chunk.content, + "position": child_chunk.position, + "score": document.metadata.get("score", 0.0), + } + segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) + segment_child_map[segment.id]["max_score"] = max( + segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0) + ) + if attachment_info: + segment_file_map[segment.id].append(attachment_info) else: - child_chunk_detail = { - "id": child_chunk.id, - "content": child_chunk.content, - "position": child_chunk.position, - "score": document.metadata.get("score", 0.0), - } - segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) - segment_child_map[segment.id]["max_score"] = max( - segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0) - ) - else: - # Handle normal documents - index_node_id = document.metadata.get("doc_id") - if not index_node_id: - continue - document_segment_stmt = select(DocumentSegment).where( - DocumentSegment.dataset_id == dataset_document.dataset_id, - DocumentSegment.enabled == True, - DocumentSegment.status == "completed", - DocumentSegment.index_node_id == index_node_id, - ) - segment = db.session.scalar(document_segment_stmt) + # Handle normal documents + segment = None + if document.metadata.get("doc_type") == DocType.IMAGE: + attachment_info_dict = cls.get_segment_attachment_info( + dataset_document.dataset_id, + dataset_document.tenant_id, + document.metadata.get("doc_id") or "", + session, + ) + if attachment_info_dict: + attachment_info = attachment_info_dict["attchment_info"] + segment_id = attachment_info_dict["segment_id"] + document_segment_stmt = select(DocumentSegment).where( + DocumentSegment.dataset_id == dataset_document.dataset_id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.id == segment_id, + ) + segment = db.session.scalar(document_segment_stmt) + if segment: + segment_file_map[segment.id] = [attachment_info] + else: + index_node_id = document.metadata.get("doc_id") + if not index_node_id: + continue + document_segment_stmt = select(DocumentSegment).where( + DocumentSegment.dataset_id == dataset_document.dataset_id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.index_node_id == index_node_id, + ) + segment = db.session.scalar(document_segment_stmt) - if not segment: - continue - - include_segment_ids.add(segment.id) - record = { - "segment": segment, - "score": document.metadata.get("score"), # type: ignore - } - records.append(record) + if not segment: + continue + if segment.id not in include_segment_ids: + include_segment_ids.add(segment.id) + record = { + "segment": segment, + "score": document.metadata.get("score"), # type: ignore + } + if attachment_info: + segment_file_map[segment.id] = [attachment_info] + records.append(record) + else: + if attachment_info: + attachment_infos = segment_file_map.get(segment.id, []) + if attachment_info not in attachment_infos: + attachment_infos.append(attachment_info) + segment_file_map[segment.id] = attachment_infos # Add child chunks information to records for record in records: if record["segment"].id in segment_child_map: record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore record["score"] = segment_child_map[record["segment"].id]["max_score"] + if record["segment"].id in segment_file_map: + record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment] result = [] for record in records: @@ -447,6 +534,11 @@ class RetrievalService: if not isinstance(child_chunks, list): child_chunks = None + # Extract files, ensuring it's a list or None + files = record.get("files") + if not isinstance(files, list): + files = None + # Extract score, ensuring it's a float or None score_value = record.get("score") score = ( @@ -456,10 +548,149 @@ class RetrievalService: ) # Create RetrievalSegments object - retrieval_segment = RetrievalSegments(segment=segment, child_chunks=child_chunks, score=score) + retrieval_segment = RetrievalSegments( + segment=segment, child_chunks=child_chunks, score=score, files=files + ) result.append(retrieval_segment) return result except Exception as e: db.session.rollback() raise e + + def _retrieve( + self, + flask_app: Flask, + retrieval_method: RetrievalMethod, + dataset: Dataset, + query: str | None = None, + top_k: int = 4, + score_threshold: float | None = 0.0, + reranking_model: dict | None = None, + reranking_mode: str = "reranking_model", + weights: dict | None = None, + document_ids_filter: list[str] | None = None, + attachment_id: str | None = None, + all_documents: list[Document] = [], + exceptions: list[str] = [], + ): + if not query and not attachment_id: + return + with flask_app.app_context(): + all_documents_item: list[Document] = [] + # Optimize multithreading with thread pools + with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore + futures = [] + if retrieval_method == RetrievalMethod.KEYWORD_SEARCH and query: + futures.append( + executor.submit( + self.keyword_search, + flask_app=current_app._get_current_object(), # type: ignore + dataset_id=dataset.id, + query=query, + top_k=top_k, + all_documents=all_documents_item, + exceptions=exceptions, + document_ids_filter=document_ids_filter, + ) + ) + if RetrievalMethod.is_support_semantic_search(retrieval_method): + if query: + futures.append( + executor.submit( + self.embedding_search, + flask_app=current_app._get_current_object(), # type: ignore + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + all_documents=all_documents_item, + retrieval_method=retrieval_method, + exceptions=exceptions, + document_ids_filter=document_ids_filter, + query_type=QueryType.TEXT_QUERY, + ) + ) + if attachment_id: + futures.append( + executor.submit( + self.embedding_search, + flask_app=current_app._get_current_object(), # type: ignore + dataset_id=dataset.id, + query=attachment_id, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + all_documents=all_documents_item, + retrieval_method=retrieval_method, + exceptions=exceptions, + document_ids_filter=document_ids_filter, + query_type=QueryType.IMAGE_QUERY, + ) + ) + if RetrievalMethod.is_support_fulltext_search(retrieval_method) and query: + futures.append( + executor.submit( + self.full_text_index_search, + flask_app=current_app._get_current_object(), # type: ignore + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + all_documents=all_documents_item, + retrieval_method=retrieval_method, + exceptions=exceptions, + document_ids_filter=document_ids_filter, + ) + ) + concurrent.futures.wait(futures, timeout=300, return_when=concurrent.futures.ALL_COMPLETED) + + if exceptions: + raise ValueError(";\n".join(exceptions)) + + # Deduplicate documents for hybrid search to avoid duplicate chunks + if retrieval_method == RetrievalMethod.HYBRID_SEARCH: + if attachment_id and reranking_mode == RerankMode.WEIGHTED_SCORE: + all_documents.extend(all_documents_item) + all_documents_item = self._deduplicate_documents(all_documents_item) + data_post_processor = DataPostProcessor( + str(dataset.tenant_id), reranking_mode, reranking_model, weights, False + ) + + query = query or attachment_id + if not query: + return + all_documents_item = data_post_processor.invoke( + query=query, + documents=all_documents_item, + score_threshold=score_threshold, + top_n=top_k, + query_type=QueryType.TEXT_QUERY if query else QueryType.IMAGE_QUERY, + ) + + all_documents.extend(all_documents_item) + + @classmethod + def get_segment_attachment_info( + cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session + ) -> dict[str, Any] | None: + upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first() + if upload_file: + attachment_binding = ( + session.query(SegmentAttachmentBinding) + .where(SegmentAttachmentBinding.attachment_id == upload_file.id) + .first() + ) + if attachment_binding: + attchment_info = { + "id": upload_file.id, + "name": upload_file.name, + "extension": "." + upload_file.extension, + "mime_type": upload_file.mime_type, + "source_url": sign_upload_file(upload_file.id, upload_file.extension), + "size": upload_file.size, + } + return {"attchment_info": attchment_info, "segment_id": attachment_binding.segment_id} + return None diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 0beb388693..3a47241293 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -1,3 +1,4 @@ +import base64 import logging import time from abc import ABC, abstractmethod @@ -12,10 +13,13 @@ from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_type import VectorType from core.rag.embedding.cached_embedding import CacheEmbedding from core.rag.embedding.embedding_base import Embeddings +from core.rag.index_processor.constant.doc_type import DocType from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client +from extensions.ext_storage import storage from models.dataset import Dataset, Whitelist +from models.model import UploadFile logger = logging.getLogger(__name__) @@ -203,6 +207,47 @@ class Vector: self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs) logger.info("Embedding %s texts took %s s", len(texts), time.time() - start) + def create_multimodal(self, file_documents: list | None = None, **kwargs): + if file_documents: + start = time.time() + logger.info("start embedding %s files %s", len(file_documents), start) + batch_size = 1000 + total_batches = len(file_documents) + batch_size - 1 + for i in range(0, len(file_documents), batch_size): + batch = file_documents[i : i + batch_size] + batch_start = time.time() + logger.info("Processing batch %s/%s (%s files)", i // batch_size + 1, total_batches, len(batch)) + + # Batch query all upload files to avoid N+1 queries + attachment_ids = [doc.metadata["doc_id"] for doc in batch] + stmt = select(UploadFile).where(UploadFile.id.in_(attachment_ids)) + upload_files = db.session.scalars(stmt).all() + upload_file_map = {str(f.id): f for f in upload_files} + + file_base64_list = [] + real_batch = [] + for document in batch: + attachment_id = document.metadata["doc_id"] + doc_type = document.metadata["doc_type"] + upload_file = upload_file_map.get(attachment_id) + if upload_file: + blob = storage.load_once(upload_file.key) + file_base64_str = base64.b64encode(blob).decode() + file_base64_list.append( + { + "content": file_base64_str, + "content_type": doc_type, + "file_id": attachment_id, + } + ) + real_batch.append(document) + batch_embeddings = self._embeddings.embed_multimodal_documents(file_base64_list) + logger.info( + "Embedding batch %s/%s took %s s", i // batch_size + 1, total_batches, time.time() - batch_start + ) + self._vector_processor.create(texts=real_batch, embeddings=batch_embeddings, **kwargs) + logger.info("Embedding %s files took %s s", len(file_documents), time.time() - start) + def add_texts(self, documents: list[Document], **kwargs): if kwargs.get("duplicate_check", False): documents = self._filter_duplicate_texts(documents) @@ -223,6 +268,22 @@ class Vector: query_vector = self._embeddings.embed_query(query) return self._vector_processor.search_by_vector(query_vector, **kwargs) + def search_by_file(self, file_id: str, **kwargs: Any) -> list[Document]: + upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == file_id).first() + + if not upload_file: + return [] + blob = storage.load_once(upload_file.key) + file_base64_str = base64.b64encode(blob).decode() + multimodal_vector = self._embeddings.embed_multimodal_query( + { + "content": file_base64_str, + "content_type": DocType.IMAGE, + "file_id": file_id, + } + ) + return self._vector_processor.search_by_vector(multimodal_vector, **kwargs) + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return self._vector_processor.search_by_full_text(query, **kwargs) diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 2c7bc592c0..84d1e26b34 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -79,6 +79,18 @@ class WeaviateVector(BaseVector): self._client = self._init_client(config) self._attributes = attributes + def __del__(self): + """ + Destructor to properly close the Weaviate client connection. + Prevents connection leaks and resource warnings. + """ + if hasattr(self, "_client") and self._client is not None: + try: + self._client.close() + except Exception as e: + # Ignore errors during cleanup as object is being destroyed + logger.warning("Error closing Weaviate client %s", e, exc_info=True) + def _init_client(self, config: WeaviateConfig) -> weaviate.WeaviateClient: """ Initializes and returns a connected Weaviate client. diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 74a2653e9d..1fe74d3042 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -5,9 +5,9 @@ from sqlalchemy import func, select from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from core.rag.models.document import Document +from core.rag.models.document import AttachmentDocument, Document from extensions.ext_database import db -from models.dataset import ChildChunk, Dataset, DocumentSegment +from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding class DatasetDocumentStore: @@ -120,6 +120,9 @@ class DatasetDocumentStore: db.session.add(segment_document) db.session.flush() + self.add_multimodel_documents_binding( + segment_id=segment_document.id, multimodel_documents=doc.attachments + ) if save_child: if doc.children: for position, child in enumerate(doc.children, start=1): @@ -144,6 +147,9 @@ class DatasetDocumentStore: segment_document.index_node_hash = doc.metadata.get("doc_hash") segment_document.word_count = len(doc.page_content) segment_document.tokens = tokens + self.add_multimodel_documents_binding( + segment_id=segment_document.id, multimodel_documents=doc.attachments + ) if save_child and doc.children: # delete the existing child chunks db.session.query(ChildChunk).where( @@ -233,3 +239,15 @@ class DatasetDocumentStore: document_segment = db.session.scalar(stmt) return document_segment + + def add_multimodel_documents_binding(self, segment_id: str, multimodel_documents: list[AttachmentDocument] | None): + if multimodel_documents: + for multimodel_document in multimodel_documents: + binding = SegmentAttachmentBinding( + tenant_id=self._dataset.tenant_id, + dataset_id=self._dataset.id, + document_id=self._document_id, + segment_id=segment_id, + attachment_id=multimodel_document.metadata["doc_id"], + ) + db.session.add(binding) diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 7fb20c1941..3cbc7db75d 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -104,6 +104,88 @@ class CacheEmbedding(Embeddings): return text_embeddings + def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]: + """Embed file documents.""" + # use doc embedding cache or store if not exists + multimodel_embeddings: list[Any] = [None for _ in range(len(multimodel_documents))] + embedding_queue_indices = [] + for i, multimodel_document in enumerate(multimodel_documents): + file_id = multimodel_document["file_id"] + embedding = ( + db.session.query(Embedding) + .filter_by( + model_name=self._model_instance.model, hash=file_id, provider_name=self._model_instance.provider + ) + .first() + ) + if embedding: + multimodel_embeddings[i] = embedding.get_embedding() + else: + embedding_queue_indices.append(i) + + # NOTE: avoid closing the shared scoped session here; downstream code may still have pending work + + if embedding_queue_indices: + embedding_queue_multimodel_documents = [multimodel_documents[i] for i in embedding_queue_indices] + embedding_queue_embeddings = [] + try: + model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance) + model_schema = model_type_instance.get_model_schema( + self._model_instance.model, self._model_instance.credentials + ) + max_chunks = ( + model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] + if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties + else 1 + ) + for i in range(0, len(embedding_queue_multimodel_documents), max_chunks): + batch_multimodel_documents = embedding_queue_multimodel_documents[i : i + max_chunks] + + embedding_result = self._model_instance.invoke_multimodal_embedding( + multimodel_documents=batch_multimodel_documents, + user=self._user, + input_type=EmbeddingInputType.DOCUMENT, + ) + + for vector in embedding_result.embeddings: + try: + # FIXME: type ignore for numpy here + normalized_embedding = (vector / np.linalg.norm(vector)).tolist() # type: ignore + # stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan + if np.isnan(normalized_embedding).any(): + # for issue #11827 float values are not json compliant + logger.warning("Normalized embedding is nan: %s", normalized_embedding) + continue + embedding_queue_embeddings.append(normalized_embedding) + except IntegrityError: + db.session.rollback() + except Exception: + logger.exception("Failed transform embedding") + cache_embeddings = [] + try: + for i, n_embedding in zip(embedding_queue_indices, embedding_queue_embeddings): + multimodel_embeddings[i] = n_embedding + file_id = multimodel_documents[i]["file_id"] + if file_id not in cache_embeddings: + embedding_cache = Embedding( + model_name=self._model_instance.model, + hash=file_id, + provider_name=self._model_instance.provider, + embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL), + ) + embedding_cache.set_embedding(n_embedding) + db.session.add(embedding_cache) + cache_embeddings.append(file_id) + db.session.commit() + except IntegrityError: + db.session.rollback() + except Exception as ex: + db.session.rollback() + logger.exception("Failed to embed documents") + raise ex + + return multimodel_embeddings + def embed_query(self, text: str) -> list[float]: """Embed query text.""" # use doc embedding cache or store if not exists @@ -146,3 +228,46 @@ class CacheEmbedding(Embeddings): raise ex return embedding_results # type: ignore + + def embed_multimodal_query(self, multimodel_document: dict) -> list[float]: + """Embed multimodal documents.""" + # use doc embedding cache or store if not exists + file_id = multimodel_document["file_id"] + embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{file_id}" + embedding = redis_client.get(embedding_cache_key) + if embedding: + redis_client.expire(embedding_cache_key, 600) + decoded_embedding = np.frombuffer(base64.b64decode(embedding), dtype="float") + return [float(x) for x in decoded_embedding] + try: + embedding_result = self._model_instance.invoke_multimodal_embedding( + multimodel_documents=[multimodel_document], user=self._user, input_type=EmbeddingInputType.QUERY + ) + + embedding_results = embedding_result.embeddings[0] + # FIXME: type ignore for numpy here + embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() # type: ignore + if np.isnan(embedding_results).any(): + raise ValueError("Normalized embedding is nan please try again") + except Exception as ex: + if dify_config.DEBUG: + logger.exception("Failed to embed multimodal document '%s'", multimodel_document["file_id"]) + raise ex + + try: + # encode embedding to base64 + embedding_vector = np.array(embedding_results) + vector_bytes = embedding_vector.tobytes() + # Transform to Base64 + encoded_vector = base64.b64encode(vector_bytes) + # Transform to string + encoded_str = encoded_vector.decode("utf-8") + redis_client.setex(embedding_cache_key, 600, encoded_str) + except Exception as ex: + if dify_config.DEBUG: + logger.exception( + "Failed to add embedding to redis for the multimodal document '%s'", multimodel_document["file_id"] + ) + raise ex + + return embedding_results # type: ignore diff --git a/api/core/rag/embedding/embedding_base.py b/api/core/rag/embedding/embedding_base.py index 9f232ab910..1be55bda80 100644 --- a/api/core/rag/embedding/embedding_base.py +++ b/api/core/rag/embedding/embedding_base.py @@ -9,11 +9,21 @@ class Embeddings(ABC): """Embed search docs.""" raise NotImplementedError + @abstractmethod + def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]: + """Embed file documents.""" + raise NotImplementedError + @abstractmethod def embed_query(self, text: str) -> list[float]: """Embed query text.""" raise NotImplementedError + @abstractmethod + def embed_multimodal_query(self, multimodel_document: dict) -> list[float]: + """Embed multimodal query.""" + raise NotImplementedError + async def aembed_documents(self, texts: list[str]) -> list[list[float]]: """Asynchronous Embed search docs.""" raise NotImplementedError diff --git a/api/core/rag/embedding/retrieval.py b/api/core/rag/embedding/retrieval.py index 8e92191568..b54a37b49e 100644 --- a/api/core/rag/embedding/retrieval.py +++ b/api/core/rag/embedding/retrieval.py @@ -19,3 +19,4 @@ class RetrievalSegments(BaseModel): segment: DocumentSegment child_chunks: list[RetrievalChildChunk] | None = None score: float | None = None + files: list[dict[str, str | int]] | None = None diff --git a/api/core/rag/entities/citation_metadata.py b/api/core/rag/entities/citation_metadata.py index aca879df7d..9f66cd9a03 100644 --- a/api/core/rag/entities/citation_metadata.py +++ b/api/core/rag/entities/citation_metadata.py @@ -21,3 +21,4 @@ class RetrievalSourceMetadata(BaseModel): page: int | None = None doc_metadata: dict[str, Any] | None = None title: str | None = None + files: list[dict[str, Any]] | None = None diff --git a/api/core/rag/index_processor/constant/doc_type.py b/api/core/rag/index_processor/constant/doc_type.py new file mode 100644 index 0000000000..93c8fecb8d --- /dev/null +++ b/api/core/rag/index_processor/constant/doc_type.py @@ -0,0 +1,6 @@ +from enum import StrEnum + + +class DocType(StrEnum): + TEXT = "text" + IMAGE = "image" diff --git a/api/core/rag/index_processor/constant/index_type.py b/api/core/rag/index_processor/constant/index_type.py index 659086e808..09617413f7 100644 --- a/api/core/rag/index_processor/constant/index_type.py +++ b/api/core/rag/index_processor/constant/index_type.py @@ -1,7 +1,12 @@ from enum import StrEnum -class IndexType(StrEnum): +class IndexStructureType(StrEnum): PARAGRAPH_INDEX = "text_model" QA_INDEX = "qa_model" PARENT_CHILD_INDEX = "hierarchical_model" + + +class IndexTechniqueType(StrEnum): + ECONOMY = "economy" + HIGH_QUALITY = "high_quality" diff --git a/api/core/rag/index_processor/constant/query_type.py b/api/core/rag/index_processor/constant/query_type.py new file mode 100644 index 0000000000..342bfef3f7 --- /dev/null +++ b/api/core/rag/index_processor/constant/query_type.py @@ -0,0 +1,6 @@ +from enum import StrEnum + + +class QueryType(StrEnum): + TEXT_QUERY = "text_query" + IMAGE_QUERY = "image_query" diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index d4eff53204..8a28eb477a 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -1,20 +1,34 @@ """Abstract interface for document loader implementations.""" +import cgi +import logging +import mimetypes +import os +import re from abc import ABC, abstractmethod from collections.abc import Mapping from typing import TYPE_CHECKING, Any, Optional +from urllib.parse import unquote, urlparse + +import httpx from configs import dify_config +from core.helper import ssrf_proxy from core.rag.extractor.entity.extract_setting import ExtractSetting -from core.rag.models.document import Document +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.models.document import AttachmentDocument, Document from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.splitter.fixed_text_splitter import ( EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter, ) from core.rag.splitter.text_splitter import TextSplitter +from extensions.ext_database import db +from extensions.ext_storage import storage +from models import Account, ToolFile from models.dataset import Dataset, DatasetProcessRule from models.dataset import Document as DatasetDocument +from models.model import UploadFile if TYPE_CHECKING: from core.model_manager import ModelInstance @@ -28,11 +42,18 @@ class BaseIndexProcessor(ABC): raise NotImplementedError @abstractmethod - def transform(self, documents: list[Document], **kwargs) -> list[Document]: + def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]: raise NotImplementedError @abstractmethod - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): + def load( + self, + dataset: Dataset, + documents: list[Document], + multimodal_documents: list[AttachmentDocument] | None = None, + with_keywords: bool = True, + **kwargs, + ): raise NotImplementedError @abstractmethod @@ -96,3 +117,178 @@ class BaseIndexProcessor(ABC): ) return character_splitter # type: ignore + + def _get_content_files(self, document: Document, current_user: Account | None = None) -> list[AttachmentDocument]: + """ + Get the content files from the document. + """ + multi_model_documents: list[AttachmentDocument] = [] + text = document.page_content + images = self._extract_markdown_images(text) + if not images: + return multi_model_documents + upload_file_id_list = [] + + for image in images: + # Collect all upload_file_ids including duplicates to preserve occurrence count + + # For data before v0.10.0 + pattern = r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?" + match = re.search(pattern, image) + if match: + upload_file_id = match.group(1) + upload_file_id_list.append(upload_file_id) + continue + + # For data after v0.10.0 + pattern = r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?" + match = re.search(pattern, image) + if match: + upload_file_id = match.group(1) + upload_file_id_list.append(upload_file_id) + continue + + # For tools directory - direct file formats (e.g., .png, .jpg, etc.) + # Match URL including any query parameters up to common URL boundaries (space, parenthesis, quotes) + pattern = r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?" + match = re.search(pattern, image) + if match: + if current_user: + tool_file_id = match.group(1) + upload_file_id = self._download_tool_file(tool_file_id, current_user) + if upload_file_id: + upload_file_id_list.append(upload_file_id) + continue + if current_user: + upload_file_id = self._download_image(image.split(" ")[0], current_user) + if upload_file_id: + upload_file_id_list.append(upload_file_id) + + if not upload_file_id_list: + return multi_model_documents + + # Get unique IDs for database query + unique_upload_file_ids = list(set(upload_file_id_list)) + upload_files = db.session.query(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids)).all() + + # Create a mapping from ID to UploadFile for quick lookup + upload_file_map = {upload_file.id: upload_file for upload_file in upload_files} + + # Create a Document for each occurrence (including duplicates) + for upload_file_id in upload_file_id_list: + upload_file = upload_file_map.get(upload_file_id) + if upload_file: + multi_model_documents.append( + AttachmentDocument( + page_content=upload_file.name, + metadata={ + "doc_id": upload_file.id, + "doc_hash": "", + "document_id": document.metadata.get("document_id"), + "dataset_id": document.metadata.get("dataset_id"), + "doc_type": DocType.IMAGE, + }, + ) + ) + return multi_model_documents + + def _extract_markdown_images(self, text: str) -> list[str]: + """ + Extract the markdown images from the text. + """ + pattern = r"!\[.*?\]\((.*?)\)" + return re.findall(pattern, text) + + def _download_image(self, image_url: str, current_user: Account) -> str | None: + """ + Download the image from the URL. + Image size must not exceed 2MB. + """ + from services.file_service import FileService + + MAX_IMAGE_SIZE = dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024 + DOWNLOAD_TIMEOUT = dify_config.ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT + + try: + # Download with timeout + response = ssrf_proxy.get(image_url, timeout=DOWNLOAD_TIMEOUT) + response.raise_for_status() + + # Check Content-Length header if available + content_length = response.headers.get("Content-Length") + if content_length and int(content_length) > MAX_IMAGE_SIZE: + logging.warning("Image from %s exceeds 2MB limit (size: %s bytes)", image_url, content_length) + return None + + filename = None + + content_disposition = response.headers.get("content-disposition") + if content_disposition: + _, params = cgi.parse_header(content_disposition) + if "filename" in params: + filename = params["filename"] + filename = unquote(filename) + + if not filename: + parsed_url = urlparse(image_url) + # unquote 处理 URL 中的中文 + path = unquote(parsed_url.path) + filename = os.path.basename(path) + + if not filename: + filename = "downloaded_image_file" + + name, current_ext = os.path.splitext(filename) + + content_type = response.headers.get("content-type", "").split(";")[0].strip() + + real_ext = mimetypes.guess_extension(content_type) + + if not current_ext and real_ext or current_ext in [".php", ".jsp", ".asp", ".html"] and real_ext: + filename = f"{name}{real_ext}" + # Download content with size limit + blob = b"" + for chunk in response.iter_bytes(chunk_size=8192): + blob += chunk + if len(blob) > MAX_IMAGE_SIZE: + logging.warning("Image from %s exceeds 2MB limit during download", image_url) + return None + + if not blob: + logging.warning("Image from %s is empty", image_url) + return None + + upload_file = FileService(db.engine).upload_file( + filename=filename, + content=blob, + mimetype=content_type, + user=current_user, + ) + return upload_file.id + except httpx.TimeoutException: + logging.warning("Timeout downloading image from %s after %s seconds", image_url, DOWNLOAD_TIMEOUT) + return None + except httpx.RequestError as e: + logging.warning("Error downloading image from %s: %s", image_url, str(e)) + return None + except Exception: + logging.exception("Unexpected error downloading image from %s", image_url) + return None + + def _download_tool_file(self, tool_file_id: str, current_user: Account) -> str | None: + """ + Download the tool file from the ID. + """ + from services.file_service import FileService + + tool_file = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first() + if not tool_file: + return None + blob = storage.load_once(tool_file.file_key) + upload_file = FileService(db.engine).upload_file( + filename=tool_file.name, + content=blob, + mimetype=tool_file.mimetype, + user=current_user, + ) + return upload_file.id diff --git a/api/core/rag/index_processor/index_processor_factory.py b/api/core/rag/index_processor/index_processor_factory.py index c987edf342..ea6ab24699 100644 --- a/api/core/rag/index_processor/index_processor_factory.py +++ b/api/core/rag/index_processor/index_processor_factory.py @@ -1,6 +1,6 @@ """Abstract interface for document loader implementations.""" -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor @@ -19,11 +19,11 @@ class IndexProcessorFactory: if not self._index_type: raise ValueError("Index type must be specified.") - if self._index_type == IndexType.PARAGRAPH_INDEX: + if self._index_type == IndexStructureType.PARAGRAPH_INDEX: return ParagraphIndexProcessor() - elif self._index_type == IndexType.QA_INDEX: + elif self._index_type == IndexStructureType.QA_INDEX: return QAIndexProcessor() - elif self._index_type == IndexType.PARENT_CHILD_INDEX: + elif self._index_type == IndexStructureType.PARENT_CHILD_INDEX: return ParentChildIndexProcessor() else: raise ValueError(f"Index type {self._index_type} is not supported.") diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 5e5fea7ea9..a7c879f2c4 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -11,14 +11,17 @@ from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_base import BaseIndexProcessor -from core.rag.models.document import Document +from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper +from models.account import Account from models.dataset import Dataset, DatasetProcessRule from models.dataset import Document as DatasetDocument +from services.account_service import AccountService from services.entities.knowledge_entities.knowledge_entities import Rule @@ -33,7 +36,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): return text_docs - def transform(self, documents: list[Document], **kwargs) -> list[Document]: + def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]: process_rule = kwargs.get("process_rule") if not process_rule: raise ValueError("No process rule found.") @@ -69,6 +72,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor): if document_node.metadata is not None: document_node.metadata["doc_id"] = doc_id document_node.metadata["doc_hash"] = hash + multimodal_documents = ( + self._get_content_files(document_node, current_user) if document_node.metadata else None + ) + if multimodal_documents: + document_node.attachments = multimodal_documents # delete Splitter character page_content = remove_leading_symbols(document_node.page_content).strip() if len(page_content) > 0: @@ -77,10 +85,19 @@ class ParagraphIndexProcessor(BaseIndexProcessor): all_documents.extend(split_documents) return all_documents - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): + def load( + self, + dataset: Dataset, + documents: list[Document], + multimodal_documents: list[AttachmentDocument] | None = None, + with_keywords: bool = True, + **kwargs, + ): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) + if multimodal_documents and dataset.is_multimodal: + vector.create_multimodal(multimodal_documents) with_keywords = False if with_keywords: keywords_list = kwargs.get("keywords_list") @@ -134,8 +151,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor): return docs def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any): + documents: list[Any] = [] + all_multimodal_documents: list[Any] = [] if isinstance(chunks, list): - documents = [] for content in chunks: metadata = { "dataset_id": dataset.id, @@ -144,26 +162,68 @@ class ParagraphIndexProcessor(BaseIndexProcessor): "doc_hash": helper.generate_text_hash(content), } doc = Document(page_content=content, metadata=metadata) + attachments = self._get_content_files(doc) + if attachments: + doc.attachments = attachments + all_multimodal_documents.extend(attachments) documents.append(doc) - if documents: - # save node to document segment - doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) - # add document segments - doc_store.add_documents(docs=documents, save_child=False) - if dataset.indexing_technique == "high_quality": - vector = Vector(dataset) - vector.create(documents) - elif dataset.indexing_technique == "economy": - keyword = Keyword(dataset) - keyword.add_texts(documents) else: - raise ValueError("Chunks is not a list") + multimodal_general_structure = MultimodalGeneralStructureChunk.model_validate(chunks) + for general_chunk in multimodal_general_structure.general_chunks: + metadata = { + "dataset_id": dataset.id, + "document_id": document.id, + "doc_id": str(uuid.uuid4()), + "doc_hash": helper.generate_text_hash(general_chunk.content), + } + doc = Document(page_content=general_chunk.content, metadata=metadata) + if general_chunk.files: + attachments = [] + for file in general_chunk.files: + file_metadata = { + "doc_id": file.id, + "doc_hash": "", + "document_id": document.id, + "dataset_id": dataset.id, + "doc_type": DocType.IMAGE, + } + file_document = AttachmentDocument( + page_content=file.filename or "image_file", metadata=file_metadata + ) + attachments.append(file_document) + all_multimodal_documents.append(file_document) + doc.attachments = attachments + else: + account = AccountService.load_user(document.created_by) + if not account: + raise ValueError("Invalid account") + doc.attachments = self._get_content_files(doc, current_user=account) + if doc.attachments: + all_multimodal_documents.extend(doc.attachments) + documents.append(doc) + if documents: + # save node to document segment + doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) + # add document segments + doc_store.add_documents(docs=documents, save_child=False) + if dataset.indexing_technique == "high_quality": + vector = Vector(dataset) + vector.create(documents) + if all_multimodal_documents: + vector.create_multimodal(all_multimodal_documents) + elif dataset.indexing_technique == "economy": + keyword = Keyword(dataset) + keyword.add_texts(documents) def format_preview(self, chunks: Any) -> Mapping[str, Any]: if isinstance(chunks, list): preview = [] for content in chunks: preview.append({"content": content}) - return {"chunk_structure": IndexType.PARAGRAPH_INDEX, "preview": preview, "total_segments": len(chunks)} + return { + "chunk_structure": IndexStructureType.PARAGRAPH_INDEX, + "preview": preview, + "total_segments": len(chunks), + } else: raise ValueError("Chunks is not a list") diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 4fa78e2f95..ee29d2fd65 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -13,14 +13,17 @@ from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_base import BaseIndexProcessor -from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk +from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from libs import helper +from models import Account from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment from models.dataset import Document as DatasetDocument +from services.account_service import AccountService from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule @@ -35,7 +38,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): return text_docs - def transform(self, documents: list[Document], **kwargs) -> list[Document]: + def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]: process_rule = kwargs.get("process_rule") if not process_rule: raise ValueError("No process rule found.") @@ -77,6 +80,9 @@ class ParentChildIndexProcessor(BaseIndexProcessor): page_content = page_content if len(page_content) > 0: document_node.page_content = page_content + multimodel_documents = self._get_content_files(document_node, current_user) + if multimodel_documents: + document_node.attachments = multimodel_documents # parse document to child nodes child_nodes = self._split_child_nodes( document_node, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance") @@ -87,6 +93,9 @@ class ParentChildIndexProcessor(BaseIndexProcessor): elif rules.parent_mode == ParentMode.FULL_DOC: page_content = "\n".join([document.page_content for document in documents]) document = Document(page_content=page_content, metadata=documents[0].metadata) + multimodel_documents = self._get_content_files(document) + if multimodel_documents: + document.attachments = multimodel_documents # parse document to child nodes child_nodes = self._split_child_nodes( document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance") @@ -104,7 +113,14 @@ class ParentChildIndexProcessor(BaseIndexProcessor): return all_documents - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): + def load( + self, + dataset: Dataset, + documents: list[Document], + multimodal_documents: list[AttachmentDocument] | None = None, + with_keywords: bool = True, + **kwargs, + ): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) for document in documents: @@ -114,6 +130,8 @@ class ParentChildIndexProcessor(BaseIndexProcessor): Document.model_validate(child_document.model_dump()) for child_document in child_documents ] vector.create(formatted_child_documents) + if multimodal_documents and dataset.is_multimodal: + vector.create_multimodal(multimodal_documents) def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): # node_ids is segment's node_ids @@ -244,6 +262,24 @@ class ParentChildIndexProcessor(BaseIndexProcessor): } child_documents.append(ChildDocument(page_content=child, metadata=child_metadata)) doc = Document(page_content=parent_child.parent_content, metadata=metadata, children=child_documents) + if parent_child.files and len(parent_child.files) > 0: + attachments = [] + for file in parent_child.files: + file_metadata = { + "doc_id": file.id, + "doc_hash": "", + "document_id": document.id, + "dataset_id": dataset.id, + "doc_type": DocType.IMAGE, + } + file_document = AttachmentDocument(page_content=file.filename or "", metadata=file_metadata) + attachments.append(file_document) + doc.attachments = attachments + else: + account = AccountService.load_user(document.created_by) + if not account: + raise ValueError("Invalid account") + doc.attachments = self._get_content_files(doc, current_user=account) documents.append(doc) if documents: # update document parent mode @@ -267,12 +303,17 @@ class ParentChildIndexProcessor(BaseIndexProcessor): doc_store.add_documents(docs=documents, save_child=True) if dataset.indexing_technique == "high_quality": all_child_documents = [] + all_multimodal_documents = [] for doc in documents: if doc.children: all_child_documents.extend(doc.children) + if doc.attachments: + all_multimodal_documents.extend(doc.attachments) + vector = Vector(dataset) if all_child_documents: - vector = Vector(dataset) vector.create(all_child_documents) + if all_multimodal_documents: + vector.create_multimodal(all_multimodal_documents) def format_preview(self, chunks: Any) -> Mapping[str, Any]: parent_childs = ParentChildStructureChunk.model_validate(chunks) @@ -280,7 +321,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): for parent_child in parent_childs.parent_child_chunks: preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents}) return { - "chunk_structure": IndexType.PARENT_CHILD_INDEX, + "chunk_structure": IndexStructureType.PARENT_CHILD_INDEX, "parent_mode": parent_childs.parent_mode, "preview": preview, "total_segments": len(parent_childs.parent_child_chunks), diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 3e3deb0180..1183d5fbd7 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -18,12 +18,13 @@ from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_base import BaseIndexProcessor -from core.rag.models.document import Document, QAStructureChunk +from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper +from models.account import Account from models.dataset import Dataset from models.dataset import Document as DatasetDocument from services.entities.knowledge_entities.knowledge_entities import Rule @@ -41,7 +42,7 @@ class QAIndexProcessor(BaseIndexProcessor): ) return text_docs - def transform(self, documents: list[Document], **kwargs) -> list[Document]: + def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]: preview = kwargs.get("preview") process_rule = kwargs.get("process_rule") if not process_rule: @@ -116,7 +117,7 @@ class QAIndexProcessor(BaseIndexProcessor): try: # Skip the first row - df = pd.read_csv(file) + df = pd.read_csv(file) # type: ignore text_docs = [] for _, row in df.iterrows(): data = Document(page_content=row.iloc[0], metadata={"answer": row.iloc[1]}) @@ -128,10 +129,19 @@ class QAIndexProcessor(BaseIndexProcessor): raise ValueError(str(e)) return text_docs - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): + def load( + self, + dataset: Dataset, + documents: list[Document], + multimodal_documents: list[AttachmentDocument] | None = None, + with_keywords: bool = True, + **kwargs, + ): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) + if multimodal_documents and dataset.is_multimodal: + vector.create_multimodal(multimodal_documents) def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): vector = Vector(dataset) @@ -197,7 +207,7 @@ class QAIndexProcessor(BaseIndexProcessor): for qa_chunk in qa_chunks.qa_chunks: preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer}) return { - "chunk_structure": IndexType.QA_INDEX, + "chunk_structure": IndexStructureType.QA_INDEX, "qa_preview": preview, "total_segments": len(qa_chunks.qa_chunks), } diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 4bd7b1d62e..611fad9a18 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -4,6 +4,8 @@ from typing import Any from pydantic import BaseModel, Field +from core.file import File + class ChildDocument(BaseModel): """Class for storing a piece of text and associated metadata.""" @@ -15,7 +17,19 @@ class ChildDocument(BaseModel): """Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.). """ - metadata: dict = Field(default_factory=dict) + metadata: dict[str, Any] = Field(default_factory=dict) + + +class AttachmentDocument(BaseModel): + """Class for storing a piece of text and associated metadata.""" + + page_content: str + + provider: str | None = "dify" + + vector: list[float] | None = None + + metadata: dict[str, Any] = Field(default_factory=dict) class Document(BaseModel): @@ -28,12 +42,31 @@ class Document(BaseModel): """Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.). """ - metadata: dict = Field(default_factory=dict) + metadata: dict[str, Any] = Field(default_factory=dict) provider: str | None = "dify" children: list[ChildDocument] | None = None + attachments: list[AttachmentDocument] | None = None + + +class GeneralChunk(BaseModel): + """ + General Chunk. + """ + + content: str + files: list[File] | None = None + + +class MultimodalGeneralStructureChunk(BaseModel): + """ + Multimodal General Structure Chunk. + """ + + general_chunks: list[GeneralChunk] + class GeneralStructureChunk(BaseModel): """ @@ -50,6 +83,7 @@ class ParentChildChunk(BaseModel): parent_content: str child_contents: list[str] + files: list[File] | None = None class ParentChildStructureChunk(BaseModel): diff --git a/api/core/rag/rerank/rerank_base.py b/api/core/rag/rerank/rerank_base.py index 3561def008..88acb75133 100644 --- a/api/core/rag/rerank/rerank_base.py +++ b/api/core/rag/rerank/rerank_base.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document @@ -12,6 +13,7 @@ class BaseRerankRunner(ABC): score_threshold: float | None = None, top_n: int | None = None, user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ Run rerank model diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index e855b0083f..38309d3d77 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -1,6 +1,15 @@ -from core.model_manager import ModelInstance +import base64 + +from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.rerank_entities import RerankResult +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_base import BaseRerankRunner +from extensions.ext_database import db +from extensions.ext_storage import storage +from models.model import UploadFile class RerankModelRunner(BaseRerankRunner): @@ -14,6 +23,7 @@ class RerankModelRunner(BaseRerankRunner): score_threshold: float | None = None, top_n: int | None = None, user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ Run rerank model @@ -24,6 +34,56 @@ class RerankModelRunner(BaseRerankRunner): :param user: unique user id if needed :return: """ + model_manager = ModelManager() + is_support_vision = model_manager.check_model_support_vision( + tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id, + provider=self.rerank_model_instance.provider, + model=self.rerank_model_instance.model, + model_type=ModelType.RERANK, + ) + if not is_support_vision: + if query_type == QueryType.TEXT_QUERY: + rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user) + else: + return documents + else: + rerank_result, unique_documents = self.fetch_multimodal_rerank( + query, documents, score_threshold, top_n, user, query_type + ) + + rerank_documents = [] + for result in rerank_result.docs: + if score_threshold is None or result.score >= score_threshold: + # format document + rerank_document = Document( + page_content=result.text, + metadata=unique_documents[result.index].metadata, + provider=unique_documents[result.index].provider, + ) + if rerank_document.metadata is not None: + rerank_document.metadata["score"] = result.score + rerank_documents.append(rerank_document) + + rerank_documents.sort(key=lambda x: x.metadata.get("score", 0.0), reverse=True) + return rerank_documents[:top_n] if top_n else rerank_documents + + def fetch_text_rerank( + self, + query: str, + documents: list[Document], + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, + ) -> tuple[RerankResult, list[Document]]: + """ + Fetch text rerank + :param query: search query + :param documents: documents for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id if needed + :return: + """ docs = [] doc_ids = set() unique_documents = [] @@ -33,33 +93,99 @@ class RerankModelRunner(BaseRerankRunner): and document.metadata is not None and document.metadata["doc_id"] not in doc_ids ): - doc_ids.add(document.metadata["doc_id"]) - docs.append(document.page_content) - unique_documents.append(document) + if not document.metadata.get("doc_type") or document.metadata.get("doc_type") == DocType.TEXT: + doc_ids.add(document.metadata["doc_id"]) + docs.append(document.page_content) + unique_documents.append(document) elif document.provider == "external": if document not in unique_documents: docs.append(document.page_content) unique_documents.append(document) - documents = unique_documents - rerank_result = self.rerank_model_instance.invoke_rerank( query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user ) + return rerank_result, unique_documents - rerank_documents = [] + def fetch_multimodal_rerank( + self, + query: str, + documents: list[Document], + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, + ) -> tuple[RerankResult, list[Document]]: + """ + Fetch multimodal rerank + :param query: search query + :param documents: documents for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id if needed + :param query_type: query type + :return: rerank result + """ + docs = [] + doc_ids = set() + unique_documents = [] + for document in documents: + if ( + document.provider == "dify" + and document.metadata is not None + and document.metadata["doc_id"] not in doc_ids + ): + if document.metadata.get("doc_type") == DocType.IMAGE: + # Query file info within db.session context to ensure thread-safe access + upload_file = ( + db.session.query(UploadFile).where(UploadFile.id == document.metadata["doc_id"]).first() + ) + if upload_file: + blob = storage.load_once(upload_file.key) + document_file_base64 = base64.b64encode(blob).decode() + document_file_dict = { + "content": document_file_base64, + "content_type": document.metadata["doc_type"], + } + docs.append(document_file_dict) + else: + document_text_dict = { + "content": document.page_content, + "content_type": document.metadata.get("doc_type") or DocType.TEXT, + } + docs.append(document_text_dict) + doc_ids.add(document.metadata["doc_id"]) + unique_documents.append(document) + elif document.provider == "external": + if document not in unique_documents: + docs.append( + { + "content": document.page_content, + "content_type": document.metadata.get("doc_type") or DocType.TEXT, + } + ) + unique_documents.append(document) - for result in rerank_result.docs: - if score_threshold is None or result.score >= score_threshold: - # format document - rerank_document = Document( - page_content=result.text, - metadata=documents[result.index].metadata, - provider=documents[result.index].provider, + documents = unique_documents + if query_type == QueryType.TEXT_QUERY: + rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user) + return rerank_result, unique_documents + elif query_type == QueryType.IMAGE_QUERY: + # Query file info within db.session context to ensure thread-safe access + upload_file = db.session.query(UploadFile).where(UploadFile.id == query).first() + if upload_file: + blob = storage.load_once(upload_file.key) + file_query = base64.b64encode(blob).decode() + file_query_dict = { + "content": file_query, + "content_type": DocType.IMAGE, + } + rerank_result = self.rerank_model_instance.invoke_multimodal_rerank( + query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user ) - if rerank_document.metadata is not None: - rerank_document.metadata["score"] = result.score - rerank_documents.append(rerank_document) + return rerank_result, unique_documents + else: + raise ValueError(f"Upload file not found for query: {query}") - rerank_documents.sort(key=lambda x: x.metadata.get("score", 0.0), reverse=True) - return rerank_documents[:top_n] if top_n else rerank_documents + else: + raise ValueError(f"Query type {query_type} is not supported") diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index c455db6095..18020608cb 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -7,6 +7,8 @@ from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.embedding.cached_embedding import CacheEmbedding +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.entity.weight import VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner @@ -24,6 +26,7 @@ class WeightRerankRunner(BaseRerankRunner): score_threshold: float | None = None, top_n: int | None = None, user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ Run rerank model @@ -43,8 +46,10 @@ class WeightRerankRunner(BaseRerankRunner): and document.metadata is not None and document.metadata["doc_id"] not in doc_ids ): - doc_ids.add(document.metadata["doc_id"]) - unique_documents.append(document) + # weight rerank only support text documents + if not document.metadata.get("doc_type") or document.metadata.get("doc_type") == DocType.TEXT: + doc_ids.add(document.metadata["doc_id"]) + unique_documents.append(document) else: if document not in unique_documents: unique_documents.append(document) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 3db67efb0e..ec55d2d0cc 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -8,6 +8,7 @@ from typing import Any, Union, cast from flask import Flask, current_app from sqlalchemy import and_, or_, select +from sqlalchemy.orm import Session from core.app.app_config.entities import ( DatasetEntity, @@ -19,6 +20,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCre from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.agent_entities import PlanningStrategy from core.entities.model_entities import ModelStatus +from core.file import File, FileTransferMethod, FileType from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage @@ -37,7 +39,9 @@ from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.context_entities import DocumentContext from core.rag.entities.metadata_entities import Condition, MetadataCondition -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -52,10 +56,12 @@ from core.rag.retrieval.template_prompts import ( METADATA_FILTER_USER_PROMPT_2, METADATA_FILTER_USER_PROMPT_3, ) +from core.tools.signature import sign_upload_file from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from extensions.ext_database import db from libs.json_in_md_parser import parse_and_check_json_markdown -from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment +from models import UploadFile +from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService @@ -99,7 +105,8 @@ class DatasetRetrieval: message_id: str, memory: TokenBufferMemory | None = None, inputs: Mapping[str, Any] | None = None, - ) -> str | None: + vision_enabled: bool = False, + ) -> tuple[str | None, list[File] | None]: """ Retrieve dataset. :param app_id: app_id @@ -118,7 +125,7 @@ class DatasetRetrieval: """ dataset_ids = config.dataset_ids if len(dataset_ids) == 0: - return None + return None, [] retrieve_config = config.retrieve_config # check model is support tool calling @@ -136,7 +143,7 @@ class DatasetRetrieval: ) if not model_schema: - return None + return None, [] planning_strategy = PlanningStrategy.REACT_ROUTER features = model_schema.features @@ -182,8 +189,8 @@ class DatasetRetrieval: tenant_id, user_id, user_from, - available_datasets, query, + available_datasets, model_instance, model_config, planning_strategy, @@ -213,6 +220,7 @@ class DatasetRetrieval: dify_documents = [item for item in all_documents if item.provider == "dify"] external_documents = [item for item in all_documents if item.provider == "external"] document_context_list: list[DocumentContext] = [] + context_files: list[File] = [] retrieval_resource_list: list[RetrievalSourceMetadata] = [] # deal with external documents for item in external_documents: @@ -248,6 +256,31 @@ class DatasetRetrieval: score=record.score, ) ) + if vision_enabled: + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where( + SegmentAttachmentBinding.segment_id == segment.id, + ) + ).all() + if attachments_with_bindings: + for _, upload_file in attachments_with_bindings: + attchment_info = File( + id=upload_file.id, + filename=upload_file.name, + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + tenant_id=segment.tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.LOCAL_FILE, + remote_url=upload_file.source_url, + related_id=upload_file.id, + size=upload_file.size, + storage_key=upload_file.key, + url=sign_upload_file(upload_file.id, upload_file.extension), + ) + context_files.append(attchment_info) if show_retrieve_source: for record in records: segment = record.segment @@ -288,8 +321,10 @@ class DatasetRetrieval: hit_callback.return_retriever_resource_info(retrieval_resource_list) if document_context_list: document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True) - return str("\n".join([document_context.content for document_context in document_context_list])) - return "" + return str( + "\n".join([document_context.content for document_context in document_context_list]) + ), context_files + return "", context_files def single_retrieve( self, @@ -297,8 +332,8 @@ class DatasetRetrieval: tenant_id: str, user_id: str, user_from: str, - available_datasets: list, query: str, + available_datasets: list, model_instance: ModelInstance, model_config: ModelConfigWithCredentialsEntity, planning_strategy: PlanningStrategy, @@ -336,7 +371,7 @@ class DatasetRetrieval: dataset_id, router_usage = function_call_router.invoke(query, tools, model_config, model_instance) self._record_usage(router_usage) - + timer = None if dataset_id: # get retrieval model config dataset_stmt = select(Dataset).where(Dataset.id == dataset_id) @@ -406,10 +441,19 @@ class DatasetRetrieval: weights=retrieval_model_config.get("weights", None), document_ids_filter=document_ids_filter, ) - self._on_query(query, [dataset_id], app_id, user_from, user_id) + self._on_query(query, None, [dataset_id], app_id, user_from, user_id) if results: - self._on_retrieval_end(results, message_id, timer) + thread = threading.Thread( + target=self._on_retrieval_end, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "documents": results, + "message_id": message_id, + "timer": timer, + }, + ) + thread.start() return results return [] @@ -421,7 +465,7 @@ class DatasetRetrieval: user_id: str, user_from: str, available_datasets: list, - query: str, + query: str | None, top_k: int, score_threshold: float, reranking_mode: str, @@ -431,10 +475,11 @@ class DatasetRetrieval: message_id: str | None = None, metadata_filter_document_ids: dict[str, list[str]] | None = None, metadata_condition: MetadataCondition | None = None, + attachment_ids: list[str] | None = None, ): if not available_datasets: return [] - threads = [] + all_threads = [] all_documents: list[Document] = [] dataset_ids = [dataset.id for dataset in available_datasets] index_type_check = all( @@ -467,131 +512,226 @@ class DatasetRetrieval: 0 ].embedding_model_provider weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model - - for dataset in available_datasets: - index_type = dataset.indexing_technique - document_ids_filter = None - if dataset.provider != "external": - if metadata_condition and not metadata_filter_document_ids: - continue - if metadata_filter_document_ids: - document_ids = metadata_filter_document_ids.get(dataset.id, []) - if document_ids: - document_ids_filter = document_ids - else: - continue - retrieval_thread = threading.Thread( - target=self._retriever, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "dataset_id": dataset.id, - "query": query, - "top_k": top_k, - "all_documents": all_documents, - "document_ids_filter": document_ids_filter, - "metadata_condition": metadata_condition, - }, - ) - threads.append(retrieval_thread) - retrieval_thread.start() - for thread in threads: - thread.join() - with measure_time() as timer: - if reranking_enable: - # do rerank for searched documents - data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False) - - all_documents = data_post_processor.invoke( - query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k + if query: + query_thread = threading.Thread( + target=self._multiple_retrieve_thread, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "available_datasets": available_datasets, + "metadata_condition": metadata_condition, + "metadata_filter_document_ids": metadata_filter_document_ids, + "all_documents": all_documents, + "tenant_id": tenant_id, + "reranking_enable": reranking_enable, + "reranking_mode": reranking_mode, + "reranking_model": reranking_model, + "weights": weights, + "top_k": top_k, + "score_threshold": score_threshold, + "query": query, + "attachment_id": None, + }, ) - else: - if index_type == "economy": - all_documents = self.calculate_keyword_score(query, all_documents, top_k) - elif index_type == "high_quality": - all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold) - else: - all_documents = all_documents[:top_k] if top_k else all_documents - - self._on_query(query, dataset_ids, app_id, user_from, user_id) + all_threads.append(query_thread) + query_thread.start() + if attachment_ids: + for attachment_id in attachment_ids: + attachment_thread = threading.Thread( + target=self._multiple_retrieve_thread, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "available_datasets": available_datasets, + "metadata_condition": metadata_condition, + "metadata_filter_document_ids": metadata_filter_document_ids, + "all_documents": all_documents, + "tenant_id": tenant_id, + "reranking_enable": reranking_enable, + "reranking_mode": reranking_mode, + "reranking_model": reranking_model, + "weights": weights, + "top_k": top_k, + "score_threshold": score_threshold, + "query": None, + "attachment_id": attachment_id, + }, + ) + all_threads.append(attachment_thread) + attachment_thread.start() + for thread in all_threads: + thread.join() + self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id) if all_documents: - self._on_retrieval_end(all_documents, message_id, timer) - - return all_documents - - def _on_retrieval_end(self, documents: list[Document], message_id: str | None = None, timer: dict | None = None): - """Handle retrieval end.""" - dify_documents = [document for document in documents if document.provider == "dify"] - for document in dify_documents: - if document.metadata is not None: - dataset_document_stmt = select(DatasetDocument).where( - DatasetDocument.id == document.metadata["document_id"] - ) - dataset_document = db.session.scalar(dataset_document_stmt) - if dataset_document: - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: - child_chunk_stmt = select(ChildChunk).where( - ChildChunk.index_node_id == document.metadata["doc_id"], - ChildChunk.dataset_id == dataset_document.dataset_id, - ChildChunk.document_id == dataset_document.id, - ) - child_chunk = db.session.scalar(child_chunk_stmt) - if child_chunk: - _ = ( - db.session.query(DocumentSegment) - .where(DocumentSegment.id == child_chunk.segment_id) - .update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, - synchronize_session=False, - ) - ) - else: - query = db.session.query(DocumentSegment).where( - DocumentSegment.index_node_id == document.metadata["doc_id"] - ) - - # if 'dataset_id' in document.metadata: - if "dataset_id" in document.metadata: - query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"]) - - # add hit count to document segment - query.update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False - ) - - db.session.commit() - - # get tracing instance - trace_manager: TraceQueueManager | None = ( - self.application_generate_entity.trace_manager if self.application_generate_entity else None - ) - if trace_manager: - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer - ) + # add thread to call _on_retrieval_end + retrieval_end_thread = threading.Thread( + target=self._on_retrieval_end, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "documents": all_documents, + "message_id": message_id, + "timer": timer, + }, ) + retrieval_end_thread.start() + retrieval_resource_list = [] + doc_ids_filter = [] + for document in all_documents: + if document.provider == "dify": + doc_id = document.metadata.get("doc_id") + if doc_id and doc_id not in doc_ids_filter: + doc_ids_filter.append(doc_id) + retrieval_resource_list.append(document) + elif document.provider == "external": + retrieval_resource_list.append(document) + return retrieval_resource_list - def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str): + def _on_retrieval_end( + self, flask_app: Flask, documents: list[Document], message_id: str | None = None, timer: dict | None = None + ): + """Handle retrieval end.""" + with flask_app.app_context(): + dify_documents = [document for document in documents if document.provider == "dify"] + segment_ids = [] + segment_index_node_ids = [] + with Session(db.engine) as session: + for document in dify_documents: + if document.metadata is not None: + dataset_document_stmt = select(DatasetDocument).where( + DatasetDocument.id == document.metadata["document_id"] + ) + dataset_document = session.scalar(dataset_document_stmt) + if dataset_document: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: + segment_id = None + if ( + "doc_type" not in document.metadata + or document.metadata.get("doc_type") == DocType.TEXT + ): + child_chunk_stmt = select(ChildChunk).where( + ChildChunk.index_node_id == document.metadata["doc_id"], + ChildChunk.dataset_id == dataset_document.dataset_id, + ChildChunk.document_id == dataset_document.id, + ) + child_chunk = session.scalar(child_chunk_stmt) + if child_chunk: + segment_id = child_chunk.segment_id + elif ( + "doc_type" in document.metadata + and document.metadata.get("doc_type") == DocType.IMAGE + ): + attachment_info_dict = RetrievalService.get_segment_attachment_info( + dataset_document.dataset_id, + dataset_document.tenant_id, + document.metadata.get("doc_id") or "", + session, + ) + if attachment_info_dict: + segment_id = attachment_info_dict["segment_id"] + if segment_id: + if segment_id not in segment_ids: + segment_ids.append(segment_id) + _ = ( + session.query(DocumentSegment) + .where(DocumentSegment.id == segment_id) + .update( + {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, + synchronize_session=False, + ) + ) + else: + query = None + if ( + "doc_type" not in document.metadata + or document.metadata.get("doc_type") == DocType.TEXT + ): + if document.metadata["doc_id"] not in segment_index_node_ids: + segment = ( + session.query(DocumentSegment) + .where(DocumentSegment.index_node_id == document.metadata["doc_id"]) + .first() + ) + if segment: + segment_index_node_ids.append(document.metadata["doc_id"]) + segment_ids.append(segment.id) + query = session.query(DocumentSegment).where( + DocumentSegment.id == segment.id + ) + elif ( + "doc_type" in document.metadata + and document.metadata.get("doc_type") == DocType.IMAGE + ): + attachment_info_dict = RetrievalService.get_segment_attachment_info( + dataset_document.dataset_id, + dataset_document.tenant_id, + document.metadata.get("doc_id") or "", + session, + ) + if attachment_info_dict: + segment_id = attachment_info_dict["segment_id"] + if segment_id not in segment_ids: + segment_ids.append(segment_id) + query = session.query(DocumentSegment).where(DocumentSegment.id == segment_id) + if query: + # if 'dataset_id' in document.metadata: + if "dataset_id" in document.metadata: + query = query.where( + DocumentSegment.dataset_id == document.metadata["dataset_id"] + ) + + # add hit count to document segment + query.update( + {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, + synchronize_session=False, + ) + + db.session.commit() + + # get tracing instance + trace_manager: TraceQueueManager | None = ( + self.application_generate_entity.trace_manager if self.application_generate_entity else None + ) + if trace_manager: + trace_manager.add_trace_task( + TraceTask( + TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer + ) + ) + + def _on_query( + self, + query: str | None, + attachment_ids: list[str] | None, + dataset_ids: list[str], + app_id: str, + user_from: str, + user_id: str, + ): """ Handle query. """ - if not query: + if not query and not attachment_ids: return dataset_queries = [] for dataset_id in dataset_ids: - dataset_query = DatasetQuery( - dataset_id=dataset_id, - content=query, - source="app", - source_app_id=app_id, - created_by_role=user_from, - created_by=user_id, - ) - dataset_queries.append(dataset_query) - if dataset_queries: - db.session.add_all(dataset_queries) + contents = [] + if query: + contents.append({"content_type": QueryType.TEXT_QUERY, "content": query}) + if attachment_ids: + for attachment_id in attachment_ids: + contents.append({"content_type": QueryType.IMAGE_QUERY, "content": attachment_id}) + if contents: + dataset_query = DatasetQuery( + dataset_id=dataset_id, + content=json.dumps(contents), + source="app", + source_app_id=app_id, + created_by_role=user_from, + created_by=user_id, + ) + dataset_queries.append(dataset_query) + if dataset_queries: + db.session.add_all(dataset_queries) db.session.commit() def _retriever( @@ -603,6 +743,7 @@ class DatasetRetrieval: all_documents: list, document_ids_filter: list[str] | None = None, metadata_condition: MetadataCondition | None = None, + attachment_ids: list[str] | None = None, ): with flask_app.app_context(): dataset_stmt = select(Dataset).where(Dataset.id == dataset_id) @@ -611,7 +752,7 @@ class DatasetRetrieval: if not dataset: return [] - if dataset.provider == "external": + if dataset.provider == "external" and query: external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( tenant_id=dataset.tenant_id, dataset_id=dataset_id, @@ -663,6 +804,7 @@ class DatasetRetrieval: reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", weights=retrieval_model.get("weights", None), document_ids_filter=document_ids_filter, + attachment_ids=attachment_ids, ) all_documents.extend(documents) @@ -1222,3 +1364,86 @@ class DatasetRetrieval: usage = LLMUsage.empty_usage() return full_text, usage + + def _multiple_retrieve_thread( + self, + flask_app: Flask, + available_datasets: list, + metadata_condition: MetadataCondition | None, + metadata_filter_document_ids: dict[str, list[str]] | None, + all_documents: list[Document], + tenant_id: str, + reranking_enable: bool, + reranking_mode: str, + reranking_model: dict | None, + weights: dict[str, Any] | None, + top_k: int, + score_threshold: float, + query: str | None, + attachment_id: str | None, + ): + with flask_app.app_context(): + threads = [] + all_documents_item: list[Document] = [] + index_type = None + for dataset in available_datasets: + index_type = dataset.indexing_technique + document_ids_filter = None + if dataset.provider != "external": + if metadata_condition and not metadata_filter_document_ids: + continue + if metadata_filter_document_ids: + document_ids = metadata_filter_document_ids.get(dataset.id, []) + if document_ids: + document_ids_filter = document_ids + else: + continue + retrieval_thread = threading.Thread( + target=self._retriever, + kwargs={ + "flask_app": flask_app, + "dataset_id": dataset.id, + "query": query, + "top_k": top_k, + "all_documents": all_documents_item, + "document_ids_filter": document_ids_filter, + "metadata_condition": metadata_condition, + "attachment_ids": [attachment_id] if attachment_id else None, + }, + ) + threads.append(retrieval_thread) + retrieval_thread.start() + for thread in threads: + thread.join() + + if reranking_enable: + # do rerank for searched documents + data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False) + if query: + all_documents_item = data_post_processor.invoke( + query=query, + documents=all_documents_item, + score_threshold=score_threshold, + top_n=top_k, + query_type=QueryType.TEXT_QUERY, + ) + if attachment_id: + all_documents_item = data_post_processor.invoke( + documents=all_documents_item, + score_threshold=score_threshold, + top_n=top_k, + query_type=QueryType.IMAGE_QUERY, + query=attachment_id, + ) + else: + if index_type == IndexTechniqueType.ECONOMY: + if not query: + all_documents_item = [] + else: + all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k) + elif index_type == IndexTechniqueType.HIGH_QUALITY: + all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold) + else: + all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item + if all_documents_item: + all_documents.extend(all_documents_item) diff --git a/api/core/schemas/builtin/schemas/v1/multimodal_general_structure.json b/api/core/schemas/builtin/schemas/v1/multimodal_general_structure.json new file mode 100644 index 0000000000..1a07869662 --- /dev/null +++ b/api/core/schemas/builtin/schemas/v1/multimodal_general_structure.json @@ -0,0 +1,65 @@ +{ + "$id": "https://dify.ai/schemas/v1/multimodal_general_structure.json", + "$schema": "http://json-schema.org/draft-07/schema#", + "version": "1.0.0", + "type": "array", + "title": "Multimodal General Structure", + "description": "Schema for multimodal general structure (v1) - array of objects", + "properties": { + "general_chunks": { + "type": "array", + "items": { + "type": "object", + "properties": { + "content": { + "type": "string", + "description": "The content" + }, + "files": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "file name" + }, + "size": { + "type": "number", + "description": "file size" + }, + "extension": { + "type": "string", + "description": "file extension" + }, + "type": { + "type": "string", + "description": "file type" + }, + "mime_type": { + "type": "string", + "description": "file mime type" + }, + "transfer_method": { + "type": "string", + "description": "file transfer method" + }, + "url": { + "type": "string", + "description": "file url" + }, + "related_id": { + "type": "string", + "description": "file related id" + } + }, + "description": "List of files" + } + } + }, + "required": ["content"] + }, + "description": "List of content and files" + } + } +} \ No newline at end of file diff --git a/api/core/schemas/builtin/schemas/v1/multimodal_parent_child_structure.json b/api/core/schemas/builtin/schemas/v1/multimodal_parent_child_structure.json new file mode 100644 index 0000000000..4ffb590519 --- /dev/null +++ b/api/core/schemas/builtin/schemas/v1/multimodal_parent_child_structure.json @@ -0,0 +1,78 @@ +{ + "$id": "https://dify.ai/schemas/v1/multimodal_parent_child_structure.json", + "$schema": "http://json-schema.org/draft-07/schema#", + "version": "1.0.0", + "type": "object", + "title": "Multimodal Parent-Child Structure", + "description": "Schema for multimodal parent-child structure (v1)", + "properties": { + "parent_mode": { + "type": "string", + "description": "The mode of parent-child relationship" + }, + "parent_child_chunks": { + "type": "array", + "items": { + "type": "object", + "properties": { + "parent_content": { + "type": "string", + "description": "The parent content" + }, + "files": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "file name" + }, + "size": { + "type": "number", + "description": "file size" + }, + "extension": { + "type": "string", + "description": "file extension" + }, + "type": { + "type": "string", + "description": "file type" + }, + "mime_type": { + "type": "string", + "description": "file mime type" + }, + "transfer_method": { + "type": "string", + "description": "file transfer method" + }, + "url": { + "type": "string", + "description": "file url" + }, + "related_id": { + "type": "string", + "description": "file related id" + } + }, + "required": ["name", "size", "extension", "type", "mime_type", "transfer_method", "url", "related_id"] + }, + "description": "List of files" + }, + "child_contents": { + "type": "array", + "items": { + "type": "string" + }, + "description": "List of child contents" + } + }, + "required": ["parent_content", "child_contents"] + }, + "description": "List of parent-child chunk pairs" + } + }, + "required": ["parent_mode", "parent_child_chunks"] +} \ No newline at end of file diff --git a/api/core/tools/signature.py b/api/core/tools/signature.py index 5cdf473542..fef3157f27 100644 --- a/api/core/tools/signature.py +++ b/api/core/tools/signature.py @@ -25,6 +25,24 @@ def sign_tool_file(tool_file_id: str, extension: str) -> str: return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" +def sign_upload_file(upload_file_id: str, extension: str) -> str: + """ + sign file to get a temporary url for plugin access + """ + # Use internal URL for plugin/tool file access in Docker environments + base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + file_preview_url = f"{base_url}/files/{upload_file_id}/image-preview" + + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + + def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: str) -> bool: """ verify signature diff --git a/api/core/tools/utils/text_processing_utils.py b/api/core/tools/utils/text_processing_utils.py index 105823f896..80c69e94c8 100644 --- a/api/core/tools/utils/text_processing_utils.py +++ b/api/core/tools/utils/text_processing_utils.py @@ -13,5 +13,5 @@ def remove_leading_symbols(text: str) -> str: """ # Match Unicode ranges for punctuation and symbols # FIXME this pattern is confused quick fix for #11868 maybe refactor it later - pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+" + pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F\"#$%&'()*+,./:;<=>?@^_`~]+" return re.sub(pattern, "", text) diff --git a/api/core/workflow/node_events/node.py b/api/core/workflow/node_events/node.py index ebf93f2fc2..e4fa52f444 100644 --- a/api/core/workflow/node_events/node.py +++ b/api/core/workflow/node_events/node.py @@ -3,6 +3,7 @@ from datetime import datetime from pydantic import Field +from core.file import File from core.model_runtime.entities.llm_entities import LLMUsage from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities.pause_reason import PauseReason @@ -14,6 +15,7 @@ from .base import NodeEventBase class RunRetrieverResourceEvent(NodeEventBase): retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources") context: str = Field(..., description="context") + context_files: list[File] | None = Field(default=None, description="context files") class ModelInvokeCompletedEvent(NodeEventBase): diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index 5a7db6e0e6..e323533835 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -3,6 +3,7 @@ from collections.abc import Sequence from email.message import Message from typing import Any, Literal +import charset_normalizer import httpx from pydantic import BaseModel, Field, ValidationInfo, field_validator @@ -96,10 +97,12 @@ class HttpRequestNodeData(BaseNodeData): class Response: headers: dict[str, str] response: httpx.Response + _cached_text: str | None def __init__(self, response: httpx.Response): self.response = response self.headers = dict(response.headers) + self._cached_text = None @property def is_file(self): @@ -159,7 +162,31 @@ class Response: @property def text(self) -> str: - return self.response.text + """ + Get response text with robust encoding detection. + + Uses charset_normalizer for better encoding detection than httpx's default, + which helps handle Chinese and other non-ASCII characters properly. + """ + # Check cache first + if hasattr(self, "_cached_text") and self._cached_text is not None: + return self._cached_text + + # Try charset_normalizer for robust encoding detection first + detected_encoding = charset_normalizer.from_bytes(self.response.content).best() + if detected_encoding and detected_encoding.encoding: + try: + text = self.response.content.decode(detected_encoding.encoding) + self._cached_text = text + return text + except (UnicodeDecodeError, TypeError, LookupError): + # Fallback to httpx's encoding detection if charset_normalizer fails + pass + + # Fallback to httpx's built-in encoding detection + text = self.response.text + self._cached_text = text + return text @property def content(self) -> bytes: diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 8aa6a5016f..86bb2495e7 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -114,7 +114,8 @@ class KnowledgeRetrievalNodeData(BaseNodeData): """ type: str = "knowledge-retrieval" - query_variable_selector: list[str] + query_variable_selector: list[str] | None | str = None + query_attachment_selector: list[str] | None | str = None dataset_ids: list[str] retrieval_mode: Literal["single", "multiple"] multiple_retrieval_config: MultipleRetrievalConfig | None = None diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 1b57d23e24..adc474bd60 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -25,6 +25,8 @@ from core.rag.entities.metadata_entities import Condition, MetadataCondition from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.variables import ( + ArrayFileSegment, + FileSegment, StringSegment, ) from core.variables.segments import ArrayObjectSegment @@ -119,20 +121,41 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD return "1" def _run(self) -> NodeRunResult: - # extract variables - variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector) - if not isinstance(variable, StringSegment): + if not self._node_data.query_variable_selector and not self._node_data.query_attachment_selector: return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, - error="Query variable is not string type.", - ) - query = variable.value - variables = {"query": query} - if not query: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required." + process_data={}, + outputs={}, + metadata={}, + llm_usage=LLMUsage.empty_usage(), ) + variables: dict[str, Any] = {} + # extract variables + if self._node_data.query_variable_selector: + variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector) + if not isinstance(variable, StringSegment): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs={}, + error="Query variable is not string type.", + ) + query = variable.value + variables["query"] = query + + if self._node_data.query_attachment_selector: + variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_attachment_selector) + if not isinstance(variable, ArrayFileSegment) and not isinstance(variable, FileSegment): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs={}, + error="Attachments variable is not array file or file type.", + ) + if isinstance(variable, ArrayFileSegment): + variables["attachments"] = variable.value + else: + variables["attachments"] = [variable.value] + # TODO(-LAN-): Move this check outside. # check rate limit knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id) @@ -161,7 +184,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD # retrieve knowledge usage = LLMUsage.empty_usage() try: - results, usage = self._fetch_dataset_retriever(node_data=self.node_data, query=query) + results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables) outputs = {"result": ArrayObjectSegment(value=results)} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -198,12 +221,16 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD db.session.close() def _fetch_dataset_retriever( - self, node_data: KnowledgeRetrievalNodeData, query: str + self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any] ) -> tuple[list[dict[str, Any]], LLMUsage]: usage = LLMUsage.empty_usage() available_datasets = [] dataset_ids = node_data.dataset_ids - + query = variables.get("query") + attachments = variables.get("attachments") + metadata_filter_document_ids = None + metadata_condition = None + metadata_usage = LLMUsage.empty_usage() # Subquery: Count the number of available documents for each dataset subquery = ( db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count")) @@ -234,13 +261,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD if not dataset: continue available_datasets.append(dataset) - metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition( - [dataset.id for dataset in available_datasets], query, node_data - ) - usage = self._merge_usage(usage, metadata_usage) + if query: + metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition( + [dataset.id for dataset in available_datasets], query, node_data + ) + usage = self._merge_usage(usage, metadata_usage) all_documents = [] dataset_retrieval = DatasetRetrieval() - if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: + if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query: # fetch model config if node_data.single_retrieval_config is None: raise ValueError("single_retrieval_config is required") @@ -272,7 +300,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD metadata_filter_document_ids=metadata_filter_document_ids, metadata_condition=metadata_condition, ) - elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: + elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: if node_data.multiple_retrieval_config is None: raise ValueError("multiple_retrieval_config is required") if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": @@ -319,6 +347,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD reranking_enable=node_data.multiple_retrieval_config.reranking_enable, metadata_filter_document_ids=metadata_filter_document_ids, metadata_condition=metadata_condition, + attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None, ) usage = self._merge_usage(usage, dataset_retrieval.llm_usage) @@ -327,7 +356,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD retrieval_resource_list = [] # deal with external documents for item in external_documents: - source = { + source: dict[str, dict[str, str | Any | dict[Any, Any] | None] | Any | str | None] = { "metadata": { "_source": "knowledge", "dataset_id": item.metadata.get("dataset_id"), @@ -384,6 +413,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD "doc_metadata": document.doc_metadata, }, "title": document.name, + "files": list(record.files) if record.files else None, } if segment.answer: source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}" @@ -393,13 +423,21 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD if retrieval_resource_list: retrieval_resource_list = sorted( retrieval_resource_list, - key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0, + key=self._score, # type: ignore[arg-type, return-value] reverse=True, ) for position, item in enumerate(retrieval_resource_list, start=1): - item["metadata"]["position"] = position + item["metadata"]["position"] = position # type: ignore[index] return retrieval_resource_list, usage + def _score(self, item: dict[str, Any]) -> float: + meta = item.get("metadata") + if isinstance(meta, dict): + s = meta.get("score") + if isinstance(s, (int, float)): + return float(s) + return 0.0 + def _get_metadata_filter_condition( self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]: @@ -659,7 +697,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data) variable_mapping = {} - variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector + if typed_node_data.query_variable_selector: + variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector + if typed_node_data.query_attachment_selector: + variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector return variable_mapping def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 1a2473e0bb..10682ae38a 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -7,8 +7,10 @@ import time from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal +from sqlalchemy import select + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.file import FileType, file_manager +from core.file import File, FileTransferMethod, FileType, file_manager from core.helper.code_executor import CodeExecutor, CodeLanguage from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output @@ -44,6 +46,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from core.tools.signature import sign_upload_file from core.variables import ( ArrayFileSegment, ArraySegment, @@ -72,6 +75,9 @@ from core.workflow.nodes.base.entities import VariableSelector from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser from core.workflow.runtime import VariablePool +from extensions.ext_database import db +from models.dataset import SegmentAttachmentBinding +from models.model import UploadFile from . import llm_utils from .entities import ( @@ -179,12 +185,17 @@ class LLMNode(Node[LLMNodeData]): # fetch context value generator = self._fetch_context(node_data=self.node_data) context = None + context_files: list[File] = [] for event in generator: context = event.context + context_files = event.context_files or [] yield event if context: node_inputs["#context#"] = context + if context_files: + node_inputs["#context_files#"] = [file.model_dump() for file in context_files] + # fetch model config model_instance, model_config = LLMNode._fetch_model_config( node_data_model=self.node_data.model, @@ -220,6 +231,7 @@ class LLMNode(Node[LLMNodeData]): variable_pool=variable_pool, jinja2_variables=self.node_data.prompt_config.jinja2_variables, tenant_id=self.tenant_id, + context_files=context_files, ) # handle invoke result @@ -654,10 +666,13 @@ class LLMNode(Node[LLMNodeData]): context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector) if context_value_variable: if isinstance(context_value_variable, StringSegment): - yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value) + yield RunRetrieverResourceEvent( + retriever_resources=[], context=context_value_variable.value, context_files=[] + ) elif isinstance(context_value_variable, ArraySegment): context_str = "" original_retriever_resource: list[RetrievalSourceMetadata] = [] + context_files: list[File] = [] for item in context_value_variable.value: if isinstance(item, str): context_str += item + "\n" @@ -670,9 +685,34 @@ class LLMNode(Node[LLMNodeData]): retriever_resource = self._convert_to_original_retriever_resource(item) if retriever_resource: original_retriever_resource.append(retriever_resource) - + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where( + SegmentAttachmentBinding.segment_id == retriever_resource.segment_id, + ) + ).all() + if attachments_with_bindings: + for _, upload_file in attachments_with_bindings: + attchment_info = File( + id=upload_file.id, + filename=upload_file.name, + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + tenant_id=self.tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.LOCAL_FILE, + remote_url=upload_file.source_url, + related_id=upload_file.id, + size=upload_file.size, + storage_key=upload_file.key, + url=sign_upload_file(upload_file.id, upload_file.extension), + ) + context_files.append(attchment_info) yield RunRetrieverResourceEvent( - retriever_resources=original_retriever_resource, context=context_str.strip() + retriever_resources=original_retriever_resource, + context=context_str.strip(), + context_files=context_files, ) def _convert_to_original_retriever_resource(self, context_dict: dict) -> RetrievalSourceMetadata | None: @@ -700,6 +740,7 @@ class LLMNode(Node[LLMNodeData]): content=context_dict.get("content"), page=metadata.get("page"), doc_metadata=metadata.get("doc_metadata"), + files=context_dict.get("files"), ) return source @@ -741,6 +782,7 @@ class LLMNode(Node[LLMNodeData]): variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], tenant_id: str, + context_files: list["File"] | None = None, ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: prompt_messages: list[PromptMessage] = [] @@ -853,6 +895,23 @@ class LLMNode(Node[LLMNodeData]): else: prompt_messages.append(UserPromptMessage(content=file_prompts)) + # The context_files + if vision_enabled and context_files: + file_prompts = [] + for file in context_files: + file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) + file_prompts.append(file_prompt) + # If last prompt is a user prompt, add files into its contents, + # otherwise append a new user prompt + if ( + len(prompt_messages) > 0 + and isinstance(prompt_messages[-1], UserPromptMessage) + and isinstance(prompt_messages[-1].content, list) + ): + prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content) + else: + prompt_messages.append(UserPromptMessage(content=file_prompts)) + # Remove empty messages and filter unsupported content filtered_prompt_messages = [] for prompt_message in prompt_messages: diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 89c4d8fba9..1e5ec7d200 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -97,11 +97,27 @@ dataset_detail_fields = { "total_documents": fields.Integer, "total_available_documents": fields.Integer, "enable_api": fields.Boolean, + "is_multimodal": fields.Boolean, +} + +file_info_fields = { + "id": fields.String, + "name": fields.String, + "size": fields.Integer, + "extension": fields.String, + "mime_type": fields.String, + "source_url": fields.String, +} + +content_fields = { + "content_type": fields.String, + "content": fields.String, + "file_info": fields.Nested(file_info_fields, allow_null=True), } dataset_query_detail_fields = { "id": fields.String, - "content": fields.String, + "queries": fields.Nested(content_fields), "source": fields.String, "source_app_id": fields.String, "created_by_role": fields.String, diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index c12ebc09c8..a707500445 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -9,6 +9,8 @@ upload_config_fields = { "video_file_size_limit": fields.Integer, "audio_file_size_limit": fields.Integer, "workflow_file_upload_limit": fields.Integer, + "image_file_batch_limit": fields.Integer, + "single_chunk_attachment_limit": fields.Integer, } diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py index 75bdff1803..e70f9fa722 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -43,9 +43,19 @@ child_chunk_fields = { "score": fields.Float, } +files_fields = { + "id": fields.String, + "name": fields.String, + "size": fields.Integer, + "extension": fields.String, + "mime_type": fields.String, + "source_url": fields.String, +} + hit_testing_record_fields = { "segment": fields.Nested(segment_fields), "child_chunks": fields.List(fields.Nested(child_chunk_fields)), "score": fields.Float, "tsne_position": fields.Raw, + "files": fields.List(fields.Nested(files_fields)), } diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py index 2ff917d6bc..56d6b68378 100644 --- a/api/fields/segment_fields.py +++ b/api/fields/segment_fields.py @@ -13,6 +13,15 @@ child_chunk_fields = { "updated_at": TimestampField, } +attachment_fields = { + "id": fields.String, + "name": fields.String, + "size": fields.Integer, + "extension": fields.String, + "mime_type": fields.String, + "source_url": fields.String, +} + segment_fields = { "id": fields.String, "position": fields.Integer, @@ -39,4 +48,5 @@ segment_fields = { "error": fields.String, "stopped_at": TimestampField, "child_chunks": fields.List(fields.Nested(child_chunk_fields)), + "attachments": fields.List(fields.Nested(attachment_fields)), } diff --git a/api/migrations/versions/2025_11_12_1537-d57accd375ae_support_multi_modal.py b/api/migrations/versions/2025_11_12_1537-d57accd375ae_support_multi_modal.py new file mode 100644 index 0000000000..187bf7136d --- /dev/null +++ b/api/migrations/versions/2025_11_12_1537-d57accd375ae_support_multi_modal.py @@ -0,0 +1,57 @@ +"""support-multi-modal + +Revision ID: d57accd375ae +Revises: 03f8dcbc611e +Create Date: 2025-11-12 15:37:12.363670 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'd57accd375ae' +down_revision = '7bb281b7a422' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('segment_attachment_bindings', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('document_id', models.types.StringUUID(), nullable=False), + sa.Column('segment_id', models.types.StringUUID(), nullable=False), + sa.Column('attachment_id', models.types.StringUUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.PrimaryKeyConstraint('id', name='segment_attachment_binding_pkey') + ) + with op.batch_alter_table('segment_attachment_bindings', schema=None) as batch_op: + batch_op.create_index( + 'segment_attachment_binding_tenant_dataset_document_segment_idx', + ['tenant_id', 'dataset_id', 'document_id', 'segment_id'], + unique=False + ) + batch_op.create_index('segment_attachment_binding_attachment_idx', ['attachment_id'], unique=False) + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.add_column(sa.Column('is_multimodal', sa.Boolean(), server_default=sa.text('false'), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.drop_column('is_multimodal') + + + with op.batch_alter_table('segment_attachment_bindings', schema=None) as batch_op: + batch_op.drop_index('segment_attachment_binding_attachment_idx') + batch_op.drop_index('segment_attachment_binding_tenant_dataset_document_segment_idx') + + op.drop_table('segment_attachment_bindings') + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index e072711b82..5bbf44050c 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -19,7 +19,9 @@ from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource +from core.rag.index_processor.constant.query_type import QueryType from core.rag.retrieval.retrieval_methods import RetrievalMethod +from core.tools.signature import sign_upload_file from extensions.ext_storage import storage from libs.uuid_utils import uuidv7 from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule @@ -76,6 +78,7 @@ class Dataset(Base): pipeline_id = mapped_column(StringUUID, nullable=True) chunk_structure = mapped_column(sa.String(255), nullable=True) enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) + is_multimodal = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false")) @property def total_documents(self): @@ -728,9 +731,7 @@ class DocumentSegment(Base): created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at: Mapped[datetime] = mapped_column( - DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() - ) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) error = mapped_column(LongText, nullable=True) @@ -866,6 +867,47 @@ class DocumentSegment(Base): return text + @property + def attachments(self) -> list[dict[str, Any]]: + # Use JOIN to fetch attachments in a single query instead of two separate queries + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where( + SegmentAttachmentBinding.tenant_id == self.tenant_id, + SegmentAttachmentBinding.dataset_id == self.dataset_id, + SegmentAttachmentBinding.document_id == self.document_id, + SegmentAttachmentBinding.segment_id == self.id, + ) + ).all() + if not attachments_with_bindings: + return [] + attachment_list = [] + for _, attachment in attachments_with_bindings: + upload_file_id = attachment.id + nonce = os.urandom(16).hex() + timestamp = str(int(time.time())) + data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + reference_url = dify_config.CONSOLE_API_URL or "" + base_url = f"{reference_url}/files/{upload_file_id}/image-preview" + source_url = f"{base_url}?{params}" + attachment_list.append( + { + "id": attachment.id, + "name": attachment.name, + "size": attachment.size, + "extension": attachment.extension, + "mime_type": attachment.mime_type, + "source_url": source_url, + } + ) + return attachment_list + class ChildChunk(Base): __tablename__ = "child_chunks" @@ -963,6 +1005,38 @@ class DatasetQuery(TypeBase): DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False ) + @property + def queries(self) -> list[dict[str, Any]]: + try: + queries = json.loads(self.content) + if isinstance(queries, list): + for query in queries: + if query["content_type"] == QueryType.IMAGE_QUERY: + file_info = db.session.query(UploadFile).filter_by(id=query["content"]).first() + if file_info: + query["file_info"] = { + "id": file_info.id, + "name": file_info.name, + "size": file_info.size, + "extension": file_info.extension, + "mime_type": file_info.mime_type, + "source_url": sign_upload_file(file_info.id, file_info.extension), + } + else: + query["file_info"] = None + + return queries + else: + return [queries] + except JSONDecodeError: + return [ + { + "content_type": QueryType.TEXT_QUERY, + "content": self.content, + "file_info": None, + } + ] + class DatasetKeywordTable(TypeBase): __tablename__ = "dataset_keyword_tables" @@ -1470,3 +1544,25 @@ class PipelineRecommendedPlugin(TypeBase): onupdate=func.current_timestamp(), init=False, ) + + +class SegmentAttachmentBinding(Base): + __tablename__ = "segment_attachment_bindings" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="segment_attachment_binding_pkey"), + sa.Index( + "segment_attachment_binding_tenant_dataset_document_segment_idx", + "tenant_id", + "dataset_id", + "document_id", + "segment_id", + ), + sa.Index("segment_attachment_binding_attachment_idx", "attachment_id"), + ) + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7())) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + document_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + attachment_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/services/attachment_service.py b/api/services/attachment_service.py new file mode 100644 index 0000000000..2bd5627d5e --- /dev/null +++ b/api/services/attachment_service.py @@ -0,0 +1,31 @@ +import base64 + +from sqlalchemy import Engine +from sqlalchemy.orm import sessionmaker +from werkzeug.exceptions import NotFound + +from extensions.ext_storage import storage +from models.model import UploadFile + +PREVIEW_WORDS_LIMIT = 3000 + + +class AttachmentService: + _session_maker: sessionmaker + + def __init__(self, session_factory: sessionmaker | Engine | None = None): + if isinstance(session_factory, Engine): + self._session_maker = sessionmaker(bind=session_factory) + elif isinstance(session_factory, sessionmaker): + self._session_maker = session_factory + else: + raise AssertionError("must be a sessionmaker or an Engine.") + + def get_file_base64(self, file_id: str) -> str: + upload_file = ( + self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first() + ) + if not upload_file: + raise NotFound("File not found") + blob = storage.load_once(upload_file.key) + return base64.b64encode(blob).decode() diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 208ebcb018..00f06e9405 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -7,7 +7,7 @@ import time import uuid from collections import Counter from collections.abc import Sequence -from typing import Any, Literal +from typing import Any, Literal, cast import sqlalchemy as sa from redis.exceptions import LockNotOwnedError @@ -19,9 +19,10 @@ from configs import dify_config from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.helper.name_generator import generate_incremental_name from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.model_entities import ModelFeature, ModelType +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.rag.index_processor.constant.built_in_field import BuiltInField -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.retrieval.retrieval_methods import RetrievalMethod from enums.cloud_plan import CloudPlan from events.dataset_event import dataset_was_deleted @@ -46,12 +47,14 @@ from models.dataset import ( DocumentSegment, ExternalKnowledgeBindings, Pipeline, + SegmentAttachmentBinding, ) from models.model import UploadFile from models.provider_ids import ModelProviderID from models.source import DataSourceOauthBinding from models.workflow import Workflow -from services.document_indexing_task_proxy import DocumentIndexingTaskProxy +from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy +from services.document_indexing_proxy.duplicate_document_indexing_task_proxy import DuplicateDocumentIndexingTaskProxy from services.entities.knowledge_entities.knowledge_entities import ( ChildChunkUpdateArgs, KnowledgeConfig, @@ -82,7 +85,6 @@ from tasks.delete_segment_from_index_task import delete_segment_from_index_task from tasks.disable_segment_from_index_task import disable_segment_from_index_task from tasks.disable_segments_from_index_task import disable_segments_from_index_task from tasks.document_indexing_update_task import document_indexing_update_task -from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task from tasks.enable_segments_to_index_task import enable_segments_to_index_task from tasks.recover_document_indexing_task import recover_document_indexing_task from tasks.remove_document_from_index_task import remove_document_from_index_task @@ -363,6 +365,27 @@ class DatasetService: except ProviderTokenNotInitError as ex: raise ValueError(ex.description) + @staticmethod + def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str): + try: + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=tenant_id, + provider=model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=model, + ) + text_embedding_model = cast(TextEmbeddingModel, model_instance.model_type_instance) + model_schema = text_embedding_model.get_model_schema(model_instance.model, model_instance.credentials) + if not model_schema: + raise ValueError("Model schema not found") + if model_schema.features and ModelFeature.VISION in model_schema.features: + return True + else: + return False + except LLMBadRequestError: + raise ValueError("No Model available. Please configure a valid provider in the Settings -> Model Provider.") + @staticmethod def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str): try: @@ -402,13 +425,13 @@ class DatasetService: if not dataset: raise ValueError("Dataset not found") # check if dataset name is exists - - if DatasetService._has_dataset_same_name( - tenant_id=dataset.tenant_id, - dataset_id=dataset_id, - name=data.get("name", dataset.name), - ): - raise ValueError("Dataset name already exists") + if data.get("name") and data.get("name") != dataset.name: + if DatasetService._has_dataset_same_name( + tenant_id=dataset.tenant_id, + dataset_id=dataset_id, + name=data.get("name", dataset.name), + ): + raise ValueError("Dataset name already exists") # Verify user has permission to update this dataset DatasetService.check_dataset_permission(dataset, user) @@ -844,6 +867,12 @@ class DatasetService: model_type=ModelType.TEXT_EMBEDDING, model=knowledge_configuration.embedding_model or "", ) + is_multimodal = DatasetService.check_is_multimodal_model( + current_user.current_tenant_id, + knowledge_configuration.embedding_model_provider, + knowledge_configuration.embedding_model, + ) + dataset.is_multimodal = is_multimodal dataset.embedding_model = embedding_model.model dataset.embedding_model_provider = embedding_model.provider dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( @@ -880,6 +909,12 @@ class DatasetService: dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( embedding_model.provider, embedding_model.model ) + is_multimodal = DatasetService.check_is_multimodal_model( + current_user.current_tenant_id, + knowledge_configuration.embedding_model_provider, + knowledge_configuration.embedding_model, + ) + dataset.is_multimodal = is_multimodal dataset.collection_binding_id = dataset_collection_binding.id dataset.indexing_technique = knowledge_configuration.indexing_technique except LLMBadRequestError: @@ -937,6 +972,12 @@ class DatasetService: ) ) dataset.collection_binding_id = dataset_collection_binding.id + is_multimodal = DatasetService.check_is_multimodal_model( + current_user.current_tenant_id, + knowledge_configuration.embedding_model_provider, + knowledge_configuration.embedding_model, + ) + dataset.is_multimodal = is_multimodal except LLMBadRequestError: raise ValueError( "No Embedding Model available. Please configure a valid provider " @@ -1761,7 +1802,9 @@ class DocumentService: if document_ids: DocumentIndexingTaskProxy(dataset.tenant_id, dataset.id, document_ids).delay() if duplicate_document_ids: - duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) + DuplicateDocumentIndexingTaskProxy( + dataset.tenant_id, dataset.id, duplicate_document_ids + ).delay() except LockNotOwnedError: pass @@ -2303,6 +2346,7 @@ class DocumentService: embedding_model_provider=knowledge_config.embedding_model_provider, collection_binding_id=dataset_collection_binding_id, retrieval_model=retrieval_model.model_dump() if retrieval_model else None, + is_multimodal=knowledge_config.is_multimodal, ) db.session.add(dataset) @@ -2683,6 +2727,13 @@ class SegmentService: if "content" not in args or not args["content"] or not args["content"].strip(): raise ValueError("Content is empty") + if args.get("attachment_ids"): + if not isinstance(args["attachment_ids"], list): + raise ValueError("Attachment IDs is invalid") + single_chunk_attachment_limit = dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT + if len(args["attachment_ids"]) > single_chunk_attachment_limit: + raise ValueError(f"Exceeded maximum attachment limit of {single_chunk_attachment_limit}") + @classmethod def create_segment(cls, args: dict, document: Document, dataset: Dataset): assert isinstance(current_user, Account) @@ -2729,11 +2780,23 @@ class SegmentService: segment_document.word_count += len(args["answer"]) segment_document.answer = args["answer"] - db.session.add(segment_document) - # update document word count - assert document.word_count is not None - document.word_count += segment_document.word_count - db.session.add(document) + db.session.add(segment_document) + # update document word count + assert document.word_count is not None + document.word_count += segment_document.word_count + db.session.add(document) + db.session.commit() + + if args["attachment_ids"]: + for attachment_id in args["attachment_ids"]: + binding = SegmentAttachmentBinding( + tenant_id=current_user.current_tenant_id, + dataset_id=document.dataset_id, + document_id=document.id, + segment_id=segment_document.id, + attachment_id=attachment_id, + ) + db.session.add(binding) db.session.commit() # save vector index @@ -2897,7 +2960,7 @@ class SegmentService: document.word_count = max(0, document.word_count + word_count_change) db.session.add(document) # update segment index task - if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: + if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: # regenerate child chunks # get embedding model instance if dataset.indexing_technique == "high_quality": @@ -2924,12 +2987,11 @@ class SegmentService: .where(DatasetProcessRule.id == document.dataset_process_rule_id) .first() ) - if not processing_rule: - raise ValueError("No processing rule found.") - VectorService.generate_child_chunks( - segment, document, dataset, embedding_model_instance, processing_rule, True - ) - elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX): + if processing_rule: + VectorService.generate_child_chunks( + segment, document, dataset, embedding_model_instance, processing_rule, True + ) + elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX): if args.enabled or keyword_changed: # update segment vector index VectorService.update_segment_vector(args.keywords, segment, dataset) @@ -2974,7 +3036,7 @@ class SegmentService: db.session.add(document) db.session.add(segment) db.session.commit() - if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: + if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: # get embedding model instance if dataset.indexing_technique == "high_quality": # check embedding model setting @@ -3000,15 +3062,15 @@ class SegmentService: .where(DatasetProcessRule.id == document.dataset_process_rule_id) .first() ) - if not processing_rule: - raise ValueError("No processing rule found.") - VectorService.generate_child_chunks( - segment, document, dataset, embedding_model_instance, processing_rule, True - ) - elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX): + if processing_rule: + VectorService.generate_child_chunks( + segment, document, dataset, embedding_model_instance, processing_rule, True + ) + elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX): # update segment vector index VectorService.update_segment_vector(args.keywords, segment, dataset) - + # update multimodel vector index + VectorService.update_multimodel_vector(segment, args.attachment_ids or [], dataset) except Exception as e: logger.exception("update segment index failed") segment.enabled = False @@ -3046,7 +3108,9 @@ class SegmentService: ) child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]] - delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id, child_node_ids) + delete_segment_from_index_task.delay( + [segment.index_node_id], dataset.id, document.id, [segment.id], child_node_ids + ) db.session.delete(segment) # update document word count @@ -3095,7 +3159,9 @@ class SegmentService: # Start async cleanup with both parent and child node IDs if index_node_ids or child_node_ids: - delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id, child_node_ids) + delete_segment_from_index_task.delay( + index_node_ids, dataset.id, document.id, segment_db_ids, child_node_ids + ) if document.word_count is None: document.word_count = 0 diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 81e0c0ecd4..eeb14072bd 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -29,8 +29,14 @@ def get_current_user(): from models.account import Account from models.model import EndUser - if not isinstance(current_user._get_current_object(), (Account, EndUser)): # type: ignore - raise TypeError(f"current_user must be Account or EndUser, got {type(current_user).__name__}") + try: + user_object = current_user._get_current_object() + except AttributeError: + # Handle case where current_user might not be a LocalProxy in test environments + user_object = current_user + + if not isinstance(user_object, (Account, EndUser)): + raise TypeError(f"current_user must be Account or EndUser, got {type(user_object).__name__}") return current_user diff --git a/api/services/document_indexing_proxy/__init__.py b/api/services/document_indexing_proxy/__init__.py new file mode 100644 index 0000000000..74195adbe1 --- /dev/null +++ b/api/services/document_indexing_proxy/__init__.py @@ -0,0 +1,11 @@ +from .base import DocumentTaskProxyBase +from .batch_indexing_base import BatchDocumentIndexingProxy +from .document_indexing_task_proxy import DocumentIndexingTaskProxy +from .duplicate_document_indexing_task_proxy import DuplicateDocumentIndexingTaskProxy + +__all__ = [ + "BatchDocumentIndexingProxy", + "DocumentIndexingTaskProxy", + "DocumentTaskProxyBase", + "DuplicateDocumentIndexingTaskProxy", +] diff --git a/api/services/document_indexing_proxy/base.py b/api/services/document_indexing_proxy/base.py new file mode 100644 index 0000000000..56e47857c9 --- /dev/null +++ b/api/services/document_indexing_proxy/base.py @@ -0,0 +1,111 @@ +import logging +from abc import ABC, abstractmethod +from collections.abc import Callable +from functools import cached_property +from typing import Any, ClassVar + +from enums.cloud_plan import CloudPlan +from services.feature_service import FeatureService + +logger = logging.getLogger(__name__) + + +class DocumentTaskProxyBase(ABC): + """ + Base proxy for all document processing tasks. + + Handles common logic: + - Feature/billing checks + - Dispatch routing based on plan + + Subclasses must define: + - QUEUE_NAME: Redis queue identifier + - NORMAL_TASK_FUNC: Task function for normal priority + - PRIORITY_TASK_FUNC: Task function for high priority + """ + + QUEUE_NAME: ClassVar[str] + NORMAL_TASK_FUNC: ClassVar[Callable[..., Any]] + PRIORITY_TASK_FUNC: ClassVar[Callable[..., Any]] + + def __init__(self, tenant_id: str, dataset_id: str): + """ + Initialize with minimal required parameters. + + Args: + tenant_id: Tenant identifier for billing/features + dataset_id: Dataset identifier for logging + """ + self._tenant_id = tenant_id + self._dataset_id = dataset_id + + @cached_property + def features(self): + return FeatureService.get_features(self._tenant_id) + + @abstractmethod + def _send_to_direct_queue(self, task_func: Callable[..., Any]): + """ + Send task directly to Celery queue without tenant isolation. + + Subclasses implement this to pass task-specific parameters. + + Args: + task_func: The Celery task function to call + """ + pass + + @abstractmethod + def _send_to_tenant_queue(self, task_func: Callable[..., Any]): + """ + Send task to tenant-isolated queue. + + Subclasses implement this to handle queue management. + + Args: + task_func: The Celery task function to call + """ + pass + + def _send_to_default_tenant_queue(self): + """Route to normal priority with tenant isolation.""" + self._send_to_tenant_queue(self.NORMAL_TASK_FUNC) + + def _send_to_priority_tenant_queue(self): + """Route to priority queue with tenant isolation.""" + self._send_to_tenant_queue(self.PRIORITY_TASK_FUNC) + + def _send_to_priority_direct_queue(self): + """Route to priority queue without tenant isolation.""" + self._send_to_direct_queue(self.PRIORITY_TASK_FUNC) + + def _dispatch(self): + """ + Dispatch task based on billing plan. + + Routing logic: + - Sandbox plan → normal queue + tenant isolation + - Paid plans → priority queue + tenant isolation + - Self-hosted → priority queue, no isolation + """ + logger.info( + "dispatch args: %s - %s - %s", + self._tenant_id, + self.features.billing.enabled, + self.features.billing.subscription.plan, + ) + # dispatch to different indexing queue with tenant isolation when billing enabled + if self.features.billing.enabled: + if self.features.billing.subscription.plan == CloudPlan.SANDBOX: + # dispatch to normal pipeline queue with tenant self sub queue for sandbox plan + self._send_to_default_tenant_queue() + else: + # dispatch to priority pipeline queue with tenant self sub queue for other plans + self._send_to_priority_tenant_queue() + else: + # dispatch to priority queue without tenant isolation for others, e.g.: self-hosted or enterprise + self._send_to_priority_direct_queue() + + def delay(self): + """Public API: Queue the task asynchronously.""" + self._dispatch() diff --git a/api/services/document_indexing_proxy/batch_indexing_base.py b/api/services/document_indexing_proxy/batch_indexing_base.py new file mode 100644 index 0000000000..dd122f34a8 --- /dev/null +++ b/api/services/document_indexing_proxy/batch_indexing_base.py @@ -0,0 +1,76 @@ +import logging +from collections.abc import Callable, Sequence +from dataclasses import asdict +from typing import Any + +from core.entities.document_task import DocumentTask +from core.rag.pipeline.queue import TenantIsolatedTaskQueue + +from .base import DocumentTaskProxyBase + +logger = logging.getLogger(__name__) + + +class BatchDocumentIndexingProxy(DocumentTaskProxyBase): + """ + Base proxy for batch document indexing tasks (document_ids in plural). + + Adds: + - Tenant isolated queue management + - Batch document handling + """ + + def __init__(self, tenant_id: str, dataset_id: str, document_ids: Sequence[str]): + """ + Initialize with batch documents. + + Args: + tenant_id: Tenant identifier + dataset_id: Dataset identifier + document_ids: List of document IDs to process + """ + super().__init__(tenant_id, dataset_id) + self._document_ids = document_ids + self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, self.QUEUE_NAME) + + def _send_to_direct_queue(self, task_func: Callable[[str, str, Sequence[str]], Any]): + """ + Send batch task to direct queue. + + Args: + task_func: The Celery task function to call with (tenant_id, dataset_id, document_ids) + """ + logger.info("tenant %s send documents %s to direct queue", self._tenant_id, self._document_ids) + task_func.delay( # type: ignore + tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids + ) + + def _send_to_tenant_queue(self, task_func: Callable[[str, str, Sequence[str]], Any]): + """ + Send batch task to tenant-isolated queue. + + Args: + task_func: The Celery task function to call with (tenant_id, dataset_id, document_ids) + """ + logger.info( + "tenant %s send documents %s to tenant queue %s", self._tenant_id, self._document_ids, self.QUEUE_NAME + ) + if self._tenant_isolated_task_queue.get_task_key(): + # Add to waiting queue using List operations (lpush) + self._tenant_isolated_task_queue.push_tasks( + [ + asdict( + DocumentTask( + tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids + ) + ) + ] + ) + logger.info("tenant %s push tasks: %s - %s", self._tenant_id, self._dataset_id, self._document_ids) + else: + # Set flag and execute task + self._tenant_isolated_task_queue.set_task_waiting_time() + task_func.delay( # type: ignore + tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids + ) + logger.info("tenant %s init tasks: %s - %s", self._tenant_id, self._dataset_id, self._document_ids) diff --git a/api/services/document_indexing_proxy/document_indexing_task_proxy.py b/api/services/document_indexing_proxy/document_indexing_task_proxy.py new file mode 100644 index 0000000000..fce79a8387 --- /dev/null +++ b/api/services/document_indexing_proxy/document_indexing_task_proxy.py @@ -0,0 +1,12 @@ +from typing import ClassVar + +from services.document_indexing_proxy.batch_indexing_base import BatchDocumentIndexingProxy +from tasks.document_indexing_task import normal_document_indexing_task, priority_document_indexing_task + + +class DocumentIndexingTaskProxy(BatchDocumentIndexingProxy): + """Proxy for document indexing tasks.""" + + QUEUE_NAME: ClassVar[str] = "document_indexing" + NORMAL_TASK_FUNC = normal_document_indexing_task + PRIORITY_TASK_FUNC = priority_document_indexing_task diff --git a/api/services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py b/api/services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py new file mode 100644 index 0000000000..277cfbdcf1 --- /dev/null +++ b/api/services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py @@ -0,0 +1,15 @@ +from typing import ClassVar + +from services.document_indexing_proxy.batch_indexing_base import BatchDocumentIndexingProxy +from tasks.duplicate_document_indexing_task import ( + normal_duplicate_document_indexing_task, + priority_duplicate_document_indexing_task, +) + + +class DuplicateDocumentIndexingTaskProxy(BatchDocumentIndexingProxy): + """Proxy for duplicate document indexing tasks.""" + + QUEUE_NAME: ClassVar[str] = "duplicate_document_indexing" + NORMAL_TASK_FUNC = normal_duplicate_document_indexing_task + PRIORITY_TASK_FUNC = priority_duplicate_document_indexing_task diff --git a/api/services/document_indexing_task_proxy.py b/api/services/document_indexing_task_proxy.py deleted file mode 100644 index 861c84b586..0000000000 --- a/api/services/document_indexing_task_proxy.py +++ /dev/null @@ -1,83 +0,0 @@ -import logging -from collections.abc import Callable, Sequence -from dataclasses import asdict -from functools import cached_property - -from core.entities.document_task import DocumentTask -from core.rag.pipeline.queue import TenantIsolatedTaskQueue -from enums.cloud_plan import CloudPlan -from services.feature_service import FeatureService -from tasks.document_indexing_task import normal_document_indexing_task, priority_document_indexing_task - -logger = logging.getLogger(__name__) - - -class DocumentIndexingTaskProxy: - def __init__(self, tenant_id: str, dataset_id: str, document_ids: Sequence[str]): - self._tenant_id = tenant_id - self._dataset_id = dataset_id - self._document_ids = document_ids - self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing") - - @cached_property - def features(self): - return FeatureService.get_features(self._tenant_id) - - def _send_to_direct_queue(self, task_func: Callable[[str, str, Sequence[str]], None]): - logger.info("send dataset %s to direct queue", self._dataset_id) - task_func.delay( # type: ignore - tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids - ) - - def _send_to_tenant_queue(self, task_func: Callable[[str, str, Sequence[str]], None]): - logger.info("send dataset %s to tenant queue", self._dataset_id) - if self._tenant_isolated_task_queue.get_task_key(): - # Add to waiting queue using List operations (lpush) - self._tenant_isolated_task_queue.push_tasks( - [ - asdict( - DocumentTask( - tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids - ) - ) - ] - ) - logger.info("push tasks: %s - %s", self._dataset_id, self._document_ids) - else: - # Set flag and execute task - self._tenant_isolated_task_queue.set_task_waiting_time() - task_func.delay( # type: ignore - tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids - ) - logger.info("init tasks: %s - %s", self._dataset_id, self._document_ids) - - def _send_to_default_tenant_queue(self): - self._send_to_tenant_queue(normal_document_indexing_task) - - def _send_to_priority_tenant_queue(self): - self._send_to_tenant_queue(priority_document_indexing_task) - - def _send_to_priority_direct_queue(self): - self._send_to_direct_queue(priority_document_indexing_task) - - def _dispatch(self): - logger.info( - "dispatch args: %s - %s - %s", - self._tenant_id, - self.features.billing.enabled, - self.features.billing.subscription.plan, - ) - # dispatch to different indexing queue with tenant isolation when billing enabled - if self.features.billing.enabled: - if self.features.billing.subscription.plan == CloudPlan.SANDBOX: - # dispatch to normal pipeline queue with tenant self sub queue for sandbox plan - self._send_to_default_tenant_queue() - else: - # dispatch to priority pipeline queue with tenant self sub queue for other plans - self._send_to_priority_tenant_queue() - else: - # dispatch to priority queue without tenant isolation for others, e.g.: self-hosted or enterprise - self._send_to_priority_direct_queue() - - def delay(self): - self._dispatch() diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index 131e90e195..7959734e89 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -124,6 +124,14 @@ class KnowledgeConfig(BaseModel): embedding_model: str | None = None embedding_model_provider: str | None = None name: str | None = None + is_multimodal: bool = False + + +class SegmentCreateArgs(BaseModel): + content: str | None = None + answer: str | None = None + keywords: list[str] | None = None + attachment_ids: list[str] | None = None class SegmentUpdateArgs(BaseModel): @@ -132,6 +140,7 @@ class SegmentUpdateArgs(BaseModel): keywords: list[str] | None = None regenerate_child_chunks: bool = False enabled: bool | None = None + attachment_ids: list[str] | None = None class ChildChunkUpdateArgs(BaseModel): diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 27936f6278..40faa85b9a 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -324,4 +324,5 @@ class ExternalDatasetService: ) if response.status_code == 200: return cast(list[Any], response.json().get("records", [])) - return [] + else: + raise ValueError(response.text) diff --git a/api/services/file_service.py b/api/services/file_service.py index 1980cd8d59..0911cf38c4 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -1,3 +1,4 @@ +import base64 import hashlib import os import uuid @@ -123,6 +124,15 @@ class FileService: return file_size <= file_size_limit + def get_file_base64(self, file_id: str) -> str: + upload_file = ( + self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first() + ) + if not upload_file: + raise NotFound("File not found") + blob = storage.load_once(upload_file.key) + return base64.b64encode(blob).decode() + def upload_text(self, text: str, text_name: str, user_id: str, tenant_id: str) -> UploadFile: if len(text_name) > 200: text_name = text_name[:200] diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index dfb49cf2bd..8e8e78f83f 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -1,3 +1,4 @@ +import json import logging import time from typing import Any @@ -5,6 +6,7 @@ from typing import Any from core.app.app_config.entities import ModelConfig from core.model_runtime.entities import LLMMode from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -32,6 +34,7 @@ class HitTestingService: account: Account, retrieval_model: Any, # FIXME drop this any external_retrieval_model: dict, + attachment_ids: list | None = None, limit: int = 10, ): start = time.perf_counter() @@ -41,7 +44,7 @@ class HitTestingService: retrieval_model = dataset.retrieval_model or default_retrieval_model document_ids_filter = None metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {}) - if metadata_filtering_conditions: + if metadata_filtering_conditions and query: dataset_retrieval = DatasetRetrieval() from core.app.app_config.entities import MetadataFilteringCondition @@ -66,6 +69,7 @@ class HitTestingService: retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)), dataset_id=dataset.id, query=query, + attachment_ids=attachment_ids, top_k=retrieval_model.get("top_k", 4), score_threshold=retrieval_model.get("score_threshold", 0.0) if retrieval_model["score_threshold_enabled"] @@ -80,17 +84,24 @@ class HitTestingService: end = time.perf_counter() logger.debug("Hit testing retrieve in %s seconds", end - start) - - dataset_query = DatasetQuery( - dataset_id=dataset.id, - content=query, - source="hit_testing", - source_app_id=None, - created_by_role="account", - created_by=account.id, - ) - - db.session.add(dataset_query) + dataset_queries = [] + if query: + content = {"content_type": QueryType.TEXT_QUERY, "content": query} + dataset_queries.append(content) + if attachment_ids: + for attachment_id in attachment_ids: + content = {"content_type": QueryType.IMAGE_QUERY, "content": attachment_id} + dataset_queries.append(content) + if dataset_queries: + dataset_query = DatasetQuery( + dataset_id=dataset.id, + content=json.dumps(dataset_queries), + source="hit_testing", + source_app_id=None, + created_by_role="account", + created_by=account.id, + ) + db.session.add(dataset_query) db.session.commit() return cls.compact_retrieve_response(query, all_documents) @@ -168,9 +179,14 @@ class HitTestingService: @classmethod def hit_testing_args_check(cls, args): query = args["query"] + attachment_ids = args["attachment_ids"] - if not query or len(query) > 250: - raise ValueError("Query is required and cannot exceed 250 characters") + if not attachment_ids and not query: + raise ValueError("Query or attachment_ids is required") + if query and len(query) > 250: + raise ValueError("Query cannot exceed 250 characters") + if attachment_ids and not isinstance(attachment_ids, list): + raise ValueError("Attachment_ids must be a list") @staticmethod def escape_query_for_search(query: str) -> str: diff --git a/api/services/rag_pipeline/rag_pipeline_task_proxy.py b/api/services/rag_pipeline/rag_pipeline_task_proxy.py index 94dd7941da..1a7b104a70 100644 --- a/api/services/rag_pipeline/rag_pipeline_task_proxy.py +++ b/api/services/rag_pipeline/rag_pipeline_task_proxy.py @@ -38,21 +38,24 @@ class RagPipelineTaskProxy: upload_file = FileService(db.engine).upload_text( json_text, self._RAG_PIPELINE_INVOKE_ENTITIES_FILE_NAME, self._user_id, self._dataset_tenant_id ) + logger.info( + "tenant %s upload %d invoke entities", self._dataset_tenant_id, len(self._rag_pipeline_invoke_entities) + ) return upload_file.id def _send_to_direct_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]): - logger.info("send file %s to direct queue", upload_file_id) + logger.info("tenant %s send file %s to direct queue", self._dataset_tenant_id, upload_file_id) task_func.delay( # type: ignore rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id=self._dataset_tenant_id, ) def _send_to_tenant_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]): - logger.info("send file %s to tenant queue", upload_file_id) + logger.info("tenant %s send file %s to tenant queue", self._dataset_tenant_id, upload_file_id) if self._tenant_isolated_task_queue.get_task_key(): # Add to waiting queue using List operations (lpush) self._tenant_isolated_task_queue.push_tasks([upload_file_id]) - logger.info("push tasks: %s", upload_file_id) + logger.info("tenant %s push tasks: %s", self._dataset_tenant_id, upload_file_id) else: # Set flag and execute task self._tenant_isolated_task_queue.set_task_waiting_time() @@ -60,7 +63,7 @@ class RagPipelineTaskProxy: rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id=self._dataset_tenant_id, ) - logger.info("init tasks: %s", upload_file_id) + logger.info("tenant %s init tasks: %s", self._dataset_tenant_id, upload_file_id) def _send_to_default_tenant_queue(self, upload_file_id: str): self._send_to_tenant_queue(upload_file_id, rag_pipeline_run_task) diff --git a/api/services/vector_service.py b/api/services/vector_service.py index abc92a0181..f1fa33cb75 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -4,11 +4,14 @@ from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import Document +from core.rag.models.document import AttachmentDocument, Document from extensions.ext_database import db -from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment +from models import UploadFile +from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument from services.entities.knowledge_entities.knowledge_entities import ParentMode @@ -21,9 +24,10 @@ class VectorService: cls, keywords_list: list[list[str]] | None, segments: list[DocumentSegment], dataset: Dataset, doc_form: str ): documents: list[Document] = [] + multimodal_documents: list[AttachmentDocument] = [] for segment in segments: - if doc_form == IndexType.PARENT_CHILD_INDEX: + if doc_form == IndexStructureType.PARENT_CHILD_INDEX: dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first() if not dataset_document: logger.warning( @@ -70,12 +74,29 @@ class VectorService: "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, + "doc_type": DocType.TEXT, }, ) documents.append(rag_document) + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_document: AttachmentDocument = AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + multimodal_documents.append(multimodal_document) + index_processor: BaseIndexProcessor = IndexProcessorFactory(doc_form).init_index_processor() + if len(documents) > 0: - index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list) + index_processor.load(dataset, documents, None, with_keywords=True, keywords_list=keywords_list) + if len(multimodal_documents) > 0: + index_processor.load(dataset, [], multimodal_documents, with_keywords=False) @classmethod def update_segment_vector(cls, keywords: list[str] | None, segment: DocumentSegment, dataset: Dataset): @@ -130,6 +151,7 @@ class VectorService: "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, + "doc_type": DocType.TEXT, }, ) # use full doc mode to generate segment's child chunk @@ -226,3 +248,92 @@ class VectorService: def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset): vector = Vector(dataset=dataset) vector.delete_by_ids([child_chunk.index_node_id]) + + @classmethod + def update_multimodel_vector(cls, segment: DocumentSegment, attachment_ids: list[str], dataset: Dataset): + if dataset.indexing_technique != "high_quality": + return + + attachments = segment.attachments + old_attachment_ids = [attachment["id"] for attachment in attachments] if attachments else [] + + # Check if there's any actual change needed + if set(attachment_ids) == set(old_attachment_ids): + return + + try: + vector = Vector(dataset=dataset) + if dataset.is_multimodal: + # Delete old vectors if they exist + if old_attachment_ids: + vector.delete_by_ids(old_attachment_ids) + + # Delete existing segment attachment bindings in one operation + db.session.query(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id == segment.id).delete( + synchronize_session=False + ) + + if not attachment_ids: + db.session.commit() + return + + # Bulk fetch upload files - only fetch needed fields + upload_file_list = db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all() + + if not upload_file_list: + db.session.commit() + return + + # Create a mapping for quick lookup + upload_file_map = {upload_file.id: upload_file for upload_file in upload_file_list} + + # Prepare batch operations + bindings = [] + documents = [] + + # Create common metadata base to avoid repetition + base_metadata = { + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + } + + # Process attachments in the order specified by attachment_ids + for attachment_id in attachment_ids: + upload_file = upload_file_map.get(attachment_id) + if not upload_file: + logger.warning("Upload file not found for attachment_id: %s", attachment_id) + continue + + # Create segment attachment binding + bindings.append( + SegmentAttachmentBinding( + tenant_id=segment.tenant_id, + dataset_id=segment.dataset_id, + document_id=segment.document_id, + segment_id=segment.id, + attachment_id=upload_file.id, + ) + ) + + # Create document for vector indexing + documents.append( + Document(page_content=upload_file.name, metadata={**base_metadata, "doc_id": upload_file.id}) + ) + + # Bulk insert all bindings at once + if bindings: + db.session.add_all(bindings) + + # Add documents to vector store if any + if documents and dataset.is_multimodal: + vector.add_texts(documents, duplicate_check=True) + + # Single commit for all operations + db.session.commit() + + except Exception: + logger.exception("Failed to update multimodal vector for segment %s", segment.id) + db.session.rollback() + raise diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index 933ad6b9e2..e7dead8a56 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -4,9 +4,10 @@ import time import click from celery import shared_task -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import ChildDocument, Document +from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now @@ -55,6 +56,7 @@ def add_document_to_index_task(dataset_document_id: str): ) documents = [] + multimodal_documents = [] for segment in segments: document = Document( page_content=segment.content, @@ -65,7 +67,7 @@ def add_document_to_index_task(dataset_document_id: str): "dataset_id": segment.dataset_id, }, ) - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] @@ -81,11 +83,25 @@ def add_document_to_index_task(dataset_document_id: str): ) child_documents.append(child_document) document.children = child_documents + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) documents.append(document) index_type = dataset.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() - index_processor.load(dataset, documents) + index_processor.load(dataset, documents, multimodal_documents=multimodal_documents) # delete auto disable log db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id).delete() diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index 5f2a355d16..8608df6b8e 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -18,6 +18,7 @@ from models.dataset import ( DatasetQuery, Document, DocumentSegment, + SegmentAttachmentBinding, ) from models.model import UploadFile @@ -58,14 +59,20 @@ def clean_dataset_task( ) documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all() segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all() + # Use JOIN to fetch attachments with bindings in a single query + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where(SegmentAttachmentBinding.tenant_id == tenant_id, SegmentAttachmentBinding.dataset_id == dataset_id) + ).all() # Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace # This ensures all invalid doc_form values are properly handled if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()): # Use default paragraph index type for empty/invalid datasets to enable vector database cleanup - from core.rag.index_processor.constant.index_type import IndexType + from core.rag.index_processor.constant.index_type import IndexStructureType - doc_form = IndexType.PARAGRAPH_INDEX + doc_form = IndexStructureType.PARAGRAPH_INDEX logger.info( click.style(f"Invalid doc_form detected, using default index type for cleanup: {doc_form}", fg="yellow") ) @@ -90,6 +97,7 @@ def clean_dataset_task( for document in documents: db.session.delete(document) + # delete document file for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) @@ -107,6 +115,19 @@ def clean_dataset_task( ) db.session.delete(image_file) db.session.delete(segment) + # delete segment attachments + if attachments_with_bindings: + for binding, attachment_file in attachments_with_bindings: + try: + storage.delete(attachment_file.key) + except Exception: + logger.exception( + "Delete attachment_file failed when storage deleted, \ + attachment_file_id: %s", + binding.attachment_id, + ) + db.session.delete(attachment_file) + db.session.delete(binding) db.session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete() db.session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete() diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 62200715cc..6d2feb1da3 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -9,7 +9,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto from core.tools.utils.web_reader_tool import get_image_upload_file_ids from extensions.ext_database import db from extensions.ext_storage import storage -from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment +from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment, SegmentAttachmentBinding from models.model import UploadFile logger = logging.getLogger(__name__) @@ -36,6 +36,16 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i raise Exception("Document has no dataset") segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() + # Use JOIN to fetch attachments with bindings in a single query + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where( + SegmentAttachmentBinding.tenant_id == dataset.tenant_id, + SegmentAttachmentBinding.dataset_id == dataset_id, + SegmentAttachmentBinding.document_id == document_id, + ) + ).all() # check segment is exist if segments: index_node_ids = [segment.index_node_id for segment in segments] @@ -69,6 +79,19 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i logger.exception("Delete file failed when document deleted, file_id: %s", file_id) db.session.delete(file) db.session.commit() + # delete segment attachments + if attachments_with_bindings: + for binding, attachment_file in attachments_with_bindings: + try: + storage.delete(attachment_file.key) + except Exception: + logger.exception( + "Delete attachment_file failed when storage deleted, \ + attachment_file_id: %s", + binding.attachment_id, + ) + db.session.delete(attachment_file) + db.session.delete(binding) # delete dataset metadata binding db.session.query(DatasetMetadataBinding).where( diff --git a/api/tasks/deal_dataset_index_update_task.py b/api/tasks/deal_dataset_index_update_task.py index 713f149c38..3d13afdec0 100644 --- a/api/tasks/deal_dataset_index_update_task.py +++ b/api/tasks/deal_dataset_index_update_task.py @@ -4,9 +4,10 @@ import time import click from celery import shared_task # type: ignore -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import ChildDocument, Document +from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument @@ -28,7 +29,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): if not dataset: raise Exception("Dataset not found") - index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX + index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX index_processor = IndexProcessorFactory(index_type).init_index_processor() if action == "upgrade": dataset_documents = ( @@ -119,6 +120,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): ) if segments: documents = [] + multimodal_documents = [] for segment in segments: document = Document( page_content=segment.content, @@ -129,7 +131,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): "dataset_id": segment.dataset_id, }, ) - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] @@ -145,9 +147,25 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): ) child_documents.append(child_document) document.children = child_documents + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) documents.append(document) # save vector index - index_processor.load(dataset, documents, with_keywords=False) + index_processor.load( + dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False + ) db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( {"indexing_status": "completed"}, synchronize_session=False ) diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index dc6ef6fb61..1c7de3b1ce 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -1,14 +1,14 @@ import logging import time -from typing import Literal import click from celery import shared_task from sqlalchemy import select -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import ChildDocument, Document +from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) @shared_task(queue="dataset") -def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "add", "update"]): +def deal_dataset_vector_index_task(dataset_id: str, action: str): """ Async deal dataset from index :param dataset_id: dataset_id @@ -32,7 +32,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a if not dataset: raise Exception("Dataset not found") - index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX + index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX index_processor = IndexProcessorFactory(index_type).init_index_processor() if action == "remove": index_processor.clean(dataset, None, with_keywords=False) @@ -119,6 +119,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a ) if segments: documents = [] + multimodal_documents = [] for segment in segments: document = Document( page_content=segment.content, @@ -129,7 +130,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a "dataset_id": segment.dataset_id, }, ) - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] @@ -145,9 +146,25 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a ) child_documents.append(child_document) document.children = child_documents + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) documents.append(document) # save vector index - index_processor.load(dataset, documents, with_keywords=False) + index_processor.load( + dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False + ) db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( {"indexing_status": "completed"}, synchronize_session=False ) diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index e8cbd0f250..bea5c952cf 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -6,14 +6,15 @@ from celery import shared_task from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db -from models.dataset import Dataset, Document +from models.dataset import Dataset, Document, SegmentAttachmentBinding +from models.model import UploadFile logger = logging.getLogger(__name__) @shared_task(queue="dataset") def delete_segment_from_index_task( - index_node_ids: list, dataset_id: str, document_id: str, child_node_ids: list | None = None + index_node_ids: list, dataset_id: str, document_id: str, segment_ids: list, child_node_ids: list | None = None ): """ Async Remove segment from index @@ -49,6 +50,21 @@ def delete_segment_from_index_task( delete_child_chunks=True, precomputed_child_node_ids=child_node_ids, ) + if dataset.is_multimodal: + # delete segment attachment binding + segment_attachment_bindings = ( + db.session.query(SegmentAttachmentBinding) + .where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) + .all() + ) + if segment_attachment_bindings: + attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings] + index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False) + for binding in segment_attachment_bindings: + db.session.delete(binding) + # delete upload file + db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False) + db.session.commit() end_at = time.perf_counter() logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green")) diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py index 9038dc179b..c2a3de29f4 100644 --- a/api/tasks/disable_segments_from_index_task.py +++ b/api/tasks/disable_segments_from_index_task.py @@ -8,7 +8,7 @@ from sqlalchemy import select from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.dataset import Dataset, DocumentSegment +from models.dataset import Dataset, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument logger = logging.getLogger(__name__) @@ -59,6 +59,16 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen try: index_node_ids = [segment.index_node_id for segment in segments] + if dataset.is_multimodal: + segment_ids = [segment.id for segment in segments] + segment_attachment_bindings = ( + db.session.query(SegmentAttachmentBinding) + .where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) + .all() + ) + if segment_attachment_bindings: + attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings] + index_node_ids.extend(attachment_ids) index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) end_at = time.perf_counter() diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index fee4430612..acbdab631b 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -114,7 +114,13 @@ def _document_indexing_with_tenant_queue( try: _document_indexing(dataset_id, document_ids) except Exception: - logger.exception("Error processing document indexing %s for tenant %s: %s", dataset_id, tenant_id) + logger.exception( + "Error processing document indexing %s for tenant %s: %s", + dataset_id, + tenant_id, + document_ids, + exc_info=True, + ) finally: tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing") @@ -122,7 +128,7 @@ def _document_indexing_with_tenant_queue( # Use rpop to get the next task from the queue (FIFO order) next_tasks = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY) - logger.info("document indexing tenant isolation queue next tasks: %s", next_tasks) + logger.info("document indexing tenant isolation queue %s next tasks: %s", tenant_id, next_tasks) if next_tasks: for next_task in next_tasks: diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index 6492e356a3..4078c8910e 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -1,13 +1,16 @@ import logging import time +from collections.abc import Callable, Sequence import click from celery import shared_task from sqlalchemy import select from configs import dify_config +from core.entities.document_task import DocumentTask from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from extensions.ext_database import db from libs.datetime_utils import naive_utc_now @@ -24,8 +27,55 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): :param dataset_id: :param document_ids: + .. warning:: TO BE DEPRECATED + This function will be deprecated and removed in a future version. + Use normal_duplicate_document_indexing_task or priority_duplicate_document_indexing_task instead. + Usage: duplicate_document_indexing_task.delay(dataset_id, document_ids) """ + logger.warning("duplicate document indexing task received: %s - %s", dataset_id, document_ids) + _duplicate_document_indexing_task(dataset_id, document_ids) + + +def _duplicate_document_indexing_task_with_tenant_queue( + tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: Callable[[str, str, Sequence[str]], None] +): + try: + _duplicate_document_indexing_task(dataset_id, document_ids) + except Exception: + logger.exception( + "Error processing duplicate document indexing %s for tenant %s: %s", + dataset_id, + tenant_id, + document_ids, + exc_info=True, + ) + finally: + tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "duplicate_document_indexing") + + # Check if there are waiting tasks in the queue + # Use rpop to get the next task from the queue (FIFO order) + next_tasks = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY) + + logger.info("duplicate document indexing tenant isolation queue %s next tasks: %s", tenant_id, next_tasks) + + if next_tasks: + for next_task in next_tasks: + document_task = DocumentTask(**next_task) + # Process the next waiting task + # Keep the flag set to indicate a task is running + tenant_isolated_task_queue.set_task_waiting_time() + task_func.delay( # type: ignore + tenant_id=document_task.tenant_id, + dataset_id=document_task.dataset_id, + document_ids=document_task.document_ids, + ) + else: + # No more waiting tasks, clear the flag + tenant_isolated_task_queue.delete_task_key() + + +def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[str]): documents = [] start_at = time.perf_counter() @@ -110,3 +160,35 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): logger.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id) finally: db.session.close() + + +@shared_task(queue="dataset") +def normal_duplicate_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]): + """ + Async process duplicate documents + :param tenant_id: + :param dataset_id: + :param document_ids: + + Usage: normal_duplicate_document_indexing_task.delay(tenant_id, dataset_id, document_ids) + """ + logger.info("normal duplicate document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids) + _duplicate_document_indexing_task_with_tenant_queue( + tenant_id, dataset_id, document_ids, normal_duplicate_document_indexing_task + ) + + +@shared_task(queue="priority_dataset") +def priority_duplicate_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]): + """ + Async process duplicate documents + :param tenant_id: + :param dataset_id: + :param document_ids: + + Usage: priority_duplicate_document_indexing_task.delay(tenant_id, dataset_id, document_ids) + """ + logger.info("priority duplicate document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids) + _duplicate_document_indexing_task_with_tenant_queue( + tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task + ) diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 07c44f333e..7615469ed0 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -4,9 +4,10 @@ import time import click from celery import shared_task -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import ChildDocument, Document +from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now @@ -67,7 +68,7 @@ def enable_segment_to_index_task(segment_id: str): return index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] @@ -83,8 +84,24 @@ def enable_segment_to_index_task(segment_id: str): ) child_documents.append(child_document) document.children = child_documents + multimodel_documents = [] + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodel_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) + # save vector index - index_processor.load(dataset, [document]) + index_processor.load(dataset, [document], multimodal_documents=multimodel_documents) end_at = time.perf_counter() logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green")) diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py index c5ca7a6171..9f17d09e18 100644 --- a/api/tasks/enable_segments_to_index_task.py +++ b/api/tasks/enable_segments_to_index_task.py @@ -5,9 +5,10 @@ import click from celery import shared_task from sqlalchemy import select -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import ChildDocument, Document +from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now @@ -60,6 +61,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i try: documents = [] + multimodal_documents = [] for segment in segments: document = Document( page_content=segment.content, @@ -71,7 +73,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i }, ) - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] @@ -87,9 +89,24 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i ) child_documents.append(child_document) document.children = child_documents + + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) documents.append(document) # save vector index - index_processor.load(dataset, documents) + index_processor.load(dataset, documents, multimodal_documents=multimodal_documents) end_at = time.perf_counter() logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green")) diff --git a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py index a7f61d9811..1eef361a92 100644 --- a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py @@ -47,6 +47,8 @@ def priority_rag_pipeline_run_task( ) rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content) + logger.info("tenant %s received %d rag pipeline invoke entities", tenant_id, len(rag_pipeline_invoke_entities)) + # Get Flask app object for thread context flask_app = current_app._get_current_object() # type: ignore @@ -66,7 +68,7 @@ def priority_rag_pipeline_run_task( end_at = time.perf_counter() logging.info( click.style( - f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green" + f"tenant_id: {tenant_id}, Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green" ) ) except Exception: @@ -78,7 +80,7 @@ def priority_rag_pipeline_run_task( # Check if there are waiting tasks in the queue # Use rpop to get the next task from the queue (FIFO order) next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY) - logger.info("priority rag pipeline tenant isolation queue next files: %s", next_file_ids) + logger.info("priority rag pipeline tenant isolation queue %s next files: %s", tenant_id, next_file_ids) if next_file_ids: for next_file_id in next_file_ids: diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py index 92f1dfb73d..275f5abe6e 100644 --- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py @@ -47,6 +47,8 @@ def rag_pipeline_run_task( ) rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content) + logger.info("tenant %s received %d rag pipeline invoke entities", tenant_id, len(rag_pipeline_invoke_entities)) + # Get Flask app object for thread context flask_app = current_app._get_current_object() # type: ignore @@ -66,7 +68,7 @@ def rag_pipeline_run_task( end_at = time.perf_counter() logging.info( click.style( - f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green" + f"tenant_id: {tenant_id}, Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green" ) ) except Exception: @@ -78,7 +80,7 @@ def rag_pipeline_run_task( # Check if there are waiting tasks in the queue # Use rpop to get the next task from the queue (FIFO order) next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY) - logger.info("rag pipeline tenant isolation queue next files: %s", next_file_ids) + logger.info("rag pipeline tenant isolation queue %s next files: %s", tenant_id, next_file_ids) if next_file_ids: for next_file_id in next_file_ids: diff --git a/api/tests/integration_tests/controllers/console/workspace/test_trigger_provider_permissions.py b/api/tests/integration_tests/controllers/console/workspace/test_trigger_provider_permissions.py new file mode 100644 index 0000000000..e55c12e678 --- /dev/null +++ b/api/tests/integration_tests/controllers/console/workspace/test_trigger_provider_permissions.py @@ -0,0 +1,244 @@ +"""Integration tests for Trigger Provider subscription permission verification.""" + +import uuid +from unittest import mock + +import pytest +from flask.testing import FlaskClient + +from controllers.console.workspace import trigger_providers as trigger_providers_api +from libs.datetime_utils import naive_utc_now +from models import Tenant +from models.account import Account, TenantAccountJoin, TenantAccountRole + + +class TestTriggerProviderSubscriptionPermissions: + """Test permission verification for Trigger Provider subscription endpoints.""" + + @pytest.fixture + def mock_account(self, monkeypatch: pytest.MonkeyPatch): + """Create a mock Account for testing.""" + + account = Account(name="Test User", email="test@example.com") + account.id = str(uuid.uuid4()) + account.last_active_at = naive_utc_now() + account.created_at = naive_utc_now() + account.updated_at = naive_utc_now() + + # Create mock tenant + tenant = Tenant(name="Test Tenant") + tenant.id = str(uuid.uuid4()) + + mock_session_instance = mock.Mock() + + mock_tenant_join = TenantAccountJoin(role=TenantAccountRole.OWNER) + monkeypatch.setattr(mock_session_instance, "scalar", mock.Mock(return_value=mock_tenant_join)) + + mock_scalars_result = mock.Mock() + mock_scalars_result.one.return_value = tenant + monkeypatch.setattr(mock_session_instance, "scalars", mock.Mock(return_value=mock_scalars_result)) + + mock_session_context = mock.Mock() + mock_session_context.__enter__.return_value = mock_session_instance + monkeypatch.setattr("models.account.Session", lambda _, expire_on_commit: mock_session_context) + + account.current_tenant = tenant + account.current_tenant_id = tenant.id + return account + + @pytest.mark.parametrize( + ("role", "list_status", "get_status", "update_status", "create_status", "build_status", "delete_status"), + [ + # Admin/Owner can do everything + (TenantAccountRole.OWNER, 200, 200, 200, 200, 200, 200), + (TenantAccountRole.ADMIN, 200, 200, 200, 200, 200, 200), + # Editor can list, get, update (parameters), but not create, build, or delete + (TenantAccountRole.EDITOR, 200, 200, 200, 403, 403, 403), + # Normal user cannot do anything + (TenantAccountRole.NORMAL, 403, 403, 403, 403, 403, 403), + # Dataset operator cannot do anything + (TenantAccountRole.DATASET_OPERATOR, 403, 403, 403, 403, 403, 403), + ], + ) + def test_trigger_subscription_permissions( + self, + test_client: FlaskClient, + auth_header, + monkeypatch, + mock_account, + role: TenantAccountRole, + list_status: int, + get_status: int, + update_status: int, + create_status: int, + build_status: int, + delete_status: int, + ): + """Test that different roles have appropriate permissions for trigger subscription operations.""" + # Set user role + mock_account.role = role + + # Mock current user + monkeypatch.setattr(trigger_providers_api, "current_user", mock_account) + + # Mock AccountService.load_user to prevent authentication issues + from services.account_service import AccountService + + mock_load_user = mock.Mock(return_value=mock_account) + monkeypatch.setattr(AccountService, "load_user", mock_load_user) + + # Test data + provider = "some_provider/some_trigger" + subscription_builder_id = str(uuid.uuid4()) + subscription_id = str(uuid.uuid4()) + + # Mock service methods + mock_list_subscriptions = mock.Mock(return_value=[]) + monkeypatch.setattr( + "services.trigger.trigger_provider_service.TriggerProviderService.list_trigger_provider_subscriptions", + mock_list_subscriptions, + ) + + mock_get_subscription_builder = mock.Mock(return_value={"id": subscription_builder_id}) + monkeypatch.setattr( + "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.get_subscription_builder_by_id", + mock_get_subscription_builder, + ) + + mock_update_subscription_builder = mock.Mock(return_value={"id": subscription_builder_id}) + monkeypatch.setattr( + "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.update_trigger_subscription_builder", + mock_update_subscription_builder, + ) + + mock_create_subscription_builder = mock.Mock(return_value={"id": subscription_builder_id}) + monkeypatch.setattr( + "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.create_trigger_subscription_builder", + mock_create_subscription_builder, + ) + + mock_update_and_build_builder = mock.Mock() + monkeypatch.setattr( + "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.update_and_build_builder", + mock_update_and_build_builder, + ) + + mock_delete_provider = mock.Mock() + mock_delete_plugin_trigger = mock.Mock() + mock_db_session = mock.Mock() + mock_db_session.commit = mock.Mock() + + def mock_session_func(engine=None): + return mock_session_context + + mock_session_context = mock.Mock() + mock_session_context.__enter__.return_value = mock_db_session + mock_session_context.__exit__.return_value = None + + monkeypatch.setattr("services.trigger.trigger_provider_service.Session", mock_session_func) + monkeypatch.setattr("services.trigger.trigger_subscription_operator_service.Session", mock_session_func) + + monkeypatch.setattr( + "services.trigger.trigger_provider_service.TriggerProviderService.delete_trigger_provider", + mock_delete_provider, + ) + monkeypatch.setattr( + "services.trigger.trigger_subscription_operator_service.TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription", + mock_delete_plugin_trigger, + ) + + # Test 1: List subscriptions (should work for Editor, Admin, Owner) + response = test_client.get( + f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/list", + headers=auth_header, + ) + assert response.status_code == list_status + + # Test 2: Get subscription builder (should work for Editor, Admin, Owner) + response = test_client.get( + f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/{subscription_builder_id}", + headers=auth_header, + ) + assert response.status_code == get_status + + # Test 3: Update subscription builder parameters (should work for Editor, Admin, Owner) + response = test_client.post( + f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/update/{subscription_builder_id}", + headers=auth_header, + json={"parameters": {"webhook_url": "https://example.com/webhook"}}, + ) + assert response.status_code == update_status + + # Test 4: Create subscription builder (should only work for Admin, Owner) + response = test_client.post( + f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/create", + headers=auth_header, + json={"credential_type": "api_key"}, + ) + assert response.status_code == create_status + + # Test 5: Build/activate subscription (should only work for Admin, Owner) + response = test_client.post( + f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/build/{subscription_builder_id}", + headers=auth_header, + json={"name": "Test Subscription"}, + ) + assert response.status_code == build_status + + # Test 6: Delete subscription (should only work for Admin, Owner) + response = test_client.post( + f"/console/api/workspaces/current/trigger-provider/{subscription_id}/subscriptions/delete", + headers=auth_header, + ) + assert response.status_code == delete_status + + @pytest.mark.parametrize( + ("role", "status"), + [ + (TenantAccountRole.OWNER, 200), + (TenantAccountRole.ADMIN, 200), + # Editor should be able to access logs for debugging + (TenantAccountRole.EDITOR, 200), + (TenantAccountRole.NORMAL, 403), + (TenantAccountRole.DATASET_OPERATOR, 403), + ], + ) + def test_trigger_subscription_logs_permissions( + self, + test_client: FlaskClient, + auth_header, + monkeypatch, + mock_account, + role: TenantAccountRole, + status: int, + ): + """Test that different roles have appropriate permissions for accessing subscription logs.""" + # Set user role + mock_account.role = role + + # Mock current user + monkeypatch.setattr(trigger_providers_api, "current_user", mock_account) + + # Mock AccountService.load_user to prevent authentication issues + from services.account_service import AccountService + + mock_load_user = mock.Mock(return_value=mock_account) + monkeypatch.setattr(AccountService, "load_user", mock_load_user) + + # Test data + provider = "some_provider/some_trigger" + subscription_builder_id = str(uuid.uuid4()) + + # Mock service method + mock_list_logs = mock.Mock(return_value=[]) + monkeypatch.setattr( + "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.list_logs", + mock_list_logs, + ) + + # Test access to logs + response = test_client.get( + f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/logs/{subscription_builder_id}", + headers=auth_header, + ) + assert response.status_code == status diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index 9478bb9ddb..088d6ba6ba 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.ext_database import db from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -95,7 +95,7 @@ class TestAddDocumentToIndexTask: created_by=account.id, indexing_status="completed", enabled=True, - doc_form=IndexType.PARAGRAPH_INDEX, + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db.session.add(document) db.session.commit() @@ -172,7 +172,9 @@ class TestAddDocumentToIndexTask: # Assert: Verify the expected outcomes # Verify index processor was called correctly - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.PARAGRAPH_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify database state changes @@ -204,7 +206,7 @@ class TestAddDocumentToIndexTask: ) # Update document to use different index type - document.doc_form = IndexType.QA_INDEX + document.doc_form = IndexStructureType.QA_INDEX db.session.commit() # Refresh dataset to ensure doc_form property reflects the updated document @@ -221,7 +223,9 @@ class TestAddDocumentToIndexTask: add_document_to_index_task(document.id) # Assert: Verify different index type handling - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.QA_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.QA_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify the load method was called with correct parameters @@ -360,7 +364,7 @@ class TestAddDocumentToIndexTask: ) # Update document to use parent-child index type - document.doc_form = IndexType.PARENT_CHILD_INDEX + document.doc_form = IndexStructureType.PARENT_CHILD_INDEX db.session.commit() # Refresh dataset to ensure doc_form property reflects the updated document @@ -391,7 +395,7 @@ class TestAddDocumentToIndexTask: # Assert: Verify parent-child index processing mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( - IndexType.PARENT_CHILD_INDEX + IndexStructureType.PARENT_CHILD_INDEX ) mock_external_service_dependencies["index_processor"].load.assert_called_once() @@ -465,8 +469,10 @@ class TestAddDocumentToIndexTask: # Act: Execute the task add_document_to_index_task(document.id) - # Assert: Verify index processing occurred with all completed segments - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX) + # Assert: Verify index processing occurred but with empty documents list + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.PARAGRAPH_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify the load method was called with all completed segments @@ -532,7 +538,9 @@ class TestAddDocumentToIndexTask: assert len(remaining_logs) == 0 # Verify index processing occurred normally - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.PARAGRAPH_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify segments were enabled @@ -699,7 +707,9 @@ class TestAddDocumentToIndexTask: add_document_to_index_task(document.id) # Assert: Verify only eligible segments were processed - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.PARAGRAPH_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify the load method was called with correct parameters diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py index 94e9b76965..37d886f569 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -12,7 +12,7 @@ from unittest.mock import MagicMock, patch from faker import Faker -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from models import Account, Dataset, Document, DocumentSegment, Tenant from tasks.delete_segment_from_index_task import delete_segment_from_index_task @@ -164,7 +164,7 @@ class TestDeleteSegmentFromIndexTask: document.updated_at = fake.date_time_this_year() document.doc_type = kwargs.get("doc_type", "text") document.doc_metadata = kwargs.get("doc_metadata", {}) - document.doc_form = kwargs.get("doc_form", IndexType.PARAGRAPH_INDEX) + document.doc_form = kwargs.get("doc_form", IndexStructureType.PARAGRAPH_INDEX) document.doc_language = kwargs.get("doc_language", "en") db_session_with_containers.add(document) @@ -244,8 +244,11 @@ class TestDeleteSegmentFromIndexTask: mock_processor = MagicMock() mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor + # Extract segment IDs for the task + segment_ids = [segment.id for segment in segments] + # Execute the task - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed successfully assert result is None # Task should return None on success @@ -279,7 +282,7 @@ class TestDeleteSegmentFromIndexTask: index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)] # Execute the task with non-existent dataset - result = delete_segment_from_index_task(index_node_ids, non_existent_dataset_id, non_existent_document_id) + result = delete_segment_from_index_task(index_node_ids, non_existent_dataset_id, non_existent_document_id, []) # Verify the task completed without exceptions assert result is None # Task should return None when dataset not found @@ -305,7 +308,7 @@ class TestDeleteSegmentFromIndexTask: index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)] # Execute the task with non-existent document - result = delete_segment_from_index_task(index_node_ids, dataset.id, non_existent_document_id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, non_existent_document_id, []) # Verify the task completed without exceptions assert result is None # Task should return None when document not found @@ -330,9 +333,10 @@ class TestDeleteSegmentFromIndexTask: segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Execute the task with disabled document - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed without exceptions assert result is None # Task should return None when document is disabled @@ -357,9 +361,10 @@ class TestDeleteSegmentFromIndexTask: segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Execute the task with archived document - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed without exceptions assert result is None # Task should return None when document is archived @@ -386,9 +391,10 @@ class TestDeleteSegmentFromIndexTask: segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Execute the task with incomplete indexing - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed without exceptions assert result is None # Task should return None when indexing is not completed @@ -409,7 +415,11 @@ class TestDeleteSegmentFromIndexTask: fake = Faker() # Test different document forms - document_forms = [IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX, IndexType.PARENT_CHILD_INDEX] + document_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + ] for doc_form in document_forms: # Create test data for each document form @@ -420,13 +430,14 @@ class TestDeleteSegmentFromIndexTask: segments = self._create_test_document_segments(db_session_with_containers, document, account, 2, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Mock the index processor mock_processor = MagicMock() mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor # Execute the task - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed successfully assert result is None @@ -469,6 +480,7 @@ class TestDeleteSegmentFromIndexTask: segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Mock the index processor to raise an exception mock_processor = MagicMock() @@ -476,7 +488,7 @@ class TestDeleteSegmentFromIndexTask: mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor # Execute the task - should not raise exception - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed without raising exceptions assert result is None # Task should return None even when exceptions occur @@ -518,7 +530,7 @@ class TestDeleteSegmentFromIndexTask: mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor # Execute the task - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, []) # Verify the task completed successfully assert result is None @@ -555,13 +567,14 @@ class TestDeleteSegmentFromIndexTask: # Create large number of segments segments = self._create_test_document_segments(db_session_with_containers, document, account, 50, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Mock the index processor mock_processor = MagicMock() mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor # Execute the task - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed successfully assert result is None diff --git a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py new file mode 100644 index 0000000000..aca4be1ffd --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py @@ -0,0 +1,763 @@ +from unittest.mock import MagicMock, patch + +import pytest +from faker import Faker + +from enums.cloud_plan import CloudPlan +from extensions.ext_database import db +from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import Dataset, Document, DocumentSegment +from tasks.duplicate_document_indexing_task import ( + _duplicate_document_indexing_task, # Core function + _duplicate_document_indexing_task_with_tenant_queue, # Tenant queue wrapper function + duplicate_document_indexing_task, # Deprecated old interface + normal_duplicate_document_indexing_task, # New normal task + priority_duplicate_document_indexing_task, # New priority task +) + + +class TestDuplicateDocumentIndexingTasks: + """Integration tests for duplicate document indexing tasks using testcontainers. + + This test class covers: + - Core _duplicate_document_indexing_task function + - Deprecated duplicate_document_indexing_task function + - New normal_duplicate_document_indexing_task function + - New priority_duplicate_document_indexing_task function + - Tenant queue wrapper _duplicate_document_indexing_task_with_tenant_queue function + - Document segment cleanup logic + """ + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("tasks.duplicate_document_indexing_task.IndexingRunner") as mock_indexing_runner, + patch("tasks.duplicate_document_indexing_task.FeatureService") as mock_feature_service, + patch("tasks.duplicate_document_indexing_task.IndexProcessorFactory") as mock_index_processor_factory, + ): + # Setup mock indexing runner + mock_runner_instance = MagicMock() + mock_indexing_runner.return_value = mock_runner_instance + + # Setup mock feature service + mock_features = MagicMock() + mock_features.billing.enabled = False + mock_feature_service.get_features.return_value = mock_features + + # Setup mock index processor factory + mock_processor = MagicMock() + mock_processor.clean = MagicMock() + mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor + + yield { + "indexing_runner": mock_indexing_runner, + "indexing_runner_instance": mock_runner_instance, + "feature_service": mock_feature_service, + "features": mock_features, + "index_processor_factory": mock_index_processor_factory, + "index_processor": mock_processor, + } + + def _create_test_dataset_and_documents( + self, db_session_with_containers, mock_external_service_dependencies, document_count=3 + ): + """ + Helper method to create a test dataset and documents for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + document_count: Number of documents to create + + Returns: + tuple: (dataset, documents) - Created dataset and document instances + """ + fake = Faker() + + # Create account and tenant + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db.session.add(account) + db.session.commit() + + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Create dataset + dataset = Dataset( + id=fake.uuid4(), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="upload_file", + indexing_technique="high_quality", + created_by=account.id, + ) + db.session.add(dataset) + db.session.commit() + + # Create documents + documents = [] + for i in range(document_count): + document = Document( + id=fake.uuid4(), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=i, + data_source_type="upload_file", + batch="test_batch", + name=fake.file_name(), + created_from="upload_file", + created_by=account.id, + indexing_status="waiting", + enabled=True, + doc_form="text_model", + ) + db.session.add(document) + documents.append(document) + + db.session.commit() + + # Refresh dataset to ensure it's properly loaded + db.session.refresh(dataset) + + return dataset, documents + + def _create_test_dataset_with_segments( + self, db_session_with_containers, mock_external_service_dependencies, document_count=3, segments_per_doc=2 + ): + """ + Helper method to create a test dataset with documents and segments. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + document_count: Number of documents to create + segments_per_doc: Number of segments per document + + Returns: + tuple: (dataset, documents, segments) - Created dataset, documents and segments + """ + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count + ) + + fake = Faker() + segments = [] + + # Create segments for each document + for document in documents: + for i in range(segments_per_doc): + segment = DocumentSegment( + id=fake.uuid4(), + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + document_id=document.id, + position=i, + index_node_id=f"{document.id}-node-{i}", + index_node_hash=fake.sha256(), + content=fake.text(max_nb_chars=200), + word_count=50, + tokens=100, + status="completed", + enabled=True, + indexing_at=fake.date_time_this_year(), + created_by=dataset.created_by, # Add required field + ) + db.session.add(segment) + segments.append(segment) + + db.session.commit() + + # Refresh to ensure all relationships are loaded + for document in documents: + db.session.refresh(document) + + return dataset, documents, segments + + def _create_test_dataset_with_billing_features( + self, db_session_with_containers, mock_external_service_dependencies, billing_enabled=True + ): + """ + Helper method to create a test dataset with billing features configured. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + billing_enabled: Whether billing is enabled + + Returns: + tuple: (dataset, documents) - Created dataset and document instances + """ + fake = Faker() + + # Create account and tenant + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db.session.add(account) + db.session.commit() + + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Create dataset + dataset = Dataset( + id=fake.uuid4(), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="upload_file", + indexing_technique="high_quality", + created_by=account.id, + ) + db.session.add(dataset) + db.session.commit() + + # Create documents + documents = [] + for i in range(3): + document = Document( + id=fake.uuid4(), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=i, + data_source_type="upload_file", + batch="test_batch", + name=fake.file_name(), + created_from="upload_file", + created_by=account.id, + indexing_status="waiting", + enabled=True, + doc_form="text_model", + ) + db.session.add(document) + documents.append(document) + + db.session.commit() + + # Configure billing features + mock_external_service_dependencies["features"].billing.enabled = billing_enabled + if billing_enabled: + mock_external_service_dependencies["features"].billing.subscription.plan = CloudPlan.SANDBOX + mock_external_service_dependencies["features"].vector_space.limit = 100 + mock_external_service_dependencies["features"].vector_space.size = 50 + + # Refresh dataset to ensure it's properly loaded + db.session.refresh(dataset) + + return dataset, documents + + def test_duplicate_document_indexing_task_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful duplicate document indexing with multiple documents. + + This test verifies: + - Proper dataset retrieval from database + - Correct document processing and status updates + - IndexingRunner integration + - Database state updates + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=3 + ) + document_ids = [doc.id for doc in documents] + + # Act: Execute the task + _duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify the expected outcomes + # Verify indexing runner was called correctly + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify documents were updated to parsing status + # Re-query documents from database since _duplicate_document_indexing_task uses a different session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None + + # Verify the run method was called with correct documents + call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args + assert call_args is not None + processed_documents = call_args[0][0] # First argument should be documents list + assert len(processed_documents) == 3 + + def test_duplicate_document_indexing_task_with_segment_cleanup( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test duplicate document indexing with existing segments that need cleanup. + + This test verifies: + - Old segments are identified and cleaned + - Index processor clean method is called + - Segments are deleted from database + - New indexing proceeds after cleanup + """ + # Arrange: Create test data with existing segments + dataset, documents, segments = self._create_test_dataset_with_segments( + db_session_with_containers, mock_external_service_dependencies, document_count=2, segments_per_doc=3 + ) + document_ids = [doc.id for doc in documents] + + # Act: Execute the task + _duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify segment cleanup + # Verify index processor clean was called for each document with segments + assert mock_external_service_dependencies["index_processor"].clean.call_count == len(documents) + + # Verify segments were deleted from database + # Re-query segments from database since _duplicate_document_indexing_task uses a different session + for segment in segments: + deleted_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first() + assert deleted_segment is None + + # Verify documents were updated to parsing status + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None + + # Verify indexing runner was called + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + def test_duplicate_document_indexing_task_dataset_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of non-existent dataset. + + This test verifies: + - Proper error handling for missing datasets + - Early return without processing + - Database session cleanup + - No unnecessary indexing runner calls + """ + # Arrange: Use non-existent dataset ID + fake = Faker() + non_existent_dataset_id = fake.uuid4() + document_ids = [fake.uuid4() for _ in range(3)] + + # Act: Execute the task with non-existent dataset + _duplicate_document_indexing_task(non_existent_dataset_id, document_ids) + + # Assert: Verify no processing occurred + mock_external_service_dependencies["indexing_runner"].assert_not_called() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called() + mock_external_service_dependencies["index_processor"].clean.assert_not_called() + + def test_duplicate_document_indexing_task_document_not_found_in_dataset( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling when some documents don't exist in the dataset. + + This test verifies: + - Only existing documents are processed + - Non-existent documents are ignored + - Indexing runner receives only valid documents + - Database state updates correctly + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + + # Mix existing and non-existent document IDs + fake = Faker() + existing_document_ids = [doc.id for doc in documents] + non_existent_document_ids = [fake.uuid4() for _ in range(2)] + all_document_ids = existing_document_ids + non_existent_document_ids + + # Act: Execute the task with mixed document IDs + _duplicate_document_indexing_task(dataset.id, all_document_ids) + + # Assert: Verify only existing documents were processed + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify only existing documents were updated + # Re-query documents from database since _duplicate_document_indexing_task uses a different session + for doc_id in existing_document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None + + # Verify the run method was called with only existing documents + call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args + assert call_args is not None + processed_documents = call_args[0][0] # First argument should be documents list + assert len(processed_documents) == 2 # Only existing documents + + def test_duplicate_document_indexing_task_indexing_runner_exception( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of IndexingRunner exceptions. + + This test verifies: + - Exceptions from IndexingRunner are properly caught + - Task completes without raising exceptions + - Database session is properly closed + - Error logging occurs + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + document_ids = [doc.id for doc in documents] + + # Mock IndexingRunner to raise an exception + mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = Exception( + "Indexing runner failed" + ) + + # Act: Execute the task + _duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify exception was handled gracefully + # The task should complete without raising exceptions + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify documents were still updated to parsing status before the exception + # Re-query documents from database since _duplicate_document_indexing_task close the session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None + + def test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test billing validation for sandbox plan batch upload limit. + + This test verifies: + - Sandbox plan batch upload limit enforcement + - Error handling for batch upload limit exceeded + - Document status updates to error state + - Proper error message recording + """ + # Arrange: Create test data with billing enabled + dataset, documents = self._create_test_dataset_with_billing_features( + db_session_with_containers, mock_external_service_dependencies, billing_enabled=True + ) + + # Configure sandbox plan with batch limit + mock_external_service_dependencies["features"].billing.subscription.plan = CloudPlan.SANDBOX + + # Create more documents than sandbox plan allows (limit is 1) + fake = Faker() + extra_documents = [] + for i in range(2): # Total will be 5 documents (3 existing + 2 new) + document = Document( + id=fake.uuid4(), + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + position=i + 3, + data_source_type="upload_file", + batch="test_batch", + name=fake.file_name(), + created_from="upload_file", + created_by=dataset.created_by, + indexing_status="waiting", + enabled=True, + doc_form="text_model", + ) + db.session.add(document) + extra_documents.append(document) + + db.session.commit() + all_documents = documents + extra_documents + document_ids = [doc.id for doc in all_documents] + + # Act: Execute the task with too many documents for sandbox plan + _duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify error handling + # Re-query documents from database since _duplicate_document_indexing_task uses a different session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "error" + assert updated_document.error is not None + assert "batch upload" in updated_document.error.lower() + assert updated_document.stopped_at is not None + + # Verify indexing runner was not called due to early validation error + mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called() + + def test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test billing validation for vector space limit. + + This test verifies: + - Vector space limit enforcement + - Error handling for vector space limit exceeded + - Document status updates to error state + - Proper error message recording + """ + # Arrange: Create test data with billing enabled + dataset, documents = self._create_test_dataset_with_billing_features( + db_session_with_containers, mock_external_service_dependencies, billing_enabled=True + ) + + # Configure TEAM plan with vector space limit exceeded + mock_external_service_dependencies["features"].billing.subscription.plan = CloudPlan.TEAM + mock_external_service_dependencies["features"].vector_space.limit = 100 + mock_external_service_dependencies["features"].vector_space.size = 98 # Almost at limit + + document_ids = [doc.id for doc in documents] # 3 documents will exceed limit + + # Act: Execute the task with documents that will exceed vector space limit + _duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify error handling + # Re-query documents from database since _duplicate_document_indexing_task uses a different session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "error" + assert updated_document.error is not None + assert "limit" in updated_document.error.lower() + assert updated_document.stopped_at is not None + + # Verify indexing runner was not called due to early validation error + mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called() + + def test_duplicate_document_indexing_task_with_empty_document_list( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of empty document list. + + This test verifies: + - Empty document list is handled gracefully + - No processing occurs + - No errors are raised + - Database session is properly closed + """ + # Arrange: Create test dataset + dataset, _ = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=0 + ) + document_ids = [] + + # Act: Execute the task with empty document list + _duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify IndexingRunner was called with empty list + # Note: The actual implementation does call run([]) with empty list + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once_with([]) + + def test_deprecated_duplicate_document_indexing_task_delegates_to_core( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test that deprecated duplicate_document_indexing_task delegates to core function. + + This test verifies: + - Deprecated function calls core _duplicate_document_indexing_task + - Proper parameter passing + - Backward compatibility + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + document_ids = [doc.id for doc in documents] + + # Act: Execute the deprecated task + duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify core function was executed + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Clear session cache to see database updates from task's session + db.session.expire_all() + + # Verify documents were processed + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + + @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") + def test_normal_duplicate_document_indexing_task_with_tenant_queue( + self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test normal_duplicate_document_indexing_task with tenant isolation queue. + + This test verifies: + - Task uses tenant isolation queue correctly + - Core processing function is called + - Queue management (pull tasks, delete key) works properly + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + document_ids = [doc.id for doc in documents] + + # Mock tenant isolated queue to return no next tasks + mock_queue = MagicMock() + mock_queue.pull_tasks.return_value = [] + mock_queue_class.return_value = mock_queue + + # Act: Execute the normal task + normal_duplicate_document_indexing_task(dataset.tenant_id, dataset.id, document_ids) + + # Assert: Verify processing occurred + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify tenant queue was used + mock_queue_class.assert_called_with(dataset.tenant_id, "duplicate_document_indexing") + mock_queue.pull_tasks.assert_called_once() + mock_queue.delete_task_key.assert_called_once() + + # Clear session cache to see database updates from task's session + db.session.expire_all() + + # Verify documents were processed + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + + @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") + def test_priority_duplicate_document_indexing_task_with_tenant_queue( + self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test priority_duplicate_document_indexing_task with tenant isolation queue. + + This test verifies: + - Task uses tenant isolation queue correctly + - Core processing function is called + - Queue management works properly + - Same behavior as normal task with different queue assignment + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + document_ids = [doc.id for doc in documents] + + # Mock tenant isolated queue to return no next tasks + mock_queue = MagicMock() + mock_queue.pull_tasks.return_value = [] + mock_queue_class.return_value = mock_queue + + # Act: Execute the priority task + priority_duplicate_document_indexing_task(dataset.tenant_id, dataset.id, document_ids) + + # Assert: Verify processing occurred + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify tenant queue was used + mock_queue_class.assert_called_with(dataset.tenant_id, "duplicate_document_indexing") + mock_queue.pull_tasks.assert_called_once() + mock_queue.delete_task_key.assert_called_once() + + # Clear session cache to see database updates from task's session + db.session.expire_all() + + # Verify documents were processed + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + + @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") + def test_tenant_queue_wrapper_processes_next_tasks( + self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test tenant queue wrapper processes next queued tasks. + + This test verifies: + - After completing current task, next tasks are pulled from queue + - Next tasks are executed correctly + - Task waiting time is set for next tasks + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + document_ids = [doc.id for doc in documents] + + # Extract values before session detachment + tenant_id = dataset.tenant_id + dataset_id = dataset.id + + # Mock tenant isolated queue to return next task + mock_queue = MagicMock() + next_task = { + "tenant_id": tenant_id, + "dataset_id": dataset_id, + "document_ids": document_ids, + } + mock_queue.pull_tasks.return_value = [next_task] + mock_queue_class.return_value = mock_queue + + # Mock the task function to track calls + mock_task_func = MagicMock() + + # Act: Execute the wrapper function + _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func) + + # Assert: Verify next task was scheduled + mock_queue.pull_tasks.assert_called_once() + mock_queue.set_task_waiting_time.assert_called_once() + mock_task_func.delay.assert_called_once_with( + tenant_id=tenant_id, + dataset_id=dataset_id, + document_ids=document_ids, + ) + mock_queue.delete_task_key.assert_not_called() diff --git a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py index 798fe091ab..b738646736 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.ext_database import db from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -95,7 +95,7 @@ class TestEnableSegmentsToIndexTask: created_by=account.id, indexing_status="completed", enabled=True, - doc_form=IndexType.PARAGRAPH_INDEX, + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db.session.add(document) db.session.commit() @@ -166,7 +166,7 @@ class TestEnableSegmentsToIndexTask: ) # Update document to use different index type - document.doc_form = IndexType.QA_INDEX + document.doc_form = IndexStructureType.QA_INDEX db.session.commit() # Refresh dataset to ensure doc_form property reflects the updated document @@ -185,7 +185,9 @@ class TestEnableSegmentsToIndexTask: enable_segments_to_index_task(segment_ids, dataset.id, document.id) # Assert: Verify different index type handling - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.QA_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.QA_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify the load method was called with correct parameters @@ -328,7 +330,9 @@ class TestEnableSegmentsToIndexTask: enable_segments_to_index_task(non_existent_segment_ids, dataset.id, document.id) # Assert: Verify index processor was created but load was not called - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.PARAGRAPH_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_not_called() def test_enable_segments_to_index_with_parent_child_structure( @@ -350,7 +354,7 @@ class TestEnableSegmentsToIndexTask: ) # Update document to use parent-child index type - document.doc_form = IndexType.PARENT_CHILD_INDEX + document.doc_form = IndexStructureType.PARENT_CHILD_INDEX db.session.commit() # Refresh dataset to ensure doc_form property reflects the updated document @@ -383,7 +387,7 @@ class TestEnableSegmentsToIndexTask: # Assert: Verify parent-child index processing mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( - IndexType.PARENT_CHILD_INDEX + IndexStructureType.PARENT_CHILD_INDEX ) mock_external_service_dependencies["index_processor"].load.assert_called_once() diff --git a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py index d9f6dcc43c..025a0d8d70 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py +++ b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py @@ -53,7 +53,7 @@ from sqlalchemy.exc import IntegrityError from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import ModelPropertyKey -from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from core.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeConnectionError, @@ -99,10 +99,10 @@ class TestCacheEmbeddingDocuments: @pytest.fixture def sample_embedding_result(self): - """Create a sample TextEmbeddingResult for testing. + """Create a sample EmbeddingResult for testing. Returns: - TextEmbeddingResult: Mock embedding result with proper structure + EmbeddingResult: Mock embedding result with proper structure """ # Create normalized embedding vectors (dimension 1536 for ada-002) embedding_vector = np.random.randn(1536) @@ -118,7 +118,7 @@ class TestCacheEmbeddingDocuments: latency=0.5, ) - return TextEmbeddingResult( + return EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized_vector], usage=usage, @@ -197,7 +197,7 @@ class TestCacheEmbeddingDocuments: latency=0.8, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -296,7 +296,7 @@ class TestCacheEmbeddingDocuments: latency=0.6, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=new_embeddings, usage=usage, @@ -386,7 +386,7 @@ class TestCacheEmbeddingDocuments: latency=0.5, ) - return TextEmbeddingResult( + return EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -449,7 +449,7 @@ class TestCacheEmbeddingDocuments: latency=0.5, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[valid_vector.tolist(), nan_vector], usage=usage, @@ -629,7 +629,7 @@ class TestCacheEmbeddingQuery: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -728,7 +728,7 @@ class TestCacheEmbeddingQuery: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[nan_vector], usage=usage, @@ -793,7 +793,7 @@ class TestCacheEmbeddingQuery: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -873,13 +873,13 @@ class TestEmbeddingModelSwitching: latency=0.3, ) - result_ada = TextEmbeddingResult( + result_ada = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized_ada], usage=usage, ) - result_3_small = TextEmbeddingResult( + result_3_small = EmbeddingResult( model="text-embedding-3-small", embeddings=[normalized_3_small], usage=usage, @@ -953,13 +953,13 @@ class TestEmbeddingModelSwitching: latency=0.4, ) - result_openai = TextEmbeddingResult( + result_openai = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized_openai], usage=usage_openai, ) - result_cohere = TextEmbeddingResult( + result_cohere = EmbeddingResult( model="embed-english-v3.0", embeddings=[normalized_cohere], usage=usage_cohere, @@ -1042,7 +1042,7 @@ class TestEmbeddingDimensionValidation: latency=0.7, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1095,7 +1095,7 @@ class TestEmbeddingDimensionValidation: latency=0.5, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1148,7 +1148,7 @@ class TestEmbeddingDimensionValidation: latency=0.3, ) - result_ada = TextEmbeddingResult( + result_ada = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized_ada], usage=usage_ada, @@ -1181,7 +1181,7 @@ class TestEmbeddingDimensionValidation: latency=0.4, ) - result_cohere = TextEmbeddingResult( + result_cohere = EmbeddingResult( model="embed-english-v3.0", embeddings=[normalized_cohere], usage=usage_cohere, @@ -1279,7 +1279,7 @@ class TestEmbeddingEdgeCases: latency=0.1, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -1322,7 +1322,7 @@ class TestEmbeddingEdgeCases: latency=1.5, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -1370,7 +1370,7 @@ class TestEmbeddingEdgeCases: latency=0.5, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1422,7 +1422,7 @@ class TestEmbeddingEdgeCases: latency=0.2, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1478,7 +1478,7 @@ class TestEmbeddingEdgeCases: ) # Model returns embeddings for all texts - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1546,7 +1546,7 @@ class TestEmbeddingEdgeCases: latency=0.8, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1603,7 +1603,7 @@ class TestEmbeddingEdgeCases: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -1657,7 +1657,7 @@ class TestEmbeddingEdgeCases: latency=0.5, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1757,7 +1757,7 @@ class TestEmbeddingCachePerformance: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -1826,7 +1826,7 @@ class TestEmbeddingCachePerformance: latency=0.5, ) - return TextEmbeddingResult( + return EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1888,7 +1888,7 @@ class TestEmbeddingCachePerformance: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index d26e98db8d..c00fee8fe5 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -62,7 +62,7 @@ from core.indexing_runner import ( IndexingRunner, ) from core.model_runtime.entities.model_entities import ModelType -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.models.document import ChildDocument, Document from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, DatasetProcessRule @@ -112,7 +112,7 @@ def create_mock_dataset_document( document_id: str | None = None, dataset_id: str | None = None, tenant_id: str | None = None, - doc_form: str = IndexType.PARAGRAPH_INDEX, + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, data_source_type: str = "upload_file", doc_language: str = "English", ) -> Mock: @@ -133,8 +133,8 @@ def create_mock_dataset_document( Mock: A configured mock DatasetDocument object with all required attributes. Example: - >>> doc = create_mock_dataset_document(doc_form=IndexType.QA_INDEX) - >>> assert doc.doc_form == IndexType.QA_INDEX + >>> doc = create_mock_dataset_document(doc_form=IndexStructureType.QA_INDEX) + >>> assert doc.doc_form == IndexStructureType.QA_INDEX """ doc = Mock(spec=DatasetDocument) doc.id = document_id or str(uuid.uuid4()) @@ -276,7 +276,7 @@ class TestIndexingRunnerExtract: doc.id = str(uuid.uuid4()) doc.dataset_id = str(uuid.uuid4()) doc.tenant_id = str(uuid.uuid4()) - doc.doc_form = IndexType.PARAGRAPH_INDEX + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX doc.data_source_type = "upload_file" doc.data_source_info_dict = {"upload_file_id": str(uuid.uuid4())} return doc @@ -616,7 +616,7 @@ class TestIndexingRunnerLoad: doc = Mock(spec=DatasetDocument) doc.id = str(uuid.uuid4()) doc.dataset_id = str(uuid.uuid4()) - doc.doc_form = IndexType.PARAGRAPH_INDEX + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX return doc @pytest.fixture @@ -700,7 +700,7 @@ class TestIndexingRunnerLoad: """Test loading with parent-child index structure.""" # Arrange runner = IndexingRunner() - sample_dataset_document.doc_form = IndexType.PARENT_CHILD_INDEX + sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX sample_dataset.indexing_technique = "high_quality" # Add child documents @@ -775,7 +775,7 @@ class TestIndexingRunnerRun: doc.id = str(uuid.uuid4()) doc.dataset_id = str(uuid.uuid4()) doc.tenant_id = str(uuid.uuid4()) - doc.doc_form = IndexType.PARAGRAPH_INDEX + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX doc.doc_language = "English" doc.data_source_type = "upload_file" doc.data_source_info_dict = {"upload_file_id": str(uuid.uuid4())} @@ -802,6 +802,21 @@ class TestIndexingRunnerRun: mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}} mock_dependencies["db"].session.scalar.return_value = mock_process_rule + # Mock current_user (Account) for _transform + mock_current_user = MagicMock() + mock_current_user.set_tenant_id = MagicMock() + + # Setup db.session.query to return different results based on the model + def mock_query_side_effect(model): + mock_query_result = MagicMock() + if model.__name__ == "Dataset": + mock_query_result.filter_by.return_value.first.return_value = mock_dataset + elif model.__name__ == "Account": + mock_query_result.filter_by.return_value.first.return_value = mock_current_user + return mock_query_result + + mock_dependencies["db"].session.query.side_effect = mock_query_side_effect + # Mock processor mock_processor = MagicMock() mock_dependencies["factory"].return_value.init_index_processor.return_value = mock_processor @@ -1268,7 +1283,7 @@ class TestIndexingRunnerLoadSegments: doc.id = str(uuid.uuid4()) doc.dataset_id = str(uuid.uuid4()) doc.created_by = str(uuid.uuid4()) - doc.doc_form = IndexType.PARAGRAPH_INDEX + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX return doc @pytest.fixture @@ -1316,7 +1331,7 @@ class TestIndexingRunnerLoadSegments: """Test loading segments for parent-child index.""" # Arrange runner = IndexingRunner() - sample_dataset_document.doc_form = IndexType.PARENT_CHILD_INDEX + sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX # Add child documents for doc in sample_documents: @@ -1413,7 +1428,7 @@ class TestIndexingRunnerEstimate: tenant_id=tenant_id, extract_settings=extract_settings, tmp_processing_rule={"mode": "automatic", "rules": {}}, - doc_form=IndexType.PARAGRAPH_INDEX, + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) diff --git a/api/tests/unit_tests/core/rag/rerank/test_reranker.py b/api/tests/unit_tests/core/rag/rerank/test_reranker.py index 4912884c55..ebe6c37818 100644 --- a/api/tests/unit_tests/core/rag/rerank/test_reranker.py +++ b/api/tests/unit_tests/core/rag/rerank/test_reranker.py @@ -26,6 +26,18 @@ from core.rag.rerank.rerank_type import RerankMode from core.rag.rerank.weight_rerank import WeightRerankRunner +def create_mock_model_instance(): + """Create a properly configured mock ModelInstance for reranking tests.""" + mock_instance = Mock(spec=ModelInstance) + # Setup provider_model_bundle chain for check_model_support_vision + mock_instance.provider_model_bundle = Mock() + mock_instance.provider_model_bundle.configuration = Mock() + mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id" + mock_instance.provider = "test-provider" + mock_instance.model = "test-model" + return mock_instance + + class TestRerankModelRunner: """Unit tests for RerankModelRunner. @@ -37,10 +49,23 @@ class TestRerankModelRunner: - Metadata preservation and score injection """ + @pytest.fixture(autouse=True) + def mock_model_manager(self): + """Auto-use fixture to patch ModelManager for all tests in this class.""" + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + yield mock_mm + @pytest.fixture def mock_model_instance(self): """Create a mock ModelInstance for reranking.""" mock_instance = Mock(spec=ModelInstance) + # Setup provider_model_bundle chain for check_model_support_vision + mock_instance.provider_model_bundle = Mock() + mock_instance.provider_model_bundle.configuration = Mock() + mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id" + mock_instance.provider = "test-provider" + mock_instance.model = "test-model" return mock_instance @pytest.fixture @@ -803,7 +828,7 @@ class TestRerankRunnerFactory: - Parameters are forwarded to runner constructor """ # Arrange: Mock model instance - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() # Act: Create runner via factory runner = RerankRunnerFactory.create_rerank_runner( @@ -865,7 +890,7 @@ class TestRerankRunnerFactory: - String values are properly matched """ # Arrange: Mock model instance - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() # Act: Create runner using enum value runner = RerankRunnerFactory.create_rerank_runner( @@ -886,6 +911,13 @@ class TestRerankIntegration: - Real-world usage scenarios """ + @pytest.fixture(autouse=True) + def mock_model_manager(self): + """Auto-use fixture to patch ModelManager for all tests in this class.""" + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + yield mock_mm + def test_model_reranking_full_workflow(self): """Test complete model-based reranking workflow. @@ -895,7 +927,7 @@ class TestRerankIntegration: - Top results are returned correctly """ # Arrange: Create mock model and documents - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -951,7 +983,7 @@ class TestRerankIntegration: - Normalization is consistent """ # Arrange: Create mock model with various scores - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -991,6 +1023,13 @@ class TestRerankEdgeCases: - Concurrent reranking scenarios """ + @pytest.fixture(autouse=True) + def mock_model_manager(self): + """Auto-use fixture to patch ModelManager for all tests in this class.""" + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + yield mock_mm + def test_rerank_with_empty_metadata(self): """Test reranking when documents have empty metadata. @@ -1000,7 +1039,7 @@ class TestRerankEdgeCases: - Empty metadata documents are processed correctly """ # Arrange: Create documents with empty metadata - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -1046,7 +1085,7 @@ class TestRerankEdgeCases: - Score comparison logic works at boundary """ # Arrange: Create mock with various scores including negatives - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -1082,7 +1121,7 @@ class TestRerankEdgeCases: - No overflow or precision issues """ # Arrange: All documents with perfect scores - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -1117,7 +1156,7 @@ class TestRerankEdgeCases: - Content encoding is preserved """ # Arrange: Documents with special characters - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -1159,7 +1198,7 @@ class TestRerankEdgeCases: - Content is not truncated unexpectedly """ # Arrange: Documents with very long content - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() long_content = "This is a very long document. " * 1000 # ~30,000 characters mock_rerank_result = RerankResult( @@ -1196,7 +1235,7 @@ class TestRerankEdgeCases: - All documents are processed correctly """ # Arrange: Create 100 documents - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() num_docs = 100 # Create rerank results for all documents @@ -1287,7 +1326,7 @@ class TestRerankEdgeCases: - Documents can still be ranked """ # Arrange: Empty query - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -1325,6 +1364,13 @@ class TestRerankPerformance: - Score calculation optimization """ + @pytest.fixture(autouse=True) + def mock_model_manager(self): + """Auto-use fixture to patch ModelManager for all tests in this class.""" + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + yield mock_mm + def test_rerank_batch_processing(self): """Test that documents are processed in a single batch. @@ -1334,7 +1380,7 @@ class TestRerankPerformance: - Efficient batch processing """ # Arrange: Multiple documents - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[RerankDocument(index=i, text=f"Doc {i}", score=0.9 - i * 0.1) for i in range(5)], @@ -1435,6 +1481,13 @@ class TestRerankErrorHandling: - Error propagation """ + @pytest.fixture(autouse=True) + def mock_model_manager(self): + """Auto-use fixture to patch ModelManager for all tests in this class.""" + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + yield mock_mm + def test_rerank_model_invocation_error(self): """Test handling of model invocation errors. @@ -1444,7 +1497,7 @@ class TestRerankErrorHandling: - Error context is preserved """ # Arrange: Mock model that raises exception - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_model_instance.invoke_rerank.side_effect = RuntimeError("Model invocation failed") documents = [ @@ -1470,7 +1523,7 @@ class TestRerankErrorHandling: - Invalid results don't corrupt output """ # Arrange: Rerank result with invalid index - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index 0163e42992..affd6c648f 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -425,15 +425,15 @@ class TestRetrievalService: # ==================== Vector Search Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_vector_search_basic(self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents): + def test_vector_search_basic(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents): """ Test basic vector/semantic search functionality. This test validates the core vector search flow: 1. Dataset is retrieved from database - 2. embedding_search is called via ThreadPoolExecutor + 2. _retrieve is called via ThreadPoolExecutor 3. Documents are added to shared all_documents list 4. Results are returned to caller @@ -447,28 +447,28 @@ class TestRetrievalService: # Set up the mock dataset that will be "retrieved" from database mock_get_dataset.return_value = mock_dataset - # Create a side effect function that simulates embedding_search behavior - # In the real implementation, embedding_search: - # 1. Gets the dataset - # 2. Creates a Vector instance - # 3. Calls search_by_vector with embeddings - # 4. Extends all_documents with results - def side_effect_embedding_search( + # Create a side effect function that simulates _retrieve behavior + # _retrieve modifies the all_documents list in place + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - """Simulate embedding_search adding documents to the shared list.""" - all_documents.extend(sample_documents) + """Simulate _retrieve adding documents to the shared list.""" + if all_documents is not None: + all_documents.extend(sample_documents) - mock_embedding_search.side_effect = side_effect_embedding_search + mock_retrieve.side_effect = side_effect_retrieve # Define test parameters query = "What is Python?" # Natural language query @@ -481,7 +481,7 @@ class TestRetrievalService: # 1. Check if query is empty (early return if so) # 2. Get the dataset using _get_dataset # 3. Create ThreadPoolExecutor - # 4. Submit embedding_search task + # 4. Submit _retrieve task # 5. Wait for completion # 6. Return all_documents list results = RetrievalService.retrieve( @@ -502,15 +502,13 @@ class TestRetrievalService: # Verify documents maintain their scores (highest score first in sample_documents) assert results[0].metadata["score"] == 0.95, "First document should have highest score from sample_documents" - # Verify embedding_search was called exactly once + # Verify _retrieve was called exactly once # This confirms the search method was invoked by ThreadPoolExecutor - mock_embedding_search.assert_called_once() + mock_retrieve.assert_called_once() - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_vector_search_with_document_filter( - self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents - ): + def test_vector_search_with_document_filter(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents): """ Test vector search with document ID filtering. @@ -522,21 +520,25 @@ class TestRetrievalService: mock_get_dataset.return_value = mock_dataset filtered_docs = [sample_documents[0]] - def side_effect_embedding_search( + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - all_documents.extend(filtered_docs) + if all_documents is not None: + all_documents.extend(filtered_docs) - mock_embedding_search.side_effect = side_effect_embedding_search + mock_retrieve.side_effect = side_effect_retrieve document_ids_filter = [sample_documents[0].metadata["document_id"]] # Act @@ -552,12 +554,12 @@ class TestRetrievalService: assert len(results) == 1 assert results[0].metadata["doc_id"] == "doc1" # Verify document_ids_filter was passed - call_kwargs = mock_embedding_search.call_args.kwargs + call_kwargs = mock_retrieve.call_args.kwargs assert call_kwargs["document_ids_filter"] == document_ids_filter - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_vector_search_empty_results(self, mock_get_dataset, mock_embedding_search, mock_dataset): + def test_vector_search_empty_results(self, mock_get_dataset, mock_retrieve, mock_dataset): """ Test vector search when no results match the query. @@ -567,8 +569,8 @@ class TestRetrievalService: """ # Arrange mock_get_dataset.return_value = mock_dataset - # embedding_search doesn't add anything to all_documents - mock_embedding_search.side_effect = lambda *args, **kwargs: None + # _retrieve doesn't add anything to all_documents + mock_retrieve.side_effect = lambda *args, **kwargs: None # Act results = RetrievalService.retrieve( @@ -583,9 +585,9 @@ class TestRetrievalService: # ==================== Keyword Search Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_keyword_search_basic(self, mock_get_dataset, mock_keyword_search, mock_dataset, sample_documents): + def test_keyword_search_basic(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents): """ Test basic keyword search functionality. @@ -597,12 +599,25 @@ class TestRetrievalService: # Arrange mock_get_dataset.return_value = mock_dataset - def side_effect_keyword_search( - flask_app, dataset_id, query, top_k, all_documents, exceptions, document_ids_filter=None + def side_effect_retrieve( + flask_app, + retrieval_method, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, + document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - all_documents.extend(sample_documents) + if all_documents is not None: + all_documents.extend(sample_documents) - mock_keyword_search.side_effect = side_effect_keyword_search + mock_retrieve.side_effect = side_effect_retrieve query = "Python programming" top_k = 3 @@ -618,7 +633,7 @@ class TestRetrievalService: # Assert assert len(results) == 3 assert all(isinstance(doc, Document) for doc in results) - mock_keyword_search.assert_called_once() + mock_retrieve.assert_called_once() @patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") @@ -1147,11 +1162,9 @@ class TestRetrievalService: # ==================== Metadata Filtering Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_vector_search_with_metadata_filter( - self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents - ): + def test_vector_search_with_metadata_filter(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents): """ Test vector search with metadata-based document filtering. @@ -1166,21 +1179,25 @@ class TestRetrievalService: filtered_doc = sample_documents[0] filtered_doc.metadata["category"] = "programming" - def side_effect_embedding( + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - all_documents.append(filtered_doc) + if all_documents is not None: + all_documents.append(filtered_doc) - mock_embedding_search.side_effect = side_effect_embedding + mock_retrieve.side_effect = side_effect_retrieve # Act results = RetrievalService.retrieve( @@ -1243,9 +1260,9 @@ class TestRetrievalService: # Assert assert results == [] - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_retrieve_with_exception_handling(self, mock_get_dataset, mock_embedding_search, mock_dataset): + def test_retrieve_with_exception_handling(self, mock_get_dataset, mock_retrieve, mock_dataset): """ Test that exceptions during retrieval are properly handled. @@ -1256,22 +1273,26 @@ class TestRetrievalService: # Arrange mock_get_dataset.return_value = mock_dataset - # Make embedding_search add an exception to the exceptions list + # Make _retrieve add an exception to the exceptions list def side_effect_with_exception( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - exceptions.append("Search failed") + if exceptions is not None: + exceptions.append("Search failed") - mock_embedding_search.side_effect = side_effect_with_exception + mock_retrieve.side_effect = side_effect_with_exception # Act & Assert with pytest.raises(ValueError) as exc_info: @@ -1286,9 +1307,9 @@ class TestRetrievalService: # ==================== Score Threshold Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_vector_search_with_score_threshold(self, mock_get_dataset, mock_embedding_search, mock_dataset): + def test_vector_search_with_score_threshold(self, mock_get_dataset, mock_retrieve, mock_dataset): """ Test vector search with score threshold filtering. @@ -1306,21 +1327,25 @@ class TestRetrievalService: provider="dify", ) - def side_effect_embedding( + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - all_documents.append(high_score_doc) + if all_documents is not None: + all_documents.append(high_score_doc) - mock_embedding_search.side_effect = side_effect_embedding + mock_retrieve.side_effect = side_effect_retrieve score_threshold = 0.8 @@ -1339,9 +1364,9 @@ class TestRetrievalService: # ==================== Top-K Limiting Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_retrieve_respects_top_k_limit(self, mock_get_dataset, mock_embedding_search, mock_dataset): + def test_retrieve_respects_top_k_limit(self, mock_get_dataset, mock_retrieve, mock_dataset): """ Test that retrieval respects top_k parameter. @@ -1362,22 +1387,26 @@ class TestRetrievalService: for i in range(10) ] - def side_effect_embedding( + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): # Return only top_k documents - all_documents.extend(many_docs[:top_k]) + if all_documents is not None: + all_documents.extend(many_docs[:top_k]) - mock_embedding_search.side_effect = side_effect_embedding + mock_retrieve.side_effect = side_effect_retrieve top_k = 3 @@ -1390,9 +1419,9 @@ class TestRetrievalService: ) # Assert - # Verify top_k was passed to embedding_search - assert mock_embedding_search.called - call_kwargs = mock_embedding_search.call_args.kwargs + # Verify _retrieve was called + assert mock_retrieve.called + call_kwargs = mock_retrieve.call_args.kwargs assert call_kwargs["top_k"] == top_k # Verify we got the right number of results assert len(results) == top_k @@ -1421,11 +1450,9 @@ class TestRetrievalService: # ==================== Reranking Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_semantic_search_with_reranking( - self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents - ): + def test_semantic_search_with_reranking(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents): """ Test semantic search with reranking model. @@ -1439,22 +1466,26 @@ class TestRetrievalService: # Simulate reranking changing order reranked_docs = list(reversed(sample_documents)) - def side_effect_embedding( + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - # embedding_search handles reranking internally - all_documents.extend(reranked_docs) + # _retrieve handles reranking internally + if all_documents is not None: + all_documents.extend(reranked_docs) - mock_embedding_search.side_effect = side_effect_embedding + mock_retrieve.side_effect = side_effect_retrieve reranking_model = { "reranking_provider_name": "cohere", @@ -1473,7 +1504,7 @@ class TestRetrievalService: # Assert # For semantic search with reranking, reranking_model should be passed assert len(results) == 3 - call_kwargs = mock_embedding_search.call_args.kwargs + call_kwargs = mock_retrieve.call_args.kwargs assert call_kwargs["reranking_model"] == reranking_model diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py index 0f6b7e4ab6..47a5df92a4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py @@ -1,3 +1,4 @@ +import json from unittest.mock import Mock, PropertyMock, patch import httpx @@ -138,3 +139,95 @@ def test_is_file_with_no_content_disposition(mock_response): type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512)) response = Response(mock_response) assert response.is_file + + +# UTF-8 Encoding Tests +@pytest.mark.parametrize( + ("content_bytes", "expected_text", "description"), + [ + # Chinese UTF-8 bytes + ( + b'{"message": "\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c"}', + '{"message": "你好世界"}', + "Chinese characters UTF-8", + ), + # Japanese UTF-8 bytes + ( + b'{"message": "\xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf"}', + '{"message": "こんにちは"}', + "Japanese characters UTF-8", + ), + # Korean UTF-8 bytes + ( + b'{"message": "\xec\x95\x88\xeb\x85\x95\xed\x95\x98\xec\x84\xb8\xec\x9a\x94"}', + '{"message": "안녕하세요"}', + "Korean characters UTF-8", + ), + # Arabic UTF-8 + (b'{"text": "\xd9\x85\xd8\xb1\xd8\xad\xd8\xa8\xd8\xa7"}', '{"text": "مرحبا"}', "Arabic characters UTF-8"), + # European characters UTF-8 + (b'{"text": "Caf\xc3\xa9 M\xc3\xbcnchen"}', '{"text": "Café München"}', "European accented characters"), + # Simple ASCII + (b'{"text": "Hello World"}', '{"text": "Hello World"}', "Simple ASCII text"), + ], +) +def test_text_property_utf8_decoding(mock_response, content_bytes, expected_text, description): + """Test that Response.text properly decodes UTF-8 content with charset_normalizer""" + mock_response.headers = {"content-type": "application/json; charset=utf-8"} + type(mock_response).content = PropertyMock(return_value=content_bytes) + # Mock httpx response.text to return something different (simulating potential encoding issues) + mock_response.text = "incorrect-fallback-text" # To ensure we are not falling back to httpx's text property + + response = Response(mock_response) + + # Our enhanced text property should decode properly using charset_normalizer + assert response.text == expected_text, ( + f"Failed for {description}: got {repr(response.text)}, expected {repr(expected_text)}" + ) + + +def test_text_property_fallback_to_httpx(mock_response): + """Test that Response.text falls back to httpx.text when charset_normalizer fails""" + mock_response.headers = {"content-type": "application/json"} + + # Create malformed UTF-8 bytes + malformed_bytes = b'{"text": "\xff\xfe\x00\x00 invalid"}' + type(mock_response).content = PropertyMock(return_value=malformed_bytes) + + # Mock httpx.text to return some fallback value + fallback_text = '{"text": "fallback"}' + mock_response.text = fallback_text + + response = Response(mock_response) + + # Should fall back to httpx's text when charset_normalizer fails + assert response.text == fallback_text + + +@pytest.mark.parametrize( + ("json_content", "description"), + [ + # JSON with escaped Unicode (like Flask jsonify()) + ('{"message": "\\u4f60\\u597d\\u4e16\\u754c"}', "JSON with escaped Unicode"), + # JSON with mixed escape sequences and UTF-8 + ('{"mixed": "Hello \\u4f60\\u597d"}', "Mixed escaped and regular text"), + # JSON with complex escape sequences + ('{"complex": "\\ud83d\\ude00\\u4f60\\u597d"}', "Emoji and Chinese escapes"), + ], +) +def test_text_property_with_escaped_unicode(mock_response, json_content, description): + """Test Response.text with JSON containing Unicode escape sequences""" + mock_response.headers = {"content-type": "application/json"} + + content_bytes = json_content.encode("utf-8") + type(mock_response).content = PropertyMock(return_value=content_bytes) + mock_response.text = json_content # httpx would return the same for valid UTF-8 + + response = Response(mock_response) + + # Should preserve the escape sequences (valid JSON) + assert response.text == json_content, f"Failed for {description}" + + # The text should be valid JSON that can be parsed back to proper Unicode + parsed = json.loads(response.text) + assert isinstance(parsed, dict), f"Invalid JSON for {description}" diff --git a/api/tests/unit_tests/services/document_indexing_task_proxy.py b/api/tests/unit_tests/services/document_indexing_task_proxy.py index 765c4b5e32..ff243b8dc3 100644 --- a/api/tests/unit_tests/services/document_indexing_task_proxy.py +++ b/api/tests/unit_tests/services/document_indexing_task_proxy.py @@ -117,7 +117,7 @@ import pytest from core.entities.document_task import DocumentTask from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan -from services.document_indexing_task_proxy import DocumentIndexingTaskProxy +from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy # ============================================================================ # Test Data Factory @@ -370,7 +370,7 @@ class TestDocumentIndexingTaskProxy: # Features Property Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_features_property(self, mock_feature_service): """ Test cached_property features. @@ -400,7 +400,7 @@ class TestDocumentIndexingTaskProxy: mock_feature_service.get_features.assert_called_once_with("tenant-123") - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_features_property_with_different_tenants(self, mock_feature_service): """ Test features property with different tenant IDs. @@ -438,7 +438,7 @@ class TestDocumentIndexingTaskProxy: # Direct Queue Routing Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_direct_queue(self, mock_task): """ Test _send_to_direct_queue method. @@ -460,7 +460,7 @@ class TestDocumentIndexingTaskProxy: # Assert mock_task.delay.assert_called_once_with(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids) - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") def test_send_to_direct_queue_with_priority_task(self, mock_task): """ Test _send_to_direct_queue with priority task function. @@ -481,7 +481,7 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] ) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_direct_queue_with_single_document(self, mock_task): """ Test _send_to_direct_queue with single document ID. @@ -502,7 +502,7 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1"] ) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_direct_queue_with_empty_documents(self, mock_task): """ Test _send_to_direct_queue with empty document_ids list. @@ -525,7 +525,7 @@ class TestDocumentIndexingTaskProxy: # Tenant Queue Routing Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_with_existing_task_key(self, mock_task): """ Test _send_to_tenant_queue when task key exists. @@ -564,7 +564,7 @@ class TestDocumentIndexingTaskProxy: mock_task.delay.assert_not_called() - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_without_task_key(self, mock_task): """ Test _send_to_tenant_queue when no task key exists. @@ -594,7 +594,7 @@ class TestDocumentIndexingTaskProxy: proxy._tenant_isolated_task_queue.push_tasks.assert_not_called() - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") def test_send_to_tenant_queue_with_priority_task(self, mock_task): """ Test _send_to_tenant_queue with priority task function. @@ -621,7 +621,7 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] ) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_document_task_serialization(self, mock_task): """ Test DocumentTask serialization in _send_to_tenant_queue. @@ -659,7 +659,7 @@ class TestDocumentIndexingTaskProxy: # Queue Type Selection Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_default_tenant_queue(self, mock_task): """ Test _send_to_default_tenant_queue method. @@ -678,7 +678,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_tenant_queue.assert_called_once_with(mock_task) - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") def test_send_to_priority_tenant_queue(self, mock_task): """ Test _send_to_priority_tenant_queue method. @@ -697,7 +697,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_tenant_queue.assert_called_once_with(mock_task) - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") def test_send_to_priority_direct_queue(self, mock_task): """ Test _send_to_priority_direct_queue method. @@ -720,7 +720,7 @@ class TestDocumentIndexingTaskProxy: # Dispatch Logic Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service): """ Test _dispatch method when billing is enabled with SANDBOX plan. @@ -745,7 +745,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_default_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_with_billing_enabled_team_plan(self, mock_feature_service): """ Test _dispatch method when billing is enabled with TEAM plan. @@ -770,7 +770,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_priority_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_with_billing_enabled_professional_plan(self, mock_feature_service): """ Test _dispatch method when billing is enabled with PROFESSIONAL plan. @@ -795,7 +795,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_priority_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_with_billing_disabled(self, mock_feature_service): """ Test _dispatch method when billing is disabled. @@ -818,7 +818,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_priority_direct_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_edge_case_empty_plan(self, mock_feature_service): """ Test _dispatch method with empty plan string. @@ -842,7 +842,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_priority_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_edge_case_none_plan(self, mock_feature_service): """ Test _dispatch method with None plan. @@ -870,7 +870,7 @@ class TestDocumentIndexingTaskProxy: # Delay Method Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_delay_method(self, mock_feature_service): """ Test delay method integration. @@ -895,7 +895,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_default_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_delay_method_with_team_plan(self, mock_feature_service): """ Test delay method with TEAM plan. @@ -920,7 +920,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_priority_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_delay_method_with_billing_disabled(self, mock_feature_service): """ Test delay method with billing disabled. @@ -1021,7 +1021,7 @@ class TestDocumentIndexingTaskProxy: # Batch Operations Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_batch_operation_with_multiple_documents(self, mock_task): """ Test batch operation with multiple documents. @@ -1044,7 +1044,7 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=document_ids ) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_batch_operation_with_large_batch(self, mock_task): """ Test batch operation with large batch of documents. @@ -1073,7 +1073,7 @@ class TestDocumentIndexingTaskProxy: # Error Handling Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_direct_queue_task_delay_failure(self, mock_task): """ Test _send_to_direct_queue when task.delay() raises an exception. @@ -1090,7 +1090,7 @@ class TestDocumentIndexingTaskProxy: with pytest.raises(Exception, match="Task delay failed"): proxy._send_to_direct_queue(mock_task) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_push_tasks_failure(self, mock_task): """ Test _send_to_tenant_queue when push_tasks raises an exception. @@ -1111,7 +1111,7 @@ class TestDocumentIndexingTaskProxy: with pytest.raises(Exception, match="Push tasks failed"): proxy._send_to_tenant_queue(mock_task) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_set_waiting_time_failure(self, mock_task): """ Test _send_to_tenant_queue when set_task_waiting_time raises an exception. @@ -1132,7 +1132,7 @@ class TestDocumentIndexingTaskProxy: with pytest.raises(Exception, match="Set waiting time failed"): proxy._send_to_tenant_queue(mock_task) - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_feature_service_failure(self, mock_feature_service): """ Test _dispatch when FeatureService.get_features raises an exception. @@ -1153,8 +1153,8 @@ class TestDocumentIndexingTaskProxy: # Integration Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.FeatureService") - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_full_flow_sandbox_plan(self, mock_task, mock_feature_service): """ Test full flow for SANDBOX plan with tenant queue. @@ -1187,8 +1187,8 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] ) - @patch("services.document_indexing_task_proxy.FeatureService") - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") def test_full_flow_team_plan(self, mock_task, mock_feature_service): """ Test full flow for TEAM plan with priority tenant queue. @@ -1221,8 +1221,8 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] ) - @patch("services.document_indexing_task_proxy.FeatureService") - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") def test_full_flow_billing_disabled(self, mock_task, mock_feature_service): """ Test full flow for billing disabled (self-hosted/enterprise). diff --git a/api/tests/unit_tests/services/test_document_indexing_task_proxy.py b/api/tests/unit_tests/services/test_document_indexing_task_proxy.py index d9183be9fb..98c30c3722 100644 --- a/api/tests/unit_tests/services/test_document_indexing_task_proxy.py +++ b/api/tests/unit_tests/services/test_document_indexing_task_proxy.py @@ -3,7 +3,7 @@ from unittest.mock import Mock, patch from core.entities.document_task import DocumentTask from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan -from services.document_indexing_task_proxy import DocumentIndexingTaskProxy +from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy class DocumentIndexingTaskProxyTestDataFactory: @@ -59,7 +59,7 @@ class TestDocumentIndexingTaskProxy: assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id assert proxy._tenant_isolated_task_queue._unique_key == "document_indexing" - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_features_property(self, mock_feature_service): """Test cached_property features.""" # Arrange @@ -77,7 +77,7 @@ class TestDocumentIndexingTaskProxy: assert features1 is features2 # Should be the same instance due to caching mock_feature_service.get_features.assert_called_once_with("tenant-123") - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_direct_queue(self, mock_task): """Test _send_to_direct_queue method.""" # Arrange @@ -92,7 +92,7 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] ) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_with_existing_task_key(self, mock_task): """Test _send_to_tenant_queue when task key exists.""" # Arrange @@ -115,7 +115,7 @@ class TestDocumentIndexingTaskProxy: assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"] mock_task.delay.assert_not_called() - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_without_task_key(self, mock_task): """Test _send_to_tenant_queue when no task key exists.""" # Arrange @@ -135,8 +135,7 @@ class TestDocumentIndexingTaskProxy: ) proxy._tenant_isolated_task_queue.push_tasks.assert_not_called() - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") - def test_send_to_default_tenant_queue(self, mock_task): + def test_send_to_default_tenant_queue(self): """Test _send_to_default_tenant_queue method.""" # Arrange proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() @@ -146,10 +145,9 @@ class TestDocumentIndexingTaskProxy: proxy._send_to_default_tenant_queue() # Assert - proxy._send_to_tenant_queue.assert_called_once_with(mock_task) + proxy._send_to_tenant_queue.assert_called_once_with(proxy.NORMAL_TASK_FUNC) - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") - def test_send_to_priority_tenant_queue(self, mock_task): + def test_send_to_priority_tenant_queue(self): """Test _send_to_priority_tenant_queue method.""" # Arrange proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() @@ -159,10 +157,9 @@ class TestDocumentIndexingTaskProxy: proxy._send_to_priority_tenant_queue() # Assert - proxy._send_to_tenant_queue.assert_called_once_with(mock_task) + proxy._send_to_tenant_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC) - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") - def test_send_to_priority_direct_queue(self, mock_task): + def test_send_to_priority_direct_queue(self): """Test _send_to_priority_direct_queue method.""" # Arrange proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() @@ -172,9 +169,9 @@ class TestDocumentIndexingTaskProxy: proxy._send_to_priority_direct_queue() # Assert - proxy._send_to_direct_queue.assert_called_once_with(mock_task) + proxy._send_to_direct_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC) - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service): """Test _dispatch method when billing is enabled with sandbox plan.""" # Arrange @@ -191,7 +188,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_default_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service): """Test _dispatch method when billing is enabled with non-sandbox plan.""" # Arrange @@ -208,7 +205,7 @@ class TestDocumentIndexingTaskProxy: # If billing enabled with non sandbox plan, should send to priority tenant queue proxy._send_to_priority_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_dispatch_with_billing_disabled(self, mock_feature_service): """Test _dispatch method when billing is disabled.""" # Arrange @@ -223,7 +220,7 @@ class TestDocumentIndexingTaskProxy: # If billing disabled, for example: self-hosted or enterprise, should send to priority direct queue proxy._send_to_priority_direct_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_delay_method(self, mock_feature_service): """Test delay method integration.""" # Arrange @@ -256,7 +253,7 @@ class TestDocumentIndexingTaskProxy: assert task.dataset_id == dataset_id assert task.document_ids == document_ids - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_dispatch_edge_case_empty_plan(self, mock_feature_service): """Test _dispatch method with empty plan string.""" # Arrange @@ -271,7 +268,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_priority_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_dispatch_edge_case_none_plan(self, mock_feature_service): """Test _dispatch method with None plan.""" # Arrange diff --git a/api/tests/unit_tests/services/test_duplicate_document_indexing_task_proxy.py b/api/tests/unit_tests/services/test_duplicate_document_indexing_task_proxy.py new file mode 100644 index 0000000000..68bafe3d5e --- /dev/null +++ b/api/tests/unit_tests/services/test_duplicate_document_indexing_task_proxy.py @@ -0,0 +1,363 @@ +from unittest.mock import Mock, patch + +from core.entities.document_task import DocumentTask +from core.rag.pipeline.queue import TenantIsolatedTaskQueue +from enums.cloud_plan import CloudPlan +from services.document_indexing_proxy.duplicate_document_indexing_task_proxy import ( + DuplicateDocumentIndexingTaskProxy, +) + + +class DuplicateDocumentIndexingTaskProxyTestDataFactory: + """Factory class for creating test data and mock objects for DuplicateDocumentIndexingTaskProxy tests.""" + + @staticmethod + def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock: + """Create mock features with billing configuration.""" + features = Mock() + features.billing = Mock() + features.billing.enabled = billing_enabled + features.billing.subscription = Mock() + features.billing.subscription.plan = plan + return features + + @staticmethod + def create_mock_tenant_queue(has_task_key: bool = False) -> Mock: + """Create mock TenantIsolatedTaskQueue.""" + queue = Mock(spec=TenantIsolatedTaskQueue) + queue.get_task_key.return_value = "task_key" if has_task_key else None + queue.push_tasks = Mock() + queue.set_task_waiting_time = Mock() + return queue + + @staticmethod + def create_duplicate_document_task_proxy( + tenant_id: str = "tenant-123", dataset_id: str = "dataset-456", document_ids: list[str] | None = None + ) -> DuplicateDocumentIndexingTaskProxy: + """Create DuplicateDocumentIndexingTaskProxy instance for testing.""" + if document_ids is None: + document_ids = ["doc-1", "doc-2", "doc-3"] + return DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + +class TestDuplicateDocumentIndexingTaskProxy: + """Test cases for DuplicateDocumentIndexingTaskProxy class.""" + + def test_initialization(self): + """Test DuplicateDocumentIndexingTaskProxy initialization.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = "dataset-456" + document_ids = ["doc-1", "doc-2", "doc-3"] + + # Act + proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + # Assert + assert proxy._tenant_id == tenant_id + assert proxy._dataset_id == dataset_id + assert proxy._document_ids == document_ids + assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue) + assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id + assert proxy._tenant_isolated_task_queue._unique_key == "duplicate_document_indexing" + + def test_queue_name(self): + """Test QUEUE_NAME class variable.""" + # Arrange & Act + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + + # Assert + assert proxy.QUEUE_NAME == "duplicate_document_indexing" + + def test_task_functions(self): + """Test NORMAL_TASK_FUNC and PRIORITY_TASK_FUNC class variables.""" + # Arrange & Act + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + + # Assert + assert proxy.NORMAL_TASK_FUNC.__name__ == "normal_duplicate_document_indexing_task" + assert proxy.PRIORITY_TASK_FUNC.__name__ == "priority_duplicate_document_indexing_task" + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_features_property(self, mock_feature_service): + """Test cached_property features.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features() + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + + # Act + features1 = proxy.features + features2 = proxy.features # Second call should use cached property + + # Assert + assert features1 == mock_features + assert features2 == mock_features + assert features1 is features2 # Should be the same instance due to caching + mock_feature_service.get_features.assert_called_once_with("tenant-123") + + @patch( + "services.document_indexing_proxy.duplicate_document_indexing_task_proxy.normal_duplicate_document_indexing_task" + ) + def test_send_to_direct_queue(self, mock_task): + """Test _send_to_direct_queue method.""" + # Arrange + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + mock_task.delay = Mock() + + # Act + proxy._send_to_direct_queue(mock_task) + + # Assert + mock_task.delay.assert_called_once_with( + tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] + ) + + @patch( + "services.document_indexing_proxy.duplicate_document_indexing_task_proxy.normal_duplicate_document_indexing_task" + ) + def test_send_to_tenant_queue_with_existing_task_key(self, mock_task): + """Test _send_to_tenant_queue when task key exists.""" + # Arrange + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._tenant_isolated_task_queue = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue( + has_task_key=True + ) + mock_task.delay = Mock() + + # Act + proxy._send_to_tenant_queue(mock_task) + + # Assert + proxy._tenant_isolated_task_queue.push_tasks.assert_called_once() + pushed_tasks = proxy._tenant_isolated_task_queue.push_tasks.call_args[0][0] + assert len(pushed_tasks) == 1 + assert isinstance(DocumentTask(**pushed_tasks[0]), DocumentTask) + assert pushed_tasks[0]["tenant_id"] == "tenant-123" + assert pushed_tasks[0]["dataset_id"] == "dataset-456" + assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"] + mock_task.delay.assert_not_called() + + @patch( + "services.document_indexing_proxy.duplicate_document_indexing_task_proxy.normal_duplicate_document_indexing_task" + ) + def test_send_to_tenant_queue_without_task_key(self, mock_task): + """Test _send_to_tenant_queue when no task key exists.""" + # Arrange + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._tenant_isolated_task_queue = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue( + has_task_key=False + ) + mock_task.delay = Mock() + + # Act + proxy._send_to_tenant_queue(mock_task) + + # Assert + proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once() + mock_task.delay.assert_called_once_with( + tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] + ) + proxy._tenant_isolated_task_queue.push_tasks.assert_not_called() + + def test_send_to_default_tenant_queue(self): + """Test _send_to_default_tenant_queue method.""" + # Arrange + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_tenant_queue = Mock() + + # Act + proxy._send_to_default_tenant_queue() + + # Assert + proxy._send_to_tenant_queue.assert_called_once_with(proxy.NORMAL_TASK_FUNC) + + def test_send_to_priority_tenant_queue(self): + """Test _send_to_priority_tenant_queue method.""" + # Arrange + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_tenant_queue = Mock() + + # Act + proxy._send_to_priority_tenant_queue() + + # Assert + proxy._send_to_tenant_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC) + + def test_send_to_priority_direct_queue(self): + """Test _send_to_priority_direct_queue method.""" + # Arrange + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_direct_queue = Mock() + + # Act + proxy._send_to_priority_direct_queue() + + # Assert + proxy._send_to_direct_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC) + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service): + """Test _dispatch method when billing is enabled with sandbox plan.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.SANDBOX + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_default_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + proxy._send_to_default_tenant_queue.assert_called_once() + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service): + """Test _dispatch method when billing is enabled with non-sandbox plan.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.TEAM + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + # If billing enabled with non sandbox plan, should send to priority tenant queue + proxy._send_to_priority_tenant_queue.assert_called_once() + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_dispatch_with_billing_disabled(self, mock_feature_service): + """Test _dispatch method when billing is disabled.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=False) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_priority_direct_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + # If billing disabled, for example: self-hosted or enterprise, should send to priority direct queue + proxy._send_to_priority_direct_queue.assert_called_once() + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_delay_method(self, mock_feature_service): + """Test delay method integration.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.SANDBOX + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_default_tenant_queue = Mock() + + # Act + proxy.delay() + + # Assert + # If billing enabled with sandbox plan, should send to default tenant queue + proxy._send_to_default_tenant_queue.assert_called_once() + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_dispatch_edge_case_empty_plan(self, mock_feature_service): + """Test _dispatch method with empty plan string.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan="" + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + proxy._send_to_priority_tenant_queue.assert_called_once() + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_dispatch_edge_case_none_plan(self, mock_feature_service): + """Test _dispatch method with None plan.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=None + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + proxy._send_to_priority_tenant_queue.assert_called_once() + + def test_initialization_with_empty_document_ids(self): + """Test initialization with empty document_ids list.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = "dataset-456" + document_ids = [] + + # Act + proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + # Assert + assert proxy._tenant_id == tenant_id + assert proxy._dataset_id == dataset_id + assert proxy._document_ids == document_ids + + def test_initialization_with_single_document_id(self): + """Test initialization with single document_id.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = "dataset-456" + document_ids = ["doc-1"] + + # Act + proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + # Assert + assert proxy._tenant_id == tenant_id + assert proxy._dataset_id == dataset_id + assert proxy._document_ids == document_ids + + def test_initialization_with_large_batch(self): + """Test initialization with large batch of document IDs.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = "dataset-456" + document_ids = [f"doc-{i}" for i in range(100)] + + # Act + proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + # Assert + assert proxy._tenant_id == tenant_id + assert proxy._dataset_id == dataset_id + assert proxy._document_ids == document_ids + assert len(proxy._document_ids) == 100 + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_dispatch_with_professional_plan(self, mock_feature_service): + """Test _dispatch method when billing is enabled with professional plan.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.PROFESSIONAL + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + proxy._send_to_priority_tenant_queue.assert_called_once() diff --git a/api/tests/unit_tests/services/test_external_dataset_service.py b/api/tests/unit_tests/services/test_external_dataset_service.py index c12ea2f7cb..e2d62583f8 100644 --- a/api/tests/unit_tests/services/test_external_dataset_service.py +++ b/api/tests/unit_tests/services/test_external_dataset_service.py @@ -6,6 +6,7 @@ Target: 1500+ lines of comprehensive test coverage. """ import json +import re from datetime import datetime from unittest.mock import MagicMock, Mock, patch @@ -1791,8 +1792,8 @@ class TestExternalDatasetServiceFetchRetrieval: @patch("services.external_knowledge_service.ExternalDatasetService.process_external_api") @patch("services.external_knowledge_service.db") - def test_fetch_external_knowledge_retrieval_non_200_status(self, mock_db, mock_process, factory): - """Test retrieval returns empty list on non-200 status.""" + def test_fetch_external_knowledge_retrieval_non_200_status_raises_exception(self, mock_db, mock_process, factory): + """Test that non-200 status code raises Exception with response text.""" # Arrange binding = factory.create_external_knowledge_binding_mock() api = factory.create_external_knowledge_api_mock() @@ -1817,12 +1818,103 @@ class TestExternalDatasetServiceFetchRetrieval: mock_response = MagicMock() mock_response.status_code = 500 + mock_response.text = "Internal Server Error: Database connection failed" mock_process.return_value = mock_response - # Act - result = ExternalDatasetService.fetch_external_knowledge_retrieval( - "tenant-123", "dataset-123", "query", {"top_k": 5} - ) + # Act & Assert + with pytest.raises(Exception, match="Internal Server Error: Database connection failed"): + ExternalDatasetService.fetch_external_knowledge_retrieval( + "tenant-123", "dataset-123", "query", {"top_k": 5} + ) - # Assert - assert result == [] + @pytest.mark.parametrize( + ("status_code", "error_message"), + [ + (400, "Bad Request: Invalid query parameters"), + (401, "Unauthorized: Invalid API key"), + (403, "Forbidden: Access denied to resource"), + (404, "Not Found: Knowledge base not found"), + (429, "Too Many Requests: Rate limit exceeded"), + (500, "Internal Server Error: Database connection failed"), + (502, "Bad Gateway: External service unavailable"), + (503, "Service Unavailable: Maintenance mode"), + ], + ) + @patch("services.external_knowledge_service.ExternalDatasetService.process_external_api") + @patch("services.external_knowledge_service.db") + def test_fetch_external_knowledge_retrieval_various_error_status_codes( + self, mock_db, mock_process, factory, status_code, error_message + ): + """Test that various error status codes raise exceptions with response text.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = "dataset-123" + + binding = factory.create_external_knowledge_binding_mock( + dataset_id=dataset_id, external_knowledge_api_id="api-123" + ) + api = factory.create_external_knowledge_api_mock(api_id="api-123") + + mock_binding_query = MagicMock() + mock_api_query = MagicMock() + + def query_side_effect(model): + if model == ExternalKnowledgeBindings: + return mock_binding_query + elif model == ExternalKnowledgeApis: + return mock_api_query + return MagicMock() + + mock_db.session.query.side_effect = query_side_effect + + mock_binding_query.filter_by.return_value = mock_binding_query + mock_binding_query.first.return_value = binding + + mock_api_query.filter_by.return_value = mock_api_query + mock_api_query.first.return_value = api + + mock_response = MagicMock() + mock_response.status_code = status_code + mock_response.text = error_message + mock_process.return_value = mock_response + + # Act & Assert + with pytest.raises(ValueError, match=re.escape(error_message)): + ExternalDatasetService.fetch_external_knowledge_retrieval(tenant_id, dataset_id, "query", {"top_k": 5}) + + @patch("services.external_knowledge_service.ExternalDatasetService.process_external_api") + @patch("services.external_knowledge_service.db") + def test_fetch_external_knowledge_retrieval_empty_response_text(self, mock_db, mock_process, factory): + """Test exception with empty response text.""" + # Arrange + binding = factory.create_external_knowledge_binding_mock() + api = factory.create_external_knowledge_api_mock() + + mock_binding_query = MagicMock() + mock_api_query = MagicMock() + + def query_side_effect(model): + if model == ExternalKnowledgeBindings: + return mock_binding_query + elif model == ExternalKnowledgeApis: + return mock_api_query + return MagicMock() + + mock_db.session.query.side_effect = query_side_effect + + mock_binding_query.filter_by.return_value = mock_binding_query + mock_binding_query.first.return_value = binding + + mock_api_query.filter_by.return_value = mock_api_query + mock_api_query.first.return_value = api + + mock_response = MagicMock() + mock_response.status_code = 503 + mock_response.text = "" + mock_process.return_value = mock_response + + # Act & Assert + with pytest.raises(Exception, match=""): + ExternalDatasetService.fetch_external_knowledge_retrieval( + "tenant-123", "dataset-123", "query", {"top_k": 5} + ) diff --git a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py index b3b29fbe45..9d7599b8fe 100644 --- a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py @@ -19,7 +19,7 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client from models.dataset import Dataset, Document -from services.document_indexing_task_proxy import DocumentIndexingTaskProxy +from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy from tasks.document_indexing_task import ( _document_indexing, _document_indexing_with_tenant_queue, @@ -138,7 +138,9 @@ class TestTaskEnqueuing: with patch.object(DocumentIndexingTaskProxy, "features") as mock_features: mock_features.billing.enabled = False - with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task: + # Mock the class variable directly + mock_task = Mock() + with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task): proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) # Act @@ -163,7 +165,9 @@ class TestTaskEnqueuing: mock_features.billing.enabled = True mock_features.billing.subscription.plan = CloudPlan.SANDBOX - with patch("services.document_indexing_task_proxy.normal_document_indexing_task") as mock_task: + # Mock the class variable directly + mock_task = Mock() + with patch.object(DocumentIndexingTaskProxy, "NORMAL_TASK_FUNC", mock_task): proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) # Act @@ -187,7 +191,9 @@ class TestTaskEnqueuing: mock_features.billing.enabled = True mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL - with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task: + # Mock the class variable directly + mock_task = Mock() + with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task): proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) # Act @@ -211,7 +217,9 @@ class TestTaskEnqueuing: mock_features.billing.enabled = True mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL - with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task: + # Mock the class variable directly + mock_task = Mock() + with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task): proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) # Act @@ -1493,7 +1501,9 @@ class TestEdgeCases: mock_features.billing.enabled = True mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL - with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task: + # Mock the class variable directly + mock_task = Mock() + with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task): # Act - Enqueue multiple tasks rapidly for doc_ids in document_ids_list: proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, doc_ids) @@ -1898,7 +1908,7 @@ class TestRobustness: - Error is propagated appropriately """ # Arrange - with patch("services.document_indexing_task_proxy.FeatureService.get_features") as mock_get_features: + with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_get_features: # Simulate FeatureService failure mock_get_features.side_effect = Exception("Feature service unavailable") diff --git a/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py new file mode 100644 index 0000000000..0be6ea045e --- /dev/null +++ b/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py @@ -0,0 +1,567 @@ +""" +Unit tests for duplicate document indexing tasks. + +This module tests the duplicate document indexing task functionality including: +- Task enqueuing to different queues (normal, priority, tenant-isolated) +- Batch processing of multiple duplicate documents +- Progress tracking through task lifecycle +- Error handling and retry mechanisms +- Cleanup of old document data before re-indexing +""" + +import uuid +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from core.indexing_runner import DocumentIsPausedError, IndexingRunner +from core.rag.pipeline.queue import TenantIsolatedTaskQueue +from enums.cloud_plan import CloudPlan +from models.dataset import Dataset, Document, DocumentSegment +from tasks.duplicate_document_indexing_task import ( + _duplicate_document_indexing_task, + _duplicate_document_indexing_task_with_tenant_queue, + duplicate_document_indexing_task, + normal_duplicate_document_indexing_task, + priority_duplicate_document_indexing_task, +) + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture +def tenant_id(): + """Generate a unique tenant ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def dataset_id(): + """Generate a unique dataset ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def document_ids(): + """Generate a list of document IDs for testing.""" + return [str(uuid.uuid4()) for _ in range(3)] + + +@pytest.fixture +def mock_dataset(dataset_id, tenant_id): + """Create a mock Dataset object.""" + dataset = Mock(spec=Dataset) + dataset.id = dataset_id + dataset.tenant_id = tenant_id + dataset.indexing_technique = "high_quality" + dataset.embedding_model_provider = "openai" + dataset.embedding_model = "text-embedding-ada-002" + return dataset + + +@pytest.fixture +def mock_documents(document_ids, dataset_id): + """Create mock Document objects.""" + documents = [] + for doc_id in document_ids: + doc = Mock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.error = None + doc.stopped_at = None + doc.processing_started_at = None + doc.doc_form = "text_model" + documents.append(doc) + return documents + + +@pytest.fixture +def mock_document_segments(document_ids): + """Create mock DocumentSegment objects.""" + segments = [] + for doc_id in document_ids: + for i in range(3): + segment = Mock(spec=DocumentSegment) + segment.id = str(uuid.uuid4()) + segment.document_id = doc_id + segment.index_node_id = f"node-{doc_id}-{i}" + segments.append(segment) + return segments + + +@pytest.fixture +def mock_db_session(): + """Mock database session.""" + with patch("tasks.duplicate_document_indexing_task.db.session") as mock_session: + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_session.scalars.return_value = MagicMock() + yield mock_session + + +@pytest.fixture +def mock_indexing_runner(): + """Mock IndexingRunner.""" + with patch("tasks.duplicate_document_indexing_task.IndexingRunner") as mock_runner_class: + mock_runner = MagicMock(spec=IndexingRunner) + mock_runner_class.return_value = mock_runner + yield mock_runner + + +@pytest.fixture +def mock_feature_service(): + """Mock FeatureService.""" + with patch("tasks.duplicate_document_indexing_task.FeatureService") as mock_service: + mock_features = Mock() + mock_features.billing = Mock() + mock_features.billing.enabled = False + mock_features.vector_space = Mock() + mock_features.vector_space.size = 0 + mock_features.vector_space.limit = 1000 + mock_service.get_features.return_value = mock_features + yield mock_service + + +@pytest.fixture +def mock_index_processor_factory(): + """Mock IndexProcessorFactory.""" + with patch("tasks.duplicate_document_indexing_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_processor.clean = Mock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + yield mock_factory + + +@pytest.fixture +def mock_tenant_isolated_queue(): + """Mock TenantIsolatedTaskQueue.""" + with patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") as mock_queue_class: + mock_queue = MagicMock(spec=TenantIsolatedTaskQueue) + mock_queue.pull_tasks.return_value = [] + mock_queue.delete_task_key = Mock() + mock_queue.set_task_waiting_time = Mock() + mock_queue_class.return_value = mock_queue + yield mock_queue + + +# ============================================================================ +# Tests for deprecated duplicate_document_indexing_task +# ============================================================================ + + +class TestDuplicateDocumentIndexingTask: + """Tests for the deprecated duplicate_document_indexing_task function.""" + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + def test_duplicate_document_indexing_task_calls_core_function(self, mock_core_func, dataset_id, document_ids): + """Test that duplicate_document_indexing_task calls the core _duplicate_document_indexing_task function.""" + # Act + duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + mock_core_func.assert_called_once_with(dataset_id, document_ids) + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + def test_duplicate_document_indexing_task_with_empty_document_ids(self, mock_core_func, dataset_id): + """Test duplicate_document_indexing_task with empty document_ids list.""" + # Arrange + document_ids = [] + + # Act + duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + mock_core_func.assert_called_once_with(dataset_id, document_ids) + + +# ============================================================================ +# Tests for _duplicate_document_indexing_task core function +# ============================================================================ + + +class TestDuplicateDocumentIndexingTaskCore: + """Tests for the _duplicate_document_indexing_task core function.""" + + def test_successful_duplicate_document_indexing( + self, + mock_db_session, + mock_indexing_runner, + mock_feature_service, + mock_index_processor_factory, + mock_dataset, + mock_documents, + mock_document_segments, + dataset_id, + document_ids, + ): + """Test successful duplicate document indexing flow.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents + mock_db_session.scalars.return_value.all.return_value = mock_document_segments + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # Verify IndexingRunner was called + mock_indexing_runner.run.assert_called_once() + + # Verify all documents were set to parsing status + for doc in mock_documents: + assert doc.indexing_status == "parsing" + assert doc.processing_started_at is not None + + # Verify session operations + assert mock_db_session.commit.called + assert mock_db_session.close.called + + def test_duplicate_document_indexing_dataset_not_found(self, mock_db_session, dataset_id, document_ids): + """Test duplicate document indexing when dataset is not found.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.return_value = None + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # Should close the session at least once + assert mock_db_session.close.called + + def test_duplicate_document_indexing_with_billing_enabled_sandbox_plan( + self, + mock_db_session, + mock_feature_service, + mock_dataset, + dataset_id, + document_ids, + ): + """Test duplicate document indexing with billing enabled and sandbox plan.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_features = mock_feature_service.get_features.return_value + mock_features.billing.enabled = True + mock_features.billing.subscription.plan = CloudPlan.SANDBOX + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # For sandbox plan with multiple documents, should fail + mock_db_session.commit.assert_called() + + def test_duplicate_document_indexing_with_billing_limit_exceeded( + self, + mock_db_session, + mock_feature_service, + mock_dataset, + mock_documents, + dataset_id, + document_ids, + ): + """Test duplicate document indexing when billing limit is exceeded.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents + mock_db_session.scalars.return_value.all.return_value = [] # No segments to clean + mock_features = mock_feature_service.get_features.return_value + mock_features.billing.enabled = True + mock_features.billing.subscription.plan = CloudPlan.TEAM + mock_features.vector_space.size = 990 + mock_features.vector_space.limit = 1000 + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # Should commit the session + assert mock_db_session.commit.called + # Should close the session + assert mock_db_session.close.called + + def test_duplicate_document_indexing_runner_error( + self, + mock_db_session, + mock_indexing_runner, + mock_feature_service, + mock_index_processor_factory, + mock_dataset, + mock_documents, + dataset_id, + document_ids, + ): + """Test duplicate document indexing when IndexingRunner raises an error.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents + mock_db_session.scalars.return_value.all.return_value = [] + mock_indexing_runner.run.side_effect = Exception("Indexing error") + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # Should close the session even after error + mock_db_session.close.assert_called_once() + + def test_duplicate_document_indexing_document_is_paused( + self, + mock_db_session, + mock_indexing_runner, + mock_feature_service, + mock_index_processor_factory, + mock_dataset, + mock_documents, + dataset_id, + document_ids, + ): + """Test duplicate document indexing when document is paused.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents + mock_db_session.scalars.return_value.all.return_value = [] + mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused") + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # Should handle DocumentIsPausedError gracefully + mock_db_session.close.assert_called_once() + + def test_duplicate_document_indexing_cleans_old_segments( + self, + mock_db_session, + mock_indexing_runner, + mock_feature_service, + mock_index_processor_factory, + mock_dataset, + mock_documents, + mock_document_segments, + dataset_id, + document_ids, + ): + """Test that duplicate document indexing cleans old segments.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents + mock_db_session.scalars.return_value.all.return_value = mock_document_segments + mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # Verify clean was called for each document + assert mock_processor.clean.call_count == len(mock_documents) + + # Verify segments were deleted + for segment in mock_document_segments: + mock_db_session.delete.assert_any_call(segment) + + +# ============================================================================ +# Tests for tenant queue wrapper function +# ============================================================================ + + +class TestDuplicateDocumentIndexingTaskWithTenantQueue: + """Tests for _duplicate_document_indexing_task_with_tenant_queue function.""" + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + def test_tenant_queue_wrapper_calls_core_function( + self, + mock_core_func, + mock_tenant_isolated_queue, + tenant_id, + dataset_id, + document_ids, + ): + """Test that tenant queue wrapper calls the core function.""" + # Arrange + mock_task_func = Mock() + + # Act + _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func) + + # Assert + mock_core_func.assert_called_once_with(dataset_id, document_ids) + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + def test_tenant_queue_wrapper_deletes_key_when_no_tasks( + self, + mock_core_func, + mock_tenant_isolated_queue, + tenant_id, + dataset_id, + document_ids, + ): + """Test that tenant queue wrapper deletes task key when no more tasks.""" + # Arrange + mock_task_func = Mock() + mock_tenant_isolated_queue.pull_tasks.return_value = [] + + # Act + _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func) + + # Assert + mock_tenant_isolated_queue.delete_task_key.assert_called_once() + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + def test_tenant_queue_wrapper_processes_next_tasks( + self, + mock_core_func, + mock_tenant_isolated_queue, + tenant_id, + dataset_id, + document_ids, + ): + """Test that tenant queue wrapper processes next tasks from queue.""" + # Arrange + mock_task_func = Mock() + next_task = { + "tenant_id": tenant_id, + "dataset_id": dataset_id, + "document_ids": document_ids, + } + mock_tenant_isolated_queue.pull_tasks.return_value = [next_task] + + # Act + _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func) + + # Assert + mock_tenant_isolated_queue.set_task_waiting_time.assert_called_once() + mock_task_func.delay.assert_called_once_with( + tenant_id=tenant_id, + dataset_id=dataset_id, + document_ids=document_ids, + ) + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + def test_tenant_queue_wrapper_handles_core_function_error( + self, + mock_core_func, + mock_tenant_isolated_queue, + tenant_id, + dataset_id, + document_ids, + ): + """Test that tenant queue wrapper handles errors from core function.""" + # Arrange + mock_task_func = Mock() + mock_core_func.side_effect = Exception("Core function error") + + # Act + _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func) + + # Assert + # Should still check for next tasks even after error + mock_tenant_isolated_queue.pull_tasks.assert_called_once() + + +# ============================================================================ +# Tests for normal_duplicate_document_indexing_task +# ============================================================================ + + +class TestNormalDuplicateDocumentIndexingTask: + """Tests for normal_duplicate_document_indexing_task function.""" + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue") + def test_normal_task_calls_tenant_queue_wrapper( + self, + mock_wrapper_func, + tenant_id, + dataset_id, + document_ids, + ): + """Test that normal task calls tenant queue wrapper.""" + # Act + normal_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids) + + # Assert + mock_wrapper_func.assert_called_once_with( + tenant_id, dataset_id, document_ids, normal_duplicate_document_indexing_task + ) + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue") + def test_normal_task_with_empty_document_ids( + self, + mock_wrapper_func, + tenant_id, + dataset_id, + ): + """Test normal task with empty document_ids list.""" + # Arrange + document_ids = [] + + # Act + normal_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids) + + # Assert + mock_wrapper_func.assert_called_once_with( + tenant_id, dataset_id, document_ids, normal_duplicate_document_indexing_task + ) + + +# ============================================================================ +# Tests for priority_duplicate_document_indexing_task +# ============================================================================ + + +class TestPriorityDuplicateDocumentIndexingTask: + """Tests for priority_duplicate_document_indexing_task function.""" + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue") + def test_priority_task_calls_tenant_queue_wrapper( + self, + mock_wrapper_func, + tenant_id, + dataset_id, + document_ids, + ): + """Test that priority task calls tenant queue wrapper.""" + # Act + priority_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids) + + # Assert + mock_wrapper_func.assert_called_once_with( + tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task + ) + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue") + def test_priority_task_with_single_document( + self, + mock_wrapper_func, + tenant_id, + dataset_id, + ): + """Test priority task with single document.""" + # Arrange + document_ids = ["doc-1"] + + # Act + priority_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids) + + # Assert + mock_wrapper_func.assert_called_once_with( + tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task + ) + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue") + def test_priority_task_with_large_batch( + self, + mock_wrapper_func, + tenant_id, + dataset_id, + ): + """Test priority task with large batch of documents.""" + # Arrange + document_ids = [f"doc-{i}" for i in range(100)] + + # Act + priority_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids) + + # Assert + mock_wrapper_func.assert_called_once_with( + tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task + ) diff --git a/api/tests/unit_tests/utils/test_text_processing.py b/api/tests/unit_tests/utils/test_text_processing.py index 8bfc97ae63..8af47e8967 100644 --- a/api/tests/unit_tests/utils/test_text_processing.py +++ b/api/tests/unit_tests/utils/test_text_processing.py @@ -8,7 +8,9 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols [ ("...Hello, World!", "Hello, World!"), ("。测试中文标点", "测试中文标点"), - ("!@#Test symbols", "Test symbols"), + # Note: ! is not in the removal pattern, only @# are removed, leaving "!Test symbols" + # The pattern intentionally excludes ! as per #11868 fix + ("@#Test symbols", "Test symbols"), ("Hello, World!", "Hello, World!"), ("", ""), (" ", " "), diff --git a/docker/.env.example b/docker/.env.example index b71c38e07a..80e87425c1 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -808,6 +808,19 @@ UPLOAD_FILE_BATCH_LIMIT=5 # Recommended: exe,bat,cmd,com,scr,vbs,ps1,msi,dll UPLOAD_FILE_EXTENSION_BLACKLIST= +# Maximum number of files allowed in a single chunk attachment, default 10. +SINGLE_CHUNK_ATTACHMENT_LIMIT=10 + +# Maximum number of files allowed in a image batch upload operation +IMAGE_FILE_BATCH_LIMIT=10 + +# Maximum allowed image file size for attachments in megabytes, default 2. +ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2 + +# Timeout for downloading image attachments in seconds, default 60. +ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60 + + # ETL type, support: `dify`, `Unstructured` # `dify` Dify's proprietary file extraction scheme # `Unstructured` Unstructured.io file extraction scheme @@ -1415,4 +1428,4 @@ WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE=100 WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0 # Tenant isolated task queue configuration -TENANT_ISOLATED_TASK_CONCURRENCY=1 \ No newline at end of file +TENANT_ISOLATED_TASK_CONCURRENCY=1 diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 69bcd9dff8..f1061ef5f9 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -1,5 +1,24 @@ x-shared-env: &shared-api-worker-env services: + # Init container to fix permissions + init_permissions: + image: busybox:latest + command: + - sh + - -c + - | + FLAG_FILE="/app/api/storage/.init_permissions" + if [ -f "$${FLAG_FILE}" ]; then + echo "Permissions already initialized. Exiting." + exit 0 + fi + echo "Initializing permissions for /app/api/storage" + chown -R 1001:1001 /app/api/storage && touch "$${FLAG_FILE}" + echo "Permissions initialized. Exiting." + volumes: + - ./volumes/app/storage:/app/api/storage + restart: "no" + # API service api: image: langgenius/dify-api:1.10.1-fix.1 @@ -17,6 +36,8 @@ services: PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1} depends_on: + init_permissions: + condition: service_completed_successfully db_postgres: condition: service_healthy required: false @@ -54,6 +75,8 @@ services: PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1} depends_on: + init_permissions: + condition: service_completed_successfully db_postgres: condition: service_healthy required: false @@ -86,6 +109,8 @@ services: # Startup mode, 'worker_beat' starts the Celery beat for scheduling periodic tasks. MODE: beat depends_on: + init_permissions: + condition: service_completed_successfully db_postgres: condition: service_healthy required: false diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 407d240eeb..3e416c36c9 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -364,6 +364,10 @@ x-shared-env: &shared-api-worker-env UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15} UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5} UPLOAD_FILE_EXTENSION_BLACKLIST: ${UPLOAD_FILE_EXTENSION_BLACKLIST:-} + SINGLE_CHUNK_ATTACHMENT_LIMIT: ${SINGLE_CHUNK_ATTACHMENT_LIMIT:-10} + IMAGE_FILE_BATCH_LIMIT: ${IMAGE_FILE_BATCH_LIMIT:-10} + ATTACHMENT_IMAGE_FILE_SIZE_LIMIT: ${ATTACHMENT_IMAGE_FILE_SIZE_LIMIT:-2} + ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT: ${ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT:-60} ETL_TYPE: ${ETL_TYPE:-dify} UNSTRUCTURED_API_URL: ${UNSTRUCTURED_API_URL:-} UNSTRUCTURED_API_KEY: ${UNSTRUCTURED_API_KEY:-} @@ -630,6 +634,25 @@ x-shared-env: &shared-api-worker-env TENANT_ISOLATED_TASK_CONCURRENCY: ${TENANT_ISOLATED_TASK_CONCURRENCY:-1} services: + # Init container to fix permissions + init_permissions: + image: busybox:latest + command: + - sh + - -c + - | + FLAG_FILE="/app/api/storage/.init_permissions" + if [ -f "$${FLAG_FILE}" ]; then + echo "Permissions already initialized. Exiting." + exit 0 + fi + echo "Initializing permissions for /app/api/storage" + chown -R 1001:1001 /app/api/storage && touch "$${FLAG_FILE}" + echo "Permissions initialized. Exiting." + volumes: + - ./volumes/app/storage:/app/api/storage + restart: "no" + # API service api: image: langgenius/dify-api:1.10.1-fix.1 @@ -647,6 +670,8 @@ services: PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1} depends_on: + init_permissions: + condition: service_completed_successfully db_postgres: condition: service_healthy required: false @@ -684,6 +709,8 @@ services: PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1} depends_on: + init_permissions: + condition: service_completed_successfully db_postgres: condition: service_healthy required: false @@ -716,6 +743,8 @@ services: # Startup mode, 'worker_beat' starts the Celery beat for scheduling periodic tasks. MODE: beat depends_on: + init_permissions: + condition: service_completed_successfully db_postgres: condition: service_healthy required: false diff --git a/sdks/python-client/LICENSE b/sdks/python-client/LICENSE deleted file mode 100644 index 873e44b4bc..0000000000 --- a/sdks/python-client/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2023 LangGenius - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/sdks/python-client/MANIFEST.in b/sdks/python-client/MANIFEST.in deleted file mode 100644 index 34b7e8711c..0000000000 --- a/sdks/python-client/MANIFEST.in +++ /dev/null @@ -1,3 +0,0 @@ -recursive-include dify_client *.py -include README.md -include LICENSE diff --git a/sdks/python-client/README.md b/sdks/python-client/README.md deleted file mode 100644 index ebfb5f5397..0000000000 --- a/sdks/python-client/README.md +++ /dev/null @@ -1,409 +0,0 @@ -# dify-client - -A Dify App Service-API Client, using for build a webapp by request Service-API - -## Usage - -First, install `dify-client` python sdk package: - -``` -pip install dify-client -``` - -### Synchronous Usage - -Write your code with sdk: - -- completion generate with `blocking` response_mode - -```python -from dify_client import CompletionClient - -api_key = "your_api_key" - -# Initialize CompletionClient -completion_client = CompletionClient(api_key) - -# Create Completion Message using CompletionClient -completion_response = completion_client.create_completion_message(inputs={"query": "What's the weather like today?"}, - response_mode="blocking", user="user_id") -completion_response.raise_for_status() - -result = completion_response.json() - -print(result.get('answer')) -``` - -- completion using vision model, like gpt-4-vision - -```python -from dify_client import CompletionClient - -api_key = "your_api_key" - -# Initialize CompletionClient -completion_client = CompletionClient(api_key) - -files = [{ - "type": "image", - "transfer_method": "remote_url", - "url": "your_image_url" -}] - -# files = [{ -# "type": "image", -# "transfer_method": "local_file", -# "upload_file_id": "your_file_id" -# }] - -# Create Completion Message using CompletionClient -completion_response = completion_client.create_completion_message(inputs={"query": "Describe the picture."}, - response_mode="blocking", user="user_id", files=files) -completion_response.raise_for_status() - -result = completion_response.json() - -print(result.get('answer')) -``` - -- chat generate with `streaming` response_mode - -```python -import json -from dify_client import ChatClient - -api_key = "your_api_key" - -# Initialize ChatClient -chat_client = ChatClient(api_key) - -# Create Chat Message using ChatClient -chat_response = chat_client.create_chat_message(inputs={}, query="Hello", user="user_id", response_mode="streaming") -chat_response.raise_for_status() - -for line in chat_response.iter_lines(decode_unicode=True): - line = line.split('data:', 1)[-1] - if line.strip(): - line = json.loads(line.strip()) - print(line.get('answer')) -``` - -- chat using vision model, like gpt-4-vision - -```python -from dify_client import ChatClient - -api_key = "your_api_key" - -# Initialize ChatClient -chat_client = ChatClient(api_key) - -files = [{ - "type": "image", - "transfer_method": "remote_url", - "url": "your_image_url" -}] - -# files = [{ -# "type": "image", -# "transfer_method": "local_file", -# "upload_file_id": "your_file_id" -# }] - -# Create Chat Message using ChatClient -chat_response = chat_client.create_chat_message(inputs={}, query="Describe the picture.", user="user_id", - response_mode="blocking", files=files) -chat_response.raise_for_status() - -result = chat_response.json() - -print(result.get("answer")) -``` - -- upload file when using vision model - -```python -from dify_client import DifyClient - -api_key = "your_api_key" - -# Initialize Client -dify_client = DifyClient(api_key) - -file_path = "your_image_file_path" -file_name = "panda.jpeg" -mime_type = "image/jpeg" - -with open(file_path, "rb") as file: - files = { - "file": (file_name, file, mime_type) - } - response = dify_client.file_upload("user_id", files) - - result = response.json() - print(f'upload_file_id: {result.get("id")}') -``` - -- Others - -```python -from dify_client import ChatClient - -api_key = "your_api_key" - -# Initialize Client -client = ChatClient(api_key) - -# Get App parameters -parameters = client.get_application_parameters(user="user_id") -parameters.raise_for_status() - -print('[parameters]') -print(parameters.json()) - -# Get Conversation List (only for chat) -conversations = client.get_conversations(user="user_id") -conversations.raise_for_status() - -print('[conversations]') -print(conversations.json()) - -# Get Message List (only for chat) -messages = client.get_conversation_messages(user="user_id", conversation_id="conversation_id") -messages.raise_for_status() - -print('[messages]') -print(messages.json()) - -# Rename Conversation (only for chat) -rename_conversation_response = client.rename_conversation(conversation_id="conversation_id", - name="new_name", user="user_id") -rename_conversation_response.raise_for_status() - -print('[rename result]') -print(rename_conversation_response.json()) -``` - -- Using the Workflow Client - -```python -import json -import requests -from dify_client import WorkflowClient - -api_key = "your_api_key" - -# Initialize Workflow Client -client = WorkflowClient(api_key) - -# Prepare parameters for Workflow Client -user_id = "your_user_id" -context = "previous user interaction / metadata" -user_prompt = "What is the capital of France?" - -inputs = { - "context": context, - "user_prompt": user_prompt, - # Add other input fields expected by your workflow (e.g., additional context, task parameters) - -} - -# Set response mode (default: streaming) -response_mode = "blocking" - -# Run the workflow -response = client.run(inputs=inputs, response_mode=response_mode, user=user_id) -response.raise_for_status() - -# Parse result -result = json.loads(response.text) - -answer = result.get("data").get("outputs") - -print(answer["answer"]) - -``` - -- Dataset Management - -```python -from dify_client import KnowledgeBaseClient - -api_key = "your_api_key" -dataset_id = "your_dataset_id" - -# Use context manager to ensure proper resource cleanup -with KnowledgeBaseClient(api_key, dataset_id) as kb_client: - # Get dataset information - dataset_info = kb_client.get_dataset() - dataset_info.raise_for_status() - print(dataset_info.json()) - - # Update dataset configuration - update_response = kb_client.update_dataset( - name="Updated Dataset Name", - description="Updated description", - indexing_technique="high_quality" - ) - update_response.raise_for_status() - print(update_response.json()) - - # Batch update document status - batch_response = kb_client.batch_update_document_status( - action="enable", - document_ids=["doc_id_1", "doc_id_2", "doc_id_3"] - ) - batch_response.raise_for_status() - print(batch_response.json()) -``` - -- Conversation Variables Management - -```python -from dify_client import ChatClient - -api_key = "your_api_key" - -# Use context manager to ensure proper resource cleanup -with ChatClient(api_key) as chat_client: - # Get all conversation variables - variables = chat_client.get_conversation_variables( - conversation_id="conversation_id", - user="user_id" - ) - variables.raise_for_status() - print(variables.json()) - - # Update a specific conversation variable - update_var = chat_client.update_conversation_variable( - conversation_id="conversation_id", - variable_id="variable_id", - value="new_value", - user="user_id" - ) - update_var.raise_for_status() - print(update_var.json()) -``` - -### Asynchronous Usage - -The SDK provides full async/await support for all API operations using `httpx.AsyncClient`. All async clients mirror their synchronous counterparts but require `await` for method calls. - -- async chat with `blocking` response_mode - -```python -import asyncio -from dify_client import AsyncChatClient - -api_key = "your_api_key" - -async def main(): - # Use async context manager for proper resource cleanup - async with AsyncChatClient(api_key) as client: - response = await client.create_chat_message( - inputs={}, - query="Hello, how are you?", - user="user_id", - response_mode="blocking" - ) - response.raise_for_status() - result = response.json() - print(result.get('answer')) - -# Run the async function -asyncio.run(main()) -``` - -- async completion with `streaming` response_mode - -```python -import asyncio -import json -from dify_client import AsyncCompletionClient - -api_key = "your_api_key" - -async def main(): - async with AsyncCompletionClient(api_key) as client: - response = await client.create_completion_message( - inputs={"query": "What's the weather?"}, - response_mode="streaming", - user="user_id" - ) - response.raise_for_status() - - # Stream the response - async for line in response.aiter_lines(): - if line.startswith('data:'): - data = line[5:].strip() - if data: - chunk = json.loads(data) - print(chunk.get('answer', ''), end='', flush=True) - -asyncio.run(main()) -``` - -- async workflow execution - -```python -import asyncio -from dify_client import AsyncWorkflowClient - -api_key = "your_api_key" - -async def main(): - async with AsyncWorkflowClient(api_key) as client: - response = await client.run( - inputs={"query": "What is machine learning?"}, - response_mode="blocking", - user="user_id" - ) - response.raise_for_status() - result = response.json() - print(result.get("data").get("outputs")) - -asyncio.run(main()) -``` - -- async dataset management - -```python -import asyncio -from dify_client import AsyncKnowledgeBaseClient - -api_key = "your_api_key" -dataset_id = "your_dataset_id" - -async def main(): - async with AsyncKnowledgeBaseClient(api_key, dataset_id) as kb_client: - # Get dataset information - dataset_info = await kb_client.get_dataset() - dataset_info.raise_for_status() - print(dataset_info.json()) - - # List documents - docs = await kb_client.list_documents(page=1, page_size=10) - docs.raise_for_status() - print(docs.json()) - -asyncio.run(main()) -``` - -**Benefits of Async Usage:** - -- **Better Performance**: Handle multiple concurrent API requests efficiently -- **Non-blocking I/O**: Don't block the event loop during network operations -- **Scalability**: Ideal for applications handling many simultaneous requests -- **Modern Python**: Leverages Python's native async/await syntax - -**Available Async Clients:** - -- `AsyncDifyClient` - Base async client -- `AsyncChatClient` - Async chat operations -- `AsyncCompletionClient` - Async completion operations -- `AsyncWorkflowClient` - Async workflow operations -- `AsyncKnowledgeBaseClient` - Async dataset/knowledge base operations -- `AsyncWorkspaceClient` - Async workspace operations - -``` -``` diff --git a/sdks/python-client/build.sh b/sdks/python-client/build.sh deleted file mode 100755 index 525f57c1ef..0000000000 --- a/sdks/python-client/build.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash - -set -e - -rm -rf build dist *.egg-info - -pip install setuptools wheel twine -python setup.py sdist bdist_wheel -twine upload dist/* diff --git a/sdks/python-client/dify_client/__init__.py b/sdks/python-client/dify_client/__init__.py deleted file mode 100644 index ced093b20a..0000000000 --- a/sdks/python-client/dify_client/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -from dify_client.client import ( - ChatClient, - CompletionClient, - DifyClient, - KnowledgeBaseClient, - WorkflowClient, - WorkspaceClient, -) - -from dify_client.async_client import ( - AsyncChatClient, - AsyncCompletionClient, - AsyncDifyClient, - AsyncKnowledgeBaseClient, - AsyncWorkflowClient, - AsyncWorkspaceClient, -) - -__all__ = [ - # Synchronous clients - "ChatClient", - "CompletionClient", - "DifyClient", - "KnowledgeBaseClient", - "WorkflowClient", - "WorkspaceClient", - # Asynchronous clients - "AsyncChatClient", - "AsyncCompletionClient", - "AsyncDifyClient", - "AsyncKnowledgeBaseClient", - "AsyncWorkflowClient", - "AsyncWorkspaceClient", -] diff --git a/sdks/python-client/dify_client/async_client.py b/sdks/python-client/dify_client/async_client.py deleted file mode 100644 index 23126cf326..0000000000 --- a/sdks/python-client/dify_client/async_client.py +++ /dev/null @@ -1,2074 +0,0 @@ -"""Asynchronous Dify API client. - -This module provides async/await support for all Dify API operations using httpx.AsyncClient. -All client classes mirror their synchronous counterparts but require `await` for method calls. - -Example: - import asyncio - from dify_client import AsyncChatClient - - async def main(): - async with AsyncChatClient(api_key="your-key") as client: - response = await client.create_chat_message( - inputs={}, - query="Hello", - user="user-123" - ) - print(response.json()) - - asyncio.run(main()) -""" - -import json -import os -from typing import Literal, Dict, List, Any, IO, Optional, Union - -import aiofiles -import httpx - - -class AsyncDifyClient: - """Asynchronous Dify API client. - - This client uses httpx.AsyncClient for efficient async connection pooling. - It's recommended to use this client as a context manager: - - Example: - async with AsyncDifyClient(api_key="your-key") as client: - response = await client.get_app_info() - """ - - def __init__( - self, - api_key: str, - base_url: str = "https://api.dify.ai/v1", - timeout: float = 60.0, - ): - """Initialize the async Dify client. - - Args: - api_key: Your Dify API key - base_url: Base URL for the Dify API - timeout: Request timeout in seconds (default: 60.0) - """ - self.api_key = api_key - self.base_url = base_url - self._client = httpx.AsyncClient( - base_url=base_url, - timeout=httpx.Timeout(timeout, connect=5.0), - ) - - async def __aenter__(self): - """Support async context manager protocol.""" - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - """Clean up resources when exiting async context.""" - await self.aclose() - - async def aclose(self): - """Close the async HTTP client and release resources.""" - if hasattr(self, "_client"): - await self._client.aclose() - - async def _send_request( - self, - method: str, - endpoint: str, - json: Dict | None = None, - params: Dict | None = None, - stream: bool = False, - **kwargs, - ): - """Send an async HTTP request to the Dify API. - - Args: - method: HTTP method (GET, POST, PUT, PATCH, DELETE) - endpoint: API endpoint path - json: JSON request body - params: Query parameters - stream: Whether to stream the response - **kwargs: Additional arguments to pass to httpx.request - - Returns: - httpx.Response object - """ - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - } - - response = await self._client.request( - method, - endpoint, - json=json, - params=params, - headers=headers, - **kwargs, - ) - - return response - - async def _send_request_with_files(self, method: str, endpoint: str, data: dict, files: dict): - """Send an async HTTP request with file uploads. - - Args: - method: HTTP method (POST, PUT, etc.) - endpoint: API endpoint path - data: Form data - files: Files to upload - - Returns: - httpx.Response object - """ - headers = {"Authorization": f"Bearer {self.api_key}"} - - response = await self._client.request( - method, - endpoint, - data=data, - headers=headers, - files=files, - ) - - return response - - async def message_feedback(self, message_id: str, rating: Literal["like", "dislike"], user: str): - """Send feedback for a message.""" - data = {"rating": rating, "user": user} - return await self._send_request("POST", f"/messages/{message_id}/feedbacks", data) - - async def get_application_parameters(self, user: str): - """Get application parameters.""" - params = {"user": user} - return await self._send_request("GET", "/parameters", params=params) - - async def file_upload(self, user: str, files: dict): - """Upload a file.""" - data = {"user": user} - return await self._send_request_with_files("POST", "/files/upload", data=data, files=files) - - async def text_to_audio(self, text: str, user: str, streaming: bool = False): - """Convert text to audio.""" - data = {"text": text, "user": user, "streaming": streaming} - return await self._send_request("POST", "/text-to-audio", json=data) - - async def get_meta(self, user: str): - """Get metadata.""" - params = {"user": user} - return await self._send_request("GET", "/meta", params=params) - - async def get_app_info(self): - """Get basic application information including name, description, tags, and mode.""" - return await self._send_request("GET", "/info") - - async def get_app_site_info(self): - """Get application site information.""" - return await self._send_request("GET", "/site") - - async def get_file_preview(self, file_id: str): - """Get file preview by file ID.""" - return await self._send_request("GET", f"/files/{file_id}/preview") - - # App Configuration APIs - async def get_app_site_config(self, app_id: str): - """Get app site configuration. - - Args: - app_id: ID of the app - - Returns: - App site configuration - """ - url = f"/apps/{app_id}/site/config" - return await self._send_request("GET", url) - - async def update_app_site_config(self, app_id: str, config_data: Dict[str, Any]): - """Update app site configuration. - - Args: - app_id: ID of the app - config_data: Configuration data to update - - Returns: - Updated app site configuration - """ - url = f"/apps/{app_id}/site/config" - return await self._send_request("PUT", url, json=config_data) - - async def get_app_api_tokens(self, app_id: str): - """Get API tokens for an app. - - Args: - app_id: ID of the app - - Returns: - List of API tokens - """ - url = f"/apps/{app_id}/api-tokens" - return await self._send_request("GET", url) - - async def create_app_api_token(self, app_id: str, name: str, description: str | None = None): - """Create a new API token for an app. - - Args: - app_id: ID of the app - name: Name for the API token - description: Description for the API token (optional) - - Returns: - Created API token information - """ - data = {"name": name, "description": description} - url = f"/apps/{app_id}/api-tokens" - return await self._send_request("POST", url, json=data) - - async def delete_app_api_token(self, app_id: str, token_id: str): - """Delete an API token. - - Args: - app_id: ID of the app - token_id: ID of the token to delete - - Returns: - Deletion result - """ - url = f"/apps/{app_id}/api-tokens/{token_id}" - return await self._send_request("DELETE", url) - - -class AsyncCompletionClient(AsyncDifyClient): - """Async client for Completion API operations.""" - - async def create_completion_message( - self, - inputs: dict, - response_mode: Literal["blocking", "streaming"], - user: str, - files: Dict | None = None, - ): - """Create a completion message. - - Args: - inputs: Input variables for the completion - response_mode: Response mode ('blocking' or 'streaming') - user: User identifier - files: Optional files to include - - Returns: - httpx.Response object - """ - data = { - "inputs": inputs, - "response_mode": response_mode, - "user": user, - "files": files, - } - return await self._send_request( - "POST", - "/completion-messages", - data, - stream=(response_mode == "streaming"), - ) - - -class AsyncChatClient(AsyncDifyClient): - """Async client for Chat API operations.""" - - async def create_chat_message( - self, - inputs: dict, - query: str, - user: str, - response_mode: Literal["blocking", "streaming"] = "blocking", - conversation_id: str | None = None, - files: Dict | None = None, - ): - """Create a chat message. - - Args: - inputs: Input variables for the chat - query: User query/message - user: User identifier - response_mode: Response mode ('blocking' or 'streaming') - conversation_id: Optional conversation ID for context - files: Optional files to include - - Returns: - httpx.Response object - """ - data = { - "inputs": inputs, - "query": query, - "user": user, - "response_mode": response_mode, - "files": files, - } - if conversation_id: - data["conversation_id"] = conversation_id - - return await self._send_request( - "POST", - "/chat-messages", - data, - stream=(response_mode == "streaming"), - ) - - async def get_suggested(self, message_id: str, user: str): - """Get suggested questions for a message.""" - params = {"user": user} - return await self._send_request("GET", f"/messages/{message_id}/suggested", params=params) - - async def stop_message(self, task_id: str, user: str): - """Stop a running message generation.""" - data = {"user": user} - return await self._send_request("POST", f"/chat-messages/{task_id}/stop", data) - - async def get_conversations( - self, - user: str, - last_id: str | None = None, - limit: int | None = None, - pinned: bool | None = None, - ): - """Get list of conversations.""" - params = {"user": user, "last_id": last_id, "limit": limit, "pinned": pinned} - return await self._send_request("GET", "/conversations", params=params) - - async def get_conversation_messages( - self, - user: str, - conversation_id: str | None = None, - first_id: str | None = None, - limit: int | None = None, - ): - """Get messages from a conversation.""" - params = { - "user": user, - "conversation_id": conversation_id, - "first_id": first_id, - "limit": limit, - } - return await self._send_request("GET", "/messages", params=params) - - async def rename_conversation(self, conversation_id: str, name: str, auto_generate: bool, user: str): - """Rename a conversation.""" - data = {"name": name, "auto_generate": auto_generate, "user": user} - return await self._send_request("POST", f"/conversations/{conversation_id}/name", data) - - async def delete_conversation(self, conversation_id: str, user: str): - """Delete a conversation.""" - data = {"user": user} - return await self._send_request("DELETE", f"/conversations/{conversation_id}", data) - - async def audio_to_text(self, audio_file: Union[IO[bytes], tuple], user: str): - """Convert audio to text.""" - data = {"user": user} - files = {"file": audio_file} - return await self._send_request_with_files("POST", "/audio-to-text", data, files) - - # Annotation APIs - async def annotation_reply_action( - self, - action: Literal["enable", "disable"], - score_threshold: float, - embedding_provider_name: str, - embedding_model_name: str, - ): - """Enable or disable annotation reply feature.""" - data = { - "score_threshold": score_threshold, - "embedding_provider_name": embedding_provider_name, - "embedding_model_name": embedding_model_name, - } - return await self._send_request("POST", f"/apps/annotation-reply/{action}", json=data) - - async def get_annotation_reply_status(self, action: Literal["enable", "disable"], job_id: str): - """Get the status of an annotation reply action job.""" - return await self._send_request("GET", f"/apps/annotation-reply/{action}/status/{job_id}") - - async def list_annotations(self, page: int = 1, limit: int = 20, keyword: str | None = None): - """List annotations for the application.""" - params = {"page": page, "limit": limit, "keyword": keyword} - return await self._send_request("GET", "/apps/annotations", params=params) - - async def create_annotation(self, question: str, answer: str): - """Create a new annotation.""" - data = {"question": question, "answer": answer} - return await self._send_request("POST", "/apps/annotations", json=data) - - async def update_annotation(self, annotation_id: str, question: str, answer: str): - """Update an existing annotation.""" - data = {"question": question, "answer": answer} - return await self._send_request("PUT", f"/apps/annotations/{annotation_id}", json=data) - - async def delete_annotation(self, annotation_id: str): - """Delete an annotation.""" - return await self._send_request("DELETE", f"/apps/annotations/{annotation_id}") - - # Enhanced Annotation APIs - async def get_annotation_reply_job_status(self, action: str, job_id: str): - """Get status of an annotation reply action job.""" - url = f"/apps/annotation-reply/{action}/status/{job_id}" - return await self._send_request("GET", url) - - async def list_annotations_with_pagination(self, page: int = 1, limit: int = 20, keyword: str | None = None): - """List annotations for application with pagination.""" - params = {"page": page, "limit": limit} - if keyword: - params["keyword"] = keyword - return await self._send_request("GET", "/apps/annotations", params=params) - - async def create_annotation_with_response(self, question: str, answer: str): - """Create a new annotation with full response handling.""" - data = {"question": question, "answer": answer} - return await self._send_request("POST", "/apps/annotations", json=data) - - async def update_annotation_with_response(self, annotation_id: str, question: str, answer: str): - """Update an existing annotation with full response handling.""" - data = {"question": question, "answer": answer} - url = f"/apps/annotations/{annotation_id}" - return await self._send_request("PUT", url, json=data) - - async def delete_annotation_with_response(self, annotation_id: str): - """Delete an annotation with full response handling.""" - url = f"/apps/annotations/{annotation_id}" - return await self._send_request("DELETE", url) - - # Conversation Variables APIs - async def get_conversation_variables(self, conversation_id: str, user: str): - """Get all variables for a specific conversation. - - Args: - conversation_id: The conversation ID to query variables for - user: User identifier - - Returns: - Response from the API containing: - - variables: List of conversation variables with their values - - conversation_id: The conversation ID - """ - params = {"user": user} - url = f"/conversations/{conversation_id}/variables" - return await self._send_request("GET", url, params=params) - - async def update_conversation_variable(self, conversation_id: str, variable_id: str, value: Any, user: str): - """Update a specific conversation variable. - - Args: - conversation_id: The conversation ID - variable_id: The variable ID to update - value: New value for the variable - user: User identifier - - Returns: - Response from the API with updated variable information - """ - data = {"value": value, "user": user} - url = f"/conversations/{conversation_id}/variables/{variable_id}" - return await self._send_request("PATCH", url, json=data) - - # Enhanced Conversation Variable APIs - async def list_conversation_variables_with_pagination( - self, conversation_id: str, user: str, page: int = 1, limit: int = 20 - ): - """List conversation variables with pagination.""" - params = {"page": page, "limit": limit, "user": user} - url = f"/conversations/{conversation_id}/variables" - return await self._send_request("GET", url, params=params) - - async def update_conversation_variable_with_response( - self, conversation_id: str, variable_id: str, user: str, value: Any - ): - """Update a conversation variable with full response handling.""" - data = {"value": value, "user": user} - url = f"/conversations/{conversation_id}/variables/{variable_id}" - return await self._send_request("PUT", url, data=data) - - # Additional annotation methods for API parity - async def get_annotation_reply_job_status(self, action: str, job_id: str): - """Get status of an annotation reply action job.""" - url = f"/apps/annotation-reply/{action}/status/{job_id}" - return await self._send_request("GET", url) - - async def list_annotations_with_pagination(self, page: int = 1, limit: int = 20, keyword: str | None = None): - """List annotations for application with pagination.""" - params = {"page": page, "limit": limit} - if keyword: - params["keyword"] = keyword - return await self._send_request("GET", "/apps/annotations", params=params) - - async def create_annotation_with_response(self, question: str, answer: str): - """Create a new annotation with full response handling.""" - data = {"question": question, "answer": answer} - return await self._send_request("POST", "/apps/annotations", json=data) - - async def update_annotation_with_response(self, annotation_id: str, question: str, answer: str): - """Update an existing annotation with full response handling.""" - data = {"question": question, "answer": answer} - url = f"/apps/annotations/{annotation_id}" - return await self._send_request("PUT", url, json=data) - - async def delete_annotation_with_response(self, annotation_id: str): - """Delete an annotation with full response handling.""" - url = f"/apps/annotations/{annotation_id}" - return await self._send_request("DELETE", url) - - -class AsyncWorkflowClient(AsyncDifyClient): - """Async client for Workflow API operations.""" - - async def run( - self, - inputs: dict, - response_mode: Literal["blocking", "streaming"] = "streaming", - user: str = "abc-123", - ): - """Run a workflow.""" - data = {"inputs": inputs, "response_mode": response_mode, "user": user} - return await self._send_request("POST", "/workflows/run", data) - - async def stop(self, task_id: str, user: str): - """Stop a running workflow task.""" - data = {"user": user} - return await self._send_request("POST", f"/workflows/tasks/{task_id}/stop", data) - - async def get_result(self, workflow_run_id: str): - """Get workflow run result.""" - return await self._send_request("GET", f"/workflows/run/{workflow_run_id}") - - async def get_workflow_logs( - self, - keyword: str = None, - status: Literal["succeeded", "failed", "stopped"] | None = None, - page: int = 1, - limit: int = 20, - created_at__before: str = None, - created_at__after: str = None, - created_by_end_user_session_id: str = None, - created_by_account: str = None, - ): - """Get workflow execution logs with optional filtering.""" - params = { - "page": page, - "limit": limit, - "keyword": keyword, - "status": status, - "created_at__before": created_at__before, - "created_at__after": created_at__after, - "created_by_end_user_session_id": created_by_end_user_session_id, - "created_by_account": created_by_account, - } - return await self._send_request("GET", "/workflows/logs", params=params) - - async def run_specific_workflow( - self, - workflow_id: str, - inputs: dict, - response_mode: Literal["blocking", "streaming"] = "streaming", - user: str = "abc-123", - ): - """Run a specific workflow by workflow ID.""" - data = {"inputs": inputs, "response_mode": response_mode, "user": user} - return await self._send_request( - "POST", - f"/workflows/{workflow_id}/run", - data, - stream=(response_mode == "streaming"), - ) - - # Enhanced Workflow APIs - async def get_workflow_draft(self, app_id: str): - """Get workflow draft configuration. - - Args: - app_id: ID of the workflow app - - Returns: - Workflow draft configuration - """ - url = f"/apps/{app_id}/workflow/draft" - return await self._send_request("GET", url) - - async def update_workflow_draft(self, app_id: str, workflow_data: Dict[str, Any]): - """Update workflow draft configuration. - - Args: - app_id: ID of the workflow app - workflow_data: Workflow configuration data - - Returns: - Updated workflow draft - """ - url = f"/apps/{app_id}/workflow/draft" - return await self._send_request("PUT", url, json=workflow_data) - - async def publish_workflow(self, app_id: str): - """Publish workflow from draft. - - Args: - app_id: ID of the workflow app - - Returns: - Published workflow information - """ - url = f"/apps/{app_id}/workflow/publish" - return await self._send_request("POST", url) - - async def get_workflow_run_history( - self, - app_id: str, - page: int = 1, - limit: int = 20, - status: Literal["succeeded", "failed", "stopped"] | None = None, - ): - """Get workflow run history. - - Args: - app_id: ID of the workflow app - page: Page number (default: 1) - limit: Number of items per page (default: 20) - status: Filter by status (optional) - - Returns: - Paginated workflow run history - """ - params = {"page": page, "limit": limit} - if status: - params["status"] = status - url = f"/apps/{app_id}/workflow/runs" - return await self._send_request("GET", url, params=params) - - -class AsyncWorkspaceClient(AsyncDifyClient): - """Async client for workspace-related operations.""" - - async def get_available_models(self, model_type: str): - """Get available models by model type.""" - url = f"/workspaces/current/models/model-types/{model_type}" - return await self._send_request("GET", url) - - async def get_available_models_by_type(self, model_type: str): - """Get available models by model type (enhanced version).""" - url = f"/workspaces/current/models/model-types/{model_type}" - return await self._send_request("GET", url) - - async def get_model_providers(self): - """Get all model providers.""" - return await self._send_request("GET", "/workspaces/current/model-providers") - - async def get_model_provider_models(self, provider_name: str): - """Get models for a specific provider.""" - url = f"/workspaces/current/model-providers/{provider_name}/models" - return await self._send_request("GET", url) - - async def validate_model_provider_credentials(self, provider_name: str, credentials: Dict[str, Any]): - """Validate model provider credentials.""" - url = f"/workspaces/current/model-providers/{provider_name}/credentials/validate" - return await self._send_request("POST", url, json=credentials) - - # File Management APIs - async def get_file_info(self, file_id: str): - """Get information about a specific file.""" - url = f"/files/{file_id}/info" - return await self._send_request("GET", url) - - async def get_file_download_url(self, file_id: str): - """Get download URL for a file.""" - url = f"/files/{file_id}/download-url" - return await self._send_request("GET", url) - - async def delete_file(self, file_id: str): - """Delete a file.""" - url = f"/files/{file_id}" - return await self._send_request("DELETE", url) - - -class AsyncKnowledgeBaseClient(AsyncDifyClient): - """Async client for Knowledge Base API operations.""" - - def __init__( - self, - api_key: str, - base_url: str = "https://api.dify.ai/v1", - dataset_id: str | None = None, - timeout: float = 60.0, - ): - """Construct an AsyncKnowledgeBaseClient object. - - Args: - api_key: API key of Dify - base_url: Base URL of Dify API - dataset_id: ID of the dataset - timeout: Request timeout in seconds - """ - super().__init__(api_key=api_key, base_url=base_url, timeout=timeout) - self.dataset_id = dataset_id - - def _get_dataset_id(self): - """Get the dataset ID, raise error if not set.""" - if self.dataset_id is None: - raise ValueError("dataset_id is not set") - return self.dataset_id - - async def create_dataset(self, name: str, **kwargs): - """Create a new dataset.""" - return await self._send_request("POST", "/datasets", {"name": name}, **kwargs) - - async def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs): - """List all datasets.""" - return await self._send_request("GET", "/datasets", params={"page": page, "limit": page_size}, **kwargs) - - async def create_document_by_text(self, name: str, text: str, extra_params: Dict | None = None, **kwargs): - """Create a document by text. - - Args: - name: Name of the document - text: Text content of the document - extra_params: Extra parameters for the API - - Returns: - Response from the API - """ - data = { - "indexing_technique": "high_quality", - "process_rule": {"mode": "automatic"}, - "name": name, - "text": text, - } - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - url = f"/datasets/{self._get_dataset_id()}/document/create_by_text" - return await self._send_request("POST", url, json=data, **kwargs) - - async def update_document_by_text( - self, - document_id: str, - name: str, - text: str, - extra_params: Dict | None = None, - **kwargs, - ): - """Update a document by text.""" - data = {"name": name, "text": text} - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text" - return await self._send_request("POST", url, json=data, **kwargs) - - async def create_document_by_file( - self, - file_path: str, - original_document_id: str | None = None, - extra_params: Dict | None = None, - ): - """Create a document by file.""" - async with aiofiles.open(file_path, "rb") as f: - files = {"file": (os.path.basename(file_path), f)} - data = { - "process_rule": {"mode": "automatic"}, - "indexing_technique": "high_quality", - } - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - if original_document_id is not None: - data["original_document_id"] = original_document_id - url = f"/datasets/{self._get_dataset_id()}/document/create_by_file" - return await self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files) - - async def update_document_by_file(self, document_id: str, file_path: str, extra_params: Dict | None = None): - """Update a document by file.""" - async with aiofiles.open(file_path, "rb") as f: - files = {"file": (os.path.basename(file_path), f)} - data = {} - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file" - return await self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files) - - async def batch_indexing_status(self, batch_id: str, **kwargs): - """Get the status of the batch indexing.""" - url = f"/datasets/{self._get_dataset_id()}/documents/{batch_id}/indexing-status" - return await self._send_request("GET", url, **kwargs) - - async def delete_dataset(self): - """Delete this dataset.""" - url = f"/datasets/{self._get_dataset_id()}" - return await self._send_request("DELETE", url) - - async def delete_document(self, document_id: str): - """Delete a document.""" - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}" - return await self._send_request("DELETE", url) - - async def list_documents( - self, - page: int | None = None, - page_size: int | None = None, - keyword: str | None = None, - **kwargs, - ): - """Get a list of documents in this dataset.""" - params = { - "page": page, - "limit": page_size, - "keyword": keyword, - } - url = f"/datasets/{self._get_dataset_id()}/documents" - return await self._send_request("GET", url, params=params, **kwargs) - - async def add_segments(self, document_id: str, segments: list[dict], **kwargs): - """Add segments to a document.""" - data = {"segments": segments} - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments" - return await self._send_request("POST", url, json=data, **kwargs) - - async def query_segments( - self, - document_id: str, - keyword: str | None = None, - status: str | None = None, - **kwargs, - ): - """Query segments in this document. - - Args: - document_id: ID of the document - keyword: Query keyword (optional) - status: Status of the segment (optional, e.g., 'completed') - **kwargs: Additional parameters to pass to the API. - Can include a 'params' dict for extra query parameters. - - Returns: - Response from the API - """ - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments" - params = { - "keyword": keyword, - "status": status, - } - if "params" in kwargs: - params.update(kwargs.pop("params")) - return await self._send_request("GET", url, params=params, **kwargs) - - async def delete_document_segment(self, document_id: str, segment_id: str): - """Delete a segment from a document.""" - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}" - return await self._send_request("DELETE", url) - - async def update_document_segment(self, document_id: str, segment_id: str, segment_data: dict, **kwargs): - """Update a segment in a document.""" - data = {"segment": segment_data} - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}" - return await self._send_request("POST", url, json=data, **kwargs) - - # Advanced Knowledge Base APIs - async def hit_testing( - self, - query: str, - retrieval_model: Dict[str, Any] = None, - external_retrieval_model: Dict[str, Any] = None, - ): - """Perform hit testing on the dataset.""" - data = {"query": query} - if retrieval_model: - data["retrieval_model"] = retrieval_model - if external_retrieval_model: - data["external_retrieval_model"] = external_retrieval_model - url = f"/datasets/{self._get_dataset_id()}/hit-testing" - return await self._send_request("POST", url, json=data) - - async def get_dataset_metadata(self): - """Get dataset metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata" - return await self._send_request("GET", url) - - async def create_dataset_metadata(self, metadata_data: Dict[str, Any]): - """Create dataset metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata" - return await self._send_request("POST", url, json=metadata_data) - - async def update_dataset_metadata(self, metadata_id: str, metadata_data: Dict[str, Any]): - """Update dataset metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata/{metadata_id}" - return await self._send_request("PATCH", url, json=metadata_data) - - async def get_built_in_metadata(self): - """Get built-in metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata/built-in" - return await self._send_request("GET", url) - - async def manage_built_in_metadata(self, action: str, metadata_data: Dict[str, Any] = None): - """Manage built-in metadata with specified action.""" - data = metadata_data or {} - url = f"/datasets/{self._get_dataset_id()}/metadata/built-in/{action}" - return await self._send_request("POST", url, json=data) - - async def update_documents_metadata(self, operation_data: List[Dict[str, Any]]): - """Update metadata for multiple documents.""" - url = f"/datasets/{self._get_dataset_id()}/documents/metadata" - data = {"operation_data": operation_data} - return await self._send_request("POST", url, json=data) - - # Dataset Tags APIs - async def list_dataset_tags(self): - """List all dataset tags.""" - return await self._send_request("GET", "/datasets/tags") - - async def bind_dataset_tags(self, tag_ids: List[str]): - """Bind tags to dataset.""" - data = {"tag_ids": tag_ids, "target_id": self._get_dataset_id()} - return await self._send_request("POST", "/datasets/tags/binding", json=data) - - async def unbind_dataset_tag(self, tag_id: str): - """Unbind a single tag from dataset.""" - data = {"tag_id": tag_id, "target_id": self._get_dataset_id()} - return await self._send_request("POST", "/datasets/tags/unbinding", json=data) - - async def get_dataset_tags(self): - """Get tags for current dataset.""" - url = f"/datasets/{self._get_dataset_id()}/tags" - return await self._send_request("GET", url) - - # RAG Pipeline APIs - async def get_datasource_plugins(self, is_published: bool = True): - """Get datasource plugins for RAG pipeline.""" - params = {"is_published": is_published} - url = f"/datasets/{self._get_dataset_id()}/pipeline/datasource-plugins" - return await self._send_request("GET", url, params=params) - - async def run_datasource_node( - self, - node_id: str, - inputs: Dict[str, Any], - datasource_type: str, - is_published: bool = True, - credential_id: str = None, - ): - """Run a datasource node in RAG pipeline.""" - data = { - "inputs": inputs, - "datasource_type": datasource_type, - "is_published": is_published, - } - if credential_id: - data["credential_id"] = credential_id - url = f"/datasets/{self._get_dataset_id()}/pipeline/datasource/nodes/{node_id}/run" - return await self._send_request("POST", url, json=data, stream=True) - - async def run_rag_pipeline( - self, - inputs: Dict[str, Any], - datasource_type: str, - datasource_info_list: List[Dict[str, Any]], - start_node_id: str, - is_published: bool = True, - response_mode: Literal["streaming", "blocking"] = "blocking", - ): - """Run RAG pipeline.""" - data = { - "inputs": inputs, - "datasource_type": datasource_type, - "datasource_info_list": datasource_info_list, - "start_node_id": start_node_id, - "is_published": is_published, - "response_mode": response_mode, - } - url = f"/datasets/{self._get_dataset_id()}/pipeline/run" - return await self._send_request("POST", url, json=data, stream=response_mode == "streaming") - - async def upload_pipeline_file(self, file_path: str): - """Upload file for RAG pipeline.""" - async with aiofiles.open(file_path, "rb") as f: - files = {"file": (os.path.basename(file_path), f)} - return await self._send_request_with_files("POST", "/datasets/pipeline/file-upload", {}, files) - - # Dataset Management APIs - async def get_dataset(self, dataset_id: str | None = None): - """Get detailed information about a specific dataset.""" - ds_id = dataset_id or self._get_dataset_id() - url = f"/datasets/{ds_id}" - return await self._send_request("GET", url) - - async def update_dataset( - self, - dataset_id: str | None = None, - name: str | None = None, - description: str | None = None, - indexing_technique: str | None = None, - embedding_model: str | None = None, - embedding_model_provider: str | None = None, - retrieval_model: Dict[str, Any] | None = None, - **kwargs, - ): - """Update dataset configuration. - - Args: - dataset_id: Dataset ID (optional, uses current dataset_id if not provided) - name: New dataset name - description: New dataset description - indexing_technique: Indexing technique ('high_quality' or 'economy') - embedding_model: Embedding model name - embedding_model_provider: Embedding model provider - retrieval_model: Retrieval model configuration dict - **kwargs: Additional parameters to pass to the API - - Returns: - Response from the API with updated dataset information - """ - ds_id = dataset_id or self._get_dataset_id() - url = f"/datasets/{ds_id}" - - payload = { - "name": name, - "description": description, - "indexing_technique": indexing_technique, - "embedding_model": embedding_model, - "embedding_model_provider": embedding_model_provider, - "retrieval_model": retrieval_model, - } - - data = {k: v for k, v in payload.items() if v is not None} - data.update(kwargs) - - return await self._send_request("PATCH", url, json=data) - - async def batch_update_document_status( - self, - action: Literal["enable", "disable", "archive", "un_archive"], - document_ids: List[str], - dataset_id: str | None = None, - ): - """Batch update document status.""" - ds_id = dataset_id or self._get_dataset_id() - url = f"/datasets/{ds_id}/documents/status/{action}" - data = {"document_ids": document_ids} - return await self._send_request("PATCH", url, json=data) - - # Enhanced Dataset APIs - - async def create_dataset_from_template(self, template_name: str, name: str, description: str | None = None): - """Create a dataset from a predefined template. - - Args: - template_name: Name of the template to use - name: Name for the new dataset - description: Description for the dataset (optional) - - Returns: - Created dataset information - """ - data = { - "template_name": template_name, - "name": name, - "description": description, - } - return await self._send_request("POST", "/datasets/from-template", json=data) - - async def duplicate_dataset(self, dataset_id: str, name: str): - """Duplicate an existing dataset. - - Args: - dataset_id: ID of dataset to duplicate - name: Name for duplicated dataset - - Returns: - New dataset information - """ - data = {"name": name} - url = f"/datasets/{dataset_id}/duplicate" - return await self._send_request("POST", url, json=data) - - async def update_conversation_variable_with_response( - self, conversation_id: str, variable_id: str, user: str, value: Any - ): - """Update a conversation variable with full response handling.""" - data = {"value": value, "user": user} - url = f"/conversations/{conversation_id}/variables/{variable_id}" - return await self._send_request("PUT", url, json=data) - - async def list_conversation_variables_with_pagination( - self, conversation_id: str, user: str, page: int = 1, limit: int = 20 - ): - """List conversation variables with pagination.""" - params = {"page": page, "limit": limit, "user": user} - url = f"/conversations/{conversation_id}/variables" - return await self._send_request("GET", url, params=params) - - -class AsyncEnterpriseClient(AsyncDifyClient): - """Async Enterprise and Account Management APIs for Dify platform administration.""" - - async def get_account_info(self): - """Get current account information.""" - return await self._send_request("GET", "/account") - - async def update_account_info(self, account_data: Dict[str, Any]): - """Update account information.""" - return await self._send_request("PUT", "/account", json=account_data) - - # Member Management APIs - async def list_members(self, page: int = 1, limit: int = 20, keyword: str | None = None): - """List workspace members with pagination.""" - params = {"page": page, "limit": limit} - if keyword: - params["keyword"] = keyword - return await self._send_request("GET", "/members", params=params) - - async def invite_member(self, email: str, role: str, name: str | None = None): - """Invite a new member to the workspace.""" - data = {"email": email, "role": role} - if name: - data["name"] = name - return await self._send_request("POST", "/members/invite", json=data) - - async def get_member(self, member_id: str): - """Get detailed information about a specific member.""" - url = f"/members/{member_id}" - return await self._send_request("GET", url) - - async def update_member(self, member_id: str, member_data: Dict[str, Any]): - """Update member information.""" - url = f"/members/{member_id}" - return await self._send_request("PUT", url, json=member_data) - - async def remove_member(self, member_id: str): - """Remove a member from the workspace.""" - url = f"/members/{member_id}" - return await self._send_request("DELETE", url) - - async def deactivate_member(self, member_id: str): - """Deactivate a member account.""" - url = f"/members/{member_id}/deactivate" - return await self._send_request("POST", url) - - async def reactivate_member(self, member_id: str): - """Reactivate a deactivated member account.""" - url = f"/members/{member_id}/reactivate" - return await self._send_request("POST", url) - - # Role Management APIs - async def list_roles(self): - """List all available roles in the workspace.""" - return await self._send_request("GET", "/roles") - - async def create_role(self, name: str, description: str, permissions: List[str]): - """Create a new role with specified permissions.""" - data = {"name": name, "description": description, "permissions": permissions} - return await self._send_request("POST", "/roles", json=data) - - async def get_role(self, role_id: str): - """Get detailed information about a specific role.""" - url = f"/roles/{role_id}" - return await self._send_request("GET", url) - - async def update_role(self, role_id: str, role_data: Dict[str, Any]): - """Update role information.""" - url = f"/roles/{role_id}" - return await self._send_request("PUT", url, json=role_data) - - async def delete_role(self, role_id: str): - """Delete a role.""" - url = f"/roles/{role_id}" - return await self._send_request("DELETE", url) - - # Permission Management APIs - async def list_permissions(self): - """List all available permissions.""" - return await self._send_request("GET", "/permissions") - - async def get_role_permissions(self, role_id: str): - """Get permissions for a specific role.""" - url = f"/roles/{role_id}/permissions" - return await self._send_request("GET", url) - - async def update_role_permissions(self, role_id: str, permissions: List[str]): - """Update permissions for a role.""" - url = f"/roles/{role_id}/permissions" - data = {"permissions": permissions} - return await self._send_request("PUT", url, json=data) - - # Workspace Settings APIs - async def get_workspace_settings(self): - """Get workspace settings and configuration.""" - return await self._send_request("GET", "/workspace/settings") - - async def update_workspace_settings(self, settings_data: Dict[str, Any]): - """Update workspace settings.""" - return await self._send_request("PUT", "/workspace/settings", json=settings_data) - - async def get_workspace_statistics(self): - """Get workspace usage statistics.""" - return await self._send_request("GET", "/workspace/statistics") - - # Billing and Subscription APIs - async def get_billing_info(self): - """Get current billing information.""" - return await self._send_request("GET", "/billing") - - async def get_subscription_info(self): - """Get current subscription information.""" - return await self._send_request("GET", "/subscription") - - async def update_subscription(self, subscription_data: Dict[str, Any]): - """Update subscription settings.""" - return await self._send_request("PUT", "/subscription", json=subscription_data) - - async def get_billing_history(self, page: int = 1, limit: int = 20): - """Get billing history with pagination.""" - params = {"page": page, "limit": limit} - return await self._send_request("GET", "/billing/history", params=params) - - async def get_usage_metrics(self, start_date: str, end_date: str, metric_type: str | None = None): - """Get usage metrics for a date range.""" - params = {"start_date": start_date, "end_date": end_date} - if metric_type: - params["metric_type"] = metric_type - return await self._send_request("GET", "/usage/metrics", params=params) - - # Audit Logs APIs - async def get_audit_logs( - self, - page: int = 1, - limit: int = 20, - action: str | None = None, - user_id: str | None = None, - start_date: str | None = None, - end_date: str | None = None, - ): - """Get audit logs with filtering options.""" - params = {"page": page, "limit": limit} - if action: - params["action"] = action - if user_id: - params["user_id"] = user_id - if start_date: - params["start_date"] = start_date - if end_date: - params["end_date"] = end_date - return await self._send_request("GET", "/audit/logs", params=params) - - async def export_audit_logs(self, format: str = "csv", filters: Dict[str, Any] | None = None): - """Export audit logs in specified format.""" - params = {"format": format} - if filters: - params.update(filters) - return await self._send_request("GET", "/audit/logs/export", params=params) - - -class AsyncSecurityClient(AsyncDifyClient): - """Async Security and Access Control APIs for Dify platform security management.""" - - # API Key Management APIs - async def list_api_keys(self, page: int = 1, limit: int = 20, status: str | None = None): - """List all API keys with pagination and filtering.""" - params = {"page": page, "limit": limit} - if status: - params["status"] = status - return await self._send_request("GET", "/security/api-keys", params=params) - - async def create_api_key( - self, - name: str, - permissions: List[str], - expires_at: str | None = None, - description: str | None = None, - ): - """Create a new API key with specified permissions.""" - data = {"name": name, "permissions": permissions} - if expires_at: - data["expires_at"] = expires_at - if description: - data["description"] = description - return await self._send_request("POST", "/security/api-keys", json=data) - - async def get_api_key(self, key_id: str): - """Get detailed information about an API key.""" - url = f"/security/api-keys/{key_id}" - return await self._send_request("GET", url) - - async def update_api_key(self, key_id: str, key_data: Dict[str, Any]): - """Update API key information.""" - url = f"/security/api-keys/{key_id}" - return await self._send_request("PUT", url, json=key_data) - - async def revoke_api_key(self, key_id: str): - """Revoke an API key.""" - url = f"/security/api-keys/{key_id}/revoke" - return await self._send_request("POST", url) - - async def rotate_api_key(self, key_id: str): - """Rotate an API key (generate new key).""" - url = f"/security/api-keys/{key_id}/rotate" - return await self._send_request("POST", url) - - # Rate Limiting APIs - async def get_rate_limits(self): - """Get current rate limiting configuration.""" - return await self._send_request("GET", "/security/rate-limits") - - async def update_rate_limits(self, limits_config: Dict[str, Any]): - """Update rate limiting configuration.""" - return await self._send_request("PUT", "/security/rate-limits", json=limits_config) - - async def get_rate_limit_usage(self, timeframe: str = "1h"): - """Get rate limit usage statistics.""" - params = {"timeframe": timeframe} - return await self._send_request("GET", "/security/rate-limits/usage", params=params) - - # Access Control Lists APIs - async def list_access_policies(self, page: int = 1, limit: int = 20): - """List access control policies.""" - params = {"page": page, "limit": limit} - return await self._send_request("GET", "/security/access-policies", params=params) - - async def create_access_policy(self, policy_data: Dict[str, Any]): - """Create a new access control policy.""" - return await self._send_request("POST", "/security/access-policies", json=policy_data) - - async def get_access_policy(self, policy_id: str): - """Get detailed information about an access policy.""" - url = f"/security/access-policies/{policy_id}" - return await self._send_request("GET", url) - - async def update_access_policy(self, policy_id: str, policy_data: Dict[str, Any]): - """Update an access control policy.""" - url = f"/security/access-policies/{policy_id}" - return await self._send_request("PUT", url, json=policy_data) - - async def delete_access_policy(self, policy_id: str): - """Delete an access control policy.""" - url = f"/security/access-policies/{policy_id}" - return await self._send_request("DELETE", url) - - # Security Settings APIs - async def get_security_settings(self): - """Get security configuration settings.""" - return await self._send_request("GET", "/security/settings") - - async def update_security_settings(self, settings_data: Dict[str, Any]): - """Update security configuration settings.""" - return await self._send_request("PUT", "/security/settings", json=settings_data) - - async def get_security_audit_logs( - self, - page: int = 1, - limit: int = 20, - event_type: str | None = None, - start_date: str | None = None, - end_date: str | None = None, - ): - """Get security-specific audit logs.""" - params = {"page": page, "limit": limit} - if event_type: - params["event_type"] = event_type - if start_date: - params["start_date"] = start_date - if end_date: - params["end_date"] = end_date - return await self._send_request("GET", "/security/audit-logs", params=params) - - # IP Whitelist/Blacklist APIs - async def get_ip_whitelist(self): - """Get IP whitelist configuration.""" - return await self._send_request("GET", "/security/ip-whitelist") - - async def update_ip_whitelist(self, ip_list: List[str], description: str | None = None): - """Update IP whitelist configuration.""" - data = {"ip_list": ip_list} - if description: - data["description"] = description - return await self._send_request("PUT", "/security/ip-whitelist", json=data) - - async def get_ip_blacklist(self): - """Get IP blacklist configuration.""" - return await self._send_request("GET", "/security/ip-blacklist") - - async def update_ip_blacklist(self, ip_list: List[str], description: str | None = None): - """Update IP blacklist configuration.""" - data = {"ip_list": ip_list} - if description: - data["description"] = description - return await self._send_request("PUT", "/security/ip-blacklist", json=data) - - # Authentication Settings APIs - async def get_auth_settings(self): - """Get authentication configuration settings.""" - return await self._send_request("GET", "/security/auth-settings") - - async def update_auth_settings(self, auth_data: Dict[str, Any]): - """Update authentication configuration settings.""" - return await self._send_request("PUT", "/security/auth-settings", json=auth_data) - - async def test_auth_configuration(self, auth_config: Dict[str, Any]): - """Test authentication configuration.""" - return await self._send_request("POST", "/security/auth-settings/test", json=auth_config) - - -class AsyncAnalyticsClient(AsyncDifyClient): - """Async Analytics and Monitoring APIs for Dify platform insights and metrics.""" - - # Usage Analytics APIs - async def get_usage_analytics( - self, - start_date: str, - end_date: str, - granularity: str = "day", - metrics: List[str] | None = None, - ): - """Get usage analytics for specified date range.""" - params = { - "start_date": start_date, - "end_date": end_date, - "granularity": granularity, - } - if metrics: - params["metrics"] = ",".join(metrics) - return await self._send_request("GET", "/analytics/usage", params=params) - - async def get_app_usage_analytics(self, app_id: str, start_date: str, end_date: str, granularity: str = "day"): - """Get usage analytics for a specific app.""" - params = { - "start_date": start_date, - "end_date": end_date, - "granularity": granularity, - } - url = f"/analytics/apps/{app_id}/usage" - return await self._send_request("GET", url, params=params) - - async def get_user_analytics(self, start_date: str, end_date: str, user_segment: str | None = None): - """Get user analytics and behavior insights.""" - params = {"start_date": start_date, "end_date": end_date} - if user_segment: - params["user_segment"] = user_segment - return await self._send_request("GET", "/analytics/users", params=params) - - # Performance Metrics APIs - async def get_performance_metrics(self, start_date: str, end_date: str, metric_type: str | None = None): - """Get performance metrics for the platform.""" - params = {"start_date": start_date, "end_date": end_date} - if metric_type: - params["metric_type"] = metric_type - return await self._send_request("GET", "/analytics/performance", params=params) - - async def get_app_performance_metrics(self, app_id: str, start_date: str, end_date: str): - """Get performance metrics for a specific app.""" - params = {"start_date": start_date, "end_date": end_date} - url = f"/analytics/apps/{app_id}/performance" - return await self._send_request("GET", url, params=params) - - async def get_model_performance_metrics(self, model_provider: str, model_name: str, start_date: str, end_date: str): - """Get performance metrics for a specific model.""" - params = {"start_date": start_date, "end_date": end_date} - url = f"/analytics/models/{model_provider}/{model_name}/performance" - return await self._send_request("GET", url, params=params) - - # Cost Tracking APIs - async def get_cost_analytics(self, start_date: str, end_date: str, cost_type: str | None = None): - """Get cost analytics and breakdown.""" - params = {"start_date": start_date, "end_date": end_date} - if cost_type: - params["cost_type"] = cost_type - return await self._send_request("GET", "/analytics/costs", params=params) - - async def get_app_cost_analytics(self, app_id: str, start_date: str, end_date: str): - """Get cost analytics for a specific app.""" - params = {"start_date": start_date, "end_date": end_date} - url = f"/analytics/apps/{app_id}/costs" - return await self._send_request("GET", url, params=params) - - async def get_cost_forecast(self, forecast_period: str = "30d"): - """Get cost forecast for specified period.""" - params = {"forecast_period": forecast_period} - return await self._send_request("GET", "/analytics/costs/forecast", params=params) - - # Real-time Monitoring APIs - async def get_real_time_metrics(self): - """Get real-time platform metrics.""" - return await self._send_request("GET", "/analytics/realtime") - - async def get_app_real_time_metrics(self, app_id: str): - """Get real-time metrics for a specific app.""" - url = f"/analytics/apps/{app_id}/realtime" - return await self._send_request("GET", url) - - async def get_system_health(self): - """Get overall system health status.""" - return await self._send_request("GET", "/analytics/health") - - # Custom Reports APIs - async def create_custom_report(self, report_config: Dict[str, Any]): - """Create a custom analytics report.""" - return await self._send_request("POST", "/analytics/reports", json=report_config) - - async def list_custom_reports(self, page: int = 1, limit: int = 20): - """List custom analytics reports.""" - params = {"page": page, "limit": limit} - return await self._send_request("GET", "/analytics/reports", params=params) - - async def get_custom_report(self, report_id: str): - """Get a specific custom report.""" - url = f"/analytics/reports/{report_id}" - return await self._send_request("GET", url) - - async def update_custom_report(self, report_id: str, report_config: Dict[str, Any]): - """Update a custom analytics report.""" - url = f"/analytics/reports/{report_id}" - return await self._send_request("PUT", url, json=report_config) - - async def delete_custom_report(self, report_id: str): - """Delete a custom analytics report.""" - url = f"/analytics/reports/{report_id}" - return await self._send_request("DELETE", url) - - async def generate_report(self, report_id: str, format: str = "pdf"): - """Generate and download a custom report.""" - params = {"format": format} - url = f"/analytics/reports/{report_id}/generate" - return await self._send_request("GET", url, params=params) - - # Export APIs - async def export_analytics_data(self, data_type: str, start_date: str, end_date: str, format: str = "csv"): - """Export analytics data in specified format.""" - params = { - "data_type": data_type, - "start_date": start_date, - "end_date": end_date, - "format": format, - } - return await self._send_request("GET", "/analytics/export", params=params) - - -class AsyncIntegrationClient(AsyncDifyClient): - """Async Integration and Plugin APIs for Dify platform extensibility.""" - - # Webhook Management APIs - async def list_webhooks(self, page: int = 1, limit: int = 20, status: str | None = None): - """List webhooks with pagination and filtering.""" - params = {"page": page, "limit": limit} - if status: - params["status"] = status - return await self._send_request("GET", "/integrations/webhooks", params=params) - - async def create_webhook(self, webhook_data: Dict[str, Any]): - """Create a new webhook.""" - return await self._send_request("POST", "/integrations/webhooks", json=webhook_data) - - async def get_webhook(self, webhook_id: str): - """Get detailed information about a webhook.""" - url = f"/integrations/webhooks/{webhook_id}" - return await self._send_request("GET", url) - - async def update_webhook(self, webhook_id: str, webhook_data: Dict[str, Any]): - """Update webhook configuration.""" - url = f"/integrations/webhooks/{webhook_id}" - return await self._send_request("PUT", url, json=webhook_data) - - async def delete_webhook(self, webhook_id: str): - """Delete a webhook.""" - url = f"/integrations/webhooks/{webhook_id}" - return await self._send_request("DELETE", url) - - async def test_webhook(self, webhook_id: str): - """Test webhook delivery.""" - url = f"/integrations/webhooks/{webhook_id}/test" - return await self._send_request("POST", url) - - async def get_webhook_logs(self, webhook_id: str, page: int = 1, limit: int = 20): - """Get webhook delivery logs.""" - params = {"page": page, "limit": limit} - url = f"/integrations/webhooks/{webhook_id}/logs" - return await self._send_request("GET", url, params=params) - - # Plugin Management APIs - async def list_plugins(self, page: int = 1, limit: int = 20, category: str | None = None): - """List available plugins.""" - params = {"page": page, "limit": limit} - if category: - params["category"] = category - return await self._send_request("GET", "/integrations/plugins", params=params) - - async def install_plugin(self, plugin_id: str, config: Dict[str, Any] | None = None): - """Install a plugin.""" - data = {"plugin_id": plugin_id} - if config: - data["config"] = config - return await self._send_request("POST", "/integrations/plugins/install", json=data) - - async def get_installed_plugin(self, installation_id: str): - """Get information about an installed plugin.""" - url = f"/integrations/plugins/{installation_id}" - return await self._send_request("GET", url) - - async def update_plugin_config(self, installation_id: str, config: Dict[str, Any]): - """Update plugin configuration.""" - url = f"/integrations/plugins/{installation_id}/config" - return await self._send_request("PUT", url, json=config) - - async def uninstall_plugin(self, installation_id: str): - """Uninstall a plugin.""" - url = f"/integrations/plugins/{installation_id}" - return await self._send_request("DELETE", url) - - async def enable_plugin(self, installation_id: str): - """Enable a plugin.""" - url = f"/integrations/plugins/{installation_id}/enable" - return await self._send_request("POST", url) - - async def disable_plugin(self, installation_id: str): - """Disable a plugin.""" - url = f"/integrations/plugins/{installation_id}/disable" - return await self._send_request("POST", url) - - # Import/Export APIs - async def export_app_data(self, app_id: str, format: str = "json", include_data: bool = True): - """Export application data.""" - params = {"format": format, "include_data": include_data} - url = f"/integrations/export/apps/{app_id}" - return await self._send_request("GET", url, params=params) - - async def import_app_data(self, import_data: Dict[str, Any]): - """Import application data.""" - return await self._send_request("POST", "/integrations/import/apps", json=import_data) - - async def get_import_status(self, import_id: str): - """Get import operation status.""" - url = f"/integrations/import/{import_id}/status" - return await self._send_request("GET", url) - - async def export_workspace_data(self, format: str = "json", include_data: bool = True): - """Export workspace data.""" - params = {"format": format, "include_data": include_data} - return await self._send_request("GET", "/integrations/export/workspace", params=params) - - async def import_workspace_data(self, import_data: Dict[str, Any]): - """Import workspace data.""" - return await self._send_request("POST", "/integrations/import/workspace", json=import_data) - - # Backup and Restore APIs - async def create_backup(self, backup_config: Dict[str, Any] | None = None): - """Create a system backup.""" - data = backup_config or {} - return await self._send_request("POST", "/integrations/backup/create", json=data) - - async def list_backups(self, page: int = 1, limit: int = 20): - """List available backups.""" - params = {"page": page, "limit": limit} - return await self._send_request("GET", "/integrations/backup", params=params) - - async def get_backup(self, backup_id: str): - """Get backup information.""" - url = f"/integrations/backup/{backup_id}" - return await self._send_request("GET", url) - - async def restore_backup(self, backup_id: str, restore_config: Dict[str, Any] | None = None): - """Restore from backup.""" - data = restore_config or {} - url = f"/integrations/backup/{backup_id}/restore" - return await self._send_request("POST", url, json=data) - - async def delete_backup(self, backup_id: str): - """Delete a backup.""" - url = f"/integrations/backup/{backup_id}" - return await self._send_request("DELETE", url) - - -class AsyncAdvancedModelClient(AsyncDifyClient): - """Async Advanced Model Management APIs for fine-tuning and custom deployments.""" - - # Fine-tuning Job Management APIs - async def list_fine_tuning_jobs( - self, - page: int = 1, - limit: int = 20, - status: str | None = None, - model_provider: str | None = None, - ): - """List fine-tuning jobs with filtering.""" - params = {"page": page, "limit": limit} - if status: - params["status"] = status - if model_provider: - params["model_provider"] = model_provider - return await self._send_request("GET", "/models/fine-tuning/jobs", params=params) - - async def create_fine_tuning_job(self, job_config: Dict[str, Any]): - """Create a new fine-tuning job.""" - return await self._send_request("POST", "/models/fine-tuning/jobs", json=job_config) - - async def get_fine_tuning_job(self, job_id: str): - """Get fine-tuning job details.""" - url = f"/models/fine-tuning/jobs/{job_id}" - return await self._send_request("GET", url) - - async def update_fine_tuning_job(self, job_id: str, job_config: Dict[str, Any]): - """Update fine-tuning job configuration.""" - url = f"/models/fine-tuning/jobs/{job_id}" - return await self._send_request("PUT", url, json=job_config) - - async def cancel_fine_tuning_job(self, job_id: str): - """Cancel a fine-tuning job.""" - url = f"/models/fine-tuning/jobs/{job_id}/cancel" - return await self._send_request("POST", url) - - async def resume_fine_tuning_job(self, job_id: str): - """Resume a paused fine-tuning job.""" - url = f"/models/fine-tuning/jobs/{job_id}/resume" - return await self._send_request("POST", url) - - async def get_fine_tuning_job_metrics(self, job_id: str): - """Get fine-tuning job training metrics.""" - url = f"/models/fine-tuning/jobs/{job_id}/metrics" - return await self._send_request("GET", url) - - async def get_fine_tuning_job_logs(self, job_id: str, page: int = 1, limit: int = 50): - """Get fine-tuning job logs.""" - params = {"page": page, "limit": limit} - url = f"/models/fine-tuning/jobs/{job_id}/logs" - return await self._send_request("GET", url, params=params) - - # Custom Model Deployment APIs - async def list_custom_deployments(self, page: int = 1, limit: int = 20, status: str | None = None): - """List custom model deployments.""" - params = {"page": page, "limit": limit} - if status: - params["status"] = status - return await self._send_request("GET", "/models/custom/deployments", params=params) - - async def create_custom_deployment(self, deployment_config: Dict[str, Any]): - """Create a custom model deployment.""" - return await self._send_request("POST", "/models/custom/deployments", json=deployment_config) - - async def get_custom_deployment(self, deployment_id: str): - """Get custom deployment details.""" - url = f"/models/custom/deployments/{deployment_id}" - return await self._send_request("GET", url) - - async def update_custom_deployment(self, deployment_id: str, deployment_config: Dict[str, Any]): - """Update custom deployment configuration.""" - url = f"/models/custom/deployments/{deployment_id}" - return await self._send_request("PUT", url, json=deployment_config) - - async def delete_custom_deployment(self, deployment_id: str): - """Delete a custom deployment.""" - url = f"/models/custom/deployments/{deployment_id}" - return await self._send_request("DELETE", url) - - async def scale_custom_deployment(self, deployment_id: str, scale_config: Dict[str, Any]): - """Scale custom deployment resources.""" - url = f"/models/custom/deployments/{deployment_id}/scale" - return await self._send_request("POST", url, json=scale_config) - - async def restart_custom_deployment(self, deployment_id: str): - """Restart a custom deployment.""" - url = f"/models/custom/deployments/{deployment_id}/restart" - return await self._send_request("POST", url) - - # Model Performance Monitoring APIs - async def get_model_performance_history( - self, - model_provider: str, - model_name: str, - start_date: str, - end_date: str, - metrics: List[str] | None = None, - ): - """Get model performance history.""" - params = {"start_date": start_date, "end_date": end_date} - if metrics: - params["metrics"] = ",".join(metrics) - url = f"/models/{model_provider}/{model_name}/performance/history" - return await self._send_request("GET", url, params=params) - - async def get_model_health_metrics(self, model_provider: str, model_name: str): - """Get real-time model health metrics.""" - url = f"/models/{model_provider}/{model_name}/health" - return await self._send_request("GET", url) - - async def get_model_usage_stats( - self, - model_provider: str, - model_name: str, - start_date: str, - end_date: str, - granularity: str = "day", - ): - """Get model usage statistics.""" - params = { - "start_date": start_date, - "end_date": end_date, - "granularity": granularity, - } - url = f"/models/{model_provider}/{model_name}/usage" - return await self._send_request("GET", url, params=params) - - async def get_model_cost_analysis(self, model_provider: str, model_name: str, start_date: str, end_date: str): - """Get model cost analysis.""" - params = {"start_date": start_date, "end_date": end_date} - url = f"/models/{model_provider}/{model_name}/costs" - return await self._send_request("GET", url, params=params) - - # Model Versioning APIs - async def list_model_versions(self, model_provider: str, model_name: str, page: int = 1, limit: int = 20): - """List model versions.""" - params = {"page": page, "limit": limit} - url = f"/models/{model_provider}/{model_name}/versions" - return await self._send_request("GET", url, params=params) - - async def create_model_version(self, model_provider: str, model_name: str, version_config: Dict[str, Any]): - """Create a new model version.""" - url = f"/models/{model_provider}/{model_name}/versions" - return await self._send_request("POST", url, json=version_config) - - async def get_model_version(self, model_provider: str, model_name: str, version_id: str): - """Get model version details.""" - url = f"/models/{model_provider}/{model_name}/versions/{version_id}" - return await self._send_request("GET", url) - - async def promote_model_version(self, model_provider: str, model_name: str, version_id: str): - """Promote model version to production.""" - url = f"/models/{model_provider}/{model_name}/versions/{version_id}/promote" - return await self._send_request("POST", url) - - async def rollback_model_version(self, model_provider: str, model_name: str, version_id: str): - """Rollback to a specific model version.""" - url = f"/models/{model_provider}/{model_name}/versions/{version_id}/rollback" - return await self._send_request("POST", url) - - # Model Registry APIs - async def list_registry_models(self, page: int = 1, limit: int = 20, filter: str | None = None): - """List models in registry.""" - params = {"page": page, "limit": limit} - if filter: - params["filter"] = filter - return await self._send_request("GET", "/models/registry", params=params) - - async def register_model(self, model_config: Dict[str, Any]): - """Register a new model in the registry.""" - return await self._send_request("POST", "/models/registry", json=model_config) - - async def get_registry_model(self, model_id: str): - """Get registered model details.""" - url = f"/models/registry/{model_id}" - return await self._send_request("GET", url) - - async def update_registry_model(self, model_id: str, model_config: Dict[str, Any]): - """Update registered model information.""" - url = f"/models/registry/{model_id}" - return await self._send_request("PUT", url, json=model_config) - - async def unregister_model(self, model_id: str): - """Unregister a model from the registry.""" - url = f"/models/registry/{model_id}" - return await self._send_request("DELETE", url) - - -class AsyncAdvancedAppClient(AsyncDifyClient): - """Async Advanced App Configuration APIs for comprehensive app management.""" - - # App Creation and Management APIs - async def create_app(self, app_config: Dict[str, Any]): - """Create a new application.""" - return await self._send_request("POST", "/apps", json=app_config) - - async def list_apps( - self, - page: int = 1, - limit: int = 20, - app_type: str | None = None, - status: str | None = None, - ): - """List applications with filtering.""" - params = {"page": page, "limit": limit} - if app_type: - params["app_type"] = app_type - if status: - params["status"] = status - return await self._send_request("GET", "/apps", params=params) - - async def get_app(self, app_id: str): - """Get detailed application information.""" - url = f"/apps/{app_id}" - return await self._send_request("GET", url) - - async def update_app(self, app_id: str, app_config: Dict[str, Any]): - """Update application configuration.""" - url = f"/apps/{app_id}" - return await self._send_request("PUT", url, json=app_config) - - async def delete_app(self, app_id: str): - """Delete an application.""" - url = f"/apps/{app_id}" - return await self._send_request("DELETE", url) - - async def duplicate_app(self, app_id: str, duplicate_config: Dict[str, Any]): - """Duplicate an application.""" - url = f"/apps/{app_id}/duplicate" - return await self._send_request("POST", url, json=duplicate_config) - - async def archive_app(self, app_id: str): - """Archive an application.""" - url = f"/apps/{app_id}/archive" - return await self._send_request("POST", url) - - async def restore_app(self, app_id: str): - """Restore an archived application.""" - url = f"/apps/{app_id}/restore" - return await self._send_request("POST", url) - - # App Publishing and Versioning APIs - async def publish_app(self, app_id: str, publish_config: Dict[str, Any] | None = None): - """Publish an application.""" - data = publish_config or {} - url = f"/apps/{app_id}/publish" - return await self._send_request("POST", url, json=data) - - async def unpublish_app(self, app_id: str): - """Unpublish an application.""" - url = f"/apps/{app_id}/unpublish" - return await self._send_request("POST", url) - - async def list_app_versions(self, app_id: str, page: int = 1, limit: int = 20): - """List application versions.""" - params = {"page": page, "limit": limit} - url = f"/apps/{app_id}/versions" - return await self._send_request("GET", url, params=params) - - async def create_app_version(self, app_id: str, version_config: Dict[str, Any]): - """Create a new application version.""" - url = f"/apps/{app_id}/versions" - return await self._send_request("POST", url, json=version_config) - - async def get_app_version(self, app_id: str, version_id: str): - """Get application version details.""" - url = f"/apps/{app_id}/versions/{version_id}" - return await self._send_request("GET", url) - - async def rollback_app_version(self, app_id: str, version_id: str): - """Rollback application to a specific version.""" - url = f"/apps/{app_id}/versions/{version_id}/rollback" - return await self._send_request("POST", url) - - # App Template APIs - async def list_app_templates(self, page: int = 1, limit: int = 20, category: str | None = None): - """List available app templates.""" - params = {"page": page, "limit": limit} - if category: - params["category"] = category - return await self._send_request("GET", "/apps/templates", params=params) - - async def get_app_template(self, template_id: str): - """Get app template details.""" - url = f"/apps/templates/{template_id}" - return await self._send_request("GET", url) - - async def create_app_from_template(self, template_id: str, app_config: Dict[str, Any]): - """Create an app from a template.""" - url = f"/apps/templates/{template_id}/create" - return await self._send_request("POST", url, json=app_config) - - async def create_custom_template(self, app_id: str, template_config: Dict[str, Any]): - """Create a custom template from an existing app.""" - url = f"/apps/{app_id}/create-template" - return await self._send_request("POST", url, json=template_config) - - # App Analytics and Metrics APIs - async def get_app_analytics( - self, - app_id: str, - start_date: str, - end_date: str, - metrics: List[str] | None = None, - ): - """Get application analytics.""" - params = {"start_date": start_date, "end_date": end_date} - if metrics: - params["metrics"] = ",".join(metrics) - url = f"/apps/{app_id}/analytics" - return await self._send_request("GET", url, params=params) - - async def get_app_user_feedback(self, app_id: str, page: int = 1, limit: int = 20, rating: int | None = None): - """Get user feedback for an application.""" - params = {"page": page, "limit": limit} - if rating: - params["rating"] = rating - url = f"/apps/{app_id}/feedback" - return await self._send_request("GET", url, params=params) - - async def get_app_error_logs( - self, - app_id: str, - start_date: str, - end_date: str, - error_type: str | None = None, - page: int = 1, - limit: int = 20, - ): - """Get application error logs.""" - params = { - "start_date": start_date, - "end_date": end_date, - "page": page, - "limit": limit, - } - if error_type: - params["error_type"] = error_type - url = f"/apps/{app_id}/errors" - return await self._send_request("GET", url, params=params) - - # Advanced Configuration APIs - async def get_app_advanced_config(self, app_id: str): - """Get advanced application configuration.""" - url = f"/apps/{app_id}/advanced-config" - return await self._send_request("GET", url) - - async def update_app_advanced_config(self, app_id: str, config: Dict[str, Any]): - """Update advanced application configuration.""" - url = f"/apps/{app_id}/advanced-config" - return await self._send_request("PUT", url, json=config) - - async def get_app_environment_variables(self, app_id: str): - """Get application environment variables.""" - url = f"/apps/{app_id}/environment" - return await self._send_request("GET", url) - - async def update_app_environment_variables(self, app_id: str, variables: Dict[str, str]): - """Update application environment variables.""" - url = f"/apps/{app_id}/environment" - return await self._send_request("PUT", url, json=variables) - - async def get_app_resource_limits(self, app_id: str): - """Get application resource limits.""" - url = f"/apps/{app_id}/resource-limits" - return await self._send_request("GET", url) - - async def update_app_resource_limits(self, app_id: str, limits: Dict[str, Any]): - """Update application resource limits.""" - url = f"/apps/{app_id}/resource-limits" - return await self._send_request("PUT", url, json=limits) - - # App Integration APIs - async def get_app_integrations(self, app_id: str): - """Get application integrations.""" - url = f"/apps/{app_id}/integrations" - return await self._send_request("GET", url) - - async def add_app_integration(self, app_id: str, integration_config: Dict[str, Any]): - """Add integration to application.""" - url = f"/apps/{app_id}/integrations" - return await self._send_request("POST", url, json=integration_config) - - async def update_app_integration(self, app_id: str, integration_id: str, config: Dict[str, Any]): - """Update application integration.""" - url = f"/apps/{app_id}/integrations/{integration_id}" - return await self._send_request("PUT", url, json=config) - - async def remove_app_integration(self, app_id: str, integration_id: str): - """Remove integration from application.""" - url = f"/apps/{app_id}/integrations/{integration_id}" - return await self._send_request("DELETE", url) - - async def test_app_integration(self, app_id: str, integration_id: str): - """Test application integration.""" - url = f"/apps/{app_id}/integrations/{integration_id}/test" - return await self._send_request("POST", url) diff --git a/sdks/python-client/dify_client/base_client.py b/sdks/python-client/dify_client/base_client.py deleted file mode 100644 index 0ad6e07b23..0000000000 --- a/sdks/python-client/dify_client/base_client.py +++ /dev/null @@ -1,228 +0,0 @@ -"""Base client with common functionality for both sync and async clients.""" - -import json -import time -import logging -from typing import Dict, Callable, Optional - -try: - # Python 3.10+ - from typing import ParamSpec -except ImportError: - # Python < 3.10 - from typing_extensions import ParamSpec - -from urllib.parse import urljoin - -import httpx - -P = ParamSpec("P") - -from .exceptions import ( - DifyClientError, - APIError, - AuthenticationError, - RateLimitError, - ValidationError, - NetworkError, - TimeoutError, -) - - -class BaseClientMixin: - """Mixin class providing common functionality for Dify clients.""" - - def __init__( - self, - api_key: str, - base_url: str = "https://api.dify.ai/v1", - timeout: float = 60.0, - max_retries: int = 3, - retry_delay: float = 1.0, - enable_logging: bool = False, - ): - """Initialize the base client. - - Args: - api_key: Your Dify API key - base_url: Base URL for the Dify API - timeout: Request timeout in seconds - max_retries: Maximum number of retry attempts - retry_delay: Delay between retries in seconds - enable_logging: Enable detailed logging - """ - if not api_key: - raise ValidationError("API key is required") - - self.api_key = api_key - self.base_url = base_url.rstrip("/") - self.timeout = timeout - self.max_retries = max_retries - self.retry_delay = retry_delay - self.enable_logging = enable_logging - - # Setup logging - self.logger = logging.getLogger(f"dify_client.{self.__class__.__name__.lower()}") - if enable_logging and not self.logger.handlers: - # Create console handler with formatter - handler = logging.StreamHandler() - formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") - handler.setFormatter(formatter) - self.logger.addHandler(handler) - self.logger.setLevel(logging.INFO) - self.enable_logging = True - else: - self.enable_logging = enable_logging - - def _get_headers(self, content_type: str = "application/json") -> Dict[str, str]: - """Get common request headers.""" - return { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": content_type, - "User-Agent": "dify-client-python/0.1.12", - } - - def _build_url(self, endpoint: str) -> str: - """Build full URL from endpoint.""" - return urljoin(self.base_url + "/", endpoint.lstrip("/")) - - def _handle_response(self, response: httpx.Response) -> httpx.Response: - """Handle HTTP response and raise appropriate exceptions.""" - try: - if response.status_code == 401: - raise AuthenticationError( - "Authentication failed. Check your API key.", - status_code=response.status_code, - response=response.json() if response.content else None, - ) - elif response.status_code == 429: - retry_after = response.headers.get("Retry-After") - raise RateLimitError( - "Rate limit exceeded. Please try again later.", - retry_after=int(retry_after) if retry_after else None, - ) - elif response.status_code >= 400: - try: - error_data = response.json() - message = error_data.get("message", f"HTTP {response.status_code}") - except: - message = f"HTTP {response.status_code}: {response.text}" - - raise APIError( - message, - status_code=response.status_code, - response=response.json() if response.content else None, - ) - - return response - - except json.JSONDecodeError: - raise APIError( - f"Invalid JSON response: {response.text}", - status_code=response.status_code, - ) - - def _retry_request( - self, - request_func: Callable[P, httpx.Response], - request_context: str | None = None, - *args: P.args, - **kwargs: P.kwargs, - ) -> httpx.Response: - """Retry a request with exponential backoff. - - Args: - request_func: Function that performs the HTTP request - request_context: Context description for logging (e.g., "GET /v1/messages") - *args: Positional arguments to pass to request_func - **kwargs: Keyword arguments to pass to request_func - - Returns: - httpx.Response: Successful response - - Raises: - NetworkError: On network failures after retries - TimeoutError: On timeout failures after retries - APIError: On API errors (4xx/5xx responses) - DifyClientError: On unexpected failures - """ - last_exception = None - - for attempt in range(self.max_retries + 1): - try: - response = request_func(*args, **kwargs) - return response # Let caller handle response processing - - except (httpx.NetworkError, httpx.TimeoutException) as e: - last_exception = e - context_msg = f" {request_context}" if request_context else "" - - if attempt < self.max_retries: - delay = self.retry_delay * (2**attempt) # Exponential backoff - self.logger.warning( - f"Request failed{context_msg} (attempt {attempt + 1}/{self.max_retries + 1}): {e}. " - f"Retrying in {delay:.2f} seconds..." - ) - time.sleep(delay) - else: - self.logger.error(f"Request failed{context_msg} after {self.max_retries + 1} attempts: {e}") - # Convert to custom exceptions - if isinstance(e, httpx.TimeoutException): - from .exceptions import TimeoutError - - raise TimeoutError(f"Request timed out after {self.max_retries} retries{context_msg}") from e - else: - from .exceptions import NetworkError - - raise NetworkError( - f"Network error after {self.max_retries} retries{context_msg}: {str(e)}" - ) from e - - if last_exception: - raise last_exception - raise DifyClientError("Request failed after retries") - - def _validate_params(self, **params) -> None: - """Validate request parameters.""" - for key, value in params.items(): - if value is None: - continue - - # String validations - if isinstance(value, str): - if not value.strip(): - raise ValidationError(f"Parameter '{key}' cannot be empty or whitespace only") - if len(value) > 10000: - raise ValidationError(f"Parameter '{key}' exceeds maximum length of 10000 characters") - - # List validations - elif isinstance(value, list): - if len(value) > 1000: - raise ValidationError(f"Parameter '{key}' exceeds maximum size of 1000 items") - - # Dictionary validations - elif isinstance(value, dict): - if len(value) > 100: - raise ValidationError(f"Parameter '{key}' exceeds maximum size of 100 items") - - # Type-specific validations - if key == "user" and not isinstance(value, str): - raise ValidationError(f"Parameter '{key}' must be a string") - elif key in ["page", "limit", "page_size"] and not isinstance(value, int): - raise ValidationError(f"Parameter '{key}' must be an integer") - elif key == "files" and not isinstance(value, (list, dict)): - raise ValidationError(f"Parameter '{key}' must be a list or dict") - elif key == "rating" and value not in ["like", "dislike"]: - raise ValidationError(f"Parameter '{key}' must be 'like' or 'dislike'") - - def _log_request(self, method: str, url: str, **kwargs) -> None: - """Log request details.""" - self.logger.info(f"Making {method} request to {url}") - if kwargs.get("json"): - self.logger.debug(f"Request body: {kwargs['json']}") - if kwargs.get("params"): - self.logger.debug(f"Query params: {kwargs['params']}") - - def _log_response(self, response: httpx.Response) -> None: - """Log response details.""" - self.logger.info(f"Received response: {response.status_code} ({len(response.content)} bytes)") diff --git a/sdks/python-client/dify_client/client.py b/sdks/python-client/dify_client/client.py deleted file mode 100644 index cebdf6845c..0000000000 --- a/sdks/python-client/dify_client/client.py +++ /dev/null @@ -1,1267 +0,0 @@ -import json -import logging -import os -from typing import Literal, Dict, List, Any, IO, Optional, Union - -import httpx -from .base_client import BaseClientMixin -from .exceptions import ( - APIError, - AuthenticationError, - RateLimitError, - ValidationError, - FileUploadError, -) - - -class DifyClient(BaseClientMixin): - """Synchronous Dify API client. - - This client uses httpx.Client for efficient connection pooling and resource management. - It's recommended to use this client as a context manager: - - Example: - with DifyClient(api_key="your-key") as client: - response = client.get_app_info() - """ - - def __init__( - self, - api_key: str, - base_url: str = "https://api.dify.ai/v1", - timeout: float = 60.0, - max_retries: int = 3, - retry_delay: float = 1.0, - enable_logging: bool = False, - ): - """Initialize the Dify client. - - Args: - api_key: Your Dify API key - base_url: Base URL for the Dify API - timeout: Request timeout in seconds (default: 60.0) - max_retries: Maximum number of retry attempts (default: 3) - retry_delay: Delay between retries in seconds (default: 1.0) - enable_logging: Whether to enable request logging (default: True) - """ - # Initialize base client functionality - BaseClientMixin.__init__(self, api_key, base_url, timeout, max_retries, retry_delay, enable_logging) - - self._client = httpx.Client( - base_url=base_url, - timeout=httpx.Timeout(timeout, connect=5.0), - ) - - def __enter__(self): - """Support context manager protocol.""" - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """Clean up resources when exiting context.""" - self.close() - - def close(self): - """Close the HTTP client and release resources.""" - if hasattr(self, "_client"): - self._client.close() - - def _send_request( - self, - method: str, - endpoint: str, - json: Dict[str, Any] | None = None, - params: Dict[str, Any] | None = None, - stream: bool = False, - **kwargs, - ): - """Send an HTTP request to the Dify API with retry logic. - - Args: - method: HTTP method (GET, POST, PUT, PATCH, DELETE) - endpoint: API endpoint path - json: JSON request body - params: Query parameters - stream: Whether to stream the response - **kwargs: Additional arguments to pass to httpx.request - - Returns: - httpx.Response object - """ - # Validate parameters - if json: - self._validate_params(**json) - if params: - self._validate_params(**params) - - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - } - - def make_request(): - """Inner function to perform the actual HTTP request.""" - # Log request if logging is enabled - if self.enable_logging: - self.logger.info(f"Sending {method} request to {endpoint}") - # Debug logging for detailed information - if self.logger.isEnabledFor(logging.DEBUG): - if json: - self.logger.debug(f"Request body: {json}") - if params: - self.logger.debug(f"Request params: {params}") - - # httpx.Client automatically prepends base_url - response = self._client.request( - method, - endpoint, - json=json, - params=params, - headers=headers, - **kwargs, - ) - - # Log response if logging is enabled - if self.enable_logging: - self.logger.info(f"Received response: {response.status_code}") - - return response - - # Use the retry mechanism from base client - request_context = f"{method} {endpoint}" - response = self._retry_request(make_request, request_context) - - # Handle error responses (API errors don't retry) - self._handle_error_response(response) - - return response - - def _handle_error_response(self, response, is_upload_request: bool = False) -> None: - """Handle HTTP error responses and raise appropriate exceptions.""" - - if response.status_code < 400: - return # Success response - - try: - error_data = response.json() - message = error_data.get("message", f"HTTP {response.status_code}") - except (ValueError, KeyError): - message = f"HTTP {response.status_code}" - error_data = None - - # Log error response if logging is enabled - if self.enable_logging: - self.logger.error(f"API error: {response.status_code} - {message}") - - if response.status_code == 401: - raise AuthenticationError(message, response.status_code, error_data) - elif response.status_code == 429: - retry_after = response.headers.get("Retry-After") - raise RateLimitError(message, retry_after) - elif response.status_code == 422: - raise ValidationError(message, response.status_code, error_data) - elif response.status_code == 400: - # Check if this is a file upload error based on the URL or context - current_url = getattr(response, "url", "") or "" - if is_upload_request or "upload" in str(current_url).lower() or "files" in str(current_url).lower(): - raise FileUploadError(message, response.status_code, error_data) - else: - raise APIError(message, response.status_code, error_data) - elif response.status_code >= 500: - # Server errors should raise APIError - raise APIError(message, response.status_code, error_data) - elif response.status_code >= 400: - raise APIError(message, response.status_code, error_data) - - def _send_request_with_files(self, method: str, endpoint: str, data: dict, files: dict): - """Send an HTTP request with file uploads. - - Args: - method: HTTP method (POST, PUT, etc.) - endpoint: API endpoint path - data: Form data - files: Files to upload - - Returns: - httpx.Response object - """ - headers = {"Authorization": f"Bearer {self.api_key}"} - - # Log file upload request if logging is enabled - if self.enable_logging: - self.logger.info(f"Sending {method} file upload request to {endpoint}") - self.logger.debug(f"Form data: {data}") - self.logger.debug(f"Files: {files}") - - response = self._client.request( - method, - endpoint, - data=data, - headers=headers, - files=files, - ) - - # Log response if logging is enabled - if self.enable_logging: - self.logger.info(f"Received file upload response: {response.status_code}") - - # Handle error responses - self._handle_error_response(response, is_upload_request=True) - - return response - - def message_feedback(self, message_id: str, rating: Literal["like", "dislike"], user: str): - self._validate_params(message_id=message_id, rating=rating, user=user) - data = {"rating": rating, "user": user} - return self._send_request("POST", f"/messages/{message_id}/feedbacks", data) - - def get_application_parameters(self, user: str): - params = {"user": user} - return self._send_request("GET", "/parameters", params=params) - - def file_upload(self, user: str, files: dict): - data = {"user": user} - return self._send_request_with_files("POST", "/files/upload", data=data, files=files) - - def text_to_audio(self, text: str, user: str, streaming: bool = False): - data = {"text": text, "user": user, "streaming": streaming} - return self._send_request("POST", "/text-to-audio", json=data) - - def get_meta(self, user: str): - params = {"user": user} - return self._send_request("GET", "/meta", params=params) - - def get_app_info(self): - """Get basic application information including name, description, tags, and mode.""" - return self._send_request("GET", "/info") - - def get_app_site_info(self): - """Get application site information.""" - return self._send_request("GET", "/site") - - def get_file_preview(self, file_id: str): - """Get file preview by file ID.""" - return self._send_request("GET", f"/files/{file_id}/preview") - - # App Configuration APIs - def get_app_site_config(self, app_id: str): - """Get app site configuration. - - Args: - app_id: ID of the app - - Returns: - App site configuration - """ - url = f"/apps/{app_id}/site/config" - return self._send_request("GET", url) - - def update_app_site_config(self, app_id: str, config_data: Dict[str, Any]): - """Update app site configuration. - - Args: - app_id: ID of the app - config_data: Configuration data to update - - Returns: - Updated app site configuration - """ - url = f"/apps/{app_id}/site/config" - return self._send_request("PUT", url, json=config_data) - - def get_app_api_tokens(self, app_id: str): - """Get API tokens for an app. - - Args: - app_id: ID of the app - - Returns: - List of API tokens - """ - url = f"/apps/{app_id}/api-tokens" - return self._send_request("GET", url) - - def create_app_api_token(self, app_id: str, name: str, description: str | None = None): - """Create a new API token for an app. - - Args: - app_id: ID of the app - name: Name for the API token - description: Description for the API token (optional) - - Returns: - Created API token information - """ - data = {"name": name, "description": description} - url = f"/apps/{app_id}/api-tokens" - return self._send_request("POST", url, json=data) - - def delete_app_api_token(self, app_id: str, token_id: str): - """Delete an API token. - - Args: - app_id: ID of the app - token_id: ID of the token to delete - - Returns: - Deletion result - """ - url = f"/apps/{app_id}/api-tokens/{token_id}" - return self._send_request("DELETE", url) - - -class CompletionClient(DifyClient): - def create_completion_message( - self, - inputs: dict, - response_mode: Literal["blocking", "streaming"], - user: str, - files: Dict[str, Any] | None = None, - ): - # Validate parameters - if not isinstance(inputs, dict): - raise ValidationError("inputs must be a dictionary") - if response_mode not in ["blocking", "streaming"]: - raise ValidationError("response_mode must be 'blocking' or 'streaming'") - - self._validate_params(inputs=inputs, response_mode=response_mode, user=user) - - data = { - "inputs": inputs, - "response_mode": response_mode, - "user": user, - "files": files, - } - return self._send_request( - "POST", - "/completion-messages", - data, - stream=(response_mode == "streaming"), - ) - - -class ChatClient(DifyClient): - def create_chat_message( - self, - inputs: dict, - query: str, - user: str, - response_mode: Literal["blocking", "streaming"] = "blocking", - conversation_id: str | None = None, - files: Dict[str, Any] | None = None, - ): - # Validate parameters - if not isinstance(inputs, dict): - raise ValidationError("inputs must be a dictionary") - if not isinstance(query, str) or not query.strip(): - raise ValidationError("query must be a non-empty string") - if response_mode not in ["blocking", "streaming"]: - raise ValidationError("response_mode must be 'blocking' or 'streaming'") - - self._validate_params(inputs=inputs, query=query, user=user, response_mode=response_mode) - - data = { - "inputs": inputs, - "query": query, - "user": user, - "response_mode": response_mode, - "files": files, - } - if conversation_id: - data["conversation_id"] = conversation_id - - return self._send_request( - "POST", - "/chat-messages", - data, - stream=(response_mode == "streaming"), - ) - - def get_suggested(self, message_id: str, user: str): - params = {"user": user} - return self._send_request("GET", f"/messages/{message_id}/suggested", params=params) - - def stop_message(self, task_id: str, user: str): - data = {"user": user} - return self._send_request("POST", f"/chat-messages/{task_id}/stop", data) - - def get_conversations( - self, - user: str, - last_id: str | None = None, - limit: int | None = None, - pinned: bool | None = None, - ): - params = {"user": user, "last_id": last_id, "limit": limit, "pinned": pinned} - return self._send_request("GET", "/conversations", params=params) - - def get_conversation_messages( - self, - user: str, - conversation_id: str | None = None, - first_id: str | None = None, - limit: int | None = None, - ): - params = {"user": user} - - if conversation_id: - params["conversation_id"] = conversation_id - if first_id: - params["first_id"] = first_id - if limit: - params["limit"] = limit - - return self._send_request("GET", "/messages", params=params) - - def rename_conversation(self, conversation_id: str, name: str, auto_generate: bool, user: str): - data = {"name": name, "auto_generate": auto_generate, "user": user} - return self._send_request("POST", f"/conversations/{conversation_id}/name", data) - - def delete_conversation(self, conversation_id: str, user: str): - data = {"user": user} - return self._send_request("DELETE", f"/conversations/{conversation_id}", data) - - def audio_to_text(self, audio_file: Union[IO[bytes], tuple], user: str): - data = {"user": user} - files = {"file": audio_file} - return self._send_request_with_files("POST", "/audio-to-text", data, files) - - # Annotation APIs - def annotation_reply_action( - self, - action: Literal["enable", "disable"], - score_threshold: float, - embedding_provider_name: str, - embedding_model_name: str, - ): - """Enable or disable annotation reply feature.""" - data = { - "score_threshold": score_threshold, - "embedding_provider_name": embedding_provider_name, - "embedding_model_name": embedding_model_name, - } - return self._send_request("POST", f"/apps/annotation-reply/{action}", json=data) - - def get_annotation_reply_status(self, action: Literal["enable", "disable"], job_id: str): - """Get the status of an annotation reply action job.""" - return self._send_request("GET", f"/apps/annotation-reply/{action}/status/{job_id}") - - def list_annotations(self, page: int = 1, limit: int = 20, keyword: str | None = None): - """List annotations for the application.""" - params = {"page": page, "limit": limit, "keyword": keyword} - return self._send_request("GET", "/apps/annotations", params=params) - - def create_annotation(self, question: str, answer: str): - """Create a new annotation.""" - data = {"question": question, "answer": answer} - return self._send_request("POST", "/apps/annotations", json=data) - - def update_annotation(self, annotation_id: str, question: str, answer: str): - """Update an existing annotation.""" - data = {"question": question, "answer": answer} - return self._send_request("PUT", f"/apps/annotations/{annotation_id}", json=data) - - def delete_annotation(self, annotation_id: str): - """Delete an annotation.""" - return self._send_request("DELETE", f"/apps/annotations/{annotation_id}") - - # Conversation Variables APIs - def get_conversation_variables(self, conversation_id: str, user: str): - """Get all variables for a specific conversation. - - Args: - conversation_id: The conversation ID to query variables for - user: User identifier - - Returns: - Response from the API containing: - - variables: List of conversation variables with their values - - conversation_id: The conversation ID - """ - params = {"user": user} - url = f"/conversations/{conversation_id}/variables" - return self._send_request("GET", url, params=params) - - def update_conversation_variable(self, conversation_id: str, variable_id: str, value: Any, user: str): - """Update a specific conversation variable. - - Args: - conversation_id: The conversation ID - variable_id: The variable ID to update - value: New value for the variable - user: User identifier - - Returns: - Response from the API with updated variable information - """ - data = {"value": value, "user": user} - url = f"/conversations/{conversation_id}/variables/{variable_id}" - return self._send_request("PUT", url, json=data) - - def delete_annotation_with_response(self, annotation_id: str): - """Delete an annotation with full response handling.""" - url = f"/apps/annotations/{annotation_id}" - return self._send_request("DELETE", url) - - def list_conversation_variables_with_pagination( - self, conversation_id: str, user: str, page: int = 1, limit: int = 20 - ): - """List conversation variables with pagination.""" - params = {"page": page, "limit": limit, "user": user} - url = f"/conversations/{conversation_id}/variables" - return self._send_request("GET", url, params=params) - - def update_conversation_variable_with_response(self, conversation_id: str, variable_id: str, user: str, value: Any): - """Update a conversation variable with full response handling.""" - data = {"value": value, "user": user} - url = f"/conversations/{conversation_id}/variables/{variable_id}" - return self._send_request("PUT", url, json=data) - - # Enhanced Annotation APIs - def get_annotation_reply_job_status(self, action: str, job_id: str): - """Get status of an annotation reply action job.""" - url = f"/apps/annotation-reply/{action}/status/{job_id}" - return self._send_request("GET", url) - - def list_annotations_with_pagination(self, page: int = 1, limit: int = 20, keyword: str | None = None): - """List annotations with pagination.""" - params = {"page": page, "limit": limit, "keyword": keyword} - return self._send_request("GET", "/apps/annotations", params=params) - - def create_annotation_with_response(self, question: str, answer: str): - """Create an annotation with full response handling.""" - data = {"question": question, "answer": answer} - return self._send_request("POST", "/apps/annotations", json=data) - - def update_annotation_with_response(self, annotation_id: str, question: str, answer: str): - """Update an annotation with full response handling.""" - data = {"question": question, "answer": answer} - url = f"/apps/annotations/{annotation_id}" - return self._send_request("PUT", url, json=data) - - -class WorkflowClient(DifyClient): - def run( - self, - inputs: dict, - response_mode: Literal["blocking", "streaming"] = "streaming", - user: str = "abc-123", - ): - data = {"inputs": inputs, "response_mode": response_mode, "user": user} - return self._send_request("POST", "/workflows/run", data) - - def stop(self, task_id, user): - data = {"user": user} - return self._send_request("POST", f"/workflows/tasks/{task_id}/stop", data) - - def get_result(self, workflow_run_id): - return self._send_request("GET", f"/workflows/run/{workflow_run_id}") - - def get_workflow_logs( - self, - keyword: str = None, - status: Literal["succeeded", "failed", "stopped"] | None = None, - page: int = 1, - limit: int = 20, - created_at__before: str = None, - created_at__after: str = None, - created_by_end_user_session_id: str = None, - created_by_account: str = None, - ): - """Get workflow execution logs with optional filtering.""" - params = {"page": page, "limit": limit} - if keyword: - params["keyword"] = keyword - if status: - params["status"] = status - if created_at__before: - params["created_at__before"] = created_at__before - if created_at__after: - params["created_at__after"] = created_at__after - if created_by_end_user_session_id: - params["created_by_end_user_session_id"] = created_by_end_user_session_id - if created_by_account: - params["created_by_account"] = created_by_account - return self._send_request("GET", "/workflows/logs", params=params) - - def run_specific_workflow( - self, - workflow_id: str, - inputs: dict, - response_mode: Literal["blocking", "streaming"] = "streaming", - user: str = "abc-123", - ): - """Run a specific workflow by workflow ID.""" - data = {"inputs": inputs, "response_mode": response_mode, "user": user} - return self._send_request( - "POST", - f"/workflows/{workflow_id}/run", - data, - stream=(response_mode == "streaming"), - ) - - # Enhanced Workflow APIs - def get_workflow_draft(self, app_id: str): - """Get workflow draft configuration. - - Args: - app_id: ID of the workflow app - - Returns: - Workflow draft configuration - """ - url = f"/apps/{app_id}/workflow/draft" - return self._send_request("GET", url) - - def update_workflow_draft(self, app_id: str, workflow_data: Dict[str, Any]): - """Update workflow draft configuration. - - Args: - app_id: ID of the workflow app - workflow_data: Workflow configuration data - - Returns: - Updated workflow draft - """ - url = f"/apps/{app_id}/workflow/draft" - return self._send_request("PUT", url, json=workflow_data) - - def publish_workflow(self, app_id: str): - """Publish workflow from draft. - - Args: - app_id: ID of the workflow app - - Returns: - Published workflow information - """ - url = f"/apps/{app_id}/workflow/publish" - return self._send_request("POST", url) - - def get_workflow_run_history( - self, - app_id: str, - page: int = 1, - limit: int = 20, - status: Literal["succeeded", "failed", "stopped"] | None = None, - ): - """Get workflow run history. - - Args: - app_id: ID of the workflow app - page: Page number (default: 1) - limit: Number of items per page (default: 20) - status: Filter by status (optional) - - Returns: - Paginated workflow run history - """ - params = {"page": page, "limit": limit} - if status: - params["status"] = status - url = f"/apps/{app_id}/workflow/runs" - return self._send_request("GET", url, params=params) - - -class WorkspaceClient(DifyClient): - """Client for workspace-related operations.""" - - def get_available_models(self, model_type: str): - """Get available models by model type.""" - url = f"/workspaces/current/models/model-types/{model_type}" - return self._send_request("GET", url) - - def get_available_models_by_type(self, model_type: str): - """Get available models by model type (enhanced version).""" - url = f"/workspaces/current/models/model-types/{model_type}" - return self._send_request("GET", url) - - def get_model_providers(self): - """Get all model providers.""" - return self._send_request("GET", "/workspaces/current/model-providers") - - def get_model_provider_models(self, provider_name: str): - """Get models for a specific provider.""" - url = f"/workspaces/current/model-providers/{provider_name}/models" - return self._send_request("GET", url) - - def validate_model_provider_credentials(self, provider_name: str, credentials: Dict[str, Any]): - """Validate model provider credentials.""" - url = f"/workspaces/current/model-providers/{provider_name}/credentials/validate" - return self._send_request("POST", url, json=credentials) - - # File Management APIs - def get_file_info(self, file_id: str): - """Get information about a specific file.""" - url = f"/files/{file_id}/info" - return self._send_request("GET", url) - - def get_file_download_url(self, file_id: str): - """Get download URL for a file.""" - url = f"/files/{file_id}/download-url" - return self._send_request("GET", url) - - def delete_file(self, file_id: str): - """Delete a file.""" - url = f"/files/{file_id}" - return self._send_request("DELETE", url) - - -class KnowledgeBaseClient(DifyClient): - def __init__( - self, - api_key: str, - base_url: str = "https://api.dify.ai/v1", - dataset_id: str | None = None, - ): - """ - Construct a KnowledgeBaseClient object. - - Args: - api_key (str): API key of Dify. - base_url (str, optional): Base URL of Dify API. Defaults to 'https://api.dify.ai/v1'. - dataset_id (str, optional): ID of the dataset. Defaults to None. You don't need this if you just want to - create a new dataset. or list datasets. otherwise you need to set this. - """ - super().__init__(api_key=api_key, base_url=base_url) - self.dataset_id = dataset_id - - def _get_dataset_id(self): - if self.dataset_id is None: - raise ValueError("dataset_id is not set") - return self.dataset_id - - def create_dataset(self, name: str, **kwargs): - return self._send_request("POST", "/datasets", {"name": name}, **kwargs) - - def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs): - return self._send_request("GET", "/datasets", params={"page": page, "limit": page_size}, **kwargs) - - def create_document_by_text(self, name, text, extra_params: Dict[str, Any] | None = None, **kwargs): - """ - Create a document by text. - - :param name: Name of the document - :param text: Text content of the document - :param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional) - e.g. - { - 'indexing_technique': 'high_quality', - 'process_rule': { - 'rules': { - 'pre_processing_rules': [ - {'id': 'remove_extra_spaces', 'enabled': True}, - {'id': 'remove_urls_emails', 'enabled': True} - ], - 'segmentation': { - 'separator': '\n', - 'max_tokens': 500 - } - }, - 'mode': 'custom' - } - } - :return: Response from the API - """ - data = { - "indexing_technique": "high_quality", - "process_rule": {"mode": "automatic"}, - "name": name, - "text": text, - } - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - url = f"/datasets/{self._get_dataset_id()}/document/create_by_text" - return self._send_request("POST", url, json=data, **kwargs) - - def update_document_by_text( - self, - document_id: str, - name: str, - text: str, - extra_params: Dict[str, Any] | None = None, - **kwargs, - ): - """ - Update a document by text. - - :param document_id: ID of the document - :param name: Name of the document - :param text: Text content of the document - :param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional) - e.g. - { - 'indexing_technique': 'high_quality', - 'process_rule': { - 'rules': { - 'pre_processing_rules': [ - {'id': 'remove_extra_spaces', 'enabled': True}, - {'id': 'remove_urls_emails', 'enabled': True} - ], - 'segmentation': { - 'separator': '\n', - 'max_tokens': 500 - } - }, - 'mode': 'custom' - } - } - :return: Response from the API - """ - data = {"name": name, "text": text} - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text" - return self._send_request("POST", url, json=data, **kwargs) - - def create_document_by_file( - self, - file_path: str, - original_document_id: str | None = None, - extra_params: Dict[str, Any] | None = None, - ): - """ - Create a document by file. - - :param file_path: Path to the file - :param original_document_id: pass this ID if you want to replace the original document (optional) - :param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional) - e.g. - { - 'indexing_technique': 'high_quality', - 'process_rule': { - 'rules': { - 'pre_processing_rules': [ - {'id': 'remove_extra_spaces', 'enabled': True}, - {'id': 'remove_urls_emails', 'enabled': True} - ], - 'segmentation': { - 'separator': '\n', - 'max_tokens': 500 - } - }, - 'mode': 'custom' - } - } - :return: Response from the API - """ - with open(file_path, "rb") as f: - files = {"file": (os.path.basename(file_path), f)} - data = { - "process_rule": {"mode": "automatic"}, - "indexing_technique": "high_quality", - } - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - if original_document_id is not None: - data["original_document_id"] = original_document_id - url = f"/datasets/{self._get_dataset_id()}/document/create_by_file" - return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files) - - def update_document_by_file( - self, - document_id: str, - file_path: str, - extra_params: Dict[str, Any] | None = None, - ): - """ - Update a document by file. - - :param document_id: ID of the document - :param file_path: Path to the file - :param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional) - e.g. - { - 'indexing_technique': 'high_quality', - 'process_rule': { - 'rules': { - 'pre_processing_rules': [ - {'id': 'remove_extra_spaces', 'enabled': True}, - {'id': 'remove_urls_emails', 'enabled': True} - ], - 'segmentation': { - 'separator': '\n', - 'max_tokens': 500 - } - }, - 'mode': 'custom' - } - } - :return: - """ - with open(file_path, "rb") as f: - files = {"file": (os.path.basename(file_path), f)} - data = {} - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file" - return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files) - - def batch_indexing_status(self, batch_id: str, **kwargs): - """ - Get the status of the batch indexing. - - :param batch_id: ID of the batch uploading - :return: Response from the API - """ - url = f"/datasets/{self._get_dataset_id()}/documents/{batch_id}/indexing-status" - return self._send_request("GET", url, **kwargs) - - def delete_dataset(self): - """ - Delete this dataset. - - :return: Response from the API - """ - url = f"/datasets/{self._get_dataset_id()}" - return self._send_request("DELETE", url) - - def delete_document(self, document_id: str): - """ - Delete a document. - - :param document_id: ID of the document - :return: Response from the API - """ - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}" - return self._send_request("DELETE", url) - - def list_documents( - self, - page: int | None = None, - page_size: int | None = None, - keyword: str | None = None, - **kwargs, - ): - """ - Get a list of documents in this dataset. - - :return: Response from the API - """ - params = {} - if page is not None: - params["page"] = page - if page_size is not None: - params["limit"] = page_size - if keyword is not None: - params["keyword"] = keyword - url = f"/datasets/{self._get_dataset_id()}/documents" - return self._send_request("GET", url, params=params, **kwargs) - - def add_segments(self, document_id: str, segments: list[dict], **kwargs): - """ - Add segments to a document. - - :param document_id: ID of the document - :param segments: List of segments to add, example: [{"content": "1", "answer": "1", "keyword": ["a"]}] - :return: Response from the API - """ - data = {"segments": segments} - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments" - return self._send_request("POST", url, json=data, **kwargs) - - def query_segments( - self, - document_id: str, - keyword: str | None = None, - status: str | None = None, - **kwargs, - ): - """ - Query segments in this document. - - :param document_id: ID of the document - :param keyword: query keyword, optional - :param status: status of the segment, optional, e.g. completed - :param kwargs: Additional parameters to pass to the API. - Can include a 'params' dict for extra query parameters. - """ - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments" - params = {} - if keyword is not None: - params["keyword"] = keyword - if status is not None: - params["status"] = status - if "params" in kwargs: - params.update(kwargs.pop("params")) - return self._send_request("GET", url, params=params, **kwargs) - - def delete_document_segment(self, document_id: str, segment_id: str): - """ - Delete a segment from a document. - - :param document_id: ID of the document - :param segment_id: ID of the segment - :return: Response from the API - """ - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}" - return self._send_request("DELETE", url) - - def update_document_segment(self, document_id: str, segment_id: str, segment_data: dict, **kwargs): - """ - Update a segment in a document. - - :param document_id: ID of the document - :param segment_id: ID of the segment - :param segment_data: Data of the segment, example: {"content": "1", "answer": "1", "keyword": ["a"], "enabled": True} - :return: Response from the API - """ - data = {"segment": segment_data} - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}" - return self._send_request("POST", url, json=data, **kwargs) - - # Advanced Knowledge Base APIs - def hit_testing( - self, - query: str, - retrieval_model: Dict[str, Any] = None, - external_retrieval_model: Dict[str, Any] = None, - ): - """Perform hit testing on the dataset.""" - data = {"query": query} - if retrieval_model: - data["retrieval_model"] = retrieval_model - if external_retrieval_model: - data["external_retrieval_model"] = external_retrieval_model - url = f"/datasets/{self._get_dataset_id()}/hit-testing" - return self._send_request("POST", url, json=data) - - def get_dataset_metadata(self): - """Get dataset metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata" - return self._send_request("GET", url) - - def create_dataset_metadata(self, metadata_data: Dict[str, Any]): - """Create dataset metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata" - return self._send_request("POST", url, json=metadata_data) - - def update_dataset_metadata(self, metadata_id: str, metadata_data: Dict[str, Any]): - """Update dataset metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata/{metadata_id}" - return self._send_request("PATCH", url, json=metadata_data) - - def get_built_in_metadata(self): - """Get built-in metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata/built-in" - return self._send_request("GET", url) - - def manage_built_in_metadata(self, action: str, metadata_data: Dict[str, Any] = None): - """Manage built-in metadata with specified action.""" - data = metadata_data or {} - url = f"/datasets/{self._get_dataset_id()}/metadata/built-in/{action}" - return self._send_request("POST", url, json=data) - - def update_documents_metadata(self, operation_data: List[Dict[str, Any]]): - """Update metadata for multiple documents.""" - url = f"/datasets/{self._get_dataset_id()}/documents/metadata" - data = {"operation_data": operation_data} - return self._send_request("POST", url, json=data) - - # Dataset Tags APIs - def list_dataset_tags(self): - """List all dataset tags.""" - return self._send_request("GET", "/datasets/tags") - - def bind_dataset_tags(self, tag_ids: List[str]): - """Bind tags to dataset.""" - data = {"tag_ids": tag_ids, "target_id": self._get_dataset_id()} - return self._send_request("POST", "/datasets/tags/binding", json=data) - - def unbind_dataset_tag(self, tag_id: str): - """Unbind a single tag from dataset.""" - data = {"tag_id": tag_id, "target_id": self._get_dataset_id()} - return self._send_request("POST", "/datasets/tags/unbinding", json=data) - - def get_dataset_tags(self): - """Get tags for current dataset.""" - url = f"/datasets/{self._get_dataset_id()}/tags" - return self._send_request("GET", url) - - # RAG Pipeline APIs - def get_datasource_plugins(self, is_published: bool = True): - """Get datasource plugins for RAG pipeline.""" - params = {"is_published": is_published} - url = f"/datasets/{self._get_dataset_id()}/pipeline/datasource-plugins" - return self._send_request("GET", url, params=params) - - def run_datasource_node( - self, - node_id: str, - inputs: Dict[str, Any], - datasource_type: str, - is_published: bool = True, - credential_id: str = None, - ): - """Run a datasource node in RAG pipeline.""" - data = { - "inputs": inputs, - "datasource_type": datasource_type, - "is_published": is_published, - } - if credential_id: - data["credential_id"] = credential_id - url = f"/datasets/{self._get_dataset_id()}/pipeline/datasource/nodes/{node_id}/run" - return self._send_request("POST", url, json=data, stream=True) - - def run_rag_pipeline( - self, - inputs: Dict[str, Any], - datasource_type: str, - datasource_info_list: List[Dict[str, Any]], - start_node_id: str, - is_published: bool = True, - response_mode: Literal["streaming", "blocking"] = "blocking", - ): - """Run RAG pipeline.""" - data = { - "inputs": inputs, - "datasource_type": datasource_type, - "datasource_info_list": datasource_info_list, - "start_node_id": start_node_id, - "is_published": is_published, - "response_mode": response_mode, - } - url = f"/datasets/{self._get_dataset_id()}/pipeline/run" - return self._send_request("POST", url, json=data, stream=response_mode == "streaming") - - def upload_pipeline_file(self, file_path: str): - """Upload file for RAG pipeline.""" - with open(file_path, "rb") as f: - files = {"file": (os.path.basename(file_path), f)} - return self._send_request_with_files("POST", "/datasets/pipeline/file-upload", {}, files) - - # Dataset Management APIs - def get_dataset(self, dataset_id: str | None = None): - """Get detailed information about a specific dataset. - - Args: - dataset_id: Dataset ID (optional, uses current dataset_id if not provided) - - Returns: - Response from the API containing dataset details including: - - name, description, permission - - indexing_technique, embedding_model, embedding_model_provider - - retrieval_model configuration - - document_count, word_count, app_count - - created_at, updated_at - """ - ds_id = dataset_id or self._get_dataset_id() - url = f"/datasets/{ds_id}" - return self._send_request("GET", url) - - def update_dataset( - self, - dataset_id: str | None = None, - name: str | None = None, - description: str | None = None, - indexing_technique: str | None = None, - embedding_model: str | None = None, - embedding_model_provider: str | None = None, - retrieval_model: Dict[str, Any] | None = None, - **kwargs, - ): - """Update dataset configuration. - - Args: - dataset_id: Dataset ID (optional, uses current dataset_id if not provided) - name: New dataset name - description: New dataset description - indexing_technique: Indexing technique ('high_quality' or 'economy') - embedding_model: Embedding model name - embedding_model_provider: Embedding model provider - retrieval_model: Retrieval model configuration dict - **kwargs: Additional parameters to pass to the API - - Returns: - Response from the API with updated dataset information - """ - ds_id = dataset_id or self._get_dataset_id() - url = f"/datasets/{ds_id}" - - # Build data dictionary with all possible parameters - payload = { - "name": name, - "description": description, - "indexing_technique": indexing_technique, - "embedding_model": embedding_model, - "embedding_model_provider": embedding_model_provider, - "retrieval_model": retrieval_model, - } - - # Filter out None values and merge with additional kwargs - data = {k: v for k, v in payload.items() if v is not None} - data.update(kwargs) - - return self._send_request("PATCH", url, json=data) - - def batch_update_document_status( - self, - action: Literal["enable", "disable", "archive", "un_archive"], - document_ids: List[str], - dataset_id: str | None = None, - ): - """Batch update document status (enable/disable/archive/unarchive). - - Args: - action: Action to perform on documents - - 'enable': Enable documents for retrieval - - 'disable': Disable documents from retrieval - - 'archive': Archive documents - - 'un_archive': Unarchive documents - document_ids: List of document IDs to update - dataset_id: Dataset ID (optional, uses current dataset_id if not provided) - - Returns: - Response from the API with operation result - """ - ds_id = dataset_id or self._get_dataset_id() - url = f"/datasets/{ds_id}/documents/status/{action}" - data = {"document_ids": document_ids} - return self._send_request("PATCH", url, json=data) - - # Enhanced Dataset APIs - def create_dataset_from_template(self, template_name: str, name: str, description: str | None = None): - """Create a dataset from a predefined template. - - Args: - template_name: Name of the template to use - name: Name for the new dataset - description: Description for the dataset (optional) - - Returns: - Created dataset information - """ - data = { - "template_name": template_name, - "name": name, - "description": description, - } - return self._send_request("POST", "/datasets/from-template", json=data) - - def duplicate_dataset(self, dataset_id: str, name: str): - """Duplicate an existing dataset. - - Args: - dataset_id: ID of dataset to duplicate - name: Name for duplicated dataset - - Returns: - New dataset information - """ - data = {"name": name} - url = f"/datasets/{dataset_id}/duplicate" - return self._send_request("POST", url, json=data) - - def list_conversation_variables_with_pagination( - self, conversation_id: str, user: str, page: int = 1, limit: int = 20 - ): - """List conversation variables with pagination.""" - params = {"page": page, "limit": limit, "user": user} - url = f"/conversations/{conversation_id}/variables" - return self._send_request("GET", url, params=params) - - def update_conversation_variable_with_response(self, conversation_id: str, variable_id: str, user: str, value: Any): - """Update a conversation variable with full response handling.""" - data = {"value": value, "user": user} - url = f"/conversations/{conversation_id}/variables/{variable_id}" - return self._send_request("PUT", url, json=data) diff --git a/sdks/python-client/dify_client/exceptions.py b/sdks/python-client/dify_client/exceptions.py deleted file mode 100644 index e7ba2ff4b2..0000000000 --- a/sdks/python-client/dify_client/exceptions.py +++ /dev/null @@ -1,71 +0,0 @@ -"""Custom exceptions for the Dify client.""" - -from typing import Optional, Dict, Any - - -class DifyClientError(Exception): - """Base exception for all Dify client errors.""" - - def __init__(self, message: str, status_code: int | None = None, response: Dict[str, Any] | None = None): - super().__init__(message) - self.message = message - self.status_code = status_code - self.response = response - - -class APIError(DifyClientError): - """Raised when the API returns an error response.""" - - def __init__(self, message: str, status_code: int, response: Dict[str, Any] | None = None): - super().__init__(message, status_code, response) - self.status_code = status_code - - -class AuthenticationError(DifyClientError): - """Raised when authentication fails.""" - - pass - - -class RateLimitError(DifyClientError): - """Raised when rate limit is exceeded.""" - - def __init__(self, message: str = "Rate limit exceeded", retry_after: int | None = None): - super().__init__(message) - self.retry_after = retry_after - - -class ValidationError(DifyClientError): - """Raised when request validation fails.""" - - pass - - -class NetworkError(DifyClientError): - """Raised when network-related errors occur.""" - - pass - - -class TimeoutError(DifyClientError): - """Raised when request times out.""" - - pass - - -class FileUploadError(DifyClientError): - """Raised when file upload fails.""" - - pass - - -class DatasetError(DifyClientError): - """Raised when dataset operations fail.""" - - pass - - -class WorkflowError(DifyClientError): - """Raised when workflow operations fail.""" - - pass diff --git a/sdks/python-client/dify_client/models.py b/sdks/python-client/dify_client/models.py deleted file mode 100644 index 0321e9c3f4..0000000000 --- a/sdks/python-client/dify_client/models.py +++ /dev/null @@ -1,396 +0,0 @@ -"""Response models for the Dify client with proper type hints.""" - -from typing import Optional, List, Dict, Any, Literal, Union -from dataclasses import dataclass, field -from datetime import datetime - - -@dataclass -class BaseResponse: - """Base response model.""" - - success: bool = True - message: str | None = None - - -@dataclass -class ErrorResponse(BaseResponse): - """Error response model.""" - - error_code: str | None = None - details: Dict[str, Any] | None = None - success: bool = False - - -@dataclass -class FileInfo: - """File information model.""" - - id: str - name: str - size: int - mime_type: str - url: str | None = None - created_at: datetime | None = None - - -@dataclass -class MessageResponse(BaseResponse): - """Message response model.""" - - id: str = "" - answer: str = "" - conversation_id: str | None = None - created_at: int | None = None - metadata: Dict[str, Any] | None = None - files: List[Dict[str, Any]] | None = None - - -@dataclass -class ConversationResponse(BaseResponse): - """Conversation response model.""" - - id: str = "" - name: str = "" - inputs: Dict[str, Any] | None = None - status: str | None = None - created_at: int | None = None - updated_at: int | None = None - - -@dataclass -class DatasetResponse(BaseResponse): - """Dataset response model.""" - - id: str = "" - name: str = "" - description: str | None = None - permission: str | None = None - indexing_technique: str | None = None - embedding_model: str | None = None - embedding_model_provider: str | None = None - retrieval_model: Dict[str, Any] | None = None - document_count: int | None = None - word_count: int | None = None - app_count: int | None = None - created_at: int | None = None - updated_at: int | None = None - - -@dataclass -class DocumentResponse(BaseResponse): - """Document response model.""" - - id: str = "" - name: str = "" - data_source_type: str | None = None - data_source_info: Dict[str, Any] | None = None - dataset_process_rule_id: str | None = None - batch: str | None = None - position: int | None = None - enabled: bool | None = None - disabled_at: float | None = None - disabled_by: str | None = None - archived: bool | None = None - archived_reason: str | None = None - archived_at: float | None = None - archived_by: str | None = None - word_count: int | None = None - hit_count: int | None = None - doc_form: str | None = None - doc_metadata: Dict[str, Any] | None = None - created_at: float | None = None - updated_at: float | None = None - indexing_status: str | None = None - completed_at: float | None = None - paused_at: float | None = None - error: str | None = None - stopped_at: float | None = None - - -@dataclass -class DocumentSegmentResponse(BaseResponse): - """Document segment response model.""" - - id: str = "" - position: int | None = None - document_id: str | None = None - content: str | None = None - answer: str | None = None - word_count: int | None = None - tokens: int | None = None - keywords: List[str] | None = None - index_node_id: str | None = None - index_node_hash: str | None = None - hit_count: int | None = None - enabled: bool | None = None - disabled_at: float | None = None - disabled_by: str | None = None - status: str | None = None - created_by: str | None = None - created_at: float | None = None - indexing_at: float | None = None - completed_at: float | None = None - error: str | None = None - stopped_at: float | None = None - - -@dataclass -class WorkflowRunResponse(BaseResponse): - """Workflow run response model.""" - - id: str = "" - workflow_id: str | None = None - status: Literal["running", "succeeded", "failed", "stopped"] | None = None - inputs: Dict[str, Any] | None = None - outputs: Dict[str, Any] | None = None - error: str | None = None - elapsed_time: float | None = None - total_tokens: int | None = None - total_steps: int | None = None - created_at: float | None = None - finished_at: float | None = None - - -@dataclass -class ApplicationParametersResponse(BaseResponse): - """Application parameters response model.""" - - opening_statement: str | None = None - suggested_questions: List[str] | None = None - speech_to_text: Dict[str, Any] | None = None - text_to_speech: Dict[str, Any] | None = None - retriever_resource: Dict[str, Any] | None = None - sensitive_word_avoidance: Dict[str, Any] | None = None - file_upload: Dict[str, Any] | None = None - system_parameters: Dict[str, Any] | None = None - user_input_form: List[Dict[str, Any]] | None = None - - -@dataclass -class AnnotationResponse(BaseResponse): - """Annotation response model.""" - - id: str = "" - question: str = "" - answer: str = "" - content: str | None = None - created_at: float | None = None - updated_at: float | None = None - created_by: str | None = None - updated_by: str | None = None - hit_count: int | None = None - - -@dataclass -class PaginatedResponse(BaseResponse): - """Paginated response model.""" - - data: List[Any] = field(default_factory=list) - has_more: bool = False - limit: int = 0 - total: int = 0 - page: int | None = None - - -@dataclass -class ConversationVariableResponse(BaseResponse): - """Conversation variable response model.""" - - conversation_id: str = "" - variables: List[Dict[str, Any]] = field(default_factory=list) - - -@dataclass -class FileUploadResponse(BaseResponse): - """File upload response model.""" - - id: str = "" - name: str = "" - size: int = 0 - mime_type: str = "" - url: str | None = None - created_at: float | None = None - - -@dataclass -class AudioResponse(BaseResponse): - """Audio generation/response model.""" - - audio: str | None = None # Base64 encoded audio data or URL - audio_url: str | None = None - duration: float | None = None - sample_rate: int | None = None - - -@dataclass -class SuggestedQuestionsResponse(BaseResponse): - """Suggested questions response model.""" - - message_id: str = "" - questions: List[str] = field(default_factory=list) - - -@dataclass -class AppInfoResponse(BaseResponse): - """App info response model.""" - - id: str = "" - name: str = "" - description: str | None = None - icon: str | None = None - icon_background: str | None = None - mode: str | None = None - tags: List[str] | None = None - enable_site: bool | None = None - enable_api: bool | None = None - api_token: str | None = None - - -@dataclass -class WorkspaceModelsResponse(BaseResponse): - """Workspace models response model.""" - - models: List[Dict[str, Any]] = field(default_factory=list) - - -@dataclass -class HitTestingResponse(BaseResponse): - """Hit testing response model.""" - - query: str = "" - records: List[Dict[str, Any]] = field(default_factory=list) - - -@dataclass -class DatasetTagsResponse(BaseResponse): - """Dataset tags response model.""" - - tags: List[Dict[str, Any]] = field(default_factory=list) - - -@dataclass -class WorkflowLogsResponse(BaseResponse): - """Workflow logs response model.""" - - logs: List[Dict[str, Any]] = field(default_factory=list) - total: int = 0 - page: int = 0 - limit: int = 0 - has_more: bool = False - - -@dataclass -class ModelProviderResponse(BaseResponse): - """Model provider response model.""" - - provider_name: str = "" - provider_type: str = "" - models: List[Dict[str, Any]] = field(default_factory=list) - is_enabled: bool = False - credentials: Dict[str, Any] | None = None - - -@dataclass -class FileInfoResponse(BaseResponse): - """File info response model.""" - - id: str = "" - name: str = "" - size: int = 0 - mime_type: str = "" - url: str | None = None - created_at: int | None = None - metadata: Dict[str, Any] | None = None - - -@dataclass -class WorkflowDraftResponse(BaseResponse): - """Workflow draft response model.""" - - id: str = "" - app_id: str = "" - draft_data: Dict[str, Any] = field(default_factory=dict) - version: int = 0 - created_at: int | None = None - updated_at: int | None = None - - -@dataclass -class ApiTokenResponse(BaseResponse): - """API token response model.""" - - id: str = "" - name: str = "" - token: str = "" - description: str | None = None - created_at: int | None = None - last_used_at: int | None = None - is_active: bool = True - - -@dataclass -class JobStatusResponse(BaseResponse): - """Job status response model.""" - - job_id: str = "" - job_status: str = "" - error_msg: str | None = None - progress: float | None = None - created_at: int | None = None - updated_at: int | None = None - - -@dataclass -class DatasetQueryResponse(BaseResponse): - """Dataset query response model.""" - - query: str = "" - records: List[Dict[str, Any]] = field(default_factory=list) - total: int = 0 - search_time: float | None = None - retrieval_model: Dict[str, Any] | None = None - - -@dataclass -class DatasetTemplateResponse(BaseResponse): - """Dataset template response model.""" - - template_name: str = "" - display_name: str = "" - description: str = "" - category: str = "" - icon: str | None = None - config_schema: Dict[str, Any] = field(default_factory=dict) - - -# Type aliases for common response types -ResponseType = Union[ - BaseResponse, - ErrorResponse, - MessageResponse, - ConversationResponse, - DatasetResponse, - DocumentResponse, - DocumentSegmentResponse, - WorkflowRunResponse, - ApplicationParametersResponse, - AnnotationResponse, - PaginatedResponse, - ConversationVariableResponse, - FileUploadResponse, - AudioResponse, - SuggestedQuestionsResponse, - AppInfoResponse, - WorkspaceModelsResponse, - HitTestingResponse, - DatasetTagsResponse, - WorkflowLogsResponse, - ModelProviderResponse, - FileInfoResponse, - WorkflowDraftResponse, - ApiTokenResponse, - JobStatusResponse, - DatasetQueryResponse, - DatasetTemplateResponse, -] diff --git a/sdks/python-client/examples/advanced_usage.py b/sdks/python-client/examples/advanced_usage.py deleted file mode 100644 index bc8720bef2..0000000000 --- a/sdks/python-client/examples/advanced_usage.py +++ /dev/null @@ -1,264 +0,0 @@ -""" -Advanced usage examples for the Dify Python SDK. - -This example demonstrates: -- Error handling and retries -- Logging configuration -- Context managers -- Async usage -- File uploads -- Dataset management -""" - -import asyncio -import logging -from pathlib import Path - -from dify_client import ( - ChatClient, - CompletionClient, - AsyncChatClient, - KnowledgeBaseClient, - DifyClient, -) -from dify_client.exceptions import ( - APIError, - RateLimitError, - AuthenticationError, - DifyClientError, -) - - -def setup_logging(): - """Setup logging for the SDK.""" - logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") - - -def example_chat_with_error_handling(): - """Example of chat with comprehensive error handling.""" - api_key = "your-api-key-here" - - try: - with ChatClient(api_key, enable_logging=True) as client: - # Simple chat message - response = client.create_chat_message( - inputs={}, query="Hello, how are you?", user="user-123", response_mode="blocking" - ) - - result = response.json() - print(f"Response: {result.get('answer')}") - - except AuthenticationError as e: - print(f"Authentication failed: {e}") - print("Please check your API key") - - except RateLimitError as e: - print(f"Rate limit exceeded: {e}") - if e.retry_after: - print(f"Retry after {e.retry_after} seconds") - - except APIError as e: - print(f"API error: {e.message}") - print(f"Status code: {e.status_code}") - - except DifyClientError as e: - print(f"Dify client error: {e}") - - except Exception as e: - print(f"Unexpected error: {e}") - - -def example_completion_with_files(): - """Example of completion with file upload.""" - api_key = "your-api-key-here" - - with CompletionClient(api_key) as client: - # Upload an image file first - file_path = "path/to/your/image.jpg" - - try: - with open(file_path, "rb") as f: - files = {"file": (Path(file_path).name, f, "image/jpeg")} - upload_response = client.file_upload("user-123", files) - upload_response.raise_for_status() - - file_id = upload_response.json().get("id") - print(f"File uploaded with ID: {file_id}") - - # Use the uploaded file in completion - files_list = [{"type": "image", "transfer_method": "local_file", "upload_file_id": file_id}] - - completion_response = client.create_completion_message( - inputs={"query": "Describe this image"}, response_mode="blocking", user="user-123", files=files_list - ) - - result = completion_response.json() - print(f"Completion result: {result.get('answer')}") - - except FileNotFoundError: - print(f"File not found: {file_path}") - except Exception as e: - print(f"Error during file upload/completion: {e}") - - -def example_dataset_management(): - """Example of dataset management operations.""" - api_key = "your-api-key-here" - - with KnowledgeBaseClient(api_key) as kb_client: - try: - # Create a new dataset - create_response = kb_client.create_dataset(name="My Test Dataset") - create_response.raise_for_status() - - dataset_id = create_response.json().get("id") - print(f"Created dataset with ID: {dataset_id}") - - # Create a client with the dataset ID - dataset_client = KnowledgeBaseClient(api_key, dataset_id=dataset_id) - - # Add a document by text - doc_response = dataset_client.create_document_by_text( - name="Test Document", text="This is a test document for the knowledge base." - ) - doc_response.raise_for_status() - - document_id = doc_response.json().get("document", {}).get("id") - print(f"Created document with ID: {document_id}") - - # List documents - list_response = dataset_client.list_documents() - list_response.raise_for_status() - - documents = list_response.json().get("data", []) - print(f"Dataset contains {len(documents)} documents") - - # Update dataset configuration - update_response = dataset_client.update_dataset( - name="Updated Dataset Name", description="Updated description", indexing_technique="high_quality" - ) - update_response.raise_for_status() - - print("Dataset updated successfully") - - except Exception as e: - print(f"Dataset management error: {e}") - - -async def example_async_chat(): - """Example of async chat usage.""" - api_key = "your-api-key-here" - - try: - async with AsyncChatClient(api_key) as client: - # Create chat message - response = await client.create_chat_message( - inputs={}, query="What's the weather like?", user="user-456", response_mode="blocking" - ) - - result = response.json() - print(f"Async response: {result.get('answer')}") - - # Get conversations - conversations = await client.get_conversations("user-456") - conversations.raise_for_status() - - conv_data = conversations.json() - print(f"Found {len(conv_data.get('data', []))} conversations") - - except Exception as e: - print(f"Async chat error: {e}") - - -def example_streaming_response(): - """Example of handling streaming responses.""" - api_key = "your-api-key-here" - - with ChatClient(api_key) as client: - try: - response = client.create_chat_message( - inputs={}, query="Tell me a story", user="user-789", response_mode="streaming" - ) - - print("Streaming response:") - for line in response.iter_lines(decode_unicode=True): - if line.startswith("data:"): - data = line[5:].strip() - if data: - import json - - try: - chunk = json.loads(data) - answer = chunk.get("answer", "") - if answer: - print(answer, end="", flush=True) - except json.JSONDecodeError: - continue - print() # New line after streaming - - except Exception as e: - print(f"Streaming error: {e}") - - -def example_application_info(): - """Example of getting application information.""" - api_key = "your-api-key-here" - - with DifyClient(api_key) as client: - try: - # Get app info - info_response = client.get_app_info() - info_response.raise_for_status() - - app_info = info_response.json() - print(f"App name: {app_info.get('name')}") - print(f"App mode: {app_info.get('mode')}") - print(f"App tags: {app_info.get('tags', [])}") - - # Get app parameters - params_response = client.get_application_parameters("user-123") - params_response.raise_for_status() - - params = params_response.json() - print(f"Opening statement: {params.get('opening_statement')}") - print(f"Suggested questions: {params.get('suggested_questions', [])}") - - except Exception as e: - print(f"App info error: {e}") - - -def main(): - """Run all examples.""" - setup_logging() - - print("=== Dify Python SDK Advanced Usage Examples ===\n") - - print("1. Chat with Error Handling:") - example_chat_with_error_handling() - print() - - print("2. Completion with Files:") - example_completion_with_files() - print() - - print("3. Dataset Management:") - example_dataset_management() - print() - - print("4. Async Chat:") - asyncio.run(example_async_chat()) - print() - - print("5. Streaming Response:") - example_streaming_response() - print() - - print("6. Application Info:") - example_application_info() - print() - - print("All examples completed!") - - -if __name__ == "__main__": - main() diff --git a/sdks/python-client/pyproject.toml b/sdks/python-client/pyproject.toml deleted file mode 100644 index a25cb9150c..0000000000 --- a/sdks/python-client/pyproject.toml +++ /dev/null @@ -1,43 +0,0 @@ -[project] -name = "dify-client" -version = "0.1.12" -description = "A package for interacting with the Dify Service-API" -readme = "README.md" -requires-python = ">=3.10" -dependencies = [ - "httpx[http2]>=0.27.0", - "aiofiles>=23.0.0", -] -authors = [ - {name = "Dify", email = "hello@dify.ai"} -] -license = {text = "MIT"} -keywords = ["dify", "nlp", "ai", "language-processing"] -classifiers = [ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", -] - -[project.urls] -Homepage = "https://github.com/langgenius/dify" - -[project.optional-dependencies] -dev = [ - "pytest>=7.0.0", - "pytest-asyncio>=0.21.0", -] - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.build.targets.wheel] -packages = ["dify_client"] - -[tool.pytest.ini_options] -testpaths = ["tests"] -python_files = ["test_*.py"] -python_classes = ["Test*"] -python_functions = ["test_*"] -asyncio_mode = "auto" diff --git a/sdks/python-client/tests/__init__.py b/sdks/python-client/tests/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/sdks/python-client/tests/test_async_client.py b/sdks/python-client/tests/test_async_client.py deleted file mode 100644 index 4f5001866f..0000000000 --- a/sdks/python-client/tests/test_async_client.py +++ /dev/null @@ -1,250 +0,0 @@ -#!/usr/bin/env python3 -""" -Test suite for async client implementation in the Python SDK. - -This test validates the async/await functionality using httpx.AsyncClient -and ensures API parity with sync clients. -""" - -import unittest -from unittest.mock import Mock, patch, AsyncMock - -from dify_client.async_client import ( - AsyncDifyClient, - AsyncChatClient, - AsyncCompletionClient, - AsyncWorkflowClient, - AsyncWorkspaceClient, - AsyncKnowledgeBaseClient, -) - - -class TestAsyncAPIParity(unittest.TestCase): - """Test that async clients have API parity with sync clients.""" - - def test_dify_client_api_parity(self): - """Test AsyncDifyClient has same methods as DifyClient.""" - from dify_client import DifyClient - - sync_methods = {name for name in dir(DifyClient) if not name.startswith("_")} - async_methods = {name for name in dir(AsyncDifyClient) if not name.startswith("_")} - - # aclose is async-specific, close is sync-specific - sync_methods.discard("close") - async_methods.discard("aclose") - - # Verify parity - self.assertEqual(sync_methods, async_methods, "API parity mismatch for DifyClient") - - def test_chat_client_api_parity(self): - """Test AsyncChatClient has same methods as ChatClient.""" - from dify_client import ChatClient - - sync_methods = {name for name in dir(ChatClient) if not name.startswith("_")} - async_methods = {name for name in dir(AsyncChatClient) if not name.startswith("_")} - - sync_methods.discard("close") - async_methods.discard("aclose") - - self.assertEqual(sync_methods, async_methods, "API parity mismatch for ChatClient") - - def test_completion_client_api_parity(self): - """Test AsyncCompletionClient has same methods as CompletionClient.""" - from dify_client import CompletionClient - - sync_methods = {name for name in dir(CompletionClient) if not name.startswith("_")} - async_methods = {name for name in dir(AsyncCompletionClient) if not name.startswith("_")} - - sync_methods.discard("close") - async_methods.discard("aclose") - - self.assertEqual(sync_methods, async_methods, "API parity mismatch for CompletionClient") - - def test_workflow_client_api_parity(self): - """Test AsyncWorkflowClient has same methods as WorkflowClient.""" - from dify_client import WorkflowClient - - sync_methods = {name for name in dir(WorkflowClient) if not name.startswith("_")} - async_methods = {name for name in dir(AsyncWorkflowClient) if not name.startswith("_")} - - sync_methods.discard("close") - async_methods.discard("aclose") - - self.assertEqual(sync_methods, async_methods, "API parity mismatch for WorkflowClient") - - def test_workspace_client_api_parity(self): - """Test AsyncWorkspaceClient has same methods as WorkspaceClient.""" - from dify_client import WorkspaceClient - - sync_methods = {name for name in dir(WorkspaceClient) if not name.startswith("_")} - async_methods = {name for name in dir(AsyncWorkspaceClient) if not name.startswith("_")} - - sync_methods.discard("close") - async_methods.discard("aclose") - - self.assertEqual(sync_methods, async_methods, "API parity mismatch for WorkspaceClient") - - def test_knowledge_base_client_api_parity(self): - """Test AsyncKnowledgeBaseClient has same methods as KnowledgeBaseClient.""" - from dify_client import KnowledgeBaseClient - - sync_methods = {name for name in dir(KnowledgeBaseClient) if not name.startswith("_")} - async_methods = {name for name in dir(AsyncKnowledgeBaseClient) if not name.startswith("_")} - - sync_methods.discard("close") - async_methods.discard("aclose") - - self.assertEqual(sync_methods, async_methods, "API parity mismatch for KnowledgeBaseClient") - - -class TestAsyncClientMocked(unittest.IsolatedAsyncioTestCase): - """Test async client with mocked httpx.AsyncClient.""" - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_client_initialization(self, mock_httpx_async_client): - """Test async client initializes with httpx.AsyncClient.""" - mock_client_instance = AsyncMock() - mock_httpx_async_client.return_value = mock_client_instance - - client = AsyncDifyClient("test-key", "https://api.dify.ai/v1") - - # Verify httpx.AsyncClient was called - mock_httpx_async_client.assert_called_once() - self.assertEqual(client.api_key, "test-key") - - await client.aclose() - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_context_manager(self, mock_httpx_async_client): - """Test async context manager works.""" - mock_client_instance = AsyncMock() - mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncDifyClient("test-key") as client: - self.assertEqual(client.api_key, "test-key") - - # Verify aclose was called - mock_client_instance.aclose.assert_called_once() - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_send_request(self, mock_httpx_async_client): - """Test async _send_request method.""" - mock_response = AsyncMock() - mock_response.json = AsyncMock(return_value={"result": "success"}) - mock_response.status_code = 200 - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncDifyClient("test-key") as client: - response = await client._send_request("GET", "/test") - - # Verify request was called - mock_client_instance.request.assert_called_once() - call_args = mock_client_instance.request.call_args - - # Verify parameters - self.assertEqual(call_args[0][0], "GET") - self.assertEqual(call_args[0][1], "/test") - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_chat_client(self, mock_httpx_async_client): - """Test AsyncChatClient functionality.""" - mock_response = AsyncMock() - mock_response.text = '{"answer": "Hello!"}' - mock_response.json = AsyncMock(return_value={"answer": "Hello!"}) - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncChatClient("test-key") as client: - response = await client.create_chat_message({}, "Hi", "user123") - self.assertIn("answer", response.text) - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_completion_client(self, mock_httpx_async_client): - """Test AsyncCompletionClient functionality.""" - mock_response = AsyncMock() - mock_response.text = '{"answer": "Response"}' - mock_response.json = AsyncMock(return_value={"answer": "Response"}) - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncCompletionClient("test-key") as client: - response = await client.create_completion_message({"query": "test"}, "blocking", "user123") - self.assertIn("answer", response.text) - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_workflow_client(self, mock_httpx_async_client): - """Test AsyncWorkflowClient functionality.""" - mock_response = AsyncMock() - mock_response.json = AsyncMock(return_value={"result": "success"}) - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncWorkflowClient("test-key") as client: - response = await client.run({"input": "test"}, "blocking", "user123") - data = await response.json() - self.assertEqual(data["result"], "success") - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_workspace_client(self, mock_httpx_async_client): - """Test AsyncWorkspaceClient functionality.""" - mock_response = AsyncMock() - mock_response.json = AsyncMock(return_value={"data": []}) - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncWorkspaceClient("test-key") as client: - response = await client.get_available_models("llm") - data = await response.json() - self.assertIn("data", data) - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_knowledge_base_client(self, mock_httpx_async_client): - """Test AsyncKnowledgeBaseClient functionality.""" - mock_response = AsyncMock() - mock_response.json = AsyncMock(return_value={"data": [], "total": 0}) - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncKnowledgeBaseClient("test-key") as client: - response = await client.list_datasets() - data = await response.json() - self.assertIn("data", data) - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_all_async_client_classes(self, mock_httpx_async_client): - """Test all async client classes work with httpx.AsyncClient.""" - mock_client_instance = AsyncMock() - mock_httpx_async_client.return_value = mock_client_instance - - clients = [ - AsyncDifyClient("key"), - AsyncChatClient("key"), - AsyncCompletionClient("key"), - AsyncWorkflowClient("key"), - AsyncWorkspaceClient("key"), - AsyncKnowledgeBaseClient("key"), - ] - - # Verify httpx.AsyncClient was called for each - self.assertEqual(mock_httpx_async_client.call_count, 6) - - # Clean up - for client in clients: - await client.aclose() - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/tests/test_client.py b/sdks/python-client/tests/test_client.py deleted file mode 100644 index b0d2f8ba23..0000000000 --- a/sdks/python-client/tests/test_client.py +++ /dev/null @@ -1,489 +0,0 @@ -import os -import time -import unittest -from unittest.mock import Mock, patch, mock_open - -from dify_client.client import ( - ChatClient, - CompletionClient, - DifyClient, - KnowledgeBaseClient, -) - -API_KEY = os.environ.get("API_KEY") -APP_ID = os.environ.get("APP_ID") -API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.dify.ai/v1") -FILE_PATH_BASE = os.path.dirname(__file__) - - -class TestKnowledgeBaseClient(unittest.TestCase): - def setUp(self): - self.api_key = "test-api-key" - self.base_url = "https://api.dify.ai/v1" - self.knowledge_base_client = KnowledgeBaseClient(self.api_key, base_url=self.base_url) - self.README_FILE_PATH = os.path.abspath(os.path.join(FILE_PATH_BASE, "../README.md")) - self.dataset_id = "test-dataset-id" - self.document_id = "test-document-id" - self.segment_id = "test-segment-id" - self.batch_id = "test-batch-id" - - def _get_dataset_kb_client(self): - return KnowledgeBaseClient(self.api_key, base_url=self.base_url, dataset_id=self.dataset_id) - - @patch("dify_client.client.httpx.Client") - def test_001_create_dataset(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.json.return_value = {"id": self.dataset_id, "name": "test_dataset"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Re-create client with mocked httpx - self.knowledge_base_client = KnowledgeBaseClient(self.api_key, base_url=self.base_url) - - response = self.knowledge_base_client.create_dataset(name="test_dataset") - data = response.json() - self.assertIn("id", data) - self.assertEqual("test_dataset", data["name"]) - - # the following tests require to be executed in order because they use - # the dataset/document/segment ids from the previous test - self._test_002_list_datasets() - self._test_003_create_document_by_text() - self._test_004_update_document_by_text() - self._test_006_update_document_by_file() - self._test_007_list_documents() - self._test_008_delete_document() - self._test_009_create_document_by_file() - self._test_010_add_segments() - self._test_011_query_segments() - self._test_012_update_document_segment() - self._test_013_delete_document_segment() - self._test_014_delete_dataset() - - def _test_002_list_datasets(self): - # Mock the response - using the already mocked client from test_001_create_dataset - mock_response = Mock() - mock_response.json.return_value = {"data": [], "total": 0} - mock_response.status_code = 200 - self.knowledge_base_client._client.request.return_value = mock_response - - response = self.knowledge_base_client.list_datasets() - data = response.json() - self.assertIn("data", data) - self.assertIn("total", data) - - def _test_003_create_document_by_text(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"document": {"id": self.document_id}, "batch": self.batch_id} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.create_document_by_text("test_document", "test_text") - data = response.json() - self.assertIn("document", data) - - def _test_004_update_document_by_text(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"document": {"id": self.document_id}, "batch": self.batch_id} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.update_document_by_text(self.document_id, "test_document_updated", "test_text_updated") - data = response.json() - self.assertIn("document", data) - self.assertIn("batch", data) - - def _test_006_update_document_by_file(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"document": {"id": self.document_id}, "batch": self.batch_id} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.update_document_by_file(self.document_id, self.README_FILE_PATH) - data = response.json() - self.assertIn("document", data) - self.assertIn("batch", data) - - def _test_007_list_documents(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"data": []} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.list_documents() - data = response.json() - self.assertIn("data", data) - - def _test_008_delete_document(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.delete_document(self.document_id) - data = response.json() - self.assertIn("result", data) - self.assertEqual("success", data["result"]) - - def _test_009_create_document_by_file(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"document": {"id": self.document_id}, "batch": self.batch_id} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.create_document_by_file(self.README_FILE_PATH) - data = response.json() - self.assertIn("document", data) - - def _test_010_add_segments(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"data": [{"id": self.segment_id, "content": "test text segment 1"}]} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.add_segments(self.document_id, [{"content": "test text segment 1"}]) - data = response.json() - self.assertIn("data", data) - self.assertGreater(len(data["data"]), 0) - - def _test_011_query_segments(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"data": [{"id": self.segment_id, "content": "test text segment 1"}]} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.query_segments(self.document_id) - data = response.json() - self.assertIn("data", data) - self.assertGreater(len(data["data"]), 0) - - def _test_012_update_document_segment(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"data": {"id": self.segment_id, "content": "test text segment 1 updated"}} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.update_document_segment( - self.document_id, - self.segment_id, - {"content": "test text segment 1 updated"}, - ) - data = response.json() - self.assertIn("data", data) - self.assertEqual("test text segment 1 updated", data["data"]["content"]) - - def _test_013_delete_document_segment(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.delete_document_segment(self.document_id, self.segment_id) - data = response.json() - self.assertIn("result", data) - self.assertEqual("success", data["result"]) - - def _test_014_delete_dataset(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.status_code = 204 - client._client.request.return_value = mock_response - - response = client.delete_dataset() - self.assertEqual(204, response.status_code) - - -class TestChatClient(unittest.TestCase): - @patch("dify_client.client.httpx.Client") - def setUp(self, mock_httpx_client): - self.api_key = "test-api-key" - self.chat_client = ChatClient(self.api_key) - - # Set up default mock response for the client - mock_response = Mock() - mock_response.text = '{"answer": "Hello! This is a test response."}' - mock_response.json.return_value = {"answer": "Hello! This is a test response."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - @patch("dify_client.client.httpx.Client") - def test_create_chat_message(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "Hello! This is a test response."}' - mock_response.json.return_value = {"answer": "Hello! This is a test response."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - chat_client = ChatClient(self.api_key) - response = chat_client.create_chat_message({}, "Hello, World!", "test_user") - self.assertIn("answer", response.text) - - @patch("dify_client.client.httpx.Client") - def test_create_chat_message_with_vision_model_by_remote_url(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "I can see this is a test image description."}' - mock_response.json.return_value = {"answer": "I can see this is a test image description."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - chat_client = ChatClient(self.api_key) - files = [{"type": "image", "transfer_method": "remote_url", "url": "https://example.com/test-image.jpg"}] - response = chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files) - self.assertIn("answer", response.text) - - @patch("dify_client.client.httpx.Client") - def test_create_chat_message_with_vision_model_by_local_file(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "I can see this is a test uploaded image."}' - mock_response.json.return_value = {"answer": "I can see this is a test uploaded image."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - chat_client = ChatClient(self.api_key) - files = [ - { - "type": "image", - "transfer_method": "local_file", - "upload_file_id": "test-file-id", - } - ] - response = chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files) - self.assertIn("answer", response.text) - - @patch("dify_client.client.httpx.Client") - def test_get_conversation_messages(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "Here are the conversation messages."}' - mock_response.json.return_value = {"answer": "Here are the conversation messages."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - chat_client = ChatClient(self.api_key) - response = chat_client.get_conversation_messages("test_user", "test-conversation-id") - self.assertIn("answer", response.text) - - @patch("dify_client.client.httpx.Client") - def test_get_conversations(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"data": [{"id": "conv1", "name": "Test Conversation"}]}' - mock_response.json.return_value = {"data": [{"id": "conv1", "name": "Test Conversation"}]} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - chat_client = ChatClient(self.api_key) - response = chat_client.get_conversations("test_user") - self.assertIn("data", response.text) - - -class TestCompletionClient(unittest.TestCase): - @patch("dify_client.client.httpx.Client") - def setUp(self, mock_httpx_client): - self.api_key = "test-api-key" - self.completion_client = CompletionClient(self.api_key) - - # Set up default mock response for the client - mock_response = Mock() - mock_response.text = '{"answer": "This is a test completion response."}' - mock_response.json.return_value = {"answer": "This is a test completion response."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - @patch("dify_client.client.httpx.Client") - def test_create_completion_message(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "The weather today is sunny with a temperature of 75°F."}' - mock_response.json.return_value = {"answer": "The weather today is sunny with a temperature of 75°F."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - completion_client = CompletionClient(self.api_key) - response = completion_client.create_completion_message( - {"query": "What's the weather like today?"}, "blocking", "test_user" - ) - self.assertIn("answer", response.text) - - @patch("dify_client.client.httpx.Client") - def test_create_completion_message_with_vision_model_by_remote_url(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "This is a test image description from completion API."}' - mock_response.json.return_value = {"answer": "This is a test image description from completion API."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - completion_client = CompletionClient(self.api_key) - files = [{"type": "image", "transfer_method": "remote_url", "url": "https://example.com/test-image.jpg"}] - response = completion_client.create_completion_message( - {"query": "Describe the picture."}, "blocking", "test_user", files - ) - self.assertIn("answer", response.text) - - @patch("dify_client.client.httpx.Client") - def test_create_completion_message_with_vision_model_by_local_file(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "This is a test uploaded image description from completion API."}' - mock_response.json.return_value = {"answer": "This is a test uploaded image description from completion API."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - completion_client = CompletionClient(self.api_key) - files = [ - { - "type": "image", - "transfer_method": "local_file", - "upload_file_id": "test-file-id", - } - ] - response = completion_client.create_completion_message( - {"query": "Describe the picture."}, "blocking", "test_user", files - ) - self.assertIn("answer", response.text) - - -class TestDifyClient(unittest.TestCase): - @patch("dify_client.client.httpx.Client") - def setUp(self, mock_httpx_client): - self.api_key = "test-api-key" - self.dify_client = DifyClient(self.api_key) - - # Set up default mock response for the client - mock_response = Mock() - mock_response.text = '{"result": "success"}' - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - @patch("dify_client.client.httpx.Client") - def test_message_feedback(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"success": true}' - mock_response.json.return_value = {"success": True} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - dify_client = DifyClient(self.api_key) - response = dify_client.message_feedback("test-message-id", "like", "test_user") - self.assertIn("success", response.text) - - @patch("dify_client.client.httpx.Client") - def test_get_application_parameters(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"user_input_form": [{"field": "text", "label": "Input"}]}' - mock_response.json.return_value = {"user_input_form": [{"field": "text", "label": "Input"}]} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - dify_client = DifyClient(self.api_key) - response = dify_client.get_application_parameters("test_user") - self.assertIn("user_input_form", response.text) - - @patch("dify_client.client.httpx.Client") - @patch("builtins.open", new_callable=mock_open, read_data=b"fake image data") - def test_file_upload(self, mock_file_open, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"name": "panda.jpeg", "id": "test-file-id"}' - mock_response.json.return_value = {"name": "panda.jpeg", "id": "test-file-id"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - dify_client = DifyClient(self.api_key) - file_path = "/path/to/test/panda.jpeg" - file_name = "panda.jpeg" - mime_type = "image/jpeg" - - with open(file_path, "rb") as file: - files = {"file": (file_name, file, mime_type)} - response = dify_client.file_upload("test_user", files) - self.assertIn("name", response.text) - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/tests/test_exceptions.py b/sdks/python-client/tests/test_exceptions.py deleted file mode 100644 index eb44895749..0000000000 --- a/sdks/python-client/tests/test_exceptions.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Tests for custom exceptions.""" - -import unittest -from dify_client.exceptions import ( - DifyClientError, - APIError, - AuthenticationError, - RateLimitError, - ValidationError, - NetworkError, - TimeoutError, - FileUploadError, - DatasetError, - WorkflowError, -) - - -class TestExceptions(unittest.TestCase): - """Test custom exception classes.""" - - def test_base_exception(self): - """Test base DifyClientError.""" - error = DifyClientError("Test message", 500, {"error": "details"}) - self.assertEqual(str(error), "Test message") - self.assertEqual(error.status_code, 500) - self.assertEqual(error.response, {"error": "details"}) - - def test_api_error(self): - """Test APIError.""" - error = APIError("API failed", 400) - self.assertEqual(error.status_code, 400) - self.assertEqual(error.message, "API failed") - - def test_authentication_error(self): - """Test AuthenticationError.""" - error = AuthenticationError("Invalid API key") - self.assertEqual(str(error), "Invalid API key") - - def test_rate_limit_error(self): - """Test RateLimitError.""" - error = RateLimitError("Rate limited", retry_after=60) - self.assertEqual(error.retry_after, 60) - - error_default = RateLimitError() - self.assertEqual(error_default.retry_after, None) - - def test_validation_error(self): - """Test ValidationError.""" - error = ValidationError("Invalid parameter") - self.assertEqual(str(error), "Invalid parameter") - - def test_network_error(self): - """Test NetworkError.""" - error = NetworkError("Connection failed") - self.assertEqual(str(error), "Connection failed") - - def test_timeout_error(self): - """Test TimeoutError.""" - error = TimeoutError("Request timed out") - self.assertEqual(str(error), "Request timed out") - - def test_file_upload_error(self): - """Test FileUploadError.""" - error = FileUploadError("Upload failed") - self.assertEqual(str(error), "Upload failed") - - def test_dataset_error(self): - """Test DatasetError.""" - error = DatasetError("Dataset operation failed") - self.assertEqual(str(error), "Dataset operation failed") - - def test_workflow_error(self): - """Test WorkflowError.""" - error = WorkflowError("Workflow failed") - self.assertEqual(str(error), "Workflow failed") - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/tests/test_httpx_migration.py b/sdks/python-client/tests/test_httpx_migration.py deleted file mode 100644 index cf26de6eba..0000000000 --- a/sdks/python-client/tests/test_httpx_migration.py +++ /dev/null @@ -1,333 +0,0 @@ -#!/usr/bin/env python3 -""" -Test suite for httpx migration in the Python SDK. - -This test validates that the migration from requests to httpx maintains -backward compatibility and proper resource management. -""" - -import unittest -from unittest.mock import Mock, patch - -from dify_client import ( - DifyClient, - ChatClient, - CompletionClient, - WorkflowClient, - WorkspaceClient, - KnowledgeBaseClient, -) - - -class TestHttpxMigrationMocked(unittest.TestCase): - """Test cases for httpx migration with mocked requests.""" - - def setUp(self): - """Set up test fixtures.""" - self.api_key = "test-api-key" - self.base_url = "https://api.dify.ai/v1" - - @patch("dify_client.client.httpx.Client") - def test_client_initialization(self, mock_httpx_client): - """Test that client initializes with httpx.Client.""" - mock_client_instance = Mock() - mock_httpx_client.return_value = mock_client_instance - - client = DifyClient(self.api_key, self.base_url) - - # Verify httpx.Client was called with correct parameters - mock_httpx_client.assert_called_once() - call_kwargs = mock_httpx_client.call_args[1] - self.assertEqual(call_kwargs["base_url"], self.base_url) - - # Verify client properties - self.assertEqual(client.api_key, self.api_key) - self.assertEqual(client.base_url, self.base_url) - - client.close() - - @patch("dify_client.client.httpx.Client") - def test_context_manager_support(self, mock_httpx_client): - """Test that client works as context manager.""" - mock_client_instance = Mock() - mock_httpx_client.return_value = mock_client_instance - - with DifyClient(self.api_key, self.base_url) as client: - self.assertEqual(client.api_key, self.api_key) - - # Verify close was called - mock_client_instance.close.assert_called_once() - - @patch("dify_client.client.httpx.Client") - def test_manual_close(self, mock_httpx_client): - """Test manual close() method.""" - mock_client_instance = Mock() - mock_httpx_client.return_value = mock_client_instance - - client = DifyClient(self.api_key, self.base_url) - client.close() - - # Verify close was called - mock_client_instance.close.assert_called_once() - - @patch("dify_client.client.httpx.Client") - def test_send_request_httpx_compatibility(self, mock_httpx_client): - """Test _send_request uses httpx.Client.request properly.""" - mock_response = Mock() - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - client = DifyClient(self.api_key, self.base_url) - response = client._send_request("GET", "/test-endpoint") - - # Verify httpx.Client.request was called correctly - mock_client_instance.request.assert_called_once() - call_args = mock_client_instance.request.call_args - - # Verify method and endpoint - self.assertEqual(call_args[0][0], "GET") - self.assertEqual(call_args[0][1], "/test-endpoint") - - # Verify headers contain authorization - headers = call_args[1]["headers"] - self.assertEqual(headers["Authorization"], f"Bearer {self.api_key}") - self.assertEqual(headers["Content-Type"], "application/json") - - client.close() - - @patch("dify_client.client.httpx.Client") - def test_response_compatibility(self, mock_httpx_client): - """Test httpx.Response is compatible with requests.Response API.""" - mock_response = Mock() - mock_response.json.return_value = {"key": "value"} - mock_response.text = '{"key": "value"}' - mock_response.content = b'{"key": "value"}' - mock_response.status_code = 200 - mock_response.headers = {"Content-Type": "application/json"} - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - client = DifyClient(self.api_key, self.base_url) - response = client._send_request("GET", "/test") - - # Verify all common response methods work - self.assertEqual(response.json(), {"key": "value"}) - self.assertEqual(response.text, '{"key": "value"}') - self.assertEqual(response.content, b'{"key": "value"}') - self.assertEqual(response.status_code, 200) - self.assertEqual(response.headers["Content-Type"], "application/json") - - client.close() - - @patch("dify_client.client.httpx.Client") - def test_all_client_classes_use_httpx(self, mock_httpx_client): - """Test that all client classes properly use httpx.""" - mock_client_instance = Mock() - mock_httpx_client.return_value = mock_client_instance - - clients = [ - DifyClient(self.api_key, self.base_url), - ChatClient(self.api_key, self.base_url), - CompletionClient(self.api_key, self.base_url), - WorkflowClient(self.api_key, self.base_url), - WorkspaceClient(self.api_key, self.base_url), - KnowledgeBaseClient(self.api_key, self.base_url), - ] - - # Verify httpx.Client was called for each client - self.assertEqual(mock_httpx_client.call_count, 6) - - # Clean up - for client in clients: - client.close() - - @patch("dify_client.client.httpx.Client") - def test_json_parameter_handling(self, mock_httpx_client): - """Test that json parameter is passed correctly.""" - mock_response = Mock() - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 # Add status_code attribute - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - client = DifyClient(self.api_key, self.base_url) - test_data = {"key": "value", "number": 123} - - client._send_request("POST", "/test", json=test_data) - - # Verify json parameter was passed - call_args = mock_client_instance.request.call_args - self.assertEqual(call_args[1]["json"], test_data) - - client.close() - - @patch("dify_client.client.httpx.Client") - def test_params_parameter_handling(self, mock_httpx_client): - """Test that params parameter is passed correctly.""" - mock_response = Mock() - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 # Add status_code attribute - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - client = DifyClient(self.api_key, self.base_url) - test_params = {"page": 1, "limit": 20} - - client._send_request("GET", "/test", params=test_params) - - # Verify params parameter was passed - call_args = mock_client_instance.request.call_args - self.assertEqual(call_args[1]["params"], test_params) - - client.close() - - @patch("dify_client.client.httpx.Client") - def test_inheritance_chain(self, mock_httpx_client): - """Test that inheritance chain is maintained.""" - mock_client_instance = Mock() - mock_httpx_client.return_value = mock_client_instance - - # ChatClient inherits from DifyClient - chat_client = ChatClient(self.api_key, self.base_url) - self.assertIsInstance(chat_client, DifyClient) - - # CompletionClient inherits from DifyClient - completion_client = CompletionClient(self.api_key, self.base_url) - self.assertIsInstance(completion_client, DifyClient) - - # WorkflowClient inherits from DifyClient - workflow_client = WorkflowClient(self.api_key, self.base_url) - self.assertIsInstance(workflow_client, DifyClient) - - # Clean up - chat_client.close() - completion_client.close() - workflow_client.close() - - @patch("dify_client.client.httpx.Client") - def test_nested_context_managers(self, mock_httpx_client): - """Test nested context managers work correctly.""" - mock_client_instance = Mock() - mock_httpx_client.return_value = mock_client_instance - - with DifyClient(self.api_key, self.base_url) as client1: - with ChatClient(self.api_key, self.base_url) as client2: - self.assertEqual(client1.api_key, self.api_key) - self.assertEqual(client2.api_key, self.api_key) - - # Both close methods should have been called - self.assertEqual(mock_client_instance.close.call_count, 2) - - -class TestChatClientHttpx(unittest.TestCase): - """Test ChatClient specific httpx integration.""" - - @patch("dify_client.client.httpx.Client") - def test_create_chat_message_httpx(self, mock_httpx_client): - """Test create_chat_message works with httpx.""" - mock_response = Mock() - mock_response.text = '{"answer": "Hello!"}' - mock_response.json.return_value = {"answer": "Hello!"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - with ChatClient("test-key") as client: - response = client.create_chat_message({}, "Hi", "user123") - self.assertIn("answer", response.text) - self.assertEqual(response.json()["answer"], "Hello!") - - -class TestCompletionClientHttpx(unittest.TestCase): - """Test CompletionClient specific httpx integration.""" - - @patch("dify_client.client.httpx.Client") - def test_create_completion_message_httpx(self, mock_httpx_client): - """Test create_completion_message works with httpx.""" - mock_response = Mock() - mock_response.text = '{"answer": "Response"}' - mock_response.json.return_value = {"answer": "Response"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - with CompletionClient("test-key") as client: - response = client.create_completion_message({"query": "test"}, "blocking", "user123") - self.assertIn("answer", response.text) - - -class TestKnowledgeBaseClientHttpx(unittest.TestCase): - """Test KnowledgeBaseClient specific httpx integration.""" - - @patch("dify_client.client.httpx.Client") - def test_list_datasets_httpx(self, mock_httpx_client): - """Test list_datasets works with httpx.""" - mock_response = Mock() - mock_response.json.return_value = {"data": [], "total": 0} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - with KnowledgeBaseClient("test-key") as client: - response = client.list_datasets() - data = response.json() - self.assertIn("data", data) - self.assertIn("total", data) - - -class TestWorkflowClientHttpx(unittest.TestCase): - """Test WorkflowClient specific httpx integration.""" - - @patch("dify_client.client.httpx.Client") - def test_run_workflow_httpx(self, mock_httpx_client): - """Test run workflow works with httpx.""" - mock_response = Mock() - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - with WorkflowClient("test-key") as client: - response = client.run({"input": "test"}, "blocking", "user123") - self.assertEqual(response.json()["result"], "success") - - -class TestWorkspaceClientHttpx(unittest.TestCase): - """Test WorkspaceClient specific httpx integration.""" - - @patch("dify_client.client.httpx.Client") - def test_get_available_models_httpx(self, mock_httpx_client): - """Test get_available_models works with httpx.""" - mock_response = Mock() - mock_response.json.return_value = {"data": []} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - with WorkspaceClient("test-key") as client: - response = client.get_available_models("llm") - self.assertIn("data", response.json()) - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/tests/test_integration.py b/sdks/python-client/tests/test_integration.py deleted file mode 100644 index 6f38c5de56..0000000000 --- a/sdks/python-client/tests/test_integration.py +++ /dev/null @@ -1,539 +0,0 @@ -"""Integration tests with proper mocking.""" - -import unittest -from unittest.mock import Mock, patch, MagicMock -import json -import httpx -from dify_client import ( - DifyClient, - ChatClient, - CompletionClient, - WorkflowClient, - KnowledgeBaseClient, - WorkspaceClient, -) -from dify_client.exceptions import ( - APIError, - AuthenticationError, - RateLimitError, - ValidationError, -) - - -class TestDifyClientIntegration(unittest.TestCase): - """Integration tests for DifyClient with mocked HTTP responses.""" - - def setUp(self): - self.api_key = "test_api_key" - self.base_url = "https://api.dify.ai/v1" - self.client = DifyClient(api_key=self.api_key, base_url=self.base_url, enable_logging=False) - - @patch("httpx.Client.request") - def test_get_app_info_integration(self, mock_request): - """Test get_app_info integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "app_123", - "name": "Test App", - "description": "A test application", - "mode": "chat", - } - mock_request.return_value = mock_response - - response = self.client.get_app_info() - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["id"], "app_123") - self.assertEqual(data["name"], "Test App") - mock_request.assert_called_once_with( - "GET", - "/info", - json=None, - params=None, - headers={ - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - }, - ) - - @patch("httpx.Client.request") - def test_get_application_parameters_integration(self, mock_request): - """Test get_application_parameters integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "opening_statement": "Hello! How can I help you?", - "suggested_questions": ["What is AI?", "How does this work?"], - "speech_to_text": {"enabled": True}, - "text_to_speech": {"enabled": False}, - } - mock_request.return_value = mock_response - - response = self.client.get_application_parameters("user_123") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["opening_statement"], "Hello! How can I help you?") - self.assertEqual(len(data["suggested_questions"]), 2) - mock_request.assert_called_once_with( - "GET", - "/parameters", - json=None, - params={"user": "user_123"}, - headers={ - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - }, - ) - - @patch("httpx.Client.request") - def test_file_upload_integration(self, mock_request): - """Test file_upload integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "file_123", - "name": "test.txt", - "size": 1024, - "mime_type": "text/plain", - } - mock_request.return_value = mock_response - - files = {"file": ("test.txt", "test content", "text/plain")} - response = self.client.file_upload("user_123", files) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["id"], "file_123") - self.assertEqual(data["name"], "test.txt") - - @patch("httpx.Client.request") - def test_message_feedback_integration(self, mock_request): - """Test message_feedback integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = {"success": True} - mock_request.return_value = mock_response - - response = self.client.message_feedback("msg_123", "like", "user_123") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertTrue(data["success"]) - mock_request.assert_called_once_with( - "POST", - "/messages/msg_123/feedbacks", - json={"rating": "like", "user": "user_123"}, - params=None, - headers={ - "Authorization": "Bearer test_api_key", - "Content-Type": "application/json", - }, - ) - - -class TestChatClientIntegration(unittest.TestCase): - """Integration tests for ChatClient.""" - - def setUp(self): - self.client = ChatClient("test_api_key", enable_logging=False) - - @patch("httpx.Client.request") - def test_create_chat_message_blocking(self, mock_request): - """Test create_chat_message with blocking response.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "msg_123", - "answer": "Hello! How can I help you today?", - "conversation_id": "conv_123", - "created_at": 1234567890, - } - mock_request.return_value = mock_response - - response = self.client.create_chat_message( - inputs={"query": "Hello"}, - query="Hello, AI!", - user="user_123", - response_mode="blocking", - ) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["answer"], "Hello! How can I help you today?") - self.assertEqual(data["conversation_id"], "conv_123") - - @patch("httpx.Client.request") - def test_create_chat_message_streaming(self, mock_request): - """Test create_chat_message with streaming response.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.iter_lines.return_value = [ - b'data: {"answer": "Hello"}', - b'data: {"answer": " world"}', - b'data: {"answer": "!"}', - ] - mock_request.return_value = mock_response - - response = self.client.create_chat_message(inputs={}, query="Hello", user="user_123", response_mode="streaming") - - self.assertEqual(response.status_code, 200) - lines = list(response.iter_lines()) - self.assertEqual(len(lines), 3) - - @patch("httpx.Client.request") - def test_get_conversations_integration(self, mock_request): - """Test get_conversations integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "data": [ - {"id": "conv_1", "name": "Conversation 1"}, - {"id": "conv_2", "name": "Conversation 2"}, - ], - "has_more": False, - "limit": 20, - } - mock_request.return_value = mock_response - - response = self.client.get_conversations("user_123", limit=20) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(len(data["data"]), 2) - self.assertEqual(data["data"][0]["name"], "Conversation 1") - - @patch("httpx.Client.request") - def test_get_conversation_messages_integration(self, mock_request): - """Test get_conversation_messages integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "data": [ - {"id": "msg_1", "role": "user", "content": "Hello"}, - {"id": "msg_2", "role": "assistant", "content": "Hi there!"}, - ] - } - mock_request.return_value = mock_response - - response = self.client.get_conversation_messages("user_123", conversation_id="conv_123") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(len(data["data"]), 2) - self.assertEqual(data["data"][0]["role"], "user") - - -class TestCompletionClientIntegration(unittest.TestCase): - """Integration tests for CompletionClient.""" - - def setUp(self): - self.client = CompletionClient("test_api_key", enable_logging=False) - - @patch("httpx.Client.request") - def test_create_completion_message_blocking(self, mock_request): - """Test create_completion_message with blocking response.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "comp_123", - "answer": "This is a completion response.", - "created_at": 1234567890, - } - mock_request.return_value = mock_response - - response = self.client.create_completion_message( - inputs={"prompt": "Complete this sentence"}, - response_mode="blocking", - user="user_123", - ) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["answer"], "This is a completion response.") - - @patch("httpx.Client.request") - def test_create_completion_message_with_files(self, mock_request): - """Test create_completion_message with files.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "comp_124", - "answer": "I can see the image shows...", - "files": [{"id": "file_1", "type": "image"}], - } - mock_request.return_value = mock_response - - files = { - "file": { - "type": "image", - "transfer_method": "remote_url", - "url": "https://example.com/image.jpg", - } - } - response = self.client.create_completion_message( - inputs={"prompt": "Describe this image"}, - response_mode="blocking", - user="user_123", - files=files, - ) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertIn("image", data["answer"]) - self.assertEqual(len(data["files"]), 1) - - -class TestWorkflowClientIntegration(unittest.TestCase): - """Integration tests for WorkflowClient.""" - - def setUp(self): - self.client = WorkflowClient("test_api_key", enable_logging=False) - - @patch("httpx.Client.request") - def test_run_workflow_blocking(self, mock_request): - """Test run workflow with blocking response.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "run_123", - "workflow_id": "workflow_123", - "status": "succeeded", - "inputs": {"query": "Test input"}, - "outputs": {"result": "Test output"}, - "elapsed_time": 2.5, - } - mock_request.return_value = mock_response - - response = self.client.run(inputs={"query": "Test input"}, response_mode="blocking", user="user_123") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["status"], "succeeded") - self.assertEqual(data["outputs"]["result"], "Test output") - - @patch("httpx.Client.request") - def test_get_workflow_logs(self, mock_request): - """Test get_workflow_logs integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "logs": [ - {"id": "log_1", "status": "succeeded", "created_at": 1234567890}, - {"id": "log_2", "status": "failed", "created_at": 1234567891}, - ], - "total": 2, - "page": 1, - "limit": 20, - } - mock_request.return_value = mock_response - - response = self.client.get_workflow_logs(page=1, limit=20) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(len(data["logs"]), 2) - self.assertEqual(data["logs"][0]["status"], "succeeded") - - -class TestKnowledgeBaseClientIntegration(unittest.TestCase): - """Integration tests for KnowledgeBaseClient.""" - - def setUp(self): - self.client = KnowledgeBaseClient("test_api_key") - - @patch("httpx.Client.request") - def test_create_dataset(self, mock_request): - """Test create_dataset integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "dataset_123", - "name": "Test Dataset", - "description": "A test dataset", - "created_at": 1234567890, - } - mock_request.return_value = mock_response - - response = self.client.create_dataset(name="Test Dataset") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["name"], "Test Dataset") - self.assertEqual(data["id"], "dataset_123") - - @patch("httpx.Client.request") - def test_list_datasets(self, mock_request): - """Test list_datasets integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "data": [ - {"id": "dataset_1", "name": "Dataset 1"}, - {"id": "dataset_2", "name": "Dataset 2"}, - ], - "has_more": False, - "limit": 20, - } - mock_request.return_value = mock_response - - response = self.client.list_datasets(page=1, page_size=20) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(len(data["data"]), 2) - - @patch("httpx.Client.request") - def test_create_document_by_text(self, mock_request): - """Test create_document_by_text integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "document": { - "id": "doc_123", - "name": "Test Document", - "word_count": 100, - "status": "indexing", - } - } - mock_request.return_value = mock_response - - # Mock dataset_id - self.client.dataset_id = "dataset_123" - - response = self.client.create_document_by_text(name="Test Document", text="This is test document content.") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["document"]["name"], "Test Document") - self.assertEqual(data["document"]["word_count"], 100) - - -class TestWorkspaceClientIntegration(unittest.TestCase): - """Integration tests for WorkspaceClient.""" - - def setUp(self): - self.client = WorkspaceClient("test_api_key", enable_logging=False) - - @patch("httpx.Client.request") - def test_get_available_models(self, mock_request): - """Test get_available_models integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "models": [ - {"id": "gpt-4", "name": "GPT-4", "provider": "openai"}, - {"id": "claude-3", "name": "Claude 3", "provider": "anthropic"}, - ] - } - mock_request.return_value = mock_response - - response = self.client.get_available_models("llm") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(len(data["models"]), 2) - self.assertEqual(data["models"][0]["id"], "gpt-4") - - -class TestErrorScenariosIntegration(unittest.TestCase): - """Integration tests for error scenarios.""" - - def setUp(self): - self.client = DifyClient("test_api_key", enable_logging=False) - - @patch("httpx.Client.request") - def test_authentication_error_integration(self, mock_request): - """Test authentication error in integration.""" - mock_response = Mock() - mock_response.status_code = 401 - mock_response.json.return_value = {"message": "Invalid API key"} - mock_request.return_value = mock_response - - with self.assertRaises(AuthenticationError) as context: - self.client.get_app_info() - - self.assertEqual(str(context.exception), "Invalid API key") - self.assertEqual(context.exception.status_code, 401) - - @patch("httpx.Client.request") - def test_rate_limit_error_integration(self, mock_request): - """Test rate limit error in integration.""" - mock_response = Mock() - mock_response.status_code = 429 - mock_response.json.return_value = {"message": "Rate limit exceeded"} - mock_response.headers = {"Retry-After": "60"} - mock_request.return_value = mock_response - - with self.assertRaises(RateLimitError) as context: - self.client.get_app_info() - - self.assertEqual(str(context.exception), "Rate limit exceeded") - self.assertEqual(context.exception.retry_after, "60") - - @patch("httpx.Client.request") - def test_server_error_with_retry_integration(self, mock_request): - """Test server error with retry in integration.""" - # API errors don't retry by design - only network/timeout errors retry - mock_response_500 = Mock() - mock_response_500.status_code = 500 - mock_response_500.json.return_value = {"message": "Internal server error"} - - mock_request.return_value = mock_response_500 - - with patch("time.sleep"): # Skip actual sleep - with self.assertRaises(APIError) as context: - self.client.get_app_info() - - self.assertEqual(str(context.exception), "Internal server error") - self.assertEqual(mock_request.call_count, 1) - - @patch("httpx.Client.request") - def test_validation_error_integration(self, mock_request): - """Test validation error in integration.""" - mock_response = Mock() - mock_response.status_code = 422 - mock_response.json.return_value = { - "message": "Validation failed", - "details": {"field": "query", "error": "required"}, - } - mock_request.return_value = mock_response - - with self.assertRaises(ValidationError) as context: - self.client.get_app_info() - - self.assertEqual(str(context.exception), "Validation failed") - self.assertEqual(context.exception.status_code, 422) - - -class TestContextManagerIntegration(unittest.TestCase): - """Integration tests for context manager usage.""" - - @patch("httpx.Client.close") - @patch("httpx.Client.request") - def test_context_manager_usage(self, mock_request, mock_close): - """Test context manager properly closes connections.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = {"id": "app_123", "name": "Test App"} - mock_request.return_value = mock_response - - with DifyClient("test_api_key") as client: - response = client.get_app_info() - self.assertEqual(response.status_code, 200) - - # Verify close was called - mock_close.assert_called_once() - - @patch("httpx.Client.close") - def test_manual_close(self, mock_close): - """Test manual close method.""" - client = DifyClient("test_api_key") - client.close() - mock_close.assert_called_once() - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/tests/test_models.py b/sdks/python-client/tests/test_models.py deleted file mode 100644 index db9d92ad5b..0000000000 --- a/sdks/python-client/tests/test_models.py +++ /dev/null @@ -1,640 +0,0 @@ -"""Unit tests for response models.""" - -import unittest -import json -from datetime import datetime -from dify_client.models import ( - BaseResponse, - ErrorResponse, - FileInfo, - MessageResponse, - ConversationResponse, - DatasetResponse, - DocumentResponse, - DocumentSegmentResponse, - WorkflowRunResponse, - ApplicationParametersResponse, - AnnotationResponse, - PaginatedResponse, - ConversationVariableResponse, - FileUploadResponse, - AudioResponse, - SuggestedQuestionsResponse, - AppInfoResponse, - WorkspaceModelsResponse, - HitTestingResponse, - DatasetTagsResponse, - WorkflowLogsResponse, - ModelProviderResponse, - FileInfoResponse, - WorkflowDraftResponse, - ApiTokenResponse, - JobStatusResponse, - DatasetQueryResponse, - DatasetTemplateResponse, -) - - -class TestResponseModels(unittest.TestCase): - """Test cases for response model classes.""" - - def test_base_response(self): - """Test BaseResponse model.""" - response = BaseResponse(success=True, message="Operation successful") - self.assertTrue(response.success) - self.assertEqual(response.message, "Operation successful") - - def test_base_response_defaults(self): - """Test BaseResponse with default values.""" - response = BaseResponse(success=True) - self.assertTrue(response.success) - self.assertIsNone(response.message) - - def test_error_response(self): - """Test ErrorResponse model.""" - response = ErrorResponse( - success=False, - message="Error occurred", - error_code="VALIDATION_ERROR", - details={"field": "invalid_value"}, - ) - self.assertFalse(response.success) - self.assertEqual(response.message, "Error occurred") - self.assertEqual(response.error_code, "VALIDATION_ERROR") - self.assertEqual(response.details["field"], "invalid_value") - - def test_file_info(self): - """Test FileInfo model.""" - now = datetime.now() - file_info = FileInfo( - id="file_123", - name="test.txt", - size=1024, - mime_type="text/plain", - url="https://example.com/file.txt", - created_at=now, - ) - self.assertEqual(file_info.id, "file_123") - self.assertEqual(file_info.name, "test.txt") - self.assertEqual(file_info.size, 1024) - self.assertEqual(file_info.mime_type, "text/plain") - self.assertEqual(file_info.url, "https://example.com/file.txt") - self.assertEqual(file_info.created_at, now) - - def test_message_response(self): - """Test MessageResponse model.""" - response = MessageResponse( - success=True, - id="msg_123", - answer="Hello, world!", - conversation_id="conv_123", - created_at=1234567890, - metadata={"model": "gpt-4"}, - files=[{"id": "file_1", "type": "image"}], - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "msg_123") - self.assertEqual(response.answer, "Hello, world!") - self.assertEqual(response.conversation_id, "conv_123") - self.assertEqual(response.created_at, 1234567890) - self.assertEqual(response.metadata["model"], "gpt-4") - self.assertEqual(response.files[0]["id"], "file_1") - - def test_conversation_response(self): - """Test ConversationResponse model.""" - response = ConversationResponse( - success=True, - id="conv_123", - name="Test Conversation", - inputs={"query": "Hello"}, - status="active", - created_at=1234567890, - updated_at=1234567891, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "conv_123") - self.assertEqual(response.name, "Test Conversation") - self.assertEqual(response.inputs["query"], "Hello") - self.assertEqual(response.status, "active") - self.assertEqual(response.created_at, 1234567890) - self.assertEqual(response.updated_at, 1234567891) - - def test_dataset_response(self): - """Test DatasetResponse model.""" - response = DatasetResponse( - success=True, - id="dataset_123", - name="Test Dataset", - description="A test dataset", - permission="read", - indexing_technique="high_quality", - embedding_model="text-embedding-ada-002", - embedding_model_provider="openai", - retrieval_model={"search_type": "semantic"}, - document_count=10, - word_count=5000, - app_count=2, - created_at=1234567890, - updated_at=1234567891, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "dataset_123") - self.assertEqual(response.name, "Test Dataset") - self.assertEqual(response.description, "A test dataset") - self.assertEqual(response.permission, "read") - self.assertEqual(response.indexing_technique, "high_quality") - self.assertEqual(response.embedding_model, "text-embedding-ada-002") - self.assertEqual(response.embedding_model_provider, "openai") - self.assertEqual(response.retrieval_model["search_type"], "semantic") - self.assertEqual(response.document_count, 10) - self.assertEqual(response.word_count, 5000) - self.assertEqual(response.app_count, 2) - - def test_document_response(self): - """Test DocumentResponse model.""" - response = DocumentResponse( - success=True, - id="doc_123", - name="test_document.txt", - data_source_type="upload_file", - position=1, - enabled=True, - word_count=1000, - hit_count=5, - doc_form="text_model", - created_at=1234567890.0, - indexing_status="completed", - completed_at=1234567891.0, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "doc_123") - self.assertEqual(response.name, "test_document.txt") - self.assertEqual(response.data_source_type, "upload_file") - self.assertEqual(response.position, 1) - self.assertTrue(response.enabled) - self.assertEqual(response.word_count, 1000) - self.assertEqual(response.hit_count, 5) - self.assertEqual(response.doc_form, "text_model") - self.assertEqual(response.created_at, 1234567890.0) - self.assertEqual(response.indexing_status, "completed") - self.assertEqual(response.completed_at, 1234567891.0) - - def test_document_segment_response(self): - """Test DocumentSegmentResponse model.""" - response = DocumentSegmentResponse( - success=True, - id="seg_123", - position=1, - document_id="doc_123", - content="This is a test segment.", - answer="Test answer", - word_count=5, - tokens=10, - keywords=["test", "segment"], - hit_count=2, - enabled=True, - status="completed", - created_at=1234567890.0, - completed_at=1234567891.0, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "seg_123") - self.assertEqual(response.position, 1) - self.assertEqual(response.document_id, "doc_123") - self.assertEqual(response.content, "This is a test segment.") - self.assertEqual(response.answer, "Test answer") - self.assertEqual(response.word_count, 5) - self.assertEqual(response.tokens, 10) - self.assertEqual(response.keywords, ["test", "segment"]) - self.assertEqual(response.hit_count, 2) - self.assertTrue(response.enabled) - self.assertEqual(response.status, "completed") - self.assertEqual(response.created_at, 1234567890.0) - self.assertEqual(response.completed_at, 1234567891.0) - - def test_workflow_run_response(self): - """Test WorkflowRunResponse model.""" - response = WorkflowRunResponse( - success=True, - id="run_123", - workflow_id="workflow_123", - status="succeeded", - inputs={"query": "test"}, - outputs={"answer": "result"}, - elapsed_time=5.5, - total_tokens=100, - total_steps=3, - created_at=1234567890.0, - finished_at=1234567895.5, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "run_123") - self.assertEqual(response.workflow_id, "workflow_123") - self.assertEqual(response.status, "succeeded") - self.assertEqual(response.inputs["query"], "test") - self.assertEqual(response.outputs["answer"], "result") - self.assertEqual(response.elapsed_time, 5.5) - self.assertEqual(response.total_tokens, 100) - self.assertEqual(response.total_steps, 3) - self.assertEqual(response.created_at, 1234567890.0) - self.assertEqual(response.finished_at, 1234567895.5) - - def test_application_parameters_response(self): - """Test ApplicationParametersResponse model.""" - response = ApplicationParametersResponse( - success=True, - opening_statement="Hello! How can I help you?", - suggested_questions=["What is AI?", "How does this work?"], - speech_to_text={"enabled": True}, - text_to_speech={"enabled": False, "voice": "alloy"}, - retriever_resource={"enabled": True}, - sensitive_word_avoidance={"enabled": False}, - file_upload={"enabled": True, "file_size_limit": 10485760}, - system_parameters={"max_tokens": 1000}, - user_input_form=[{"type": "text", "label": "Query"}], - ) - self.assertTrue(response.success) - self.assertEqual(response.opening_statement, "Hello! How can I help you?") - self.assertEqual(response.suggested_questions, ["What is AI?", "How does this work?"]) - self.assertTrue(response.speech_to_text["enabled"]) - self.assertFalse(response.text_to_speech["enabled"]) - self.assertEqual(response.text_to_speech["voice"], "alloy") - self.assertTrue(response.retriever_resource["enabled"]) - self.assertFalse(response.sensitive_word_avoidance["enabled"]) - self.assertTrue(response.file_upload["enabled"]) - self.assertEqual(response.file_upload["file_size_limit"], 10485760) - self.assertEqual(response.system_parameters["max_tokens"], 1000) - self.assertEqual(response.user_input_form[0]["type"], "text") - - def test_annotation_response(self): - """Test AnnotationResponse model.""" - response = AnnotationResponse( - success=True, - id="annotation_123", - question="What is the capital of France?", - answer="Paris", - content="Additional context", - created_at=1234567890.0, - updated_at=1234567891.0, - created_by="user_123", - updated_by="user_123", - hit_count=5, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "annotation_123") - self.assertEqual(response.question, "What is the capital of France?") - self.assertEqual(response.answer, "Paris") - self.assertEqual(response.content, "Additional context") - self.assertEqual(response.created_at, 1234567890.0) - self.assertEqual(response.updated_at, 1234567891.0) - self.assertEqual(response.created_by, "user_123") - self.assertEqual(response.updated_by, "user_123") - self.assertEqual(response.hit_count, 5) - - def test_paginated_response(self): - """Test PaginatedResponse model.""" - response = PaginatedResponse( - success=True, - data=[{"id": 1}, {"id": 2}, {"id": 3}], - has_more=True, - limit=10, - total=100, - page=1, - ) - self.assertTrue(response.success) - self.assertEqual(len(response.data), 3) - self.assertEqual(response.data[0]["id"], 1) - self.assertTrue(response.has_more) - self.assertEqual(response.limit, 10) - self.assertEqual(response.total, 100) - self.assertEqual(response.page, 1) - - def test_conversation_variable_response(self): - """Test ConversationVariableResponse model.""" - response = ConversationVariableResponse( - success=True, - conversation_id="conv_123", - variables=[ - {"id": "var_1", "name": "user_name", "value": "John"}, - {"id": "var_2", "name": "preferences", "value": {"theme": "dark"}}, - ], - ) - self.assertTrue(response.success) - self.assertEqual(response.conversation_id, "conv_123") - self.assertEqual(len(response.variables), 2) - self.assertEqual(response.variables[0]["name"], "user_name") - self.assertEqual(response.variables[0]["value"], "John") - self.assertEqual(response.variables[1]["name"], "preferences") - self.assertEqual(response.variables[1]["value"]["theme"], "dark") - - def test_file_upload_response(self): - """Test FileUploadResponse model.""" - response = FileUploadResponse( - success=True, - id="file_123", - name="test.txt", - size=1024, - mime_type="text/plain", - url="https://example.com/files/test.txt", - created_at=1234567890.0, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "file_123") - self.assertEqual(response.name, "test.txt") - self.assertEqual(response.size, 1024) - self.assertEqual(response.mime_type, "text/plain") - self.assertEqual(response.url, "https://example.com/files/test.txt") - self.assertEqual(response.created_at, 1234567890.0) - - def test_audio_response(self): - """Test AudioResponse model.""" - response = AudioResponse( - success=True, - audio="base64_encoded_audio_data", - audio_url="https://example.com/audio.mp3", - duration=10.5, - sample_rate=44100, - ) - self.assertTrue(response.success) - self.assertEqual(response.audio, "base64_encoded_audio_data") - self.assertEqual(response.audio_url, "https://example.com/audio.mp3") - self.assertEqual(response.duration, 10.5) - self.assertEqual(response.sample_rate, 44100) - - def test_suggested_questions_response(self): - """Test SuggestedQuestionsResponse model.""" - response = SuggestedQuestionsResponse( - success=True, - message_id="msg_123", - questions=[ - "What is machine learning?", - "How does AI work?", - "Can you explain neural networks?", - ], - ) - self.assertTrue(response.success) - self.assertEqual(response.message_id, "msg_123") - self.assertEqual(len(response.questions), 3) - self.assertEqual(response.questions[0], "What is machine learning?") - - def test_app_info_response(self): - """Test AppInfoResponse model.""" - response = AppInfoResponse( - success=True, - id="app_123", - name="Test App", - description="A test application", - icon="🤖", - icon_background="#FF6B6B", - mode="chat", - tags=["AI", "Chat", "Test"], - enable_site=True, - enable_api=True, - api_token="app_token_123", - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "app_123") - self.assertEqual(response.name, "Test App") - self.assertEqual(response.description, "A test application") - self.assertEqual(response.icon, "🤖") - self.assertEqual(response.icon_background, "#FF6B6B") - self.assertEqual(response.mode, "chat") - self.assertEqual(response.tags, ["AI", "Chat", "Test"]) - self.assertTrue(response.enable_site) - self.assertTrue(response.enable_api) - self.assertEqual(response.api_token, "app_token_123") - - def test_workspace_models_response(self): - """Test WorkspaceModelsResponse model.""" - response = WorkspaceModelsResponse( - success=True, - models=[ - {"id": "gpt-4", "name": "GPT-4", "provider": "openai"}, - {"id": "claude-3", "name": "Claude 3", "provider": "anthropic"}, - ], - ) - self.assertTrue(response.success) - self.assertEqual(len(response.models), 2) - self.assertEqual(response.models[0]["id"], "gpt-4") - self.assertEqual(response.models[0]["name"], "GPT-4") - self.assertEqual(response.models[0]["provider"], "openai") - - def test_hit_testing_response(self): - """Test HitTestingResponse model.""" - response = HitTestingResponse( - success=True, - query="What is machine learning?", - records=[ - {"content": "Machine learning is a subset of AI...", "score": 0.95}, - {"content": "ML algorithms learn from data...", "score": 0.87}, - ], - ) - self.assertTrue(response.success) - self.assertEqual(response.query, "What is machine learning?") - self.assertEqual(len(response.records), 2) - self.assertEqual(response.records[0]["score"], 0.95) - - def test_dataset_tags_response(self): - """Test DatasetTagsResponse model.""" - response = DatasetTagsResponse( - success=True, - tags=[ - {"id": "tag_1", "name": "Technology", "color": "#FF0000"}, - {"id": "tag_2", "name": "Science", "color": "#00FF00"}, - ], - ) - self.assertTrue(response.success) - self.assertEqual(len(response.tags), 2) - self.assertEqual(response.tags[0]["name"], "Technology") - self.assertEqual(response.tags[0]["color"], "#FF0000") - - def test_workflow_logs_response(self): - """Test WorkflowLogsResponse model.""" - response = WorkflowLogsResponse( - success=True, - logs=[ - {"id": "log_1", "status": "succeeded", "created_at": 1234567890}, - {"id": "log_2", "status": "failed", "created_at": 1234567891}, - ], - total=50, - page=1, - limit=10, - has_more=True, - ) - self.assertTrue(response.success) - self.assertEqual(len(response.logs), 2) - self.assertEqual(response.logs[0]["status"], "succeeded") - self.assertEqual(response.total, 50) - self.assertEqual(response.page, 1) - self.assertEqual(response.limit, 10) - self.assertTrue(response.has_more) - - def test_model_serialization(self): - """Test that models can be serialized to JSON.""" - response = MessageResponse( - success=True, - id="msg_123", - answer="Hello, world!", - conversation_id="conv_123", - ) - - # Convert to dict and then to JSON - response_dict = { - "success": response.success, - "id": response.id, - "answer": response.answer, - "conversation_id": response.conversation_id, - } - - json_str = json.dumps(response_dict) - parsed = json.loads(json_str) - - self.assertTrue(parsed["success"]) - self.assertEqual(parsed["id"], "msg_123") - self.assertEqual(parsed["answer"], "Hello, world!") - self.assertEqual(parsed["conversation_id"], "conv_123") - - # Tests for new response models - def test_model_provider_response(self): - """Test ModelProviderResponse model.""" - response = ModelProviderResponse( - success=True, - provider_name="openai", - provider_type="llm", - models=[ - {"id": "gpt-4", "name": "GPT-4", "max_tokens": 8192}, - {"id": "gpt-3.5-turbo", "name": "GPT-3.5 Turbo", "max_tokens": 4096}, - ], - is_enabled=True, - credentials={"api_key": "sk-..."}, - ) - self.assertTrue(response.success) - self.assertEqual(response.provider_name, "openai") - self.assertEqual(response.provider_type, "llm") - self.assertEqual(len(response.models), 2) - self.assertEqual(response.models[0]["id"], "gpt-4") - self.assertTrue(response.is_enabled) - self.assertEqual(response.credentials["api_key"], "sk-...") - - def test_file_info_response(self): - """Test FileInfoResponse model.""" - response = FileInfoResponse( - success=True, - id="file_123", - name="document.pdf", - size=2048576, - mime_type="application/pdf", - url="https://example.com/files/document.pdf", - created_at=1234567890, - metadata={"pages": 10, "author": "John Doe"}, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "file_123") - self.assertEqual(response.name, "document.pdf") - self.assertEqual(response.size, 2048576) - self.assertEqual(response.mime_type, "application/pdf") - self.assertEqual(response.url, "https://example.com/files/document.pdf") - self.assertEqual(response.created_at, 1234567890) - self.assertEqual(response.metadata["pages"], 10) - - def test_workflow_draft_response(self): - """Test WorkflowDraftResponse model.""" - response = WorkflowDraftResponse( - success=True, - id="draft_123", - app_id="app_456", - draft_data={"nodes": [], "edges": [], "config": {"name": "Test Workflow"}}, - version=1, - created_at=1234567890, - updated_at=1234567891, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "draft_123") - self.assertEqual(response.app_id, "app_456") - self.assertEqual(response.draft_data["config"]["name"], "Test Workflow") - self.assertEqual(response.version, 1) - self.assertEqual(response.created_at, 1234567890) - self.assertEqual(response.updated_at, 1234567891) - - def test_api_token_response(self): - """Test ApiTokenResponse model.""" - response = ApiTokenResponse( - success=True, - id="token_123", - name="Production Token", - token="app-xxxxxxxxxxxx", - description="Token for production environment", - created_at=1234567890, - last_used_at=1234567891, - is_active=True, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "token_123") - self.assertEqual(response.name, "Production Token") - self.assertEqual(response.token, "app-xxxxxxxxxxxx") - self.assertEqual(response.description, "Token for production environment") - self.assertEqual(response.created_at, 1234567890) - self.assertEqual(response.last_used_at, 1234567891) - self.assertTrue(response.is_active) - - def test_job_status_response(self): - """Test JobStatusResponse model.""" - response = JobStatusResponse( - success=True, - job_id="job_123", - job_status="running", - error_msg=None, - progress=0.75, - created_at=1234567890, - updated_at=1234567891, - ) - self.assertTrue(response.success) - self.assertEqual(response.job_id, "job_123") - self.assertEqual(response.job_status, "running") - self.assertIsNone(response.error_msg) - self.assertEqual(response.progress, 0.75) - self.assertEqual(response.created_at, 1234567890) - self.assertEqual(response.updated_at, 1234567891) - - def test_dataset_query_response(self): - """Test DatasetQueryResponse model.""" - response = DatasetQueryResponse( - success=True, - query="What is machine learning?", - records=[ - {"content": "Machine learning is...", "score": 0.95}, - {"content": "ML algorithms...", "score": 0.87}, - ], - total=2, - search_time=0.123, - retrieval_model={"method": "semantic_search", "top_k": 3}, - ) - self.assertTrue(response.success) - self.assertEqual(response.query, "What is machine learning?") - self.assertEqual(len(response.records), 2) - self.assertEqual(response.total, 2) - self.assertEqual(response.search_time, 0.123) - self.assertEqual(response.retrieval_model["method"], "semantic_search") - - def test_dataset_template_response(self): - """Test DatasetTemplateResponse model.""" - response = DatasetTemplateResponse( - success=True, - template_name="customer_support", - display_name="Customer Support", - description="Template for customer support knowledge base", - category="support", - icon="🎧", - config_schema={"fields": [{"name": "category", "type": "string"}]}, - ) - self.assertTrue(response.success) - self.assertEqual(response.template_name, "customer_support") - self.assertEqual(response.display_name, "Customer Support") - self.assertEqual(response.description, "Template for customer support knowledge base") - self.assertEqual(response.category, "support") - self.assertEqual(response.icon, "🎧") - self.assertEqual(response.config_schema["fields"][0]["name"], "category") - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/tests/test_retry_and_error_handling.py b/sdks/python-client/tests/test_retry_and_error_handling.py deleted file mode 100644 index bd415bde43..0000000000 --- a/sdks/python-client/tests/test_retry_and_error_handling.py +++ /dev/null @@ -1,313 +0,0 @@ -"""Unit tests for retry mechanism and error handling.""" - -import unittest -from unittest.mock import Mock, patch, MagicMock -import httpx -from dify_client.client import DifyClient -from dify_client.exceptions import ( - APIError, - AuthenticationError, - RateLimitError, - ValidationError, - NetworkError, - TimeoutError, - FileUploadError, -) - - -class TestRetryMechanism(unittest.TestCase): - """Test cases for retry mechanism.""" - - def setUp(self): - self.api_key = "test_api_key" - self.base_url = "https://api.dify.ai/v1" - self.client = DifyClient( - api_key=self.api_key, - base_url=self.base_url, - max_retries=3, - retry_delay=0.1, # Short delay for tests - enable_logging=False, - ) - - @patch("httpx.Client.request") - def test_successful_request_no_retry(self, mock_request): - """Test that successful requests don't trigger retries.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.content = b'{"success": true}' - mock_request.return_value = mock_response - - response = self.client._send_request("GET", "/test") - - self.assertEqual(response, mock_response) - self.assertEqual(mock_request.call_count, 1) - - @patch("httpx.Client.request") - @patch("time.sleep") - def test_retry_on_network_error(self, mock_sleep, mock_request): - """Test retry on network errors.""" - # First two calls raise network error, third succeeds - mock_request.side_effect = [ - httpx.NetworkError("Connection failed"), - httpx.NetworkError("Connection failed"), - Mock(status_code=200, content=b'{"success": true}'), - ] - mock_response = Mock() - mock_response.status_code = 200 - mock_response.content = b'{"success": true}' - - response = self.client._send_request("GET", "/test") - - self.assertEqual(response.status_code, 200) - self.assertEqual(mock_request.call_count, 3) - self.assertEqual(mock_sleep.call_count, 2) - - @patch("httpx.Client.request") - @patch("time.sleep") - def test_retry_on_timeout_error(self, mock_sleep, mock_request): - """Test retry on timeout errors.""" - mock_request.side_effect = [ - httpx.TimeoutException("Request timed out"), - httpx.TimeoutException("Request timed out"), - Mock(status_code=200, content=b'{"success": true}'), - ] - - response = self.client._send_request("GET", "/test") - - self.assertEqual(response.status_code, 200) - self.assertEqual(mock_request.call_count, 3) - self.assertEqual(mock_sleep.call_count, 2) - - @patch("httpx.Client.request") - @patch("time.sleep") - def test_max_retries_exceeded(self, mock_sleep, mock_request): - """Test behavior when max retries are exceeded.""" - mock_request.side_effect = httpx.NetworkError("Persistent network error") - - with self.assertRaises(NetworkError): - self.client._send_request("GET", "/test") - - self.assertEqual(mock_request.call_count, 4) # 1 initial + 3 retries - self.assertEqual(mock_sleep.call_count, 3) - - @patch("httpx.Client.request") - def test_no_retry_on_client_error(self, mock_request): - """Test that client errors (4xx) don't trigger retries.""" - mock_response = Mock() - mock_response.status_code = 401 - mock_response.json.return_value = {"message": "Unauthorized"} - mock_request.return_value = mock_response - - with self.assertRaises(AuthenticationError): - self.client._send_request("GET", "/test") - - self.assertEqual(mock_request.call_count, 1) - - @patch("httpx.Client.request") - def test_retry_on_server_error(self, mock_request): - """Test that server errors (5xx) don't retry - they raise APIError immediately.""" - mock_response_500 = Mock() - mock_response_500.status_code = 500 - mock_response_500.json.return_value = {"message": "Internal server error"} - - mock_request.return_value = mock_response_500 - - with self.assertRaises(APIError) as context: - self.client._send_request("GET", "/test") - - self.assertEqual(str(context.exception), "Internal server error") - self.assertEqual(context.exception.status_code, 500) - # Should not retry server errors - self.assertEqual(mock_request.call_count, 1) - - @patch("httpx.Client.request") - def test_exponential_backoff(self, mock_request): - """Test exponential backoff timing.""" - mock_request.side_effect = [ - httpx.NetworkError("Connection failed"), - httpx.NetworkError("Connection failed"), - httpx.NetworkError("Connection failed"), - httpx.NetworkError("Connection failed"), # All attempts fail - ] - - with patch("time.sleep") as mock_sleep: - with self.assertRaises(NetworkError): - self.client._send_request("GET", "/test") - - # Check exponential backoff: 0.1, 0.2, 0.4 - expected_calls = [0.1, 0.2, 0.4] - actual_calls = [call[0][0] for call in mock_sleep.call_args_list] - self.assertEqual(actual_calls, expected_calls) - - -class TestErrorHandling(unittest.TestCase): - """Test cases for error handling.""" - - def setUp(self): - self.client = DifyClient(api_key="test_api_key", enable_logging=False) - - @patch("httpx.Client.request") - def test_authentication_error(self, mock_request): - """Test AuthenticationError handling.""" - mock_response = Mock() - mock_response.status_code = 401 - mock_response.json.return_value = {"message": "Invalid API key"} - mock_request.return_value = mock_response - - with self.assertRaises(AuthenticationError) as context: - self.client._send_request("GET", "/test") - - self.assertEqual(str(context.exception), "Invalid API key") - self.assertEqual(context.exception.status_code, 401) - - @patch("httpx.Client.request") - def test_rate_limit_error(self, mock_request): - """Test RateLimitError handling.""" - mock_response = Mock() - mock_response.status_code = 429 - mock_response.json.return_value = {"message": "Rate limit exceeded"} - mock_response.headers = {"Retry-After": "60"} - mock_request.return_value = mock_response - - with self.assertRaises(RateLimitError) as context: - self.client._send_request("GET", "/test") - - self.assertEqual(str(context.exception), "Rate limit exceeded") - self.assertEqual(context.exception.retry_after, "60") - - @patch("httpx.Client.request") - def test_validation_error(self, mock_request): - """Test ValidationError handling.""" - mock_response = Mock() - mock_response.status_code = 422 - mock_response.json.return_value = {"message": "Invalid parameters"} - mock_request.return_value = mock_response - - with self.assertRaises(ValidationError) as context: - self.client._send_request("GET", "/test") - - self.assertEqual(str(context.exception), "Invalid parameters") - self.assertEqual(context.exception.status_code, 422) - - @patch("httpx.Client.request") - def test_api_error(self, mock_request): - """Test general APIError handling.""" - mock_response = Mock() - mock_response.status_code = 500 - mock_response.json.return_value = {"message": "Internal server error"} - mock_request.return_value = mock_response - - with self.assertRaises(APIError) as context: - self.client._send_request("GET", "/test") - - self.assertEqual(str(context.exception), "Internal server error") - self.assertEqual(context.exception.status_code, 500) - - @patch("httpx.Client.request") - def test_error_response_without_json(self, mock_request): - """Test error handling when response doesn't contain valid JSON.""" - mock_response = Mock() - mock_response.status_code = 500 - mock_response.content = b"Internal Server Error" - mock_response.json.side_effect = ValueError("No JSON object could be decoded") - mock_request.return_value = mock_response - - with self.assertRaises(APIError) as context: - self.client._send_request("GET", "/test") - - self.assertEqual(str(context.exception), "HTTP 500") - - @patch("httpx.Client.request") - def test_file_upload_error(self, mock_request): - """Test FileUploadError handling.""" - mock_response = Mock() - mock_response.status_code = 400 - mock_response.json.return_value = {"message": "File upload failed"} - mock_request.return_value = mock_response - - with self.assertRaises(FileUploadError) as context: - self.client._send_request_with_files("POST", "/upload", {}, {}) - - self.assertEqual(str(context.exception), "File upload failed") - self.assertEqual(context.exception.status_code, 400) - - -class TestParameterValidation(unittest.TestCase): - """Test cases for parameter validation.""" - - def setUp(self): - self.client = DifyClient(api_key="test_api_key", enable_logging=False) - - def test_empty_string_validation(self): - """Test validation of empty strings.""" - with self.assertRaises(ValidationError): - self.client._validate_params(empty_string="") - - def test_whitespace_only_string_validation(self): - """Test validation of whitespace-only strings.""" - with self.assertRaises(ValidationError): - self.client._validate_params(whitespace_string=" ") - - def test_long_string_validation(self): - """Test validation of overly long strings.""" - long_string = "a" * 10001 # Exceeds 10000 character limit - with self.assertRaises(ValidationError): - self.client._validate_params(long_string=long_string) - - def test_large_list_validation(self): - """Test validation of overly large lists.""" - large_list = list(range(1001)) # Exceeds 1000 item limit - with self.assertRaises(ValidationError): - self.client._validate_params(large_list=large_list) - - def test_large_dict_validation(self): - """Test validation of overly large dictionaries.""" - large_dict = {f"key_{i}": i for i in range(101)} # Exceeds 100 item limit - with self.assertRaises(ValidationError): - self.client._validate_params(large_dict=large_dict) - - def test_valid_parameters_pass(self): - """Test that valid parameters pass validation.""" - # Should not raise any exception - self.client._validate_params( - valid_string="Hello, World!", - valid_list=[1, 2, 3], - valid_dict={"key": "value"}, - none_value=None, - ) - - def test_message_feedback_validation(self): - """Test validation in message_feedback method.""" - with self.assertRaises(ValidationError): - self.client.message_feedback("msg_id", "invalid_rating", "user") - - def test_completion_message_validation(self): - """Test validation in create_completion_message method.""" - from dify_client.client import CompletionClient - - client = CompletionClient("test_api_key") - - with self.assertRaises(ValidationError): - client.create_completion_message( - inputs="not_a_dict", # Should be a dict - response_mode="invalid_mode", # Should be 'blocking' or 'streaming' - user="test_user", - ) - - def test_chat_message_validation(self): - """Test validation in create_chat_message method.""" - from dify_client.client import ChatClient - - client = ChatClient("test_api_key") - - with self.assertRaises(ValidationError): - client.create_chat_message( - inputs="not_a_dict", # Should be a dict - query="", # Should not be empty - user="test_user", - response_mode="invalid_mode", # Should be 'blocking' or 'streaming' - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/uv.lock b/sdks/python-client/uv.lock deleted file mode 100644 index 4a9d7d5193..0000000000 --- a/sdks/python-client/uv.lock +++ /dev/null @@ -1,307 +0,0 @@ -version = 1 -revision = 3 -requires-python = ">=3.10" - -[[package]] -name = "aiofiles" -version = "25.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/41/c3/534eac40372d8ee36ef40df62ec129bee4fdb5ad9706e58a29be53b2c970/aiofiles-25.1.0.tar.gz", hash = "sha256:a8d728f0a29de45dc521f18f07297428d56992a742f0cd2701ba86e44d23d5b2", size = 46354, upload-time = "2025-10-09T20:51:04.358Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/bc/8a/340a1555ae33d7354dbca4faa54948d76d89a27ceef032c8c3bc661d003e/aiofiles-25.1.0-py3-none-any.whl", hash = "sha256:abe311e527c862958650f9438e859c1fa7568a141b22abcd015e120e86a85695", size = 14668, upload-time = "2025-10-09T20:51:03.174Z" }, -] - -[[package]] -name = "anyio" -version = "4.11.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, - { name = "idna" }, - { name = "sniffio" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/c6/78/7d432127c41b50bccba979505f272c16cbcadcc33645d5fa3a738110ae75/anyio-4.11.0.tar.gz", hash = "sha256:82a8d0b81e318cc5ce71a5f1f8b5c4e63619620b63141ef8c995fa0db95a57c4", size = 219094, upload-time = "2025-09-23T09:19:12.58Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/15/b3/9b1a8074496371342ec1e796a96f99c82c945a339cd81a8e73de28b4cf9e/anyio-4.11.0-py3-none-any.whl", hash = "sha256:0287e96f4d26d4149305414d4e3bc32f0dcd0862365a4bddea19d7a1ec38c4fc", size = 109097, upload-time = "2025-09-23T09:19:10.601Z" }, -] - -[[package]] -name = "backports-asyncio-runner" -version = "1.2.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/8e/ff/70dca7d7cb1cbc0edb2c6cc0c38b65cba36cccc491eca64cabd5fe7f8670/backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162", size = 69893, upload-time = "2025-07-02T02:27:15.685Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a0/59/76ab57e3fe74484f48a53f8e337171b4a2349e506eabe136d7e01d059086/backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5", size = 12313, upload-time = "2025-07-02T02:27:14.263Z" }, -] - -[[package]] -name = "certifi" -version = "2025.10.5" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/4c/5b/b6ce21586237c77ce67d01dc5507039d444b630dd76611bbca2d8e5dcd91/certifi-2025.10.5.tar.gz", hash = "sha256:47c09d31ccf2acf0be3f701ea53595ee7e0b8fa08801c6624be771df09ae7b43", size = 164519, upload-time = "2025-10-05T04:12:15.808Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e4/37/af0d2ef3967ac0d6113837b44a4f0bfe1328c2b9763bd5b1744520e5cfed/certifi-2025.10.5-py3-none-any.whl", hash = "sha256:0f212c2744a9bb6de0c56639a6f68afe01ecd92d91f14ae897c4fe7bbeeef0de", size = 163286, upload-time = "2025-10-05T04:12:14.03Z" }, -] - -[[package]] -name = "colorama" -version = "0.4.6" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, -] - -[[package]] -name = "dify-client" -version = "0.1.12" -source = { editable = "." } -dependencies = [ - { name = "aiofiles" }, - { name = "httpx", extra = ["http2"] }, -] - -[package.optional-dependencies] -dev = [ - { name = "pytest" }, - { name = "pytest-asyncio" }, -] - -[package.metadata] -requires-dist = [ - { name = "aiofiles", specifier = ">=23.0.0" }, - { name = "httpx", extras = ["http2"], specifier = ">=0.27.0" }, - { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0.0" }, - { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.21.0" }, -] -provides-extras = ["dev"] - -[[package]] -name = "exceptiongroup" -version = "1.3.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/36/f4/c6e662dade71f56cd2f3735141b265c3c79293c109549c1e6933b0651ffc/exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10", size = 16674, upload-time = "2025-05-10T17:42:49.33Z" }, -] - -[[package]] -name = "h11" -version = "0.16.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload-time = "2025-04-24T03:35:25.427Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, -] - -[[package]] -name = "h2" -version = "4.3.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "hpack" }, - { name = "hyperframe" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/1d/17/afa56379f94ad0fe8defd37d6eb3f89a25404ffc71d4d848893d270325fc/h2-4.3.0.tar.gz", hash = "sha256:6c59efe4323fa18b47a632221a1888bd7fde6249819beda254aeca909f221bf1", size = 2152026, upload-time = "2025-08-23T18:12:19.778Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/69/b2/119f6e6dcbd96f9069ce9a2665e0146588dc9f88f29549711853645e736a/h2-4.3.0-py3-none-any.whl", hash = "sha256:c438f029a25f7945c69e0ccf0fb951dc3f73a5f6412981daee861431b70e2bdd", size = 61779, upload-time = "2025-08-23T18:12:17.779Z" }, -] - -[[package]] -name = "hpack" -version = "4.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/2c/48/71de9ed269fdae9c8057e5a4c0aa7402e8bb16f2c6e90b3aa53327b113f8/hpack-4.1.0.tar.gz", hash = "sha256:ec5eca154f7056aa06f196a557655c5b009b382873ac8d1e66e79e87535f1dca", size = 51276, upload-time = "2025-01-22T21:44:58.347Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/07/c6/80c95b1b2b94682a72cbdbfb85b81ae2daffa4291fbfa1b1464502ede10d/hpack-4.1.0-py3-none-any.whl", hash = "sha256:157ac792668d995c657d93111f46b4535ed114f0c9c8d672271bbec7eae1b496", size = 34357, upload-time = "2025-01-22T21:44:56.92Z" }, -] - -[[package]] -name = "httpcore" -version = "1.0.9" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "certifi" }, - { name = "h11" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484, upload-time = "2025-04-24T22:06:22.219Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" }, -] - -[[package]] -name = "httpx" -version = "0.28.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio" }, - { name = "certifi" }, - { name = "httpcore" }, - { name = "idna" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, -] - -[package.optional-dependencies] -http2 = [ - { name = "h2" }, -] - -[[package]] -name = "hyperframe" -version = "6.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/02/e7/94f8232d4a74cc99514c13a9f995811485a6903d48e5d952771ef6322e30/hyperframe-6.1.0.tar.gz", hash = "sha256:f630908a00854a7adeabd6382b43923a4c4cd4b821fcb527e6ab9e15382a3b08", size = 26566, upload-time = "2025-01-22T21:41:49.302Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/48/30/47d0bf6072f7252e6521f3447ccfa40b421b6824517f82854703d0f5a98b/hyperframe-6.1.0-py3-none-any.whl", hash = "sha256:b03380493a519fce58ea5af42e4a42317bf9bd425596f7a0835ffce80f1a42e5", size = 13007, upload-time = "2025-01-22T21:41:47.295Z" }, -] - -[[package]] -name = "idna" -version = "3.10" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490, upload-time = "2024-09-15T18:07:39.745Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" }, -] - -[[package]] -name = "iniconfig" -version = "2.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" }, -] - -[[package]] -name = "packaging" -version = "25.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" }, -] - -[[package]] -name = "pluggy" -version = "1.6.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, -] - -[[package]] -name = "pygments" -version = "2.19.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, -] - -[[package]] -name = "pytest" -version = "8.4.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, - { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, - { name = "iniconfig" }, - { name = "packaging" }, - { name = "pluggy" }, - { name = "pygments" }, - { name = "tomli", marker = "python_full_version < '3.11'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a3/5c/00a0e072241553e1a7496d638deababa67c5058571567b92a7eaa258397c/pytest-8.4.2.tar.gz", hash = "sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01", size = 1519618, upload-time = "2025-09-04T14:34:22.711Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a8/a4/20da314d277121d6534b3a980b29035dcd51e6744bd79075a6ce8fa4eb8d/pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79", size = 365750, upload-time = "2025-09-04T14:34:20.226Z" }, -] - -[[package]] -name = "pytest-asyncio" -version = "1.2.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "backports-asyncio-runner", marker = "python_full_version < '3.11'" }, - { name = "pytest" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/42/86/9e3c5f48f7b7b638b216e4b9e645f54d199d7abbbab7a64a13b4e12ba10f/pytest_asyncio-1.2.0.tar.gz", hash = "sha256:c609a64a2a8768462d0c99811ddb8bd2583c33fd33cf7f21af1c142e824ffb57", size = 50119, upload-time = "2025-09-12T07:33:53.816Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/04/93/2fa34714b7a4ae72f2f8dad66ba17dd9a2c793220719e736dda28b7aec27/pytest_asyncio-1.2.0-py3-none-any.whl", hash = "sha256:8e17ae5e46d8e7efe51ab6494dd2010f4ca8dae51652aa3c8d55acf50bfb2e99", size = 15095, upload-time = "2025-09-12T07:33:52.639Z" }, -] - -[[package]] -name = "sniffio" -version = "1.3.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372, upload-time = "2024-02-25T23:20:04.057Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" }, -] - -[[package]] -name = "tomli" -version = "2.3.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/52/ed/3f73f72945444548f33eba9a87fc7a6e969915e7b1acc8260b30e1f76a2f/tomli-2.3.0.tar.gz", hash = "sha256:64be704a875d2a59753d80ee8a533c3fe183e3f06807ff7dc2232938ccb01549", size = 17392, upload-time = "2025-10-08T22:01:47.119Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b3/2e/299f62b401438d5fe1624119c723f5d877acc86a4c2492da405626665f12/tomli-2.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:88bd15eb972f3664f5ed4b57c1634a97153b4bac4479dcb6a495f41921eb7f45", size = 153236, upload-time = "2025-10-08T22:01:00.137Z" }, - { url = "https://files.pythonhosted.org/packages/86/7f/d8fffe6a7aefdb61bced88fcb5e280cfd71e08939da5894161bd71bea022/tomli-2.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:883b1c0d6398a6a9d29b508c331fa56adbcdff647f6ace4dfca0f50e90dfd0ba", size = 148084, upload-time = "2025-10-08T22:01:01.63Z" }, - { url = "https://files.pythonhosted.org/packages/47/5c/24935fb6a2ee63e86d80e4d3b58b222dafaf438c416752c8b58537c8b89a/tomli-2.3.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d1381caf13ab9f300e30dd8feadb3de072aeb86f1d34a8569453ff32a7dea4bf", size = 234832, upload-time = "2025-10-08T22:01:02.543Z" }, - { url = "https://files.pythonhosted.org/packages/89/da/75dfd804fc11e6612846758a23f13271b76d577e299592b4371a4ca4cd09/tomli-2.3.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a0e285d2649b78c0d9027570d4da3425bdb49830a6156121360b3f8511ea3441", size = 242052, upload-time = "2025-10-08T22:01:03.836Z" }, - { url = "https://files.pythonhosted.org/packages/70/8c/f48ac899f7b3ca7eb13af73bacbc93aec37f9c954df3c08ad96991c8c373/tomli-2.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0a154a9ae14bfcf5d8917a59b51ffd5a3ac1fd149b71b47a3a104ca4edcfa845", size = 239555, upload-time = "2025-10-08T22:01:04.834Z" }, - { url = "https://files.pythonhosted.org/packages/ba/28/72f8afd73f1d0e7829bfc093f4cb98ce0a40ffc0cc997009ee1ed94ba705/tomli-2.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:74bf8464ff93e413514fefd2be591c3b0b23231a77f901db1eb30d6f712fc42c", size = 245128, upload-time = "2025-10-08T22:01:05.84Z" }, - { url = "https://files.pythonhosted.org/packages/b6/eb/a7679c8ac85208706d27436e8d421dfa39d4c914dcf5fa8083a9305f58d9/tomli-2.3.0-cp311-cp311-win32.whl", hash = "sha256:00b5f5d95bbfc7d12f91ad8c593a1659b6387b43f054104cda404be6bda62456", size = 96445, upload-time = "2025-10-08T22:01:06.896Z" }, - { url = "https://files.pythonhosted.org/packages/0a/fe/3d3420c4cb1ad9cb462fb52967080575f15898da97e21cb6f1361d505383/tomli-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:4dc4ce8483a5d429ab602f111a93a6ab1ed425eae3122032db7e9acf449451be", size = 107165, upload-time = "2025-10-08T22:01:08.107Z" }, - { url = "https://files.pythonhosted.org/packages/ff/b7/40f36368fcabc518bb11c8f06379a0fd631985046c038aca08c6d6a43c6e/tomli-2.3.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d7d86942e56ded512a594786a5ba0a5e521d02529b3826e7761a05138341a2ac", size = 154891, upload-time = "2025-10-08T22:01:09.082Z" }, - { url = "https://files.pythonhosted.org/packages/f9/3f/d9dd692199e3b3aab2e4e4dd948abd0f790d9ded8cd10cbaae276a898434/tomli-2.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:73ee0b47d4dad1c5e996e3cd33b8a76a50167ae5f96a2607cbe8cc773506ab22", size = 148796, upload-time = "2025-10-08T22:01:10.266Z" }, - { url = "https://files.pythonhosted.org/packages/60/83/59bff4996c2cf9f9387a0f5a3394629c7efa5ef16142076a23a90f1955fa/tomli-2.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:792262b94d5d0a466afb5bc63c7daa9d75520110971ee269152083270998316f", size = 242121, upload-time = "2025-10-08T22:01:11.332Z" }, - { url = "https://files.pythonhosted.org/packages/45/e5/7c5119ff39de8693d6baab6c0b6dcb556d192c165596e9fc231ea1052041/tomli-2.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4f195fe57ecceac95a66a75ac24d9d5fbc98ef0962e09b2eddec5d39375aae52", size = 250070, upload-time = "2025-10-08T22:01:12.498Z" }, - { url = "https://files.pythonhosted.org/packages/45/12/ad5126d3a278f27e6701abde51d342aa78d06e27ce2bb596a01f7709a5a2/tomli-2.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e31d432427dcbf4d86958c184b9bfd1e96b5b71f8eb17e6d02531f434fd335b8", size = 245859, upload-time = "2025-10-08T22:01:13.551Z" }, - { url = "https://files.pythonhosted.org/packages/fb/a1/4d6865da6a71c603cfe6ad0e6556c73c76548557a8d658f9e3b142df245f/tomli-2.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7b0882799624980785240ab732537fcfc372601015c00f7fc367c55308c186f6", size = 250296, upload-time = "2025-10-08T22:01:14.614Z" }, - { url = "https://files.pythonhosted.org/packages/a0/b7/a7a7042715d55c9ba6e8b196d65d2cb662578b4d8cd17d882d45322b0d78/tomli-2.3.0-cp312-cp312-win32.whl", hash = "sha256:ff72b71b5d10d22ecb084d345fc26f42b5143c5533db5e2eaba7d2d335358876", size = 97124, upload-time = "2025-10-08T22:01:15.629Z" }, - { url = "https://files.pythonhosted.org/packages/06/1e/f22f100db15a68b520664eb3328fb0ae4e90530887928558112c8d1f4515/tomli-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:1cb4ed918939151a03f33d4242ccd0aa5f11b3547d0cf30f7c74a408a5b99878", size = 107698, upload-time = "2025-10-08T22:01:16.51Z" }, - { url = "https://files.pythonhosted.org/packages/89/48/06ee6eabe4fdd9ecd48bf488f4ac783844fd777f547b8d1b61c11939974e/tomli-2.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5192f562738228945d7b13d4930baffda67b69425a7f0da96d360b0a3888136b", size = 154819, upload-time = "2025-10-08T22:01:17.964Z" }, - { url = "https://files.pythonhosted.org/packages/f1/01/88793757d54d8937015c75dcdfb673c65471945f6be98e6a0410fba167ed/tomli-2.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:be71c93a63d738597996be9528f4abe628d1adf5e6eb11607bc8fe1a510b5dae", size = 148766, upload-time = "2025-10-08T22:01:18.959Z" }, - { url = "https://files.pythonhosted.org/packages/42/17/5e2c956f0144b812e7e107f94f1cc54af734eb17b5191c0bbfb72de5e93e/tomli-2.3.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c4665508bcbac83a31ff8ab08f424b665200c0e1e645d2bd9ab3d3e557b6185b", size = 240771, upload-time = "2025-10-08T22:01:20.106Z" }, - { url = "https://files.pythonhosted.org/packages/d5/f4/0fbd014909748706c01d16824eadb0307115f9562a15cbb012cd9b3512c5/tomli-2.3.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4021923f97266babc6ccab9f5068642a0095faa0a51a246a6a02fccbb3514eaf", size = 248586, upload-time = "2025-10-08T22:01:21.164Z" }, - { url = "https://files.pythonhosted.org/packages/30/77/fed85e114bde5e81ecf9bc5da0cc69f2914b38f4708c80ae67d0c10180c5/tomli-2.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a4ea38c40145a357d513bffad0ed869f13c1773716cf71ccaa83b0fa0cc4e42f", size = 244792, upload-time = "2025-10-08T22:01:22.417Z" }, - { url = "https://files.pythonhosted.org/packages/55/92/afed3d497f7c186dc71e6ee6d4fcb0acfa5f7d0a1a2878f8beae379ae0cc/tomli-2.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ad805ea85eda330dbad64c7ea7a4556259665bdf9d2672f5dccc740eb9d3ca05", size = 248909, upload-time = "2025-10-08T22:01:23.859Z" }, - { url = "https://files.pythonhosted.org/packages/f8/84/ef50c51b5a9472e7265ce1ffc7f24cd4023d289e109f669bdb1553f6a7c2/tomli-2.3.0-cp313-cp313-win32.whl", hash = "sha256:97d5eec30149fd3294270e889b4234023f2c69747e555a27bd708828353ab606", size = 96946, upload-time = "2025-10-08T22:01:24.893Z" }, - { url = "https://files.pythonhosted.org/packages/b2/b7/718cd1da0884f281f95ccfa3a6cc572d30053cba64603f79d431d3c9b61b/tomli-2.3.0-cp313-cp313-win_amd64.whl", hash = "sha256:0c95ca56fbe89e065c6ead5b593ee64b84a26fca063b5d71a1122bf26e533999", size = 107705, upload-time = "2025-10-08T22:01:26.153Z" }, - { url = "https://files.pythonhosted.org/packages/19/94/aeafa14a52e16163008060506fcb6aa1949d13548d13752171a755c65611/tomli-2.3.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:cebc6fe843e0733ee827a282aca4999b596241195f43b4cc371d64fc6639da9e", size = 154244, upload-time = "2025-10-08T22:01:27.06Z" }, - { url = "https://files.pythonhosted.org/packages/db/e4/1e58409aa78eefa47ccd19779fc6f36787edbe7d4cd330eeeedb33a4515b/tomli-2.3.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:4c2ef0244c75aba9355561272009d934953817c49f47d768070c3c94355c2aa3", size = 148637, upload-time = "2025-10-08T22:01:28.059Z" }, - { url = "https://files.pythonhosted.org/packages/26/b6/d1eccb62f665e44359226811064596dd6a366ea1f985839c566cd61525ae/tomli-2.3.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c22a8bf253bacc0cf11f35ad9808b6cb75ada2631c2d97c971122583b129afbc", size = 241925, upload-time = "2025-10-08T22:01:29.066Z" }, - { url = "https://files.pythonhosted.org/packages/70/91/7cdab9a03e6d3d2bb11beae108da5bdc1c34bdeb06e21163482544ddcc90/tomli-2.3.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0eea8cc5c5e9f89c9b90c4896a8deefc74f518db5927d0e0e8d4a80953d774d0", size = 249045, upload-time = "2025-10-08T22:01:31.98Z" }, - { url = "https://files.pythonhosted.org/packages/15/1b/8c26874ed1f6e4f1fcfeb868db8a794cbe9f227299402db58cfcc858766c/tomli-2.3.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:b74a0e59ec5d15127acdabd75ea17726ac4c5178ae51b85bfe39c4f8a278e879", size = 245835, upload-time = "2025-10-08T22:01:32.989Z" }, - { url = "https://files.pythonhosted.org/packages/fd/42/8e3c6a9a4b1a1360c1a2a39f0b972cef2cc9ebd56025168c4137192a9321/tomli-2.3.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:b5870b50c9db823c595983571d1296a6ff3e1b88f734a4c8f6fc6188397de005", size = 253109, upload-time = "2025-10-08T22:01:34.052Z" }, - { url = "https://files.pythonhosted.org/packages/22/0c/b4da635000a71b5f80130937eeac12e686eefb376b8dee113b4a582bba42/tomli-2.3.0-cp314-cp314-win32.whl", hash = "sha256:feb0dacc61170ed7ab602d3d972a58f14ee3ee60494292d384649a3dc38ef463", size = 97930, upload-time = "2025-10-08T22:01:35.082Z" }, - { url = "https://files.pythonhosted.org/packages/b9/74/cb1abc870a418ae99cd5c9547d6bce30701a954e0e721821df483ef7223c/tomli-2.3.0-cp314-cp314-win_amd64.whl", hash = "sha256:b273fcbd7fc64dc3600c098e39136522650c49bca95df2d11cf3b626422392c8", size = 107964, upload-time = "2025-10-08T22:01:36.057Z" }, - { url = "https://files.pythonhosted.org/packages/54/78/5c46fff6432a712af9f792944f4fcd7067d8823157949f4e40c56b8b3c83/tomli-2.3.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:940d56ee0410fa17ee1f12b817b37a4d4e4dc4d27340863cc67236c74f582e77", size = 163065, upload-time = "2025-10-08T22:01:37.27Z" }, - { url = "https://files.pythonhosted.org/packages/39/67/f85d9bd23182f45eca8939cd2bc7050e1f90c41f4a2ecbbd5963a1d1c486/tomli-2.3.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:f85209946d1fe94416debbb88d00eb92ce9cd5266775424ff81bc959e001acaf", size = 159088, upload-time = "2025-10-08T22:01:38.235Z" }, - { url = "https://files.pythonhosted.org/packages/26/5a/4b546a0405b9cc0659b399f12b6adb750757baf04250b148d3c5059fc4eb/tomli-2.3.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a56212bdcce682e56b0aaf79e869ba5d15a6163f88d5451cbde388d48b13f530", size = 268193, upload-time = "2025-10-08T22:01:39.712Z" }, - { url = "https://files.pythonhosted.org/packages/42/4f/2c12a72ae22cf7b59a7fe75b3465b7aba40ea9145d026ba41cb382075b0e/tomli-2.3.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c5f3ffd1e098dfc032d4d3af5c0ac64f6d286d98bc148698356847b80fa4de1b", size = 275488, upload-time = "2025-10-08T22:01:40.773Z" }, - { url = "https://files.pythonhosted.org/packages/92/04/a038d65dbe160c3aa5a624e93ad98111090f6804027d474ba9c37c8ae186/tomli-2.3.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:5e01decd096b1530d97d5d85cb4dff4af2d8347bd35686654a004f8dea20fc67", size = 272669, upload-time = "2025-10-08T22:01:41.824Z" }, - { url = "https://files.pythonhosted.org/packages/be/2f/8b7c60a9d1612a7cbc39ffcca4f21a73bf368a80fc25bccf8253e2563267/tomli-2.3.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:8a35dd0e643bb2610f156cca8db95d213a90015c11fee76c946aa62b7ae7e02f", size = 279709, upload-time = "2025-10-08T22:01:43.177Z" }, - { url = "https://files.pythonhosted.org/packages/7e/46/cc36c679f09f27ded940281c38607716c86cf8ba4a518d524e349c8b4874/tomli-2.3.0-cp314-cp314t-win32.whl", hash = "sha256:a1f7f282fe248311650081faafa5f4732bdbfef5d45fe3f2e702fbc6f2d496e0", size = 107563, upload-time = "2025-10-08T22:01:44.233Z" }, - { url = "https://files.pythonhosted.org/packages/84/ff/426ca8683cf7b753614480484f6437f568fd2fda2edbdf57a2d3d8b27a0b/tomli-2.3.0-cp314-cp314t-win_amd64.whl", hash = "sha256:70a251f8d4ba2d9ac2542eecf008b3c8a9fc5c3f9f02c56a9d7952612be2fdba", size = 119756, upload-time = "2025-10-08T22:01:45.234Z" }, - { url = "https://files.pythonhosted.org/packages/77/b8/0135fadc89e73be292b473cb820b4f5a08197779206b33191e801feeae40/tomli-2.3.0-py3-none-any.whl", hash = "sha256:e95b1af3c5b07d9e643909b5abbec77cd9f1217e6d0bca72b0234736b9fb1f1b", size = 14408, upload-time = "2025-10-08T22:01:46.04Z" }, -] - -[[package]] -name = "typing-extensions" -version = "4.15.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/72/94/1a15dd82efb362ac84269196e94cf00f187f7ed21c242792a923cdb1c61f/typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466", size = 109391, upload-time = "2025-08-25T13:49:26.313Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" }, -] diff --git a/web/__mocks__/provider-context.ts b/web/__mocks__/provider-context.ts new file mode 100644 index 0000000000..594fe38f14 --- /dev/null +++ b/web/__mocks__/provider-context.ts @@ -0,0 +1,47 @@ +import { merge, noop } from 'lodash-es' +import { defaultPlan } from '@/app/components/billing/config' +import { baseProviderContextValue } from '@/context/provider-context' +import type { ProviderContextState } from '@/context/provider-context' +import type { Plan, UsagePlanInfo } from '@/app/components/billing/type' + +export const createMockProviderContextValue = (overrides: Partial = {}): ProviderContextState => { + const merged = merge({}, baseProviderContextValue, overrides) + + return { + ...merged, + refreshModelProviders: merged.refreshModelProviders ?? noop, + onPlanInfoChanged: merged.onPlanInfoChanged ?? noop, + refreshLicenseLimit: merged.refreshLicenseLimit ?? noop, + } +} + +export const createMockPlan = (plan: Plan): ProviderContextState => + createMockProviderContextValue({ + plan: merge({}, defaultPlan, { + type: plan, + }), + }) + +export const createMockPlanUsage = (usage: UsagePlanInfo, ctx: Partial): ProviderContextState => + createMockProviderContextValue({ + ...ctx, + plan: merge(ctx.plan, { + usage, + }), + }) + +export const createMockPlanTotal = (total: UsagePlanInfo, ctx: Partial): ProviderContextState => + createMockProviderContextValue({ + ...ctx, + plan: merge(ctx.plan, { + total, + }), + }) + +export const createMockPlanReset = (reset: Partial, ctx: Partial): ProviderContextState => + createMockProviderContextValue({ + ...ctx, + plan: merge(ctx?.plan, { + reset, + }), + }) diff --git a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx index b329c1a113..576244c0d4 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx @@ -32,6 +32,7 @@ import { canFindTool } from '@/utils' import { useAllBuiltInTools, useAllCustomTools, useAllMCPTools, useAllWorkflowTools } from '@/service/use-tools' import type { ToolWithProvider } from '@/app/components/workflow/types' import { useMittContextSelector } from '@/context/mitt-context' +import { addDefaultValue, toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema' import { DefaultToolIcon } from '@/app/components/base/icons/src/public/other' type AgentToolWithMoreInfo = (AgentTool & { icon: any; collection?: Collection; use_end_user_credentials?: boolean; end_user_credential_type?: string }) | null @@ -101,13 +102,17 @@ const AgentTools: FC = () => { })) }, []) const getToolValue = (tool: ToolDefaultValue) => { + const currToolInCollections = collectionList.find(c => c.id === tool.provider_id) + const currToolWithConfigs = currToolInCollections?.tools.find(t => t.name === tool.tool_name) + const formSchemas = currToolWithConfigs ? toolParametersToFormSchemas(currToolWithConfigs.parameters) : [] + const paramsWithDefaultValue = addDefaultValue(tool.params, formSchemas) return { provider_id: tool.provider_id, provider_type: tool.provider_type as CollectionType, provider_name: tool.provider_name, tool_name: tool.tool_name, tool_label: tool.tool_label, - tool_parameters: tool.params, + tool_parameters: paramsWithDefaultValue, notAuthor: !tool.is_team_authorization, enabled: true, use_end_user_credentials: false, @@ -129,7 +134,7 @@ const AgentTools: FC = () => { } const getProviderShowName = (item: AgentTool) => { const type = item.provider_type - if(type === CollectionType.builtIn) + if (type === CollectionType.builtIn) return item.provider_name.split('/').pop() return item.provider_name } diff --git a/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx index aaa4d5830e..a3decb4b04 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx @@ -16,7 +16,7 @@ import Description from '@/app/components/plugins/card/base/description' import TabSlider from '@/app/components/base/tab-slider-plain' import Button from '@/app/components/base/button' import Form from '@/app/components/header/account-setting/model-provider-page/model-modal/Form' -import { addDefaultValue, toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema' +import { toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema' import type { Collection, Tool } from '@/app/components/tools/types' import { CollectionType } from '@/app/components/tools/types' import { fetchBuiltInToolList, fetchCustomToolList, fetchModelToolList, fetchWorkflowToolList } from '@/service/tools' @@ -100,15 +100,11 @@ const SettingBuiltInTool: FC = ({ }()) }) setTools(list) - const currTool = list.find(tool => tool.name === toolName) - if (currTool) { - const formSchemas = toolParametersToFormSchemas(currTool.parameters) - setTempSetting(addDefaultValue(setting, formSchemas)) - } } catch { } setIsLoading(false) })() + // eslint-disable-next-line react-hooks/exhaustive-deps }, [collection?.name, collection?.id, collection?.type]) useEffect(() => { @@ -261,7 +257,7 @@ const SettingBuiltInTool: FC = ({ {!readonly && !isInfoActive && (
- +
)} diff --git a/web/app/components/app/configuration/dataset-config/index.tsx b/web/app/components/app/configuration/dataset-config/index.tsx index bf81858565..44a54f8e8b 100644 --- a/web/app/components/app/configuration/dataset-config/index.tsx +++ b/web/app/components/app/configuration/dataset-config/index.tsx @@ -77,7 +77,7 @@ const DatasetConfig: FC = () => { const oldRetrievalConfig = { top_k, score_threshold, - reranking_model: (reranking_model.reranking_provider_name && reranking_model.reranking_model_name) ? { + reranking_model: (reranking_model && reranking_model.reranking_provider_name && reranking_model.reranking_model_name) ? { provider: reranking_model.reranking_provider_name, model: reranking_model.reranking_model_name, } : undefined, diff --git a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx index feb7a38165..6857c38e1e 100644 --- a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx +++ b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx @@ -1,18 +1,20 @@ 'use client' import type { FC } from 'react' -import React, { useRef, useState } from 'react' -import { useGetState, useInfiniteScroll } from 'ahooks' +import React, { useEffect, useMemo, useRef, useState } from 'react' +import { useInfiniteScroll } from 'ahooks' import { useTranslation } from 'react-i18next' import Link from 'next/link' import Modal from '@/app/components/base/modal' import type { DataSet } from '@/models/datasets' import Button from '@/app/components/base/button' -import { fetchDatasets } from '@/service/datasets' import Loading from '@/app/components/base/loading' import Badge from '@/app/components/base/badge' import { useKnowledge } from '@/hooks/use-knowledge' import cn from '@/utils/classnames' import AppIcon from '@/app/components/base/app-icon' +import { useInfiniteDatasets } from '@/service/knowledge/use-dataset' +import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import FeatureIcon from '@/app/components/header/account-setting/model-provider-page/model-selector/feature-icon' export type ISelectDataSetProps = { isShow: boolean @@ -28,51 +30,70 @@ const SelectDataSet: FC = ({ onSelect, }) => { const { t } = useTranslation() - const [selected, setSelected] = React.useState([]) - const [loaded, setLoaded] = React.useState(false) - const [datasets, setDataSets] = React.useState(null) - const [hasInitialized, setHasInitialized] = React.useState(false) - const hasNoData = !datasets || datasets?.length === 0 + const [selected, setSelected] = useState([]) const canSelectMulti = true + const { formatIndexingTechniqueAndMethod } = useKnowledge() + const { data, isLoading, isFetchingNextPage, fetchNextPage, hasNextPage } = useInfiniteDatasets( + { page: 1 }, + { enabled: isShow, staleTime: 0, refetchOnMount: 'always' }, + ) + const pages = data?.pages || [] + const datasets = useMemo(() => { + return pages.flatMap(page => page.data.filter(item => item.indexing_technique || item.provider === 'external')) + }, [pages]) + const hasNoData = !isLoading && datasets.length === 0 const listRef = useRef(null) - const [page, setPage, getPage] = useGetState(1) - const [isNoMore, setIsNoMore] = useState(false) - const { formatIndexingTechniqueAndMethod } = useKnowledge() + const isNoMore = hasNextPage === false useInfiniteScroll( async () => { - if (!isNoMore) { - const { data, has_more } = await fetchDatasets({ url: '/datasets', params: { page } }) - setPage(getPage() + 1) - setIsNoMore(!has_more) - const newList = [...(datasets || []), ...data.filter(item => item.indexing_technique || item.provider === 'external')] - setDataSets(newList) - setLoaded(true) - - // Initialize selected datasets based on selectedIds and available datasets - if (!hasInitialized) { - if (selectedIds.length > 0) { - const validSelectedDatasets = selectedIds - .map(id => newList.find(item => item.id === id)) - .filter(Boolean) as DataSet[] - setSelected(validSelectedDatasets) - } - setHasInitialized(true) - } - } + if (!hasNextPage || isFetchingNextPage) + return { list: [] } + await fetchNextPage() return { list: [] } }, { target: listRef, - isNoMore: () => { - return isNoMore - }, - reloadDeps: [isNoMore], + isNoMore: () => isNoMore, + reloadDeps: [isNoMore, isFetchingNextPage], }, ) + const prevSelectedIdsRef = useRef([]) + const hasUserModifiedSelectionRef = useRef(false) + useEffect(() => { + if (isShow) + hasUserModifiedSelectionRef.current = false + }, [isShow]) + useEffect(() => { + const prevSelectedIds = prevSelectedIdsRef.current + const idsChanged = selectedIds.length !== prevSelectedIds.length + || selectedIds.some((id, idx) => id !== prevSelectedIds[idx]) + + if (!selectedIds.length && (!hasUserModifiedSelectionRef.current || idsChanged)) { + setSelected([]) + prevSelectedIdsRef.current = selectedIds + hasUserModifiedSelectionRef.current = false + return + } + + if (!idsChanged && hasUserModifiedSelectionRef.current) + return + + setSelected((prev) => { + const prevMap = new Map(prev.map(item => [item.id, item])) + const nextSelected = selectedIds + .map(id => datasets.find(item => item.id === id) || prevMap.get(id)) + .filter(Boolean) as DataSet[] + return nextSelected + }) + prevSelectedIdsRef.current = selectedIds + hasUserModifiedSelectionRef.current = false + }, [datasets, selectedIds]) + const toggleSelect = (dataSet: DataSet) => { + hasUserModifiedSelectionRef.current = true const isSelected = selected.some(item => item.id === dataSet.id) if (isSelected) { setSelected(selected.filter(item => item.id !== dataSet.id)) @@ -96,13 +117,13 @@ const SelectDataSet: FC = ({ className='w-[400px]' title={t('appDebug.feature.dataSet.selectTitle')} > - {!loaded && ( + {(isLoading && datasets.length === 0) && (
)} - {(loaded && hasNoData) && ( + {hasNoData && (
= ({
)} - {datasets && datasets?.length > 0 && ( + {datasets.length > 0 && ( <>
{datasets.map(item => (
i.id === item.id) && 'border-[1.5px] border-components-option-card-option-selected-border bg-state-accent-hover shadow-xs hover:border-components-option-card-option-selected-border hover:bg-state-accent-hover hover:shadow-xs', !item.embedding_available && 'hover:border-components-panel-border-subtle hover:bg-components-panel-on-panel-item-bg hover:shadow-xs', )} @@ -131,7 +152,7 @@ const SelectDataSet: FC = ({ toggleSelect(item) }} > -
+
= ({ {t('dataset.unavailable')} )}
+ {item.is_multimodal && ( +
+ +
+ )} { item.indexing_technique && ( = ({
)} - {loaded && ( + {!isLoading && (
{selected.length > 0 && `${selected.length} ${t('appDebug.feature.dataSet.selected')}`} diff --git a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx index 93d0384aee..cd6e39011e 100644 --- a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx +++ b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx @@ -1,5 +1,5 @@ import type { FC } from 'react' -import { useRef, useState } from 'react' +import { useMemo, useRef, useState } from 'react' import { useMount } from 'ahooks' import { useTranslation } from 'react-i18next' import { isEqual } from 'lodash-es' @@ -25,15 +25,13 @@ import { isReRankModelSelected } from '@/app/components/datasets/common/check-re import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' import PermissionSelector from '@/app/components/datasets/settings/permission-selector' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' -import { - useModelList, - useModelListAndDefaultModelAndCurrentProviderAndModel, -} from '@/app/components/header/account-setting/model-provider-page/hooks' +import { useModelList } from '@/app/components/header/account-setting/model-provider-page/hooks' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { fetchMembers } from '@/service/common' import type { Member } from '@/models/common' import { IndexingType } from '@/app/components/datasets/create/step-two' import { useDocLink } from '@/context/i18n' +import { checkShowMultiModalTip } from '@/app/components/datasets/settings/utils' type SettingsModalProps = { currentDataset: DataSet @@ -54,10 +52,8 @@ const SettingsModal: FC = ({ onCancel, onSave, }) => { - const { data: embeddingsModelList } = useModelList(ModelTypeEnum.textEmbedding) - const { - modelList: rerankModelList, - } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) + const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding) + const { data: rerankModelList } = useModelList(ModelTypeEnum.rerank) const { t } = useTranslation() const docLink = useDocLink() const { notify } = useToastContext() @@ -181,6 +177,23 @@ const SettingsModal: FC = ({ getMembers() }) + const showMultiModalTip = useMemo(() => { + return checkShowMultiModalTip({ + embeddingModel: { + provider: localeCurrentDataset.embedding_model_provider, + model: localeCurrentDataset.embedding_model, + }, + rerankingEnable: retrievalConfig.reranking_enable, + rerankModel: { + rerankingProviderName: retrievalConfig.reranking_model.reranking_provider_name, + rerankingModelName: retrievalConfig.reranking_model.reranking_model_name, + }, + indexMethod, + embeddingModelList, + rerankModelList, + }) + }, [localeCurrentDataset.embedding_model, localeCurrentDataset.embedding_model_provider, retrievalConfig.reranking_enable, retrievalConfig.reranking_model, indexMethod, embeddingModelList, rerankModelList]) + return (
= ({ provider: localeCurrentDataset.embedding_model_provider, model: localeCurrentDataset.embedding_model, }} - modelList={embeddingsModelList} + modelList={embeddingModelList} />
@@ -344,6 +357,7 @@ const SettingsModal: FC = ({ ) : ( diff --git a/web/app/components/app/configuration/index.tsx b/web/app/components/app/configuration/index.tsx index afe640278e..2537062e13 100644 --- a/web/app/components/app/configuration/index.tsx +++ b/web/app/components/app/configuration/index.tsx @@ -307,7 +307,7 @@ const Configuration: FC = () => { const oldRetrievalConfig = { top_k, score_threshold, - reranking_model: (reranking_model.reranking_provider_name && reranking_model.reranking_model_name) ? { + reranking_model: (reranking_model?.reranking_provider_name && reranking_model?.reranking_model_name) ? { provider: reranking_model.reranking_provider_name, model: reranking_model.reranking_model_name, } : undefined, diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index d21d35eeee..0ff375d815 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -816,9 +816,12 @@ const CompletionConversationDetailComp: FC<{ appId?: string; conversationId?: st const { notify } = useContext(ToastContext) const { t } = useTranslation() - const handleFeedback = async (mid: string, { rating }: FeedbackType): Promise => { + const handleFeedback = async (mid: string, { rating, content }: FeedbackType): Promise => { try { - await updateLogMessageFeedbacks({ url: `/apps/${appId}/feedbacks`, body: { message_id: mid, rating } }) + await updateLogMessageFeedbacks({ + url: `/apps/${appId}/feedbacks`, + body: { message_id: mid, rating, content: content ?? undefined }, + }) conversationDetailMutate() notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) return true @@ -861,9 +864,12 @@ const ChatConversationDetailComp: FC<{ appId?: string; conversationId?: string } const { notify } = useContext(ToastContext) const { t } = useTranslation() - const handleFeedback = async (mid: string, { rating }: FeedbackType): Promise => { + const handleFeedback = async (mid: string, { rating, content }: FeedbackType): Promise => { try { - await updateLogMessageFeedbacks({ url: `/apps/${appId}/feedbacks`, body: { message_id: mid, rating } }) + await updateLogMessageFeedbacks({ + url: `/apps/${appId}/feedbacks`, + body: { message_id: mid, rating, content: content ?? undefined }, + }) notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) return true } diff --git a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx index dddd5f2526..4be6746a0e 100644 --- a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx +++ b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx @@ -352,7 +352,6 @@ const ChatWrapper = () => { themeBuilder={themeBuilder} switchSibling={siblingMessageId => setTargetMessageId(siblingMessageId)} inputDisabled={inputDisabled} - isMobile={isMobile} sidebarCollapseState={sidebarCollapseState} questionIcon={ initUserVariables?.avatar_url diff --git a/web/app/components/base/chat/chat/index.tsx b/web/app/components/base/chat/chat/index.tsx index 51b5df4f32..0e947f8137 100644 --- a/web/app/components/base/chat/chat/index.tsx +++ b/web/app/components/base/chat/chat/index.tsx @@ -71,7 +71,6 @@ export type ChatProps = { onFeatureBarClick?: (state: boolean) => void noSpacing?: boolean inputDisabled?: boolean - isMobile?: boolean sidebarCollapseState?: boolean } @@ -110,7 +109,6 @@ const Chat: FC = ({ onFeatureBarClick, noSpacing, inputDisabled, - isMobile, sidebarCollapseState, }) => { const { t } = useTranslation() @@ -321,7 +319,6 @@ const Chat: FC = ({ ) } diff --git a/web/app/components/base/chat/chat/try-to-ask.tsx b/web/app/components/base/chat/chat/try-to-ask.tsx index 7e3dcc95f9..665f7b3b13 100644 --- a/web/app/components/base/chat/chat/try-to-ask.tsx +++ b/web/app/components/base/chat/chat/try-to-ask.tsx @@ -4,28 +4,25 @@ import { useTranslation } from 'react-i18next' import type { OnSend } from '../types' import Button from '@/app/components/base/button' import Divider from '@/app/components/base/divider' -import cn from '@/utils/classnames' type TryToAskProps = { suggestedQuestions: string[] onSend: OnSend - isMobile?: boolean } const TryToAsk: FC = ({ suggestedQuestions, onSend, - isMobile, }) => { const { t } = useTranslation() return (
-
- +
+
{t('appDebug.feature.suggestedQuestionsAfterAnswer.tryToAsk')}
- {!isMobile && } +
-
+
{ suggestedQuestions.map((suggestQuestion, index) => ( )}
) diff --git a/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx b/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx index 5d7b11c7e0..396dd4a1b0 100644 --- a/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx +++ b/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx @@ -8,7 +8,7 @@ import { ALL_PLANS } from '../../../config' import Toast from '../../../../base/toast' import { PlanRange } from '../../plan-switcher/plan-range-switcher' import { useAppContext } from '@/context/app-context' -import { fetchSubscriptionUrls } from '@/service/billing' +import { fetchBillingUrl, fetchSubscriptionUrls } from '@/service/billing' import List from './list' import Button from './button' import { Professional, Sandbox, Team } from '../../assets' @@ -39,7 +39,8 @@ const CloudPlanItem: FC = ({ const planInfo = ALL_PLANS[plan] const isYear = planRange === PlanRange.yearly const isCurrent = plan === currentPlan - const isPlanDisabled = planInfo.level <= ALL_PLANS[currentPlan].level + const isCurrentPaidPlan = isCurrent && !isFreePlan + const isPlanDisabled = isCurrentPaidPlan ? false : planInfo.level <= ALL_PLANS[currentPlan].level const { isCurrentWorkspaceManager } = useAppContext() const btnText = useMemo(() => { @@ -60,10 +61,6 @@ const CloudPlanItem: FC = ({ if (isPlanDisabled) return - if (isFreePlan) - return - - // Only workspace manager can buy plan if (!isCurrentWorkspaceManager) { Toast.notify({ type: 'error', @@ -74,6 +71,15 @@ const CloudPlanItem: FC = ({ } setLoading(true) try { + if (isCurrentPaidPlan) { + const res = await fetchBillingUrl() + window.open(res.url, '_blank') + return + } + + if (isFreePlan) + return + const res = await fetchSubscriptionUrls(plan, isYear ? 'year' : 'month') // Adb Block additional tracking block the gtag, so we need to redirect directly window.location.href = res.url diff --git a/web/app/components/datasets/common/document-status-with-action/index-failed.tsx b/web/app/components/datasets/common/document-status-with-action/index-failed.tsx index 802e3d872f..4713d944e0 100644 --- a/web/app/components/datasets/common/document-status-with-action/index-failed.tsx +++ b/web/app/components/datasets/common/document-status-with-action/index-failed.tsx @@ -2,11 +2,11 @@ import type { FC } from 'react' import React, { useEffect, useReducer } from 'react' import { useTranslation } from 'react-i18next' -import useSWR from 'swr' import StatusWithAction from './status-with-action' -import { getErrorDocs, retryErrorDocs } from '@/service/datasets' +import { retryErrorDocs } from '@/service/datasets' import type { IndexingStatusResponse } from '@/models/datasets' import { noop } from 'lodash-es' +import { useDatasetErrorDocs } from '@/service/knowledge/use-dataset' type Props = { datasetId: string @@ -35,16 +35,19 @@ const indexStateReducer = (state: IIndexState, action: IAction) => { const RetryButton: FC = ({ datasetId }) => { const { t } = useTranslation() const [indexState, dispatch] = useReducer(indexStateReducer, { value: 'success' }) - const { data: errorDocs, isLoading } = useSWR({ datasetId }, getErrorDocs) + const { data: errorDocs, isLoading, refetch: refetchErrorDocs } = useDatasetErrorDocs(datasetId) const onRetryErrorDocs = async () => { dispatch({ type: 'retry' }) const document_ids = errorDocs?.data.map((doc: IndexingStatusResponse) => doc.id) || [] const res = await retryErrorDocs({ datasetId, document_ids }) - if (res.result === 'success') + if (res.result === 'success') { + refetchErrorDocs() dispatch({ type: 'success' }) - else + } + else { dispatch({ type: 'error' }) + } } useEffect(() => { diff --git a/web/app/components/datasets/common/image-list/index.tsx b/web/app/components/datasets/common/image-list/index.tsx new file mode 100644 index 0000000000..8b0cf62e4a --- /dev/null +++ b/web/app/components/datasets/common/image-list/index.tsx @@ -0,0 +1,88 @@ +import { useCallback, useMemo, useState } from 'react' +import type { FileEntity } from '@/app/components/base/file-thumb' +import FileThumb from '@/app/components/base/file-thumb' +import cn from '@/utils/classnames' +import More from './more' +import type { ImageInfo } from '../image-previewer' +import ImagePreviewer from '../image-previewer' + +type Image = { + name: string + mimeType: string + sourceUrl: string + size: number + extension: string +} + +type ImageListProps = { + images: Image[] + size: 'sm' | 'md' + limit?: number + className?: string +} + +const ImageList = ({ + images, + size, + limit = 9, + className, +}: ImageListProps) => { + const [showMore, setShowMore] = useState(false) + const [previewIndex, setPreviewIndex] = useState(0) + const [previewImages, setPreviewImages] = useState([]) + + const limitedImages = useMemo(() => { + return showMore ? images : images.slice(0, limit) + }, [images, limit, showMore]) + + const handleShowMore = useCallback(() => { + setShowMore(true) + }, []) + + const handleImageClick = useCallback((file: FileEntity) => { + const index = limitedImages.findIndex(image => image.sourceUrl === file.sourceUrl) + if (index === -1) return + setPreviewIndex(index) + setPreviewImages(limitedImages.map(image => ({ + url: image.sourceUrl, + name: image.name, + size: image.size, + }))) + }, [limitedImages]) + + const handleClosePreview = useCallback(() => { + setPreviewImages([]) + }, []) + + return ( + <> +
+ { + limitedImages.map(image => ( + + )) + } + {images.length > limit && !showMore && ( + + )} +
+ {previewImages.length > 0 && ( + + )} + + ) +} + +export default ImageList diff --git a/web/app/components/datasets/common/image-list/more.tsx b/web/app/components/datasets/common/image-list/more.tsx new file mode 100644 index 0000000000..6da85e6939 --- /dev/null +++ b/web/app/components/datasets/common/image-list/more.tsx @@ -0,0 +1,39 @@ +import React, { useCallback } from 'react' + +type MoreProps = { + count: number + onClick?: () => void +} + +const More = ({ count, onClick }: MoreProps) => { + const formatNumber = (num: number) => { + if (num === 0) + return '0' + if (num < 1000) + return num.toString() + if (num < 1000000) + return `${(num / 1000).toFixed(1)}k` + return `${(num / 1000000).toFixed(1)}M` + } + + const handleClick = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + e.preventDefault() + onClick?.() + }, [onClick]) + + return ( +
+
+
+ + {`+${formatNumber(count)}`} + +
+
+
+
+ ) +} + +export default React.memo(More) diff --git a/web/app/components/datasets/common/image-previewer/index.tsx b/web/app/components/datasets/common/image-previewer/index.tsx new file mode 100644 index 0000000000..14e48d65fc --- /dev/null +++ b/web/app/components/datasets/common/image-previewer/index.tsx @@ -0,0 +1,223 @@ +import { useCallback, useEffect, useMemo, useRef, useState } from 'react' +import Button from '@/app/components/base/button' +import Loading from '@/app/components/base/loading' +import { formatFileSize } from '@/utils/format' +import { RiArrowLeftLine, RiArrowRightLine, RiCloseLine, RiRefreshLine } from '@remixicon/react' +import { createPortal } from 'react-dom' +import { useHotkeys } from 'react-hotkeys-hook' + +type CachedImage = { + blobUrl?: string + status: 'loading' | 'loaded' | 'error' + width: number + height: number +} + +const imageCache = new Map() + +export type ImageInfo = { + url: string + name: string + size: number +} + +type ImagePreviewerProps = { + images: ImageInfo[] + initialIndex?: number + onClose: () => void +} + +const ImagePreviewer = ({ + images, + initialIndex = 0, + onClose, +}: ImagePreviewerProps) => { + const [currentIndex, setCurrentIndex] = useState(initialIndex) + const [cachedImages, setCachedImages] = useState>(() => { + return images.reduce((acc, image) => { + acc[image.url] = { + status: 'loading', + width: 0, + height: 0, + } + return acc + }, {} as Record) + }) + const isMounted = useRef(false) + + const fetchImage = useCallback(async (image: ImageInfo) => { + const { url } = image + // Skip if already cached + if (imageCache.has(url)) return + + try { + const res = await fetch(url) + if (!res.ok) throw new Error(`Failed to load: ${url}`) + const blob = await res.blob() + const blobUrl = URL.createObjectURL(blob) + + const img = new Image() + img.src = blobUrl + img.onload = () => { + if (!isMounted.current) return + imageCache.set(url, { + blobUrl, + status: 'loaded', + width: img.naturalWidth, + height: img.naturalHeight, + }) + setCachedImages((prev) => { + return { + ...prev, + [url]: { + blobUrl, + status: 'loaded', + width: img.naturalWidth, + height: img.naturalHeight, + }, + } + }) + } + } + catch { + if (isMounted.current) { + setCachedImages((prev) => { + return { + ...prev, + [url]: { + status: 'error', + width: 0, + height: 0, + }, + } + }) + } + } + }, []) + + useEffect(() => { + isMounted.current = true + + images.forEach((image) => { + fetchImage(image) + }) + + return () => { + isMounted.current = false + // Cleanup released blob URLs not in current list + imageCache.forEach(({ blobUrl }, key) => { + if (blobUrl) + URL.revokeObjectURL(blobUrl) + imageCache.delete(key) + }) + } + }, []) + + const currentImage = useMemo(() => { + return images[currentIndex] + }, [images, currentIndex]) + + const prevImage = useCallback(() => { + if (currentIndex === 0) + return + setCurrentIndex(prevIndex => prevIndex - 1) + }, [currentIndex]) + + const nextImage = useCallback(() => { + if (currentIndex === images.length - 1) + return + setCurrentIndex(prevIndex => prevIndex + 1) + }, [currentIndex, images.length]) + + const retryImage = useCallback((image: ImageInfo) => { + setCachedImages((prev) => { + return { + ...prev, + [image.url]: { + ...prev[image.url], + status: 'loading', + }, + } + }) + fetchImage(image) + }, [fetchImage]) + + useHotkeys('esc', onClose) + useHotkeys('left', prevImage) + useHotkeys('right', nextImage) + + return createPortal( +
e.stopPropagation()} + tabIndex={-1} + > +
+ + + Esc + +
+ {cachedImages[currentImage.url].status === 'loading' && ( + + )} + {cachedImages[currentImage.url].status === 'error' && ( +
+ {`Failed to load image: ${currentImage.url}. Please try again.`} + +
+ )} + {cachedImages[currentImage.url].status === 'loaded' && ( +
+ {currentImage.name} +
+ {currentImage.name} + · + {`${cachedImages[currentImage.url].width} ×  ${cachedImages[currentImage.url].height}`} + · + {formatFileSize(currentImage.size)} +
+
+ )} + + +
, + document.body, + ) +} + +export default ImagePreviewer diff --git a/web/app/components/datasets/common/image-uploader/constants.ts b/web/app/components/datasets/common/image-uploader/constants.ts new file mode 100644 index 0000000000..671ed94fcf --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/constants.ts @@ -0,0 +1,7 @@ +export const ACCEPT_TYPES = ['jpg', 'jpeg', 'png', 'gif'] + +export const DEFAULT_IMAGE_FILE_SIZE_LIMIT = 2 + +export const DEFAULT_IMAGE_FILE_BATCH_LIMIT = 5 + +export const DEFAULT_SINGLE_CHUNK_ATTACHMENT_LIMIT = 10 diff --git a/web/app/components/datasets/common/image-uploader/hooks/use-upload.ts b/web/app/components/datasets/common/image-uploader/hooks/use-upload.ts new file mode 100644 index 0000000000..aefe48f0cd --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/hooks/use-upload.ts @@ -0,0 +1,273 @@ +import { useCallback, useEffect, useMemo, useRef, useState } from 'react' +import { useFileUploadConfig } from '@/service/use-common' +import type { FileEntity, FileUploadConfig } from '../types' +import { getFileType, getFileUploadConfig, traverseFileEntry } from '../utils' +import Toast from '@/app/components/base/toast' +import { useTranslation } from 'react-i18next' +import { ACCEPT_TYPES } from '../constants' +import { useFileStore } from '../store' +import { produce } from 'immer' +import { fileUpload, getFileUploadErrorMessage } from '@/app/components/base/file-uploader/utils' +import { v4 as uuid4 } from 'uuid' + +export const useUpload = () => { + const { t } = useTranslation() + const fileStore = useFileStore() + + const [dragging, setDragging] = useState(false) + const uploaderRef = useRef(null) + const dragRef = useRef(null) + const dropRef = useRef(null) + + const { data: fileUploadConfigResponse } = useFileUploadConfig() + + const fileUploadConfig: FileUploadConfig = useMemo(() => { + return getFileUploadConfig(fileUploadConfigResponse) + }, [fileUploadConfigResponse]) + + const handleDragEnter = (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + if (e.target !== dragRef.current) + setDragging(true) + } + const handleDragOver = (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + } + const handleDragLeave = (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + if (e.target === dragRef.current) + setDragging(false) + } + + const checkFileType = useCallback((file: File) => { + const ext = getFileType(file) + return ACCEPT_TYPES.includes(ext.toLowerCase()) + }, []) + + const checkFileSize = useCallback((file: File) => { + const { size } = file + return size <= fileUploadConfig.imageFileSizeLimit * 1024 * 1024 + }, [fileUploadConfig]) + + const showErrorMessage = useCallback((type: 'type' | 'size') => { + if (type === 'type') + Toast.notify({ type: 'error', message: t('common.fileUploader.fileExtensionNotSupport') }) + else + Toast.notify({ type: 'error', message: t('dataset.imageUploader.fileSizeLimitExceeded', { size: fileUploadConfig.imageFileSizeLimit }) }) + }, [fileUploadConfig, t]) + + const getValidFiles = useCallback((files: File[]) => { + let validType = true + let validSize = true + const validFiles = files.filter((file) => { + if (!checkFileType(file)) { + validType = false + return false + } + if (!checkFileSize(file)) { + validSize = false + return false + } + return true + }) + if (!validType) + showErrorMessage('type') + else if (!validSize) + showErrorMessage('size') + + return validFiles + }, [checkFileType, checkFileSize, showErrorMessage]) + + const selectHandle = () => { + if (uploaderRef.current) + uploaderRef.current.click() + } + + const handleAddFile = useCallback((newFile: FileEntity) => { + const { + files, + setFiles, + } = fileStore.getState() + + const newFiles = produce(files, (draft) => { + draft.push(newFile) + }) + setFiles(newFiles) + }, [fileStore]) + + const handleUpdateFile = useCallback((newFile: FileEntity) => { + const { + files, + setFiles, + } = fileStore.getState() + + const newFiles = produce(files, (draft) => { + const index = draft.findIndex(file => file.id === newFile.id) + + if (index > -1) + draft[index] = newFile + }) + setFiles(newFiles) + }, [fileStore]) + + const handleRemoveFile = useCallback((fileId: string) => { + const { + files, + setFiles, + } = fileStore.getState() + + const newFiles = files.filter(file => file.id !== fileId) + setFiles(newFiles) + }, [fileStore]) + + const handleReUploadFile = useCallback((fileId: string) => { + const { + files, + setFiles, + } = fileStore.getState() + const index = files.findIndex(file => file.id === fileId) + + if (index > -1) { + const uploadingFile = files[index] + const newFiles = produce(files, (draft) => { + draft[index].progress = 0 + }) + setFiles(newFiles) + fileUpload({ + file: uploadingFile.originalFile!, + onProgressCallback: (progress) => { + handleUpdateFile({ ...uploadingFile, progress }) + }, + onSuccessCallback: (res) => { + handleUpdateFile({ ...uploadingFile, uploadedId: res.id, progress: 100 }) + }, + onErrorCallback: (error?: any) => { + const errorMessage = getFileUploadErrorMessage(error, t('common.fileUploader.uploadFromComputerUploadError'), t) + Toast.notify({ type: 'error', message: errorMessage }) + handleUpdateFile({ ...uploadingFile, progress: -1 }) + }, + }) + } + }, [fileStore, t, handleUpdateFile]) + + const handleLocalFileUpload = useCallback((file: File) => { + const reader = new FileReader() + const isImage = file.type.startsWith('image') + + reader.addEventListener( + 'load', + () => { + const uploadingFile = { + id: uuid4(), + name: file.name, + extension: getFileType(file), + mimeType: file.type, + size: file.size, + progress: 0, + originalFile: file, + base64Url: isImage ? reader.result as string : '', + } + handleAddFile(uploadingFile) + fileUpload({ + file: uploadingFile.originalFile, + onProgressCallback: (progress) => { + handleUpdateFile({ ...uploadingFile, progress }) + }, + onSuccessCallback: (res) => { + handleUpdateFile({ + ...uploadingFile, + extension: res.extension, + mimeType: res.mime_type, + size: res.size, + uploadedId: res.id, + progress: 100, + }) + }, + onErrorCallback: (error?: any) => { + const errorMessage = getFileUploadErrorMessage(error, t('common.fileUploader.uploadFromComputerUploadError'), t) + Toast.notify({ type: 'error', message: errorMessage }) + handleUpdateFile({ ...uploadingFile, progress: -1 }) + }, + }) + }, + false, + ) + reader.addEventListener( + 'error', + () => { + Toast.notify({ type: 'error', message: t('common.fileUploader.uploadFromComputerReadError') }) + }, + false, + ) + reader.readAsDataURL(file) + }, [t, handleAddFile, handleUpdateFile]) + + const handleFileUpload = useCallback((newFiles: File[]) => { + const { files } = fileStore.getState() + const { singleChunkAttachmentLimit } = fileUploadConfig + if (newFiles.length === 0) return + if (files.length + newFiles.length > singleChunkAttachmentLimit) { + Toast.notify({ + type: 'error', + message: t('datasetHitTesting.imageUploader.singleChunkAttachmentLimitTooltip', { limit: singleChunkAttachmentLimit }), + }) + return + } + for (const file of newFiles) + handleLocalFileUpload(file) + }, [fileUploadConfig, fileStore, t, handleLocalFileUpload]) + + const fileChangeHandle = useCallback((e: React.ChangeEvent) => { + const { imageFileBatchLimit } = fileUploadConfig + const files = Array.from(e.target.files ?? []).slice(0, imageFileBatchLimit) + const validFiles = getValidFiles(files) + handleFileUpload(validFiles) + }, [getValidFiles, handleFileUpload, fileUploadConfig]) + + const handleDrop = useCallback(async (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + setDragging(false) + if (!e.dataTransfer) return + const nested = await Promise.all( + Array.from(e.dataTransfer.items).map((it) => { + const entry = (it as any).webkitGetAsEntry?.() + if (entry) return traverseFileEntry(entry) + const f = it.getAsFile?.() + return f ? Promise.resolve([f]) : Promise.resolve([]) + }), + ) + const files = nested.flat().slice(0, fileUploadConfig.imageFileBatchLimit) + const validFiles = getValidFiles(files) + handleFileUpload(validFiles) + }, [fileUploadConfig, handleFileUpload, getValidFiles]) + + useEffect(() => { + dropRef.current?.addEventListener('dragenter', handleDragEnter) + dropRef.current?.addEventListener('dragover', handleDragOver) + dropRef.current?.addEventListener('dragleave', handleDragLeave) + dropRef.current?.addEventListener('drop', handleDrop) + return () => { + dropRef.current?.removeEventListener('dragenter', handleDragEnter) + dropRef.current?.removeEventListener('dragover', handleDragOver) + dropRef.current?.removeEventListener('dragleave', handleDragLeave) + dropRef.current?.removeEventListener('drop', handleDrop) + } + }, [handleDrop]) + + return { + dragging, + fileUploadConfig, + dragRef, + dropRef, + uploaderRef, + fileChangeHandle, + selectHandle, + handleRemoveFile, + handleReUploadFile, + handleLocalFileUpload, + } +} diff --git a/web/app/components/datasets/common/image-uploader/image-uploader-in-chunk/image-input.tsx b/web/app/components/datasets/common/image-uploader/image-uploader-in-chunk/image-input.tsx new file mode 100644 index 0000000000..3e15b92705 --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/image-uploader-in-chunk/image-input.tsx @@ -0,0 +1,64 @@ +import React from 'react' +import cn from '@/utils/classnames' +import { RiUploadCloud2Line } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { useUpload } from '../hooks/use-upload' +import { ACCEPT_TYPES } from '../constants' + +const ImageUploader = () => { + const { t } = useTranslation() + + const { + dragging, + fileUploadConfig, + dragRef, + dropRef, + uploaderRef, + fileChangeHandle, + selectHandle, + } = useUpload() + + return ( +
+ `.${ext}`).join(',')} + onChange={fileChangeHandle} + /> +
+
+ +
+ {t('dataset.imageUploader.button')} + + {t('dataset.imageUploader.browse')} + +
+
+
+ {t('dataset.imageUploader.tip', { + size: fileUploadConfig.imageFileSizeLimit, + supportTypes: ACCEPT_TYPES.join(', '), + batchCount: fileUploadConfig.imageFileBatchLimit, + })} +
+ {dragging &&
} +
+
+ ) +} + +export default React.memo(ImageUploader) diff --git a/web/app/components/datasets/common/image-uploader/image-uploader-in-chunk/image-item.tsx b/web/app/components/datasets/common/image-uploader/image-uploader-in-chunk/image-item.tsx new file mode 100644 index 0000000000..a5bfb65fa2 --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/image-uploader-in-chunk/image-item.tsx @@ -0,0 +1,95 @@ +import { + memo, + useCallback, +} from 'react' +import { + RiCloseLine, +} from '@remixicon/react' +import FileImageRender from '@/app/components/base/file-uploader/file-image-render' +import type { FileEntity } from '../types' +import ProgressCircle from '@/app/components/base/progress-bar/progress-circle' +import { ReplayLine } from '@/app/components/base/icons/src/vender/other' +import { fileIsUploaded } from '../utils' +import Button from '@/app/components/base/button' + +type ImageItemProps = { + file: FileEntity + showDeleteAction?: boolean + onRemove?: (fileId: string) => void + onReUpload?: (fileId: string) => void + onPreview?: (fileId: string) => void +} +const ImageItem = ({ + file, + showDeleteAction, + onRemove, + onReUpload, + onPreview, +}: ImageItemProps) => { + const { id, progress, base64Url, sourceUrl } = file + + const handlePreview = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + e.preventDefault() + onPreview?.(id) + }, [onPreview, id]) + + const handleRemove = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + e.preventDefault() + onRemove?.(id) + }, [onRemove, id]) + + const handleReUpload = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + e.preventDefault() + onReUpload?.(id) + }, [onReUpload, id]) + + return ( +
+ { + showDeleteAction && ( + + ) + } + + { + progress >= 0 && !fileIsUploaded(file) && ( +
+ +
+ ) + } + { + progress === -1 && ( +
+ +
+ ) + } +
+ ) +} + +export default memo(ImageItem) diff --git a/web/app/components/datasets/common/image-uploader/image-uploader-in-chunk/index.tsx b/web/app/components/datasets/common/image-uploader/image-uploader-in-chunk/index.tsx new file mode 100644 index 0000000000..3efa3a19d7 --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/image-uploader-in-chunk/index.tsx @@ -0,0 +1,94 @@ +import { + FileContextProvider, + useFileStoreWithSelector, +} from '../store' +import type { FileEntity } from '../types' +import FileItem from './image-item' +import { useUpload } from '../hooks/use-upload' +import ImageInput from './image-input' +import cn from '@/utils/classnames' +import { useCallback, useState } from 'react' +import type { ImageInfo } from '@/app/components/datasets/common/image-previewer' +import ImagePreviewer from '@/app/components/datasets/common/image-previewer' + +type ImageUploaderInChunkProps = { + disabled?: boolean + className?: string +} +const ImageUploaderInChunk = ({ + disabled, + className, +}: ImageUploaderInChunkProps) => { + const files = useFileStoreWithSelector(s => s.files) + const [previewIndex, setPreviewIndex] = useState(0) + const [previewImages, setPreviewImages] = useState([]) + + const handleImagePreview = useCallback((fileId: string) => { + const index = files.findIndex(item => item.id === fileId) + if (index === -1) return + setPreviewIndex(index) + setPreviewImages(files.map(item => ({ + url: item.base64Url || item.sourceUrl || '', + name: item.name, + size: item.size, + }))) + }, [files]) + + const handleClosePreview = useCallback(() => { + setPreviewImages([]) + }, []) + + const { + handleRemoveFile, + handleReUploadFile, + } = useUpload() + + return ( +
+ {!disabled && } +
+ { + files.map(file => ( + + )) + } +
+ {previewImages.length > 0 && ( + + )} +
+ ) +} + +export type ImageUploaderInChunkWrapperProps = { + value?: FileEntity[] + onChange: (files: FileEntity[]) => void +} & ImageUploaderInChunkProps + +const ImageUploaderInChunkWrapper = ({ + value, + onChange, + ...props +}: ImageUploaderInChunkWrapperProps) => { + return ( + + + + ) +} + +export default ImageUploaderInChunkWrapper diff --git a/web/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/image-input.tsx b/web/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/image-input.tsx new file mode 100644 index 0000000000..4f230e3957 --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/image-input.tsx @@ -0,0 +1,64 @@ +import React from 'react' +import { useTranslation } from 'react-i18next' +import { useUpload } from '../hooks/use-upload' +import { ACCEPT_TYPES } from '../constants' +import { useFileStoreWithSelector } from '../store' +import { RiImageAddLine } from '@remixicon/react' +import Tooltip from '@/app/components/base/tooltip' + +const ImageUploader = () => { + const { t } = useTranslation() + const files = useFileStoreWithSelector(s => s.files) + + const { + fileUploadConfig, + uploaderRef, + fileChangeHandle, + selectHandle, + } = useUpload() + + return ( +
+ `.${ext}`).join(',')} + onChange={fileChangeHandle} + /> +
+ +
+
+ +
+ {files.length === 0 && ( + + {t('datasetHitTesting.imageUploader.tip', { + size: fileUploadConfig.imageFileSizeLimit, + batchCount: fileUploadConfig.imageFileBatchLimit, + })} + + )} +
+
+
+
+ ) +} + +export default React.memo(ImageUploader) diff --git a/web/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/image-item.tsx b/web/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/image-item.tsx new file mode 100644 index 0000000000..a47356e560 --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/image-item.tsx @@ -0,0 +1,95 @@ +import { + memo, + useCallback, +} from 'react' +import { + RiCloseLine, +} from '@remixicon/react' +import FileImageRender from '@/app/components/base/file-uploader/file-image-render' +import type { FileEntity } from '../types' +import ProgressCircle from '@/app/components/base/progress-bar/progress-circle' +import { ReplayLine } from '@/app/components/base/icons/src/vender/other' +import { fileIsUploaded } from '../utils' +import Button from '@/app/components/base/button' + +type ImageItemProps = { + file: FileEntity + showDeleteAction?: boolean + onRemove?: (fileId: string) => void + onReUpload?: (fileId: string) => void + onPreview?: (fileId: string) => void +} +const ImageItem = ({ + file, + showDeleteAction, + onRemove, + onReUpload, + onPreview, +}: ImageItemProps) => { + const { id, progress, base64Url, sourceUrl } = file + + const handlePreview = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + e.preventDefault() + onPreview?.(id) + }, [onPreview, id]) + + const handleRemove = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + e.preventDefault() + onRemove?.(id) + }, [onRemove, id]) + + const handleReUpload = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + e.preventDefault() + onReUpload?.(id) + }, [onReUpload, id]) + + return ( +
+ { + showDeleteAction && ( + + ) + } + + { + progress >= 0 && !fileIsUploaded(file) && ( +
+ +
+ ) + } + { + progress === -1 && ( +
+ +
+ ) + } +
+ ) +} + +export default memo(ImageItem) diff --git a/web/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/index.tsx b/web/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/index.tsx new file mode 100644 index 0000000000..2d04132842 --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/index.tsx @@ -0,0 +1,131 @@ +import { + useCallback, + useState, +} from 'react' +import { + FileContextProvider, +} from '../store' +import type { FileEntity } from '../types' +import { useUpload } from '../hooks/use-upload' +import ImageInput from './image-input' +import cn from '@/utils/classnames' +import { useTranslation } from 'react-i18next' +import { useFileStoreWithSelector } from '../store' +import ImageItem from './image-item' +import type { ImageInfo } from '@/app/components/datasets/common/image-previewer' +import ImagePreviewer from '@/app/components/datasets/common/image-previewer' + +type ImageUploaderInRetrievalTestingProps = { + textArea: React.ReactNode + actionButton: React.ReactNode + showUploader?: boolean + className?: string + actionAreaClassName?: string +} +const ImageUploaderInRetrievalTesting = ({ + textArea, + actionButton, + showUploader = true, + className, + actionAreaClassName, +}: ImageUploaderInRetrievalTestingProps) => { + const { t } = useTranslation() + const files = useFileStoreWithSelector(s => s.files) + const [previewIndex, setPreviewIndex] = useState(0) + const [previewImages, setPreviewImages] = useState([]) + const { + dragging, + dragRef, + dropRef, + handleRemoveFile, + handleReUploadFile, + } = useUpload() + + const handleImagePreview = useCallback((fileId: string) => { + const index = files.findIndex(item => item.id === fileId) + if (index === -1) return + setPreviewIndex(index) + setPreviewImages(files.map(item => ({ + url: item.base64Url || item.sourceUrl || '', + name: item.name, + size: item.size, + }))) + }, [files]) + + const handleClosePreview = useCallback(() => { + setPreviewImages([]) + }, []) + + return ( +
+ {dragging && ( +
+
{t('datasetHitTesting.imageUploader.dropZoneTip')}
+
+
+ )} + {textArea} + { + showUploader && !!files.length && ( +
+ { + files.map(file => ( + + )) + } +
+ ) + } +
+ {showUploader && } + {actionButton} +
+ {previewImages.length > 0 && ( + + )} +
+ ) +} + +export type ImageUploaderInRetrievalTestingWrapperProps = { + value?: FileEntity[] + onChange: (files: FileEntity[]) => void +} & ImageUploaderInRetrievalTestingProps + +const ImageUploaderInRetrievalTestingWrapper = ({ + value, + onChange, + ...props +}: ImageUploaderInRetrievalTestingWrapperProps) => { + return ( + + + + ) +} + +export default ImageUploaderInRetrievalTestingWrapper diff --git a/web/app/components/datasets/common/image-uploader/store.tsx b/web/app/components/datasets/common/image-uploader/store.tsx new file mode 100644 index 0000000000..e3c9e28a84 --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/store.tsx @@ -0,0 +1,67 @@ +import { + createContext, + useContext, + useRef, +} from 'react' +import { + create, + useStore, +} from 'zustand' +import type { + FileEntity, +} from './types' + +type Shape = { + files: FileEntity[] + setFiles: (files: FileEntity[]) => void +} + +export const createFileStore = ( + value: FileEntity[] = [], + onChange?: (files: FileEntity[]) => void, +) => { + return create(set => ({ + files: value ? [...value] : [], + setFiles: (files) => { + set({ files }) + onChange?.(files) + }, + })) +} + +type FileStore = ReturnType +export const FileContext = createContext(null) + +export function useFileStoreWithSelector(selector: (state: Shape) => T): T { + const store = useContext(FileContext) + if (!store) + throw new Error('Missing FileContext.Provider in the tree') + + return useStore(store, selector) +} + +export const useFileStore = () => { + return useContext(FileContext)! +} + +type FileProviderProps = { + children: React.ReactNode + value?: FileEntity[] + onChange?: (files: FileEntity[]) => void +} +export const FileContextProvider = ({ + children, + value, + onChange, +}: FileProviderProps) => { + const storeRef = useRef(undefined) + + if (!storeRef.current) + storeRef.current = createFileStore(value, onChange) + + return ( + + {children} + + ) +} diff --git a/web/app/components/datasets/common/image-uploader/types.ts b/web/app/components/datasets/common/image-uploader/types.ts new file mode 100644 index 0000000000..e918f2b41e --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/types.ts @@ -0,0 +1,18 @@ +export type FileEntity = { + id: string + name: string + size: number + extension: string + mimeType: string + progress: number // -1: error, 0 ~ 99: uploading, 100: uploaded + originalFile?: File // used for re-uploading + uploadedId?: string // for uploaded image id + sourceUrl?: string // for uploaded image + base64Url?: string // for image preview during uploading +} + +export type FileUploadConfig = { + imageFileSizeLimit: number + imageFileBatchLimit: number + singleChunkAttachmentLimit: number +} diff --git a/web/app/components/datasets/common/image-uploader/utils.ts b/web/app/components/datasets/common/image-uploader/utils.ts new file mode 100644 index 0000000000..842b279a98 --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/utils.ts @@ -0,0 +1,92 @@ +import type { FileUploadConfigResponse } from '@/models/common' +import type { FileEntity } from './types' +import { + DEFAULT_IMAGE_FILE_BATCH_LIMIT, + DEFAULT_IMAGE_FILE_SIZE_LIMIT, + DEFAULT_SINGLE_CHUNK_ATTACHMENT_LIMIT, +} from './constants' + +export const getFileType = (currentFile: File) => { + if (!currentFile) + return '' + + const arr = currentFile.name.split('.') + return arr[arr.length - 1] +} + +type FileWithPath = { + relativePath?: string +} & File + +export const traverseFileEntry = (entry: any, prefix = ''): Promise => { + return new Promise((resolve) => { + if (entry.isFile) { + entry.file((file: FileWithPath) => { + file.relativePath = `${prefix}${file.name}` + resolve([file]) + }) + } + else if (entry.isDirectory) { + const reader = entry.createReader() + const entries: any[] = [] + const read = () => { + reader.readEntries(async (results: FileSystemEntry[]) => { + if (!results.length) { + const files = await Promise.all( + entries.map(ent => + traverseFileEntry(ent, `${prefix}${entry.name}/`), + ), + ) + resolve(files.flat()) + } + else { + entries.push(...results) + read() + } + }) + } + read() + } + else { + resolve([]) + } + }) +} + +export const fileIsUploaded = (file: FileEntity) => { + if (file.uploadedId || file.progress === 100) + return true +} + +const getNumberValue = (value: number | string | undefined | null): number => { + if (value === undefined || value === null) + return 0 + if (typeof value === 'number') + return value + if (typeof value === 'string') + return Number(value) + return 0 +} + +export const getFileUploadConfig = (fileUploadConfigResponse: FileUploadConfigResponse | undefined) => { + if (!fileUploadConfigResponse) { + return { + imageFileSizeLimit: DEFAULT_IMAGE_FILE_SIZE_LIMIT, + imageFileBatchLimit: DEFAULT_IMAGE_FILE_BATCH_LIMIT, + singleChunkAttachmentLimit: DEFAULT_SINGLE_CHUNK_ATTACHMENT_LIMIT, + } + } + const { + image_file_batch_limit, + single_chunk_attachment_limit, + attachment_image_file_size_limit, + } = fileUploadConfigResponse + const imageFileSizeLimit = getNumberValue(attachment_image_file_size_limit) + const imageFileBatchLimit = getNumberValue(image_file_batch_limit) + const singleChunkAttachmentLimit = getNumberValue(single_chunk_attachment_limit) + return { + imageFileSizeLimit: imageFileSizeLimit > 0 ? imageFileSizeLimit : DEFAULT_IMAGE_FILE_SIZE_LIMIT, + imageFileBatchLimit: imageFileBatchLimit > 0 ? imageFileBatchLimit : DEFAULT_IMAGE_FILE_BATCH_LIMIT, + singleChunkAttachmentLimit: singleChunkAttachmentLimit > 0 ? singleChunkAttachmentLimit : DEFAULT_SINGLE_CHUNK_ATTACHMENT_LIMIT, + } +} diff --git a/web/app/components/datasets/common/retrieval-method-config/index.tsx b/web/app/components/datasets/common/retrieval-method-config/index.tsx index ed230c52ce..c0952ed4a4 100644 --- a/web/app/components/datasets/common/retrieval-method-config/index.tsx +++ b/web/app/components/datasets/common/retrieval-method-config/index.tsx @@ -20,12 +20,14 @@ import { EffectColor } from '../../settings/chunk-structure/types' type Props = { disabled?: boolean value: RetrievalConfig + showMultiModalTip?: boolean onChange: (value: RetrievalConfig) => void } const RetrievalMethodConfig: FC = ({ disabled = false, value, + showMultiModalTip = false, onChange, }) => { const { t } = useTranslation() @@ -110,6 +112,7 @@ const RetrievalMethodConfig: FC = ({ type={RETRIEVE_METHOD.semantic} value={value} onChange={onChange} + showMultiModalTip={showMultiModalTip} /> )} @@ -132,6 +135,7 @@ const RetrievalMethodConfig: FC = ({ type={RETRIEVE_METHOD.fullText} value={value} onChange={onChange} + showMultiModalTip={showMultiModalTip} /> )} @@ -155,6 +159,7 @@ const RetrievalMethodConfig: FC = ({ type={RETRIEVE_METHOD.hybrid} value={value} onChange={onChange} + showMultiModalTip={showMultiModalTip} /> )} diff --git a/web/app/components/datasets/common/retrieval-param-config/index.tsx b/web/app/components/datasets/common/retrieval-param-config/index.tsx index 0c28149d56..2b703cc44d 100644 --- a/web/app/components/datasets/common/retrieval-param-config/index.tsx +++ b/web/app/components/datasets/common/retrieval-param-config/index.tsx @@ -24,16 +24,19 @@ import { import WeightedScore from '@/app/components/app/configuration/dataset-config/params-config/weighted-score' import Toast from '@/app/components/base/toast' import RadioCard from '@/app/components/base/radio-card' +import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' type Props = { type: RETRIEVE_METHOD value: RetrievalConfig + showMultiModalTip?: boolean onChange: (value: RetrievalConfig) => void } const RetrievalParamConfig: FC = ({ type, value, + showMultiModalTip = false, onChange, }) => { const { t } = useTranslation() @@ -133,19 +136,32 @@ const RetrievalParamConfig: FC = ({
{ value.reranking_enable && ( - { - onChange({ - ...value, - reranking_model: { - reranking_provider_name: v.provider, - reranking_model_name: v.model, - }, - }) - }} - /> + <> + { + onChange({ + ...value, + reranking_model: { + reranking_provider_name: v.provider, + reranking_model_name: v.model, + }, + }) + }} + /> + {showMultiModalTip && ( +
+
+
+ +
+ + {t('datasetSettings.form.retrievalSetting.multiModalTip')} + +
+ )} + ) }
@@ -239,19 +255,32 @@ const RetrievalParamConfig: FC = ({ } { value.reranking_mode !== RerankingModeEnum.WeightedScore && ( - { - onChange({ - ...value, - reranking_model: { - reranking_provider_name: v.provider, - reranking_model_name: v.model, - }, - }) - }} - /> + <> + { + onChange({ + ...value, + reranking_model: { + reranking_provider_name: v.provider, + reranking_model_name: v.model, + }, + }) + }} + /> + {showMultiModalTip && ( +
+
+
+ +
+ + {t('datasetSettings.form.retrievalSetting.multiModalTip')} + +
+ )} + ) }
diff --git a/web/app/components/datasets/create/embedding-process/index.tsx b/web/app/components/datasets/create/embedding-process/index.tsx index 7b2eda1dcd..4e78eb2034 100644 --- a/web/app/components/datasets/create/embedding-process/index.tsx +++ b/web/app/components/datasets/create/embedding-process/index.tsx @@ -1,9 +1,7 @@ import type { FC } from 'react' import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react' -import useSWR from 'swr' import { useRouter } from 'next/navigation' import { useTranslation } from 'react-i18next' -import { omit } from 'lodash-es' import { RiArrowRightLine, RiCheckboxCircleFill, @@ -25,7 +23,7 @@ import type { LegacyDataSourceInfo, ProcessRuleResponse, } from '@/models/datasets' -import { fetchIndexingStatusBatch as doFetchIndexingStatus, fetchProcessRule } from '@/service/datasets' +import { fetchIndexingStatusBatch as doFetchIndexingStatus } from '@/service/datasets' import { DataSourceType, ProcessMode } from '@/models/datasets' import NotionIcon from '@/app/components/base/notion-icon' import PriorityLabel from '@/app/components/billing/priority-label' @@ -40,6 +38,7 @@ import { useInvalidDocumentList } from '@/service/knowledge/use-document' import Divider from '@/app/components/base/divider' import { useDatasetApiAccessUrl } from '@/hooks/use-api-access-url' import Link from 'next/link' +import { useProcessRule } from '@/service/knowledge/use-dataset' type Props = { datasetId: string @@ -207,12 +206,7 @@ const EmbeddingProcess: FC = ({ datasetId, batchId, documents = [], index }, []) // get rule - const { data: ruleDetail } = useSWR({ - action: 'fetchProcessRule', - params: { documentId: getFirstDocument.id }, - }, apiParams => fetchProcessRule(omit(apiParams, 'action')), { - revalidateOnFocus: false, - }) + const { data: ruleDetail } = useProcessRule(getFirstDocument?.id) const router = useRouter() const invalidDocumentList = useInvalidDocumentList() diff --git a/web/app/components/datasets/create/file-uploader/index.tsx b/web/app/components/datasets/create/file-uploader/index.tsx index 4aec0d4082..abe2564ad2 100644 --- a/web/app/components/datasets/create/file-uploader/index.tsx +++ b/web/app/components/datasets/create/file-uploader/index.tsx @@ -2,7 +2,6 @@ import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' -import useSWR from 'swr' import { RiDeleteBinLine, RiUploadCloud2Line } from '@remixicon/react' import DocumentFileIcon from '../../common/document-file-icon' import cn from '@/utils/classnames' @@ -11,8 +10,7 @@ import { ToastContext } from '@/app/components/base/toast' import SimplePieChart from '@/app/components/base/simple-pie-chart' import { upload } from '@/service/base' -import { fetchFileUploadConfig } from '@/service/common' -import { fetchSupportFileTypes } from '@/service/datasets' +import { useFileSupportTypes, useFileUploadConfig } from '@/service/use-common' import I18n from '@/context/i18n' import { LanguagesSupported } from '@/i18n-config/language' import { IS_CE_EDITION } from '@/config' @@ -48,8 +46,8 @@ const FileUploader = ({ const fileUploader = useRef(null) const hideUpload = notSupportBatchUpload && fileList.length > 0 - const { data: fileUploadConfigResponse } = useSWR({ url: '/files/upload' }, fetchFileUploadConfig) - const { data: supportFileTypesResponse } = useSWR({ url: '/files/support-type' }, fetchSupportFileTypes) + const { data: fileUploadConfigResponse } = useFileUploadConfig() + const { data: supportFileTypesResponse } = useFileSupportTypes() const supportTypes = supportFileTypesResponse?.allowed_extensions || [] const supportTypesShowNames = (() => { const extensionMap: { [key: string]: string } = { @@ -68,11 +66,11 @@ const FileUploader = ({ .join(locale !== LanguagesSupported[1] ? ', ' : '、 ') })() const ACCEPTS = supportTypes.map((ext: string) => `.${ext}`) - const fileUploadConfig = useMemo(() => fileUploadConfigResponse ?? { - file_size_limit: 15, - batch_count_limit: 5, - file_upload_limit: 5, - }, [fileUploadConfigResponse]) + const fileUploadConfig = useMemo(() => ({ + file_size_limit: fileUploadConfigResponse?.file_size_limit ?? 15, + batch_count_limit: fileUploadConfigResponse?.batch_count_limit ?? 5, + file_upload_limit: fileUploadConfigResponse?.file_upload_limit ?? 5, + }), [fileUploadConfigResponse]) const fileListRef = useRef([]) diff --git a/web/app/components/datasets/create/step-two/index.tsx b/web/app/components/datasets/create/step-two/index.tsx index 22d6837754..43be89c326 100644 --- a/web/app/components/datasets/create/step-two/index.tsx +++ b/web/app/components/datasets/create/step-two/index.tsx @@ -1,6 +1,6 @@ 'use client' import type { FC, PropsWithChildren } from 'react' -import React, { useCallback, useEffect, useState } from 'react' +import React, { useCallback, useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import { @@ -63,6 +63,7 @@ import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/aler import { noop } from 'lodash-es' import { useDocLink } from '@/context/i18n' import { useInvalidDatasetList } from '@/service/knowledge/use-dataset' +import { checkShowMultiModalTip } from '../../settings/utils' const TextLabel: FC = (props) => { return @@ -495,12 +496,6 @@ const StepTwo = ({ setDefaultConfig(data.rules) setLimitMaxChunkLength(data.limits.indexing_max_segmentation_tokens_length) }, - onError(error) { - Toast.notify({ - type: 'error', - message: `${error}`, - }) - }, }) const getRulesFromDetail = () => { @@ -538,22 +533,8 @@ const StepTwo = ({ setSegmentationType(documentDetail.dataset_process_rule.mode) } - const createFirstDocumentMutation = useCreateFirstDocument({ - onError(error) { - Toast.notify({ - type: 'error', - message: `${error}`, - }) - }, - }) - const createDocumentMutation = useCreateDocument(datasetId!, { - onError(error) { - Toast.notify({ - type: 'error', - message: `${error}`, - }) - }, - }) + const createFirstDocumentMutation = useCreateFirstDocument() + const createDocumentMutation = useCreateDocument(datasetId!) const isCreating = createFirstDocumentMutation.isPending || createDocumentMutation.isPending const invalidDatasetList = useInvalidDatasetList() @@ -613,6 +594,20 @@ const StepTwo = ({ const isModelAndRetrievalConfigDisabled = !!datasetId && !!currentDataset?.data_source_type + const showMultiModalTip = useMemo(() => { + return checkShowMultiModalTip({ + embeddingModel, + rerankingEnable: retrievalConfig.reranking_enable, + rerankModel: { + rerankingProviderName: retrievalConfig.reranking_model.reranking_provider_name, + rerankingModelName: retrievalConfig.reranking_model.reranking_model_name, + }, + indexMethod: indexType, + embeddingModelList, + rerankModelList, + }) + }, [embeddingModel, retrievalConfig.reranking_enable, retrievalConfig.reranking_model, indexType, embeddingModelList, rerankModelList]) + return (
@@ -1012,6 +1007,7 @@ const StepTwo = ({ disabled={isModelAndRetrievalConfigDisabled} value={retrievalConfig} onChange={setRetrievalConfig} + showMultiModalTip={showMultiModalTip} /> ) : ( diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.tsx index 868621e1a3..555f2497ef 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.tsx @@ -21,8 +21,6 @@ import dynamic from 'next/dynamic' const SimplePieChart = dynamic(() => import('@/app/components/base/simple-pie-chart'), { ssr: false }) -const FILES_NUMBER_LIMIT = 20 - export type LocalFileProps = { allowedExtensions: string[] notSupportBatchUpload?: boolean @@ -64,10 +62,11 @@ const LocalFile = ({ .join(locale !== LanguagesSupported[1] ? ', ' : '、 ') }, [locale, allowedExtensions]) const ACCEPTS = allowedExtensions.map((ext: string) => `.${ext}`) - const fileUploadConfig = useMemo(() => fileUploadConfigResponse ?? { - file_size_limit: 15, - batch_count_limit: 5, - }, [fileUploadConfigResponse]) + const fileUploadConfig = useMemo(() => ({ + file_size_limit: fileUploadConfigResponse?.file_size_limit ?? 15, + batch_count_limit: fileUploadConfigResponse?.batch_count_limit ?? 5, + file_upload_limit: fileUploadConfigResponse?.file_upload_limit ?? 5, + }), [fileUploadConfigResponse]) const updateFile = useCallback((fileItem: FileItem, progress: number, list: FileItem[]) => { const { setLocalFileList } = dataSourceStore.getState() @@ -186,11 +185,12 @@ const LocalFile = ({ }, [fileUploadConfig, uploadBatchFiles]) const initialUpload = useCallback((files: File[]) => { + const filesCountLimit = fileUploadConfig.file_upload_limit if (!files.length) return false - if (files.length + localFileList.length > FILES_NUMBER_LIMIT && !IS_CE_EDITION) { - notify({ type: 'error', message: t('datasetCreation.stepOne.uploader.validation.filesNumber', { filesNumber: FILES_NUMBER_LIMIT }) }) + if (files.length + localFileList.length > filesCountLimit && !IS_CE_EDITION) { + notify({ type: 'error', message: t('datasetCreation.stepOne.uploader.validation.filesNumber', { filesNumber: filesCountLimit }) }) return false } @@ -203,7 +203,7 @@ const LocalFile = ({ updateFileList(newFiles) fileListRef.current = newFiles uploadMultipleFiles(preparedFiles) - }, [updateFileList, uploadMultipleFiles, notify, t, localFileList]) + }, [fileUploadConfig.file_upload_limit, localFileList.length, updateFileList, uploadMultipleFiles, notify, t]) const handleDragEnter = (e: DragEvent) => { e.preventDefault() @@ -250,9 +250,10 @@ const LocalFile = ({ updateFileList([...fileListRef.current]) } const fileChangeHandle = useCallback((e: React.ChangeEvent) => { - const files = [...(e.target.files ?? [])] as File[] + let files = [...(e.target.files ?? [])] as File[] + files = files.slice(0, fileUploadConfig.batch_count_limit) initialUpload(files.filter(isValid)) - }, [isValid, initialUpload]) + }, [isValid, initialUpload, fileUploadConfig.batch_count_limit]) const { theme } = useTheme() const chartColor = useMemo(() => theme === Theme.dark ? '#5289ff' : '#296dff', [theme]) @@ -305,6 +306,7 @@ const LocalFile = ({ size: fileUploadConfig.file_size_limit, supportTypes: supportTypesShowNames, batchCount: notSupportBatchUpload ? 1 : fileUploadConfig.batch_count_limit, + totalCount: fileUploadConfig.file_upload_limit, })}
{dragging &&
}
diff --git a/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx b/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx index 317db84c43..2049ae0d03 100644 --- a/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx +++ b/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx @@ -13,11 +13,10 @@ import Button from '@/app/components/base/button' import type { FileItem } from '@/models/datasets' import { upload } from '@/service/base' import { getFileUploadErrorMessage } from '@/app/components/base/file-uploader/utils' -import useSWR from 'swr' -import { fetchFileUploadConfig } from '@/service/common' import SimplePieChart from '@/app/components/base/simple-pie-chart' import { Theme } from '@/types/app' import useTheme from '@/hooks/use-theme' +import { useFileUploadConfig } from '@/service/use-common' export type Props = { file: FileItem | undefined @@ -34,7 +33,7 @@ const CSVUploader: FC = ({ const dropRef = useRef(null) const dragRef = useRef(null) const fileUploader = useRef(null) - const { data: fileUploadConfigResponse } = useSWR({ url: '/files/upload' }, fetchFileUploadConfig) + const { data: fileUploadConfigResponse } = useFileUploadConfig() const fileUploadConfig = useMemo(() => fileUploadConfigResponse ?? { file_size_limit: 15, }, [fileUploadConfigResponse]) diff --git a/web/app/components/datasets/documents/detail/completed/common/action-buttons.tsx b/web/app/components/datasets/documents/detail/completed/common/action-buttons.tsx index 4bed7b461d..c5d3bf5629 100644 --- a/web/app/components/datasets/documents/detail/completed/common/action-buttons.tsx +++ b/web/app/components/datasets/documents/detail/completed/common/action-buttons.tsx @@ -13,6 +13,7 @@ type IActionButtonsProps = { actionType?: 'edit' | 'add' handleRegeneration?: () => void isChildChunk?: boolean + showRegenerationButton?: boolean } const ActionButtons: FC = ({ @@ -22,6 +23,7 @@ const ActionButtons: FC = ({ actionType = 'edit', handleRegeneration, isChildChunk = false, + showRegenerationButton = true, }) => { const { t } = useTranslation() const docForm = useDocumentContext(s => s.docForm) @@ -54,7 +56,7 @@ const ActionButtons: FC = ({ ESC
- {(isParentChildParagraphMode && actionType === 'edit' && !isChildChunk) + {(isParentChildParagraphMode && actionType === 'edit' && !isChildChunk && showRegenerationButton) ?