diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 94e5b0f969..d6f326d4dc 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -9,6 +9,14 @@ # Backend (default owner, more specific rules below will override) api/ @QuantumGhost +# Backend - MCP +api/core/mcp/ @Nov1c444 +api/core/entities/mcp_provider.py @Nov1c444 +api/services/tools/mcp_tools_manage_service.py @Nov1c444 +api/controllers/mcp/ @Nov1c444 +api/controllers/console/app/mcp_server.py @Nov1c444 +api/tests/**/*mcp* @Nov1c444 + # Backend - Workflow - Engine (Core graph execution engine) api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost api/core/workflow/runtime/ @laipz8200 @QuantumGhost diff --git a/.github/ISSUE_TEMPLATE/refactor.yml b/.github/ISSUE_TEMPLATE/refactor.yml index cf74dcc546..dbe8cbb602 100644 --- a/.github/ISSUE_TEMPLATE/refactor.yml +++ b/.github/ISSUE_TEMPLATE/refactor.yml @@ -1,8 +1,6 @@ -name: "✨ Refactor" -description: Refactor existing code for improved readability and maintainability. -title: "[Chore/Refactor] " -labels: - - refactor +name: "✨ Refactor or Chore" +description: Refactor existing code or perform maintenance chores to improve readability and reliability. +title: "[Refactor/Chore] " body: - type: checkboxes attributes: @@ -11,7 +9,7 @@ body: options: - label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542). required: true - - label: This is only for refactoring, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general). + - label: This is only for refactors or chores; if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general). required: true - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. required: true @@ -25,14 +23,14 @@ body: id: description attributes: label: Description - placeholder: "Describe the refactor you are proposing." + placeholder: "Describe the refactor or chore you are proposing." validations: required: true - type: textarea id: motivation attributes: label: Motivation - placeholder: "Explain why this refactor is necessary." + placeholder: "Explain why this refactor or chore is necessary." validations: required: false - type: textarea diff --git a/.github/ISSUE_TEMPLATE/tracker.yml b/.github/ISSUE_TEMPLATE/tracker.yml deleted file mode 100644 index 35fedefc75..0000000000 --- a/.github/ISSUE_TEMPLATE/tracker.yml +++ /dev/null @@ -1,13 +0,0 @@ -name: "👾 Tracker" -description: For inner usages, please do not use this template. -title: "[Tracker] " -labels: - - tracker -body: - - type: textarea - id: content - attributes: - label: Blockers - placeholder: "- [ ] ..." - validations: - required: true diff --git a/api/.env.example b/api/.env.example index 35aaabbc10..516a119d98 100644 --- a/api/.env.example +++ b/api/.env.example @@ -654,3 +654,9 @@ TENANT_ISOLATED_TASK_CONCURRENCY=1 # Maximum number of segments for dataset segments API (0 for unlimited) DATASET_MAX_SEGMENTS_PER_REQUEST=0 + +# Multimodal knowledgebase limit +SINGLE_CHUNK_ATTACHMENT_LIMIT=10 +ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2 +ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60 +IMAGE_FILE_BATCH_LIMIT=10 diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index b5ffd09d01..a5916241df 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -360,6 +360,26 @@ class FileUploadConfig(BaseSettings): default=10, ) + IMAGE_FILE_BATCH_LIMIT: PositiveInt = Field( + description="Maximum number of files allowed in a image batch upload operation", + default=10, + ) + + SINGLE_CHUNK_ATTACHMENT_LIMIT: PositiveInt = Field( + description="Maximum number of files allowed in a single chunk attachment", + default=10, + ) + + ATTACHMENT_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field( + description="Maximum allowed image file size for attachments in megabytes", + default=2, + ) + + ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT: NonNegativeInt = Field( + description="Timeout for downloading image attachments in seconds", + default=60, + ) + inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field( description=( "Comma-separated list of file extensions that are blocked from upload. " diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 377297c84c..12ada8b798 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -61,6 +61,7 @@ class ChatMessagesQuery(BaseModel): class MessageFeedbackPayload(BaseModel): message_id: str = Field(..., description="Message ID") rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating") + content: str | None = Field(default=None, description="Feedback content") @field_validator("message_id") @classmethod @@ -324,6 +325,7 @@ class MessageFeedbackApi(Resource): db.session.delete(feedback) elif args.rating and feedback: feedback.rating = args.rating + feedback.content = args.content elif not args.rating and not feedback: raise ValueError("rating cannot be None when feedback not exists") else: @@ -335,6 +337,7 @@ class MessageFeedbackApi(Resource): conversation_id=message.conversation_id, message_id=message.id, rating=rating_value, + content=args.content, from_source="admin", from_account_id=current_user.id, ) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 1fad8abd52..c0422ef6f4 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -151,6 +151,7 @@ class DatasetUpdatePayload(BaseModel): external_knowledge_id: str | None = None external_knowledge_api_id: str | None = None icon_info: dict[str, Any] | None = None + is_multimodal: bool | None = False @field_validator("indexing_technique") @classmethod @@ -423,17 +424,16 @@ class DatasetApi(Resource): payload = DatasetUpdatePayload.model_validate(console_ns.payload or {}) payload_data = payload.model_dump(exclude_unset=True) current_user, current_tenant_id = current_account_with_tenant() - # check embedding model setting if ( payload.indexing_technique == "high_quality" and payload.embedding_model_provider is not None and payload.embedding_model is not None ): - DatasetService.check_embedding_model_setting( + is_multimodal = DatasetService.check_is_multimodal_model( dataset.tenant_id, payload.embedding_model_provider, payload.embedding_model ) - + payload.is_multimodal = is_multimodal # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator DatasetPermissionService.check_permission( current_user, dataset, payload.permission, payload.partial_member_list diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 2520111281..6145da31a5 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -424,6 +424,10 @@ class DatasetInitApi(Resource): model_type=ModelType.TEXT_EMBEDDING, model=knowledge_config.embedding_model, ) + is_multimodal = DatasetService.check_is_multimodal_model( + current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model + ) + knowledge_config.is_multimodal = is_multimodal except InvokeAuthorizationError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index ee390cbfb7..e73abc2555 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -51,6 +51,7 @@ class SegmentCreatePayload(BaseModel): content: str answer: str | None = None keywords: list[str] | None = None + attachment_ids: list[str] | None = None class SegmentUpdatePayload(BaseModel): @@ -58,6 +59,7 @@ class SegmentUpdatePayload(BaseModel): answer: str | None = None keywords: list[str] | None = None regenerate_child_chunks: bool = False + attachment_ids: list[str] | None = None class BatchImportPayload(BaseModel): diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index fac90a0135..db7c50f422 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -1,7 +1,7 @@ import logging from typing import Any -from flask_restx import marshal +from flask_restx import marshal, reqparse from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -33,6 +33,7 @@ class HitTestingPayload(BaseModel): query: str = Field(max_length=250) retrieval_model: dict[str, Any] | None = None external_retrieval_model: dict[str, Any] | None = None + attachment_ids: list[str] | None = None class DatasetsHitTestingBase: @@ -54,16 +55,28 @@ class DatasetsHitTestingBase: def hit_testing_args_check(args: dict[str, Any]): HitTestingService.hit_testing_args_check(args) + @staticmethod + def parse_args(): + parser = ( + reqparse.RequestParser() + .add_argument("query", type=str, required=False, location="json") + .add_argument("attachment_ids", type=list, required=False, location="json") + .add_argument("retrieval_model", type=dict, required=False, location="json") + .add_argument("external_retrieval_model", type=dict, required=False, location="json") + ) + return parser.parse_args() + @staticmethod def perform_hit_testing(dataset, args): assert isinstance(current_user, Account) try: response = HitTestingService.retrieve( dataset=dataset, - query=args["query"], + query=args.get("query"), account=current_user, - retrieval_model=args["retrieval_model"], - external_retrieval_model=args["external_retrieval_model"], + retrieval_model=args.get("retrieval_model"), + external_retrieval_model=args.get("external_retrieval_model"), + attachment_ids=args.get("attachment_ids"), limit=10, ) return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)} diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index fdd7c2f479..29417dc896 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -45,6 +45,9 @@ class FileApi(Resource): "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT, + "image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT, + "single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT, + "attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT, }, 200 @setup_required diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 9a9832dd4a..e2e6c11480 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -83,6 +83,7 @@ class AppRunner: context: str | None = None, memory: TokenBufferMemory | None = None, image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + context_files: list["File"] | None = None, ) -> tuple[list[PromptMessage], list[str] | None]: """ Organize prompt messages @@ -111,6 +112,7 @@ class AppRunner: memory=memory, model_config=model_config, image_detail_config=image_detail_config, + context_files=context_files, ) else: memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 53188cf506..f8338b226b 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -11,6 +11,7 @@ from core.app.entities.app_invoke_entities import ( ) from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.file import File from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import ImagePromptMessageContent @@ -146,6 +147,7 @@ class ChatAppRunner(AppRunner): # get context from datasets context = None + context_files: list[File] = [] if app_config.dataset and app_config.dataset.dataset_ids: hit_callback = DatasetIndexToolCallbackHandler( queue_manager, @@ -156,7 +158,7 @@ class ChatAppRunner(AppRunner): ) dataset_retrieval = DatasetRetrieval(application_generate_entity) - context = dataset_retrieval.retrieve( + context, retrieved_files = dataset_retrieval.retrieve( app_id=app_record.id, user_id=application_generate_entity.user_id, tenant_id=app_record.tenant_id, @@ -171,7 +173,11 @@ class ChatAppRunner(AppRunner): memory=memory, message_id=message.id, inputs=inputs, + vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get( + "enabled", False + ), ) + context_files = retrieved_files or [] # reorganize all inputs and template to prompt messages # Include: prompt template, inputs, query(optional), files(optional) @@ -186,6 +192,7 @@ class ChatAppRunner(AppRunner): context=context, memory=memory, image_detail_config=image_detail_config, + context_files=context_files, ) # check hosting moderation diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index e2be4146e1..ddfb5725b4 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import ( CompletionAppGenerateEntity, ) from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.file import File from core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.moderation.base import ModerationError @@ -102,6 +103,7 @@ class CompletionAppRunner(AppRunner): # get context from datasets context = None + context_files: list[File] = [] if app_config.dataset and app_config.dataset.dataset_ids: hit_callback = DatasetIndexToolCallbackHandler( queue_manager, @@ -116,7 +118,7 @@ class CompletionAppRunner(AppRunner): query = inputs.get(dataset_config.retrieve_config.query_variable, "") dataset_retrieval = DatasetRetrieval(application_generate_entity) - context = dataset_retrieval.retrieve( + context, retrieved_files = dataset_retrieval.retrieve( app_id=app_record.id, user_id=application_generate_entity.user_id, tenant_id=app_record.tenant_id, @@ -130,7 +132,11 @@ class CompletionAppRunner(AppRunner): hit_callback=hit_callback, message_id=message.id, inputs=inputs, + vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get( + "enabled", False + ), ) + context_files = retrieved_files or [] # reorganize all inputs and template to prompt messages # Include: prompt template, inputs, query(optional), files(optional) @@ -144,6 +150,7 @@ class CompletionAppRunner(AppRunner): query=query, context=context, image_detail_config=image_detail_config, + context_files=context_files, ) # check hosting moderation diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 14d5f38dcd..d0279349ca 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -7,7 +7,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueRetrieverResourcesEvent from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import ChildChunk, DatasetQuery, DocumentSegment @@ -59,7 +59,7 @@ class DatasetIndexToolCallbackHandler: document_id, ) continue - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunk_stmt = select(ChildChunk).where( ChildChunk.index_node_id == document.metadata["doc_id"], ChildChunk.dataset_id == dataset_document.dataset_id, diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 36b38b7b45..59de4f403d 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -7,7 +7,7 @@ import time import uuid from typing import Any -from flask import current_app +from flask import Flask, current_app from sqlalchemy import select from sqlalchemy.orm.exc import ObjectDeletedError @@ -21,7 +21,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import ChildDocument, Document @@ -36,6 +36,7 @@ from extensions.ext_redis import redis_client from extensions.ext_storage import storage from libs import helper from libs.datetime_utils import naive_utc_now +from models import Account from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment from models.dataset import Document as DatasetDocument from models.model import UploadFile @@ -89,8 +90,17 @@ class IndexingRunner: text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict()) # transform + current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first() + if not current_user: + raise ValueError("no current user found") + current_user.set_tenant_id(dataset.tenant_id) documents = self._transform( - index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict() + index_processor, + dataset, + text_docs, + requeried_document.doc_language, + processing_rule.to_dict(), + current_user=current_user, ) # save segment self._load_segments(dataset, requeried_document, documents) @@ -136,7 +146,7 @@ class IndexingRunner: for document_segment in document_segments: db.session.delete(document_segment) - if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if requeried_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: # delete child chunks db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete() db.session.commit() @@ -152,8 +162,17 @@ class IndexingRunner: text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict()) # transform + current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first() + if not current_user: + raise ValueError("no current user found") + current_user.set_tenant_id(dataset.tenant_id) documents = self._transform( - index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict() + index_processor, + dataset, + text_docs, + requeried_document.doc_language, + processing_rule.to_dict(), + current_user=current_user, ) # save segment self._load_segments(dataset, requeried_document, documents) @@ -209,7 +228,7 @@ class IndexingRunner: "dataset_id": document_segment.dataset_id, }, ) - if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if requeried_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = document_segment.get_child_chunks() if child_chunks: child_documents = [] @@ -302,6 +321,7 @@ class IndexingRunner: text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) documents = index_processor.transform( text_docs, + current_user=None, embedding_model_instance=embedding_model_instance, process_rule=processing_rule.to_dict(), tenant_id=tenant_id, @@ -551,7 +571,10 @@ class IndexingRunner: indexing_start_at = time.perf_counter() tokens = 0 create_keyword_thread = None - if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy": + if ( + dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX + and dataset.indexing_technique == "economy" + ): # create keyword index create_keyword_thread = threading.Thread( target=self._process_keyword_index, @@ -590,7 +613,7 @@ class IndexingRunner: for future in futures: tokens += future.result() if ( - dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX + dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy" and create_keyword_thread is not None ): @@ -635,7 +658,13 @@ class IndexingRunner: db.session.commit() def _process_chunk( - self, flask_app, index_processor, chunk_documents, dataset, dataset_document, embedding_model_instance + self, + flask_app: Flask, + index_processor: BaseIndexProcessor, + chunk_documents: list[Document], + dataset: Dataset, + dataset_document: DatasetDocument, + embedding_model_instance: ModelInstance | None, ): with flask_app.app_context(): # check document is paused @@ -646,8 +675,15 @@ class IndexingRunner: page_content_list = [document.page_content for document in chunk_documents] tokens += sum(embedding_model_instance.get_text_embedding_num_tokens(page_content_list)) + multimodal_documents = [] + for document in chunk_documents: + if document.attachments and dataset.is_multimodal: + multimodal_documents.extend(document.attachments) + # load index - index_processor.load(dataset, chunk_documents, with_keywords=False) + index_processor.load( + dataset, chunk_documents, multimodal_documents=multimodal_documents, with_keywords=False + ) document_ids = [document.metadata["doc_id"] for document in chunk_documents] db.session.query(DocumentSegment).where( @@ -710,6 +746,7 @@ class IndexingRunner: text_docs: list[Document], doc_language: str, process_rule: dict, + current_user: Account | None = None, ) -> list[Document]: # get embedding model instance embedding_model_instance = None @@ -729,6 +766,7 @@ class IndexingRunner: documents = index_processor.transform( text_docs, + current_user, embedding_model_instance=embedding_model_instance, process_rule=process_rule, tenant_id=dataset.tenant_id, @@ -737,14 +775,16 @@ class IndexingRunner: return documents - def _load_segments(self, dataset, dataset_document, documents): + def _load_segments(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]): # save node to document segment doc_store = DatasetDocumentStore( dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id ) # add document segments - doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX) + doc_store.add_documents( + docs=documents, save_child=dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX + ) # update document status to indexing cur_time = naive_utc_now() diff --git a/api/core/model_manager.py b/api/core/model_manager.py index a63e94d59c..5a28bbcc3a 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -10,9 +10,9 @@ from core.errors.error import ProviderTokenNotInitError from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.entities.rerank_entities import RerankResult -from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.entities.text_embedding_entities import EmbeddingResult from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.moderation_model import ModerationModel @@ -200,7 +200,7 @@ class ModelInstance: def invoke_text_embedding( self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT - ) -> TextEmbeddingResult: + ) -> EmbeddingResult: """ Invoke large language model @@ -212,7 +212,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") return cast( - TextEmbeddingResult, + EmbeddingResult, self._round_robin_invoke( function=self.model_type_instance.invoke, model=self.model, @@ -223,6 +223,34 @@ class ModelInstance: ), ) + def invoke_multimodal_embedding( + self, + multimodel_documents: list[dict], + user: str | None = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> EmbeddingResult: + """ + Invoke large language model + + :param multimodel_documents: multimodel documents to embed + :param user: unique user id + :param input_type: input type + :return: embeddings result + """ + if not isinstance(self.model_type_instance, TextEmbeddingModel): + raise Exception("Model type instance is not TextEmbeddingModel") + return cast( + EmbeddingResult, + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + multimodel_documents=multimodel_documents, + user=user, + input_type=input_type, + ), + ) + def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]: """ Get number of tokens for text embedding @@ -276,6 +304,40 @@ class ModelInstance: ), ) + def invoke_multimodal_rerank( + self, + query: dict, + docs: list[dict], + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, + ) -> RerankResult: + """ + Invoke rerank model + + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id + :return: rerank result + """ + if not isinstance(self.model_type_instance, RerankModel): + raise Exception("Model type instance is not RerankModel") + return cast( + RerankResult, + self._round_robin_invoke( + function=self.model_type_instance.invoke_multimodal_rerank, + model=self.model, + credentials=self.credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, + user=user, + ), + ) + def invoke_moderation(self, text: str, user: str | None = None) -> bool: """ Invoke moderation model @@ -461,6 +523,32 @@ class ModelManager: model=default_model_entity.model, ) + def check_model_support_vision(self, tenant_id: str, provider: str, model: str, model_type: ModelType) -> bool: + """ + Check if model supports vision + :param tenant_id: tenant id + :param provider: provider name + :param model: model name + :return: True if model supports vision, False otherwise + """ + model_instance = self.get_model_instance(tenant_id, provider, model_type, model) + model_type_instance = model_instance.model_type_instance + match model_type: + case ModelType.LLM: + model_type_instance = cast(LargeLanguageModel, model_type_instance) + case ModelType.TEXT_EMBEDDING: + model_type_instance = cast(TextEmbeddingModel, model_type_instance) + case ModelType.RERANK: + model_type_instance = cast(RerankModel, model_type_instance) + case _: + raise ValueError(f"Model type {model_type} is not supported") + model_schema = model_type_instance.get_model_schema(model, model_instance.credentials) + if not model_schema: + return False + if model_schema.features and ModelFeature.VISION in model_schema.features: + return True + return False + class LBModelManager: def __init__( diff --git a/api/core/model_runtime/entities/text_embedding_entities.py b/api/core/model_runtime/entities/text_embedding_entities.py index 846b89d658..854c448250 100644 --- a/api/core/model_runtime/entities/text_embedding_entities.py +++ b/api/core/model_runtime/entities/text_embedding_entities.py @@ -19,7 +19,7 @@ class EmbeddingUsage(ModelUsage): latency: float -class TextEmbeddingResult(BaseModel): +class EmbeddingResult(BaseModel): """ Model class for text embedding result. """ @@ -27,3 +27,13 @@ class TextEmbeddingResult(BaseModel): model: str embeddings: list[list[float]] usage: EmbeddingUsage + + +class FileEmbeddingResult(BaseModel): + """ + Model class for file embedding result. + """ + + model: str + embeddings: list[list[float]] + usage: EmbeddingUsage diff --git a/api/core/model_runtime/model_providers/__base/rerank_model.py b/api/core/model_runtime/model_providers/__base/rerank_model.py index 36067118b0..0a576b832a 100644 --- a/api/core/model_runtime/model_providers/__base/rerank_model.py +++ b/api/core/model_runtime/model_providers/__base/rerank_model.py @@ -50,3 +50,43 @@ class RerankModel(AIModel): ) except Exception as e: raise self._transform_invoke_error(e) + + def invoke_multimodal_rerank( + self, + model: str, + credentials: dict, + query: dict, + docs: list[dict], + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, + ) -> RerankResult: + """ + Invoke multimodal rerank model + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id + :return: rerank result + """ + try: + from core.plugin.impl.model import PluginModelClient + + plugin_model_manager = PluginModelClient() + return plugin_model_manager.invoke_multimodal_rerank( + tenant_id=self.tenant_id, + user_id=user or "unknown", + plugin_id=self.plugin_id, + provider=self.provider_name, + model=model, + credentials=credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, + ) + except Exception as e: + raise self._transform_invoke_error(e) diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py index bd68ffe903..4c902e2c11 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py @@ -2,7 +2,7 @@ from pydantic import ConfigDict from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.entities.text_embedding_entities import EmbeddingResult from core.model_runtime.model_providers.__base.ai_model import AIModel @@ -20,16 +20,18 @@ class TextEmbeddingModel(AIModel): self, model: str, credentials: dict, - texts: list[str], + texts: list[str] | None = None, + multimodel_documents: list[dict] | None = None, user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, - ) -> TextEmbeddingResult: + ) -> EmbeddingResult: """ Invoke text embedding model :param model: model name :param credentials: model credentials :param texts: texts to embed + :param files: files to embed :param user: unique user id :param input_type: input type :return: embeddings result @@ -38,16 +40,29 @@ class TextEmbeddingModel(AIModel): try: plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_text_embedding( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - texts=texts, - input_type=input_type, - ) + if texts: + return plugin_model_manager.invoke_text_embedding( + tenant_id=self.tenant_id, + user_id=user or "unknown", + plugin_id=self.plugin_id, + provider=self.provider_name, + model=model, + credentials=credentials, + texts=texts, + input_type=input_type, + ) + if multimodel_documents: + return plugin_model_manager.invoke_multimodal_embedding( + tenant_id=self.tenant_id, + user_id=user or "unknown", + plugin_id=self.plugin_id, + provider=self.provider_name, + model=model, + credentials=credentials, + documents=multimodel_documents, + input_type=input_type, + ) + raise ValueError("No texts or files provided") except Exception as e: raise self._transform_invoke_error(e) diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py index 5dfc3c212e..5d70980967 100644 --- a/api/core/plugin/impl/model.py +++ b/api/core/plugin/impl/model.py @@ -6,7 +6,7 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.entities.rerank_entities import RerankResult -from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.entities.text_embedding_entities import EmbeddingResult from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, @@ -243,14 +243,14 @@ class PluginModelClient(BasePluginClient): credentials: dict, texts: list[str], input_type: str, - ) -> TextEmbeddingResult: + ) -> EmbeddingResult: """ Invoke text embedding """ response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke", - type_=TextEmbeddingResult, + type_=EmbeddingResult, data=jsonable_encoder( { "user_id": user_id, @@ -275,6 +275,48 @@ class PluginModelClient(BasePluginClient): raise ValueError("Failed to invoke text embedding") + def invoke_multimodal_embedding( + self, + tenant_id: str, + user_id: str, + plugin_id: str, + provider: str, + model: str, + credentials: dict, + documents: list[dict], + input_type: str, + ) -> EmbeddingResult: + """ + Invoke file embedding + """ + response = self._request_with_plugin_daemon_response_stream( + method="POST", + path=f"plugin/{tenant_id}/dispatch/multimodal_embedding/invoke", + type_=EmbeddingResult, + data=jsonable_encoder( + { + "user_id": user_id, + "data": { + "provider": provider, + "model_type": "text-embedding", + "model": model, + "credentials": credentials, + "documents": documents, + "input_type": input_type, + }, + } + ), + headers={ + "X-Plugin-ID": plugin_id, + "Content-Type": "application/json", + }, + ) + + for resp in response: + return resp + + raise ValueError("Failed to invoke file embedding") + def get_text_embedding_num_tokens( self, tenant_id: str, @@ -361,6 +403,51 @@ class PluginModelClient(BasePluginClient): raise ValueError("Failed to invoke rerank") + def invoke_multimodal_rerank( + self, + tenant_id: str, + user_id: str, + plugin_id: str, + provider: str, + model: str, + credentials: dict, + query: dict, + docs: list[dict], + score_threshold: float | None = None, + top_n: int | None = None, + ) -> RerankResult: + """ + Invoke multimodal rerank + """ + response = self._request_with_plugin_daemon_response_stream( + method="POST", + path=f"plugin/{tenant_id}/dispatch/multimodal_rerank/invoke", + type_=RerankResult, + data=jsonable_encoder( + { + "user_id": user_id, + "data": { + "provider": provider, + "model_type": "rerank", + "model": model, + "credentials": credentials, + "query": query, + "docs": docs, + "score_threshold": score_threshold, + "top_n": top_n, + }, + } + ), + headers={ + "X-Plugin-ID": plugin_id, + "Content-Type": "application/json", + }, + ) + for resp in response: + return resp + + raise ValueError("Failed to invoke multimodal rerank") + def invoke_tts( self, tenant_id: str, diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index d1d518a55d..f072092ea7 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -49,6 +49,7 @@ class SimplePromptTransform(PromptTransform): memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + context_files: list["File"] | None = None, ) -> tuple[list[PromptMessage], list[str] | None]: inputs = {key: str(value) for key, value in inputs.items()} @@ -64,6 +65,7 @@ class SimplePromptTransform(PromptTransform): memory=memory, model_config=model_config, image_detail_config=image_detail_config, + context_files=context_files, ) else: prompt_messages, stops = self._get_completion_model_prompt_messages( @@ -76,6 +78,7 @@ class SimplePromptTransform(PromptTransform): memory=memory, model_config=model_config, image_detail_config=image_detail_config, + context_files=context_files, ) return prompt_messages, stops @@ -187,6 +190,7 @@ class SimplePromptTransform(PromptTransform): memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + context_files: list["File"] | None = None, ) -> tuple[list[PromptMessage], list[str] | None]: prompt_messages: list[PromptMessage] = [] @@ -216,9 +220,9 @@ class SimplePromptTransform(PromptTransform): ) if query: - prompt_messages.append(self._get_last_user_message(query, files, image_detail_config)) + prompt_messages.append(self._get_last_user_message(query, files, image_detail_config, context_files)) else: - prompt_messages.append(self._get_last_user_message(prompt, files, image_detail_config)) + prompt_messages.append(self._get_last_user_message(prompt, files, image_detail_config, context_files)) return prompt_messages, None @@ -233,6 +237,7 @@ class SimplePromptTransform(PromptTransform): memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + context_files: list["File"] | None = None, ) -> tuple[list[PromptMessage], list[str] | None]: # get prompt prompt, prompt_rules = self._get_prompt_str_and_rules( @@ -275,20 +280,27 @@ class SimplePromptTransform(PromptTransform): if stops is not None and len(stops) == 0: stops = None - return [self._get_last_user_message(prompt, files, image_detail_config)], stops + return [self._get_last_user_message(prompt, files, image_detail_config, context_files)], stops def _get_last_user_message( self, prompt: str, files: Sequence["File"], image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + context_files: list["File"] | None = None, ) -> UserPromptMessage: + prompt_message_contents: list[PromptMessageContentUnionTypes] = [] if files: - prompt_message_contents: list[PromptMessageContentUnionTypes] = [] for file in files: prompt_message_contents.append( file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) ) + if context_files: + for file in context_files: + prompt_message_contents.append( + file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) + ) + if prompt_message_contents: prompt_message_contents.append(TextPromptMessageContent(data=prompt)) prompt_message = UserPromptMessage(content=prompt_message_contents) diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index cc946a72c3..bfa8781e9f 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -2,6 +2,7 @@ from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.rag.data_post_processor.reorder import ReorderRunner +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner @@ -30,9 +31,10 @@ class DataPostProcessor: score_threshold: float | None = None, top_n: int | None = None, user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: if self.rerank_runner: - documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user) + documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user, query_type) if self.reorder_runner: documents = self.reorder_runner.run(documents) diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 2290de19bc..e644e754ec 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -1,23 +1,30 @@ import concurrent.futures from concurrent.futures import ThreadPoolExecutor +from typing import Any from flask import Flask, current_app from sqlalchemy import select from sqlalchemy.orm import Session, load_only from configs import dify_config +from core.model_manager import ModelManager +from core.model_runtime.entities.model_entities import ModelType from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector from core.rag.embedding.retrieval import RetrievalSegments from core.rag.entities.metadata_entities import MetadataCondition -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod +from core.tools.signature import sign_upload_file from extensions.ext_database import db -from models.dataset import ChildChunk, Dataset, DocumentSegment +from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument +from models.model import UploadFile from services.external_knowledge_service import ExternalDatasetService default_retrieval_model = { @@ -37,14 +44,15 @@ class RetrievalService: retrieval_method: RetrievalMethod, dataset_id: str, query: str, - top_k: int, + top_k: int = 4, score_threshold: float | None = 0.0, reranking_model: dict | None = None, reranking_mode: str = "reranking_model", weights: dict | None = None, document_ids_filter: list[str] | None = None, + attachment_ids: list | None = None, ): - if not query: + if not query and not attachment_ids: return [] dataset = cls._get_dataset(dataset_id) if not dataset: @@ -56,69 +64,52 @@ class RetrievalService: # Optimize multithreading with thread pools with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore futures = [] - if retrieval_method == RetrievalMethod.KEYWORD_SEARCH: + retrieval_service = RetrievalService() + if query: futures.append( executor.submit( - cls.keyword_search, + retrieval_service._retrieve, flask_app=current_app._get_current_object(), # type: ignore - dataset_id=dataset_id, - query=query, - top_k=top_k, - all_documents=all_documents, - exceptions=exceptions, - document_ids_filter=document_ids_filter, - ) - ) - if RetrievalMethod.is_support_semantic_search(retrieval_method): - futures.append( - executor.submit( - cls.embedding_search, - flask_app=current_app._get_current_object(), # type: ignore - dataset_id=dataset_id, + retrieval_method=retrieval_method, + dataset=dataset, query=query, top_k=top_k, score_threshold=score_threshold, reranking_model=reranking_model, - all_documents=all_documents, - retrieval_method=retrieval_method, - exceptions=exceptions, + reranking_mode=reranking_mode, + weights=weights, document_ids_filter=document_ids_filter, + attachment_id=None, + all_documents=all_documents, + exceptions=exceptions, ) ) - if RetrievalMethod.is_support_fulltext_search(retrieval_method): - futures.append( - executor.submit( - cls.full_text_index_search, - flask_app=current_app._get_current_object(), # type: ignore - dataset_id=dataset_id, - query=query, - top_k=top_k, - score_threshold=score_threshold, - reranking_model=reranking_model, - all_documents=all_documents, - retrieval_method=retrieval_method, - exceptions=exceptions, - document_ids_filter=document_ids_filter, + if attachment_ids: + for attachment_id in attachment_ids: + futures.append( + executor.submit( + retrieval_service._retrieve, + flask_app=current_app._get_current_object(), # type: ignore + retrieval_method=retrieval_method, + dataset=dataset, + query=None, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + reranking_mode=reranking_mode, + weights=weights, + document_ids_filter=document_ids_filter, + attachment_id=attachment_id, + all_documents=all_documents, + exceptions=exceptions, + ) ) - ) - concurrent.futures.wait(futures, timeout=30, return_when=concurrent.futures.ALL_COMPLETED) + + concurrent.futures.wait(futures, timeout=3600, return_when=concurrent.futures.ALL_COMPLETED) if exceptions: raise ValueError(";\n".join(exceptions)) - # Deduplicate documents for hybrid search to avoid duplicate chunks - if retrieval_method == RetrievalMethod.HYBRID_SEARCH: - all_documents = cls._deduplicate_documents(all_documents) - data_post_processor = DataPostProcessor( - str(dataset.tenant_id), reranking_mode, reranking_model, weights, False - ) - all_documents = data_post_processor.invoke( - query=query, - documents=all_documents, - score_threshold=score_threshold, - top_n=top_k, - ) - return all_documents @classmethod @@ -223,6 +214,7 @@ class RetrievalService: retrieval_method: RetrievalMethod, exceptions: list, document_ids_filter: list[str] | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, ): with flask_app.app_context(): try: @@ -231,14 +223,30 @@ class RetrievalService: raise ValueError("dataset not found") vector = Vector(dataset=dataset) - documents = vector.search_by_vector( - query, - search_type="similarity_score_threshold", - top_k=top_k, - score_threshold=score_threshold, - filter={"group_id": [dataset.id]}, - document_ids_filter=document_ids_filter, - ) + documents = [] + if query_type == QueryType.TEXT_QUERY: + documents.extend( + vector.search_by_vector( + query, + search_type="similarity_score_threshold", + top_k=top_k, + score_threshold=score_threshold, + filter={"group_id": [dataset.id]}, + document_ids_filter=document_ids_filter, + ) + ) + if query_type == QueryType.IMAGE_QUERY: + if not dataset.is_multimodal: + return + documents.extend( + vector.search_by_file( + file_id=query, + top_k=top_k, + score_threshold=score_threshold, + filter={"group_id": [dataset.id]}, + document_ids_filter=document_ids_filter, + ) + ) if documents: if ( @@ -250,14 +258,37 @@ class RetrievalService: data_post_processor = DataPostProcessor( str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False ) - all_documents.extend( - data_post_processor.invoke( - query=query, - documents=documents, - score_threshold=score_threshold, - top_n=len(documents), + if dataset.is_multimodal: + model_manager = ModelManager() + is_support_vision = model_manager.check_model_support_vision( + tenant_id=dataset.tenant_id, + provider=reranking_model.get("reranking_provider_name") or "", + model=reranking_model.get("reranking_model_name") or "", + model_type=ModelType.RERANK, + ) + if is_support_vision: + all_documents.extend( + data_post_processor.invoke( + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=len(documents), + query_type=query_type, + ) + ) + else: + # not effective, return original documents + all_documents.extend(documents) + else: + all_documents.extend( + data_post_processor.invoke( + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=len(documents), + query_type=query_type, + ) ) - ) else: all_documents.extend(documents) except Exception as e: @@ -339,103 +370,159 @@ class RetrievalService: records = [] include_segment_ids = set() segment_child_map = {} - - # Process documents - for document in documents: - document_id = document.metadata.get("document_id") - if document_id not in dataset_documents: - continue - - dataset_document = dataset_documents[document_id] - if not dataset_document: - continue - - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: - # Handle parent-child documents - child_index_node_id = document.metadata.get("doc_id") - child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id) - child_chunk = db.session.scalar(child_chunk_stmt) - - if not child_chunk: + segment_file_map = {} + with Session(db.engine) as session: + # Process documents + for document in documents: + segment_id = None + attachment_info = None + child_chunk = None + document_id = document.metadata.get("document_id") + if document_id not in dataset_documents: continue - segment = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.dataset_id == dataset_document.dataset_id, - DocumentSegment.enabled == True, - DocumentSegment.status == "completed", - DocumentSegment.id == child_chunk.segment_id, - ) - .options( - load_only( - DocumentSegment.id, - DocumentSegment.content, - DocumentSegment.answer, + dataset_document = dataset_documents[document_id] + if not dataset_document: + continue + + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: + # Handle parent-child documents + if document.metadata.get("doc_type") == DocType.IMAGE: + attachment_info_dict = cls.get_segment_attachment_info( + dataset_document.dataset_id, + dataset_document.tenant_id, + document.metadata.get("doc_id") or "", + session, ) + if attachment_info_dict: + attachment_info = attachment_info_dict["attchment_info"] + segment_id = attachment_info_dict["segment_id"] + else: + child_index_node_id = document.metadata.get("doc_id") + child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id) + child_chunk = session.scalar(child_chunk_stmt) + + if not child_chunk: + continue + segment_id = child_chunk.segment_id + + if not segment_id: + continue + + segment = ( + session.query(DocumentSegment) + .where( + DocumentSegment.dataset_id == dataset_document.dataset_id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.id == segment_id, + ) + .options( + load_only( + DocumentSegment.id, + DocumentSegment.content, + DocumentSegment.answer, + ) + ) + .first() ) - .first() - ) - if not segment: - continue + if not segment: + continue - if segment.id not in include_segment_ids: - include_segment_ids.add(segment.id) - child_chunk_detail = { - "id": child_chunk.id, - "content": child_chunk.content, - "position": child_chunk.position, - "score": document.metadata.get("score", 0.0), - } - map_detail = { - "max_score": document.metadata.get("score", 0.0), - "child_chunks": [child_chunk_detail], - } - segment_child_map[segment.id] = map_detail - record = { - "segment": segment, - } - records.append(record) + if segment.id not in include_segment_ids: + include_segment_ids.add(segment.id) + if child_chunk: + child_chunk_detail = { + "id": child_chunk.id, + "content": child_chunk.content, + "position": child_chunk.position, + "score": document.metadata.get("score", 0.0), + } + map_detail = { + "max_score": document.metadata.get("score", 0.0), + "child_chunks": [child_chunk_detail], + } + segment_child_map[segment.id] = map_detail + record = { + "segment": segment, + } + if attachment_info: + segment_file_map[segment.id] = [attachment_info] + records.append(record) + else: + if child_chunk: + child_chunk_detail = { + "id": child_chunk.id, + "content": child_chunk.content, + "position": child_chunk.position, + "score": document.metadata.get("score", 0.0), + } + segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) + segment_child_map[segment.id]["max_score"] = max( + segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0) + ) + if attachment_info: + segment_file_map[segment.id].append(attachment_info) else: - child_chunk_detail = { - "id": child_chunk.id, - "content": child_chunk.content, - "position": child_chunk.position, - "score": document.metadata.get("score", 0.0), - } - segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) - segment_child_map[segment.id]["max_score"] = max( - segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0) - ) - else: - # Handle normal documents - index_node_id = document.metadata.get("doc_id") - if not index_node_id: - continue - document_segment_stmt = select(DocumentSegment).where( - DocumentSegment.dataset_id == dataset_document.dataset_id, - DocumentSegment.enabled == True, - DocumentSegment.status == "completed", - DocumentSegment.index_node_id == index_node_id, - ) - segment = db.session.scalar(document_segment_stmt) + # Handle normal documents + segment = None + if document.metadata.get("doc_type") == DocType.IMAGE: + attachment_info_dict = cls.get_segment_attachment_info( + dataset_document.dataset_id, + dataset_document.tenant_id, + document.metadata.get("doc_id") or "", + session, + ) + if attachment_info_dict: + attachment_info = attachment_info_dict["attchment_info"] + segment_id = attachment_info_dict["segment_id"] + document_segment_stmt = select(DocumentSegment).where( + DocumentSegment.dataset_id == dataset_document.dataset_id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.id == segment_id, + ) + segment = session.scalar(document_segment_stmt) + if segment: + segment_file_map[segment.id] = [attachment_info] + else: + index_node_id = document.metadata.get("doc_id") + if not index_node_id: + continue + document_segment_stmt = select(DocumentSegment).where( + DocumentSegment.dataset_id == dataset_document.dataset_id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.index_node_id == index_node_id, + ) + segment = session.scalar(document_segment_stmt) - if not segment: - continue - - include_segment_ids.add(segment.id) - record = { - "segment": segment, - "score": document.metadata.get("score"), # type: ignore - } - records.append(record) + if not segment: + continue + if segment.id not in include_segment_ids: + include_segment_ids.add(segment.id) + record = { + "segment": segment, + "score": document.metadata.get("score"), # type: ignore + } + if attachment_info: + segment_file_map[segment.id] = [attachment_info] + records.append(record) + else: + if attachment_info: + attachment_infos = segment_file_map.get(segment.id, []) + if attachment_info not in attachment_infos: + attachment_infos.append(attachment_info) + segment_file_map[segment.id] = attachment_infos # Add child chunks information to records for record in records: if record["segment"].id in segment_child_map: record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore record["score"] = segment_child_map[record["segment"].id]["max_score"] + if record["segment"].id in segment_file_map: + record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment] result = [] for record in records: @@ -447,6 +534,11 @@ class RetrievalService: if not isinstance(child_chunks, list): child_chunks = None + # Extract files, ensuring it's a list or None + files = record.get("files") + if not isinstance(files, list): + files = None + # Extract score, ensuring it's a float or None score_value = record.get("score") score = ( @@ -456,10 +548,149 @@ class RetrievalService: ) # Create RetrievalSegments object - retrieval_segment = RetrievalSegments(segment=segment, child_chunks=child_chunks, score=score) + retrieval_segment = RetrievalSegments( + segment=segment, child_chunks=child_chunks, score=score, files=files + ) result.append(retrieval_segment) return result except Exception as e: db.session.rollback() raise e + + def _retrieve( + self, + flask_app: Flask, + retrieval_method: RetrievalMethod, + dataset: Dataset, + query: str | None = None, + top_k: int = 4, + score_threshold: float | None = 0.0, + reranking_model: dict | None = None, + reranking_mode: str = "reranking_model", + weights: dict | None = None, + document_ids_filter: list[str] | None = None, + attachment_id: str | None = None, + all_documents: list[Document] = [], + exceptions: list[str] = [], + ): + if not query and not attachment_id: + return + with flask_app.app_context(): + all_documents_item: list[Document] = [] + # Optimize multithreading with thread pools + with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore + futures = [] + if retrieval_method == RetrievalMethod.KEYWORD_SEARCH and query: + futures.append( + executor.submit( + self.keyword_search, + flask_app=current_app._get_current_object(), # type: ignore + dataset_id=dataset.id, + query=query, + top_k=top_k, + all_documents=all_documents_item, + exceptions=exceptions, + document_ids_filter=document_ids_filter, + ) + ) + if RetrievalMethod.is_support_semantic_search(retrieval_method): + if query: + futures.append( + executor.submit( + self.embedding_search, + flask_app=current_app._get_current_object(), # type: ignore + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + all_documents=all_documents_item, + retrieval_method=retrieval_method, + exceptions=exceptions, + document_ids_filter=document_ids_filter, + query_type=QueryType.TEXT_QUERY, + ) + ) + if attachment_id: + futures.append( + executor.submit( + self.embedding_search, + flask_app=current_app._get_current_object(), # type: ignore + dataset_id=dataset.id, + query=attachment_id, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + all_documents=all_documents_item, + retrieval_method=retrieval_method, + exceptions=exceptions, + document_ids_filter=document_ids_filter, + query_type=QueryType.IMAGE_QUERY, + ) + ) + if RetrievalMethod.is_support_fulltext_search(retrieval_method) and query: + futures.append( + executor.submit( + self.full_text_index_search, + flask_app=current_app._get_current_object(), # type: ignore + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + all_documents=all_documents_item, + retrieval_method=retrieval_method, + exceptions=exceptions, + document_ids_filter=document_ids_filter, + ) + ) + concurrent.futures.wait(futures, timeout=300, return_when=concurrent.futures.ALL_COMPLETED) + + if exceptions: + raise ValueError(";\n".join(exceptions)) + + # Deduplicate documents for hybrid search to avoid duplicate chunks + if retrieval_method == RetrievalMethod.HYBRID_SEARCH: + if attachment_id and reranking_mode == RerankMode.WEIGHTED_SCORE: + all_documents.extend(all_documents_item) + all_documents_item = self._deduplicate_documents(all_documents_item) + data_post_processor = DataPostProcessor( + str(dataset.tenant_id), reranking_mode, reranking_model, weights, False + ) + + query = query or attachment_id + if not query: + return + all_documents_item = data_post_processor.invoke( + query=query, + documents=all_documents_item, + score_threshold=score_threshold, + top_n=top_k, + query_type=QueryType.TEXT_QUERY if query else QueryType.IMAGE_QUERY, + ) + + all_documents.extend(all_documents_item) + + @classmethod + def get_segment_attachment_info( + cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session + ) -> dict[str, Any] | None: + upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first() + if upload_file: + attachment_binding = ( + session.query(SegmentAttachmentBinding) + .where(SegmentAttachmentBinding.attachment_id == upload_file.id) + .first() + ) + if attachment_binding: + attchment_info = { + "id": upload_file.id, + "name": upload_file.name, + "extension": "." + upload_file.extension, + "mime_type": upload_file.mime_type, + "source_url": sign_upload_file(upload_file.id, upload_file.extension), + "size": upload_file.size, + } + return {"attchment_info": attchment_info, "segment_id": attachment_binding.segment_id} + return None diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 0beb388693..3a47241293 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -1,3 +1,4 @@ +import base64 import logging import time from abc import ABC, abstractmethod @@ -12,10 +13,13 @@ from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_type import VectorType from core.rag.embedding.cached_embedding import CacheEmbedding from core.rag.embedding.embedding_base import Embeddings +from core.rag.index_processor.constant.doc_type import DocType from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client +from extensions.ext_storage import storage from models.dataset import Dataset, Whitelist +from models.model import UploadFile logger = logging.getLogger(__name__) @@ -203,6 +207,47 @@ class Vector: self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs) logger.info("Embedding %s texts took %s s", len(texts), time.time() - start) + def create_multimodal(self, file_documents: list | None = None, **kwargs): + if file_documents: + start = time.time() + logger.info("start embedding %s files %s", len(file_documents), start) + batch_size = 1000 + total_batches = len(file_documents) + batch_size - 1 + for i in range(0, len(file_documents), batch_size): + batch = file_documents[i : i + batch_size] + batch_start = time.time() + logger.info("Processing batch %s/%s (%s files)", i // batch_size + 1, total_batches, len(batch)) + + # Batch query all upload files to avoid N+1 queries + attachment_ids = [doc.metadata["doc_id"] for doc in batch] + stmt = select(UploadFile).where(UploadFile.id.in_(attachment_ids)) + upload_files = db.session.scalars(stmt).all() + upload_file_map = {str(f.id): f for f in upload_files} + + file_base64_list = [] + real_batch = [] + for document in batch: + attachment_id = document.metadata["doc_id"] + doc_type = document.metadata["doc_type"] + upload_file = upload_file_map.get(attachment_id) + if upload_file: + blob = storage.load_once(upload_file.key) + file_base64_str = base64.b64encode(blob).decode() + file_base64_list.append( + { + "content": file_base64_str, + "content_type": doc_type, + "file_id": attachment_id, + } + ) + real_batch.append(document) + batch_embeddings = self._embeddings.embed_multimodal_documents(file_base64_list) + logger.info( + "Embedding batch %s/%s took %s s", i // batch_size + 1, total_batches, time.time() - batch_start + ) + self._vector_processor.create(texts=real_batch, embeddings=batch_embeddings, **kwargs) + logger.info("Embedding %s files took %s s", len(file_documents), time.time() - start) + def add_texts(self, documents: list[Document], **kwargs): if kwargs.get("duplicate_check", False): documents = self._filter_duplicate_texts(documents) @@ -223,6 +268,22 @@ class Vector: query_vector = self._embeddings.embed_query(query) return self._vector_processor.search_by_vector(query_vector, **kwargs) + def search_by_file(self, file_id: str, **kwargs: Any) -> list[Document]: + upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == file_id).first() + + if not upload_file: + return [] + blob = storage.load_once(upload_file.key) + file_base64_str = base64.b64encode(blob).decode() + multimodal_vector = self._embeddings.embed_multimodal_query( + { + "content": file_base64_str, + "content_type": DocType.IMAGE, + "file_id": file_id, + } + ) + return self._vector_processor.search_by_vector(multimodal_vector, **kwargs) + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return self._vector_processor.search_by_full_text(query, **kwargs) diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 74a2653e9d..1fe74d3042 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -5,9 +5,9 @@ from sqlalchemy import func, select from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from core.rag.models.document import Document +from core.rag.models.document import AttachmentDocument, Document from extensions.ext_database import db -from models.dataset import ChildChunk, Dataset, DocumentSegment +from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding class DatasetDocumentStore: @@ -120,6 +120,9 @@ class DatasetDocumentStore: db.session.add(segment_document) db.session.flush() + self.add_multimodel_documents_binding( + segment_id=segment_document.id, multimodel_documents=doc.attachments + ) if save_child: if doc.children: for position, child in enumerate(doc.children, start=1): @@ -144,6 +147,9 @@ class DatasetDocumentStore: segment_document.index_node_hash = doc.metadata.get("doc_hash") segment_document.word_count = len(doc.page_content) segment_document.tokens = tokens + self.add_multimodel_documents_binding( + segment_id=segment_document.id, multimodel_documents=doc.attachments + ) if save_child and doc.children: # delete the existing child chunks db.session.query(ChildChunk).where( @@ -233,3 +239,15 @@ class DatasetDocumentStore: document_segment = db.session.scalar(stmt) return document_segment + + def add_multimodel_documents_binding(self, segment_id: str, multimodel_documents: list[AttachmentDocument] | None): + if multimodel_documents: + for multimodel_document in multimodel_documents: + binding = SegmentAttachmentBinding( + tenant_id=self._dataset.tenant_id, + dataset_id=self._dataset.id, + document_id=self._document_id, + segment_id=segment_id, + attachment_id=multimodel_document.metadata["doc_id"], + ) + db.session.add(binding) diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 7fb20c1941..3cbc7db75d 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -104,6 +104,88 @@ class CacheEmbedding(Embeddings): return text_embeddings + def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]: + """Embed file documents.""" + # use doc embedding cache or store if not exists + multimodel_embeddings: list[Any] = [None for _ in range(len(multimodel_documents))] + embedding_queue_indices = [] + for i, multimodel_document in enumerate(multimodel_documents): + file_id = multimodel_document["file_id"] + embedding = ( + db.session.query(Embedding) + .filter_by( + model_name=self._model_instance.model, hash=file_id, provider_name=self._model_instance.provider + ) + .first() + ) + if embedding: + multimodel_embeddings[i] = embedding.get_embedding() + else: + embedding_queue_indices.append(i) + + # NOTE: avoid closing the shared scoped session here; downstream code may still have pending work + + if embedding_queue_indices: + embedding_queue_multimodel_documents = [multimodel_documents[i] for i in embedding_queue_indices] + embedding_queue_embeddings = [] + try: + model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance) + model_schema = model_type_instance.get_model_schema( + self._model_instance.model, self._model_instance.credentials + ) + max_chunks = ( + model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] + if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties + else 1 + ) + for i in range(0, len(embedding_queue_multimodel_documents), max_chunks): + batch_multimodel_documents = embedding_queue_multimodel_documents[i : i + max_chunks] + + embedding_result = self._model_instance.invoke_multimodal_embedding( + multimodel_documents=batch_multimodel_documents, + user=self._user, + input_type=EmbeddingInputType.DOCUMENT, + ) + + for vector in embedding_result.embeddings: + try: + # FIXME: type ignore for numpy here + normalized_embedding = (vector / np.linalg.norm(vector)).tolist() # type: ignore + # stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan + if np.isnan(normalized_embedding).any(): + # for issue #11827 float values are not json compliant + logger.warning("Normalized embedding is nan: %s", normalized_embedding) + continue + embedding_queue_embeddings.append(normalized_embedding) + except IntegrityError: + db.session.rollback() + except Exception: + logger.exception("Failed transform embedding") + cache_embeddings = [] + try: + for i, n_embedding in zip(embedding_queue_indices, embedding_queue_embeddings): + multimodel_embeddings[i] = n_embedding + file_id = multimodel_documents[i]["file_id"] + if file_id not in cache_embeddings: + embedding_cache = Embedding( + model_name=self._model_instance.model, + hash=file_id, + provider_name=self._model_instance.provider, + embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL), + ) + embedding_cache.set_embedding(n_embedding) + db.session.add(embedding_cache) + cache_embeddings.append(file_id) + db.session.commit() + except IntegrityError: + db.session.rollback() + except Exception as ex: + db.session.rollback() + logger.exception("Failed to embed documents") + raise ex + + return multimodel_embeddings + def embed_query(self, text: str) -> list[float]: """Embed query text.""" # use doc embedding cache or store if not exists @@ -146,3 +228,46 @@ class CacheEmbedding(Embeddings): raise ex return embedding_results # type: ignore + + def embed_multimodal_query(self, multimodel_document: dict) -> list[float]: + """Embed multimodal documents.""" + # use doc embedding cache or store if not exists + file_id = multimodel_document["file_id"] + embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{file_id}" + embedding = redis_client.get(embedding_cache_key) + if embedding: + redis_client.expire(embedding_cache_key, 600) + decoded_embedding = np.frombuffer(base64.b64decode(embedding), dtype="float") + return [float(x) for x in decoded_embedding] + try: + embedding_result = self._model_instance.invoke_multimodal_embedding( + multimodel_documents=[multimodel_document], user=self._user, input_type=EmbeddingInputType.QUERY + ) + + embedding_results = embedding_result.embeddings[0] + # FIXME: type ignore for numpy here + embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() # type: ignore + if np.isnan(embedding_results).any(): + raise ValueError("Normalized embedding is nan please try again") + except Exception as ex: + if dify_config.DEBUG: + logger.exception("Failed to embed multimodal document '%s'", multimodel_document["file_id"]) + raise ex + + try: + # encode embedding to base64 + embedding_vector = np.array(embedding_results) + vector_bytes = embedding_vector.tobytes() + # Transform to Base64 + encoded_vector = base64.b64encode(vector_bytes) + # Transform to string + encoded_str = encoded_vector.decode("utf-8") + redis_client.setex(embedding_cache_key, 600, encoded_str) + except Exception as ex: + if dify_config.DEBUG: + logger.exception( + "Failed to add embedding to redis for the multimodal document '%s'", multimodel_document["file_id"] + ) + raise ex + + return embedding_results # type: ignore diff --git a/api/core/rag/embedding/embedding_base.py b/api/core/rag/embedding/embedding_base.py index 9f232ab910..1be55bda80 100644 --- a/api/core/rag/embedding/embedding_base.py +++ b/api/core/rag/embedding/embedding_base.py @@ -9,11 +9,21 @@ class Embeddings(ABC): """Embed search docs.""" raise NotImplementedError + @abstractmethod + def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]: + """Embed file documents.""" + raise NotImplementedError + @abstractmethod def embed_query(self, text: str) -> list[float]: """Embed query text.""" raise NotImplementedError + @abstractmethod + def embed_multimodal_query(self, multimodel_document: dict) -> list[float]: + """Embed multimodal query.""" + raise NotImplementedError + async def aembed_documents(self, texts: list[str]) -> list[list[float]]: """Asynchronous Embed search docs.""" raise NotImplementedError diff --git a/api/core/rag/embedding/retrieval.py b/api/core/rag/embedding/retrieval.py index 8e92191568..b54a37b49e 100644 --- a/api/core/rag/embedding/retrieval.py +++ b/api/core/rag/embedding/retrieval.py @@ -19,3 +19,4 @@ class RetrievalSegments(BaseModel): segment: DocumentSegment child_chunks: list[RetrievalChildChunk] | None = None score: float | None = None + files: list[dict[str, str | int]] | None = None diff --git a/api/core/rag/entities/citation_metadata.py b/api/core/rag/entities/citation_metadata.py index aca879df7d..9f66cd9a03 100644 --- a/api/core/rag/entities/citation_metadata.py +++ b/api/core/rag/entities/citation_metadata.py @@ -21,3 +21,4 @@ class RetrievalSourceMetadata(BaseModel): page: int | None = None doc_metadata: dict[str, Any] | None = None title: str | None = None + files: list[dict[str, Any]] | None = None diff --git a/api/core/rag/index_processor/constant/doc_type.py b/api/core/rag/index_processor/constant/doc_type.py new file mode 100644 index 0000000000..93c8fecb8d --- /dev/null +++ b/api/core/rag/index_processor/constant/doc_type.py @@ -0,0 +1,6 @@ +from enum import StrEnum + + +class DocType(StrEnum): + TEXT = "text" + IMAGE = "image" diff --git a/api/core/rag/index_processor/constant/index_type.py b/api/core/rag/index_processor/constant/index_type.py index 659086e808..09617413f7 100644 --- a/api/core/rag/index_processor/constant/index_type.py +++ b/api/core/rag/index_processor/constant/index_type.py @@ -1,7 +1,12 @@ from enum import StrEnum -class IndexType(StrEnum): +class IndexStructureType(StrEnum): PARAGRAPH_INDEX = "text_model" QA_INDEX = "qa_model" PARENT_CHILD_INDEX = "hierarchical_model" + + +class IndexTechniqueType(StrEnum): + ECONOMY = "economy" + HIGH_QUALITY = "high_quality" diff --git a/api/core/rag/index_processor/constant/query_type.py b/api/core/rag/index_processor/constant/query_type.py new file mode 100644 index 0000000000..342bfef3f7 --- /dev/null +++ b/api/core/rag/index_processor/constant/query_type.py @@ -0,0 +1,6 @@ +from enum import StrEnum + + +class QueryType(StrEnum): + TEXT_QUERY = "text_query" + IMAGE_QUERY = "image_query" diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index d4eff53204..8a28eb477a 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -1,20 +1,34 @@ """Abstract interface for document loader implementations.""" +import cgi +import logging +import mimetypes +import os +import re from abc import ABC, abstractmethod from collections.abc import Mapping from typing import TYPE_CHECKING, Any, Optional +from urllib.parse import unquote, urlparse + +import httpx from configs import dify_config +from core.helper import ssrf_proxy from core.rag.extractor.entity.extract_setting import ExtractSetting -from core.rag.models.document import Document +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.models.document import AttachmentDocument, Document from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.splitter.fixed_text_splitter import ( EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter, ) from core.rag.splitter.text_splitter import TextSplitter +from extensions.ext_database import db +from extensions.ext_storage import storage +from models import Account, ToolFile from models.dataset import Dataset, DatasetProcessRule from models.dataset import Document as DatasetDocument +from models.model import UploadFile if TYPE_CHECKING: from core.model_manager import ModelInstance @@ -28,11 +42,18 @@ class BaseIndexProcessor(ABC): raise NotImplementedError @abstractmethod - def transform(self, documents: list[Document], **kwargs) -> list[Document]: + def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]: raise NotImplementedError @abstractmethod - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): + def load( + self, + dataset: Dataset, + documents: list[Document], + multimodal_documents: list[AttachmentDocument] | None = None, + with_keywords: bool = True, + **kwargs, + ): raise NotImplementedError @abstractmethod @@ -96,3 +117,178 @@ class BaseIndexProcessor(ABC): ) return character_splitter # type: ignore + + def _get_content_files(self, document: Document, current_user: Account | None = None) -> list[AttachmentDocument]: + """ + Get the content files from the document. + """ + multi_model_documents: list[AttachmentDocument] = [] + text = document.page_content + images = self._extract_markdown_images(text) + if not images: + return multi_model_documents + upload_file_id_list = [] + + for image in images: + # Collect all upload_file_ids including duplicates to preserve occurrence count + + # For data before v0.10.0 + pattern = r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?" + match = re.search(pattern, image) + if match: + upload_file_id = match.group(1) + upload_file_id_list.append(upload_file_id) + continue + + # For data after v0.10.0 + pattern = r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?" + match = re.search(pattern, image) + if match: + upload_file_id = match.group(1) + upload_file_id_list.append(upload_file_id) + continue + + # For tools directory - direct file formats (e.g., .png, .jpg, etc.) + # Match URL including any query parameters up to common URL boundaries (space, parenthesis, quotes) + pattern = r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?" + match = re.search(pattern, image) + if match: + if current_user: + tool_file_id = match.group(1) + upload_file_id = self._download_tool_file(tool_file_id, current_user) + if upload_file_id: + upload_file_id_list.append(upload_file_id) + continue + if current_user: + upload_file_id = self._download_image(image.split(" ")[0], current_user) + if upload_file_id: + upload_file_id_list.append(upload_file_id) + + if not upload_file_id_list: + return multi_model_documents + + # Get unique IDs for database query + unique_upload_file_ids = list(set(upload_file_id_list)) + upload_files = db.session.query(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids)).all() + + # Create a mapping from ID to UploadFile for quick lookup + upload_file_map = {upload_file.id: upload_file for upload_file in upload_files} + + # Create a Document for each occurrence (including duplicates) + for upload_file_id in upload_file_id_list: + upload_file = upload_file_map.get(upload_file_id) + if upload_file: + multi_model_documents.append( + AttachmentDocument( + page_content=upload_file.name, + metadata={ + "doc_id": upload_file.id, + "doc_hash": "", + "document_id": document.metadata.get("document_id"), + "dataset_id": document.metadata.get("dataset_id"), + "doc_type": DocType.IMAGE, + }, + ) + ) + return multi_model_documents + + def _extract_markdown_images(self, text: str) -> list[str]: + """ + Extract the markdown images from the text. + """ + pattern = r"!\[.*?\]\((.*?)\)" + return re.findall(pattern, text) + + def _download_image(self, image_url: str, current_user: Account) -> str | None: + """ + Download the image from the URL. + Image size must not exceed 2MB. + """ + from services.file_service import FileService + + MAX_IMAGE_SIZE = dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024 + DOWNLOAD_TIMEOUT = dify_config.ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT + + try: + # Download with timeout + response = ssrf_proxy.get(image_url, timeout=DOWNLOAD_TIMEOUT) + response.raise_for_status() + + # Check Content-Length header if available + content_length = response.headers.get("Content-Length") + if content_length and int(content_length) > MAX_IMAGE_SIZE: + logging.warning("Image from %s exceeds 2MB limit (size: %s bytes)", image_url, content_length) + return None + + filename = None + + content_disposition = response.headers.get("content-disposition") + if content_disposition: + _, params = cgi.parse_header(content_disposition) + if "filename" in params: + filename = params["filename"] + filename = unquote(filename) + + if not filename: + parsed_url = urlparse(image_url) + # unquote 处理 URL 中的中文 + path = unquote(parsed_url.path) + filename = os.path.basename(path) + + if not filename: + filename = "downloaded_image_file" + + name, current_ext = os.path.splitext(filename) + + content_type = response.headers.get("content-type", "").split(";")[0].strip() + + real_ext = mimetypes.guess_extension(content_type) + + if not current_ext and real_ext or current_ext in [".php", ".jsp", ".asp", ".html"] and real_ext: + filename = f"{name}{real_ext}" + # Download content with size limit + blob = b"" + for chunk in response.iter_bytes(chunk_size=8192): + blob += chunk + if len(blob) > MAX_IMAGE_SIZE: + logging.warning("Image from %s exceeds 2MB limit during download", image_url) + return None + + if not blob: + logging.warning("Image from %s is empty", image_url) + return None + + upload_file = FileService(db.engine).upload_file( + filename=filename, + content=blob, + mimetype=content_type, + user=current_user, + ) + return upload_file.id + except httpx.TimeoutException: + logging.warning("Timeout downloading image from %s after %s seconds", image_url, DOWNLOAD_TIMEOUT) + return None + except httpx.RequestError as e: + logging.warning("Error downloading image from %s: %s", image_url, str(e)) + return None + except Exception: + logging.exception("Unexpected error downloading image from %s", image_url) + return None + + def _download_tool_file(self, tool_file_id: str, current_user: Account) -> str | None: + """ + Download the tool file from the ID. + """ + from services.file_service import FileService + + tool_file = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first() + if not tool_file: + return None + blob = storage.load_once(tool_file.file_key) + upload_file = FileService(db.engine).upload_file( + filename=tool_file.name, + content=blob, + mimetype=tool_file.mimetype, + user=current_user, + ) + return upload_file.id diff --git a/api/core/rag/index_processor/index_processor_factory.py b/api/core/rag/index_processor/index_processor_factory.py index c987edf342..ea6ab24699 100644 --- a/api/core/rag/index_processor/index_processor_factory.py +++ b/api/core/rag/index_processor/index_processor_factory.py @@ -1,6 +1,6 @@ """Abstract interface for document loader implementations.""" -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor @@ -19,11 +19,11 @@ class IndexProcessorFactory: if not self._index_type: raise ValueError("Index type must be specified.") - if self._index_type == IndexType.PARAGRAPH_INDEX: + if self._index_type == IndexStructureType.PARAGRAPH_INDEX: return ParagraphIndexProcessor() - elif self._index_type == IndexType.QA_INDEX: + elif self._index_type == IndexStructureType.QA_INDEX: return QAIndexProcessor() - elif self._index_type == IndexType.PARENT_CHILD_INDEX: + elif self._index_type == IndexStructureType.PARENT_CHILD_INDEX: return ParentChildIndexProcessor() else: raise ValueError(f"Index type {self._index_type} is not supported.") diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 5e5fea7ea9..a7c879f2c4 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -11,14 +11,17 @@ from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_base import BaseIndexProcessor -from core.rag.models.document import Document +from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper +from models.account import Account from models.dataset import Dataset, DatasetProcessRule from models.dataset import Document as DatasetDocument +from services.account_service import AccountService from services.entities.knowledge_entities.knowledge_entities import Rule @@ -33,7 +36,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): return text_docs - def transform(self, documents: list[Document], **kwargs) -> list[Document]: + def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]: process_rule = kwargs.get("process_rule") if not process_rule: raise ValueError("No process rule found.") @@ -69,6 +72,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor): if document_node.metadata is not None: document_node.metadata["doc_id"] = doc_id document_node.metadata["doc_hash"] = hash + multimodal_documents = ( + self._get_content_files(document_node, current_user) if document_node.metadata else None + ) + if multimodal_documents: + document_node.attachments = multimodal_documents # delete Splitter character page_content = remove_leading_symbols(document_node.page_content).strip() if len(page_content) > 0: @@ -77,10 +85,19 @@ class ParagraphIndexProcessor(BaseIndexProcessor): all_documents.extend(split_documents) return all_documents - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): + def load( + self, + dataset: Dataset, + documents: list[Document], + multimodal_documents: list[AttachmentDocument] | None = None, + with_keywords: bool = True, + **kwargs, + ): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) + if multimodal_documents and dataset.is_multimodal: + vector.create_multimodal(multimodal_documents) with_keywords = False if with_keywords: keywords_list = kwargs.get("keywords_list") @@ -134,8 +151,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor): return docs def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any): + documents: list[Any] = [] + all_multimodal_documents: list[Any] = [] if isinstance(chunks, list): - documents = [] for content in chunks: metadata = { "dataset_id": dataset.id, @@ -144,26 +162,68 @@ class ParagraphIndexProcessor(BaseIndexProcessor): "doc_hash": helper.generate_text_hash(content), } doc = Document(page_content=content, metadata=metadata) + attachments = self._get_content_files(doc) + if attachments: + doc.attachments = attachments + all_multimodal_documents.extend(attachments) documents.append(doc) - if documents: - # save node to document segment - doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) - # add document segments - doc_store.add_documents(docs=documents, save_child=False) - if dataset.indexing_technique == "high_quality": - vector = Vector(dataset) - vector.create(documents) - elif dataset.indexing_technique == "economy": - keyword = Keyword(dataset) - keyword.add_texts(documents) else: - raise ValueError("Chunks is not a list") + multimodal_general_structure = MultimodalGeneralStructureChunk.model_validate(chunks) + for general_chunk in multimodal_general_structure.general_chunks: + metadata = { + "dataset_id": dataset.id, + "document_id": document.id, + "doc_id": str(uuid.uuid4()), + "doc_hash": helper.generate_text_hash(general_chunk.content), + } + doc = Document(page_content=general_chunk.content, metadata=metadata) + if general_chunk.files: + attachments = [] + for file in general_chunk.files: + file_metadata = { + "doc_id": file.id, + "doc_hash": "", + "document_id": document.id, + "dataset_id": dataset.id, + "doc_type": DocType.IMAGE, + } + file_document = AttachmentDocument( + page_content=file.filename or "image_file", metadata=file_metadata + ) + attachments.append(file_document) + all_multimodal_documents.append(file_document) + doc.attachments = attachments + else: + account = AccountService.load_user(document.created_by) + if not account: + raise ValueError("Invalid account") + doc.attachments = self._get_content_files(doc, current_user=account) + if doc.attachments: + all_multimodal_documents.extend(doc.attachments) + documents.append(doc) + if documents: + # save node to document segment + doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) + # add document segments + doc_store.add_documents(docs=documents, save_child=False) + if dataset.indexing_technique == "high_quality": + vector = Vector(dataset) + vector.create(documents) + if all_multimodal_documents: + vector.create_multimodal(all_multimodal_documents) + elif dataset.indexing_technique == "economy": + keyword = Keyword(dataset) + keyword.add_texts(documents) def format_preview(self, chunks: Any) -> Mapping[str, Any]: if isinstance(chunks, list): preview = [] for content in chunks: preview.append({"content": content}) - return {"chunk_structure": IndexType.PARAGRAPH_INDEX, "preview": preview, "total_segments": len(chunks)} + return { + "chunk_structure": IndexStructureType.PARAGRAPH_INDEX, + "preview": preview, + "total_segments": len(chunks), + } else: raise ValueError("Chunks is not a list") diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 4fa78e2f95..ee29d2fd65 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -13,14 +13,17 @@ from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_base import BaseIndexProcessor -from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk +from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from libs import helper +from models import Account from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment from models.dataset import Document as DatasetDocument +from services.account_service import AccountService from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule @@ -35,7 +38,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): return text_docs - def transform(self, documents: list[Document], **kwargs) -> list[Document]: + def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]: process_rule = kwargs.get("process_rule") if not process_rule: raise ValueError("No process rule found.") @@ -77,6 +80,9 @@ class ParentChildIndexProcessor(BaseIndexProcessor): page_content = page_content if len(page_content) > 0: document_node.page_content = page_content + multimodel_documents = self._get_content_files(document_node, current_user) + if multimodel_documents: + document_node.attachments = multimodel_documents # parse document to child nodes child_nodes = self._split_child_nodes( document_node, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance") @@ -87,6 +93,9 @@ class ParentChildIndexProcessor(BaseIndexProcessor): elif rules.parent_mode == ParentMode.FULL_DOC: page_content = "\n".join([document.page_content for document in documents]) document = Document(page_content=page_content, metadata=documents[0].metadata) + multimodel_documents = self._get_content_files(document) + if multimodel_documents: + document.attachments = multimodel_documents # parse document to child nodes child_nodes = self._split_child_nodes( document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance") @@ -104,7 +113,14 @@ class ParentChildIndexProcessor(BaseIndexProcessor): return all_documents - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): + def load( + self, + dataset: Dataset, + documents: list[Document], + multimodal_documents: list[AttachmentDocument] | None = None, + with_keywords: bool = True, + **kwargs, + ): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) for document in documents: @@ -114,6 +130,8 @@ class ParentChildIndexProcessor(BaseIndexProcessor): Document.model_validate(child_document.model_dump()) for child_document in child_documents ] vector.create(formatted_child_documents) + if multimodal_documents and dataset.is_multimodal: + vector.create_multimodal(multimodal_documents) def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): # node_ids is segment's node_ids @@ -244,6 +262,24 @@ class ParentChildIndexProcessor(BaseIndexProcessor): } child_documents.append(ChildDocument(page_content=child, metadata=child_metadata)) doc = Document(page_content=parent_child.parent_content, metadata=metadata, children=child_documents) + if parent_child.files and len(parent_child.files) > 0: + attachments = [] + for file in parent_child.files: + file_metadata = { + "doc_id": file.id, + "doc_hash": "", + "document_id": document.id, + "dataset_id": dataset.id, + "doc_type": DocType.IMAGE, + } + file_document = AttachmentDocument(page_content=file.filename or "", metadata=file_metadata) + attachments.append(file_document) + doc.attachments = attachments + else: + account = AccountService.load_user(document.created_by) + if not account: + raise ValueError("Invalid account") + doc.attachments = self._get_content_files(doc, current_user=account) documents.append(doc) if documents: # update document parent mode @@ -267,12 +303,17 @@ class ParentChildIndexProcessor(BaseIndexProcessor): doc_store.add_documents(docs=documents, save_child=True) if dataset.indexing_technique == "high_quality": all_child_documents = [] + all_multimodal_documents = [] for doc in documents: if doc.children: all_child_documents.extend(doc.children) + if doc.attachments: + all_multimodal_documents.extend(doc.attachments) + vector = Vector(dataset) if all_child_documents: - vector = Vector(dataset) vector.create(all_child_documents) + if all_multimodal_documents: + vector.create_multimodal(all_multimodal_documents) def format_preview(self, chunks: Any) -> Mapping[str, Any]: parent_childs = ParentChildStructureChunk.model_validate(chunks) @@ -280,7 +321,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): for parent_child in parent_childs.parent_child_chunks: preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents}) return { - "chunk_structure": IndexType.PARENT_CHILD_INDEX, + "chunk_structure": IndexStructureType.PARENT_CHILD_INDEX, "parent_mode": parent_childs.parent_mode, "preview": preview, "total_segments": len(parent_childs.parent_child_chunks), diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 3e3deb0180..1183d5fbd7 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -18,12 +18,13 @@ from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_base import BaseIndexProcessor -from core.rag.models.document import Document, QAStructureChunk +from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper +from models.account import Account from models.dataset import Dataset from models.dataset import Document as DatasetDocument from services.entities.knowledge_entities.knowledge_entities import Rule @@ -41,7 +42,7 @@ class QAIndexProcessor(BaseIndexProcessor): ) return text_docs - def transform(self, documents: list[Document], **kwargs) -> list[Document]: + def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]: preview = kwargs.get("preview") process_rule = kwargs.get("process_rule") if not process_rule: @@ -116,7 +117,7 @@ class QAIndexProcessor(BaseIndexProcessor): try: # Skip the first row - df = pd.read_csv(file) + df = pd.read_csv(file) # type: ignore text_docs = [] for _, row in df.iterrows(): data = Document(page_content=row.iloc[0], metadata={"answer": row.iloc[1]}) @@ -128,10 +129,19 @@ class QAIndexProcessor(BaseIndexProcessor): raise ValueError(str(e)) return text_docs - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): + def load( + self, + dataset: Dataset, + documents: list[Document], + multimodal_documents: list[AttachmentDocument] | None = None, + with_keywords: bool = True, + **kwargs, + ): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) + if multimodal_documents and dataset.is_multimodal: + vector.create_multimodal(multimodal_documents) def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): vector = Vector(dataset) @@ -197,7 +207,7 @@ class QAIndexProcessor(BaseIndexProcessor): for qa_chunk in qa_chunks.qa_chunks: preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer}) return { - "chunk_structure": IndexType.QA_INDEX, + "chunk_structure": IndexStructureType.QA_INDEX, "qa_preview": preview, "total_segments": len(qa_chunks.qa_chunks), } diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 4bd7b1d62e..611fad9a18 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -4,6 +4,8 @@ from typing import Any from pydantic import BaseModel, Field +from core.file import File + class ChildDocument(BaseModel): """Class for storing a piece of text and associated metadata.""" @@ -15,7 +17,19 @@ class ChildDocument(BaseModel): """Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.). """ - metadata: dict = Field(default_factory=dict) + metadata: dict[str, Any] = Field(default_factory=dict) + + +class AttachmentDocument(BaseModel): + """Class for storing a piece of text and associated metadata.""" + + page_content: str + + provider: str | None = "dify" + + vector: list[float] | None = None + + metadata: dict[str, Any] = Field(default_factory=dict) class Document(BaseModel): @@ -28,12 +42,31 @@ class Document(BaseModel): """Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.). """ - metadata: dict = Field(default_factory=dict) + metadata: dict[str, Any] = Field(default_factory=dict) provider: str | None = "dify" children: list[ChildDocument] | None = None + attachments: list[AttachmentDocument] | None = None + + +class GeneralChunk(BaseModel): + """ + General Chunk. + """ + + content: str + files: list[File] | None = None + + +class MultimodalGeneralStructureChunk(BaseModel): + """ + Multimodal General Structure Chunk. + """ + + general_chunks: list[GeneralChunk] + class GeneralStructureChunk(BaseModel): """ @@ -50,6 +83,7 @@ class ParentChildChunk(BaseModel): parent_content: str child_contents: list[str] + files: list[File] | None = None class ParentChildStructureChunk(BaseModel): diff --git a/api/core/rag/rerank/rerank_base.py b/api/core/rag/rerank/rerank_base.py index 3561def008..88acb75133 100644 --- a/api/core/rag/rerank/rerank_base.py +++ b/api/core/rag/rerank/rerank_base.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document @@ -12,6 +13,7 @@ class BaseRerankRunner(ABC): score_threshold: float | None = None, top_n: int | None = None, user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ Run rerank model diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index e855b0083f..38309d3d77 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -1,6 +1,15 @@ -from core.model_manager import ModelInstance +import base64 + +from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.rerank_entities import RerankResult +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_base import BaseRerankRunner +from extensions.ext_database import db +from extensions.ext_storage import storage +from models.model import UploadFile class RerankModelRunner(BaseRerankRunner): @@ -14,6 +23,7 @@ class RerankModelRunner(BaseRerankRunner): score_threshold: float | None = None, top_n: int | None = None, user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ Run rerank model @@ -24,6 +34,56 @@ class RerankModelRunner(BaseRerankRunner): :param user: unique user id if needed :return: """ + model_manager = ModelManager() + is_support_vision = model_manager.check_model_support_vision( + tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id, + provider=self.rerank_model_instance.provider, + model=self.rerank_model_instance.model, + model_type=ModelType.RERANK, + ) + if not is_support_vision: + if query_type == QueryType.TEXT_QUERY: + rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user) + else: + return documents + else: + rerank_result, unique_documents = self.fetch_multimodal_rerank( + query, documents, score_threshold, top_n, user, query_type + ) + + rerank_documents = [] + for result in rerank_result.docs: + if score_threshold is None or result.score >= score_threshold: + # format document + rerank_document = Document( + page_content=result.text, + metadata=unique_documents[result.index].metadata, + provider=unique_documents[result.index].provider, + ) + if rerank_document.metadata is not None: + rerank_document.metadata["score"] = result.score + rerank_documents.append(rerank_document) + + rerank_documents.sort(key=lambda x: x.metadata.get("score", 0.0), reverse=True) + return rerank_documents[:top_n] if top_n else rerank_documents + + def fetch_text_rerank( + self, + query: str, + documents: list[Document], + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, + ) -> tuple[RerankResult, list[Document]]: + """ + Fetch text rerank + :param query: search query + :param documents: documents for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id if needed + :return: + """ docs = [] doc_ids = set() unique_documents = [] @@ -33,33 +93,99 @@ class RerankModelRunner(BaseRerankRunner): and document.metadata is not None and document.metadata["doc_id"] not in doc_ids ): - doc_ids.add(document.metadata["doc_id"]) - docs.append(document.page_content) - unique_documents.append(document) + if not document.metadata.get("doc_type") or document.metadata.get("doc_type") == DocType.TEXT: + doc_ids.add(document.metadata["doc_id"]) + docs.append(document.page_content) + unique_documents.append(document) elif document.provider == "external": if document not in unique_documents: docs.append(document.page_content) unique_documents.append(document) - documents = unique_documents - rerank_result = self.rerank_model_instance.invoke_rerank( query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user ) + return rerank_result, unique_documents - rerank_documents = [] + def fetch_multimodal_rerank( + self, + query: str, + documents: list[Document], + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, + ) -> tuple[RerankResult, list[Document]]: + """ + Fetch multimodal rerank + :param query: search query + :param documents: documents for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id if needed + :param query_type: query type + :return: rerank result + """ + docs = [] + doc_ids = set() + unique_documents = [] + for document in documents: + if ( + document.provider == "dify" + and document.metadata is not None + and document.metadata["doc_id"] not in doc_ids + ): + if document.metadata.get("doc_type") == DocType.IMAGE: + # Query file info within db.session context to ensure thread-safe access + upload_file = ( + db.session.query(UploadFile).where(UploadFile.id == document.metadata["doc_id"]).first() + ) + if upload_file: + blob = storage.load_once(upload_file.key) + document_file_base64 = base64.b64encode(blob).decode() + document_file_dict = { + "content": document_file_base64, + "content_type": document.metadata["doc_type"], + } + docs.append(document_file_dict) + else: + document_text_dict = { + "content": document.page_content, + "content_type": document.metadata.get("doc_type") or DocType.TEXT, + } + docs.append(document_text_dict) + doc_ids.add(document.metadata["doc_id"]) + unique_documents.append(document) + elif document.provider == "external": + if document not in unique_documents: + docs.append( + { + "content": document.page_content, + "content_type": document.metadata.get("doc_type") or DocType.TEXT, + } + ) + unique_documents.append(document) - for result in rerank_result.docs: - if score_threshold is None or result.score >= score_threshold: - # format document - rerank_document = Document( - page_content=result.text, - metadata=documents[result.index].metadata, - provider=documents[result.index].provider, + documents = unique_documents + if query_type == QueryType.TEXT_QUERY: + rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user) + return rerank_result, unique_documents + elif query_type == QueryType.IMAGE_QUERY: + # Query file info within db.session context to ensure thread-safe access + upload_file = db.session.query(UploadFile).where(UploadFile.id == query).first() + if upload_file: + blob = storage.load_once(upload_file.key) + file_query = base64.b64encode(blob).decode() + file_query_dict = { + "content": file_query, + "content_type": DocType.IMAGE, + } + rerank_result = self.rerank_model_instance.invoke_multimodal_rerank( + query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user ) - if rerank_document.metadata is not None: - rerank_document.metadata["score"] = result.score - rerank_documents.append(rerank_document) + return rerank_result, unique_documents + else: + raise ValueError(f"Upload file not found for query: {query}") - rerank_documents.sort(key=lambda x: x.metadata.get("score", 0.0), reverse=True) - return rerank_documents[:top_n] if top_n else rerank_documents + else: + raise ValueError(f"Query type {query_type} is not supported") diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index c455db6095..18020608cb 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -7,6 +7,8 @@ from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.embedding.cached_embedding import CacheEmbedding +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.entity.weight import VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner @@ -24,6 +26,7 @@ class WeightRerankRunner(BaseRerankRunner): score_threshold: float | None = None, top_n: int | None = None, user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ Run rerank model @@ -43,8 +46,10 @@ class WeightRerankRunner(BaseRerankRunner): and document.metadata is not None and document.metadata["doc_id"] not in doc_ids ): - doc_ids.add(document.metadata["doc_id"]) - unique_documents.append(document) + # weight rerank only support text documents + if not document.metadata.get("doc_type") or document.metadata.get("doc_type") == DocType.TEXT: + doc_ids.add(document.metadata["doc_id"]) + unique_documents.append(document) else: if document not in unique_documents: unique_documents.append(document) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 3db67efb0e..ec55d2d0cc 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -8,6 +8,7 @@ from typing import Any, Union, cast from flask import Flask, current_app from sqlalchemy import and_, or_, select +from sqlalchemy.orm import Session from core.app.app_config.entities import ( DatasetEntity, @@ -19,6 +20,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCre from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.agent_entities import PlanningStrategy from core.entities.model_entities import ModelStatus +from core.file import File, FileTransferMethod, FileType from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage @@ -37,7 +39,9 @@ from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.context_entities import DocumentContext from core.rag.entities.metadata_entities import Condition, MetadataCondition -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -52,10 +56,12 @@ from core.rag.retrieval.template_prompts import ( METADATA_FILTER_USER_PROMPT_2, METADATA_FILTER_USER_PROMPT_3, ) +from core.tools.signature import sign_upload_file from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from extensions.ext_database import db from libs.json_in_md_parser import parse_and_check_json_markdown -from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment +from models import UploadFile +from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService @@ -99,7 +105,8 @@ class DatasetRetrieval: message_id: str, memory: TokenBufferMemory | None = None, inputs: Mapping[str, Any] | None = None, - ) -> str | None: + vision_enabled: bool = False, + ) -> tuple[str | None, list[File] | None]: """ Retrieve dataset. :param app_id: app_id @@ -118,7 +125,7 @@ class DatasetRetrieval: """ dataset_ids = config.dataset_ids if len(dataset_ids) == 0: - return None + return None, [] retrieve_config = config.retrieve_config # check model is support tool calling @@ -136,7 +143,7 @@ class DatasetRetrieval: ) if not model_schema: - return None + return None, [] planning_strategy = PlanningStrategy.REACT_ROUTER features = model_schema.features @@ -182,8 +189,8 @@ class DatasetRetrieval: tenant_id, user_id, user_from, - available_datasets, query, + available_datasets, model_instance, model_config, planning_strategy, @@ -213,6 +220,7 @@ class DatasetRetrieval: dify_documents = [item for item in all_documents if item.provider == "dify"] external_documents = [item for item in all_documents if item.provider == "external"] document_context_list: list[DocumentContext] = [] + context_files: list[File] = [] retrieval_resource_list: list[RetrievalSourceMetadata] = [] # deal with external documents for item in external_documents: @@ -248,6 +256,31 @@ class DatasetRetrieval: score=record.score, ) ) + if vision_enabled: + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where( + SegmentAttachmentBinding.segment_id == segment.id, + ) + ).all() + if attachments_with_bindings: + for _, upload_file in attachments_with_bindings: + attchment_info = File( + id=upload_file.id, + filename=upload_file.name, + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + tenant_id=segment.tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.LOCAL_FILE, + remote_url=upload_file.source_url, + related_id=upload_file.id, + size=upload_file.size, + storage_key=upload_file.key, + url=sign_upload_file(upload_file.id, upload_file.extension), + ) + context_files.append(attchment_info) if show_retrieve_source: for record in records: segment = record.segment @@ -288,8 +321,10 @@ class DatasetRetrieval: hit_callback.return_retriever_resource_info(retrieval_resource_list) if document_context_list: document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True) - return str("\n".join([document_context.content for document_context in document_context_list])) - return "" + return str( + "\n".join([document_context.content for document_context in document_context_list]) + ), context_files + return "", context_files def single_retrieve( self, @@ -297,8 +332,8 @@ class DatasetRetrieval: tenant_id: str, user_id: str, user_from: str, - available_datasets: list, query: str, + available_datasets: list, model_instance: ModelInstance, model_config: ModelConfigWithCredentialsEntity, planning_strategy: PlanningStrategy, @@ -336,7 +371,7 @@ class DatasetRetrieval: dataset_id, router_usage = function_call_router.invoke(query, tools, model_config, model_instance) self._record_usage(router_usage) - + timer = None if dataset_id: # get retrieval model config dataset_stmt = select(Dataset).where(Dataset.id == dataset_id) @@ -406,10 +441,19 @@ class DatasetRetrieval: weights=retrieval_model_config.get("weights", None), document_ids_filter=document_ids_filter, ) - self._on_query(query, [dataset_id], app_id, user_from, user_id) + self._on_query(query, None, [dataset_id], app_id, user_from, user_id) if results: - self._on_retrieval_end(results, message_id, timer) + thread = threading.Thread( + target=self._on_retrieval_end, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "documents": results, + "message_id": message_id, + "timer": timer, + }, + ) + thread.start() return results return [] @@ -421,7 +465,7 @@ class DatasetRetrieval: user_id: str, user_from: str, available_datasets: list, - query: str, + query: str | None, top_k: int, score_threshold: float, reranking_mode: str, @@ -431,10 +475,11 @@ class DatasetRetrieval: message_id: str | None = None, metadata_filter_document_ids: dict[str, list[str]] | None = None, metadata_condition: MetadataCondition | None = None, + attachment_ids: list[str] | None = None, ): if not available_datasets: return [] - threads = [] + all_threads = [] all_documents: list[Document] = [] dataset_ids = [dataset.id for dataset in available_datasets] index_type_check = all( @@ -467,131 +512,226 @@ class DatasetRetrieval: 0 ].embedding_model_provider weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model - - for dataset in available_datasets: - index_type = dataset.indexing_technique - document_ids_filter = None - if dataset.provider != "external": - if metadata_condition and not metadata_filter_document_ids: - continue - if metadata_filter_document_ids: - document_ids = metadata_filter_document_ids.get(dataset.id, []) - if document_ids: - document_ids_filter = document_ids - else: - continue - retrieval_thread = threading.Thread( - target=self._retriever, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "dataset_id": dataset.id, - "query": query, - "top_k": top_k, - "all_documents": all_documents, - "document_ids_filter": document_ids_filter, - "metadata_condition": metadata_condition, - }, - ) - threads.append(retrieval_thread) - retrieval_thread.start() - for thread in threads: - thread.join() - with measure_time() as timer: - if reranking_enable: - # do rerank for searched documents - data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False) - - all_documents = data_post_processor.invoke( - query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k + if query: + query_thread = threading.Thread( + target=self._multiple_retrieve_thread, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "available_datasets": available_datasets, + "metadata_condition": metadata_condition, + "metadata_filter_document_ids": metadata_filter_document_ids, + "all_documents": all_documents, + "tenant_id": tenant_id, + "reranking_enable": reranking_enable, + "reranking_mode": reranking_mode, + "reranking_model": reranking_model, + "weights": weights, + "top_k": top_k, + "score_threshold": score_threshold, + "query": query, + "attachment_id": None, + }, ) - else: - if index_type == "economy": - all_documents = self.calculate_keyword_score(query, all_documents, top_k) - elif index_type == "high_quality": - all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold) - else: - all_documents = all_documents[:top_k] if top_k else all_documents - - self._on_query(query, dataset_ids, app_id, user_from, user_id) + all_threads.append(query_thread) + query_thread.start() + if attachment_ids: + for attachment_id in attachment_ids: + attachment_thread = threading.Thread( + target=self._multiple_retrieve_thread, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "available_datasets": available_datasets, + "metadata_condition": metadata_condition, + "metadata_filter_document_ids": metadata_filter_document_ids, + "all_documents": all_documents, + "tenant_id": tenant_id, + "reranking_enable": reranking_enable, + "reranking_mode": reranking_mode, + "reranking_model": reranking_model, + "weights": weights, + "top_k": top_k, + "score_threshold": score_threshold, + "query": None, + "attachment_id": attachment_id, + }, + ) + all_threads.append(attachment_thread) + attachment_thread.start() + for thread in all_threads: + thread.join() + self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id) if all_documents: - self._on_retrieval_end(all_documents, message_id, timer) - - return all_documents - - def _on_retrieval_end(self, documents: list[Document], message_id: str | None = None, timer: dict | None = None): - """Handle retrieval end.""" - dify_documents = [document for document in documents if document.provider == "dify"] - for document in dify_documents: - if document.metadata is not None: - dataset_document_stmt = select(DatasetDocument).where( - DatasetDocument.id == document.metadata["document_id"] - ) - dataset_document = db.session.scalar(dataset_document_stmt) - if dataset_document: - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: - child_chunk_stmt = select(ChildChunk).where( - ChildChunk.index_node_id == document.metadata["doc_id"], - ChildChunk.dataset_id == dataset_document.dataset_id, - ChildChunk.document_id == dataset_document.id, - ) - child_chunk = db.session.scalar(child_chunk_stmt) - if child_chunk: - _ = ( - db.session.query(DocumentSegment) - .where(DocumentSegment.id == child_chunk.segment_id) - .update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, - synchronize_session=False, - ) - ) - else: - query = db.session.query(DocumentSegment).where( - DocumentSegment.index_node_id == document.metadata["doc_id"] - ) - - # if 'dataset_id' in document.metadata: - if "dataset_id" in document.metadata: - query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"]) - - # add hit count to document segment - query.update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False - ) - - db.session.commit() - - # get tracing instance - trace_manager: TraceQueueManager | None = ( - self.application_generate_entity.trace_manager if self.application_generate_entity else None - ) - if trace_manager: - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer - ) + # add thread to call _on_retrieval_end + retrieval_end_thread = threading.Thread( + target=self._on_retrieval_end, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "documents": all_documents, + "message_id": message_id, + "timer": timer, + }, ) + retrieval_end_thread.start() + retrieval_resource_list = [] + doc_ids_filter = [] + for document in all_documents: + if document.provider == "dify": + doc_id = document.metadata.get("doc_id") + if doc_id and doc_id not in doc_ids_filter: + doc_ids_filter.append(doc_id) + retrieval_resource_list.append(document) + elif document.provider == "external": + retrieval_resource_list.append(document) + return retrieval_resource_list - def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str): + def _on_retrieval_end( + self, flask_app: Flask, documents: list[Document], message_id: str | None = None, timer: dict | None = None + ): + """Handle retrieval end.""" + with flask_app.app_context(): + dify_documents = [document for document in documents if document.provider == "dify"] + segment_ids = [] + segment_index_node_ids = [] + with Session(db.engine) as session: + for document in dify_documents: + if document.metadata is not None: + dataset_document_stmt = select(DatasetDocument).where( + DatasetDocument.id == document.metadata["document_id"] + ) + dataset_document = session.scalar(dataset_document_stmt) + if dataset_document: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: + segment_id = None + if ( + "doc_type" not in document.metadata + or document.metadata.get("doc_type") == DocType.TEXT + ): + child_chunk_stmt = select(ChildChunk).where( + ChildChunk.index_node_id == document.metadata["doc_id"], + ChildChunk.dataset_id == dataset_document.dataset_id, + ChildChunk.document_id == dataset_document.id, + ) + child_chunk = session.scalar(child_chunk_stmt) + if child_chunk: + segment_id = child_chunk.segment_id + elif ( + "doc_type" in document.metadata + and document.metadata.get("doc_type") == DocType.IMAGE + ): + attachment_info_dict = RetrievalService.get_segment_attachment_info( + dataset_document.dataset_id, + dataset_document.tenant_id, + document.metadata.get("doc_id") or "", + session, + ) + if attachment_info_dict: + segment_id = attachment_info_dict["segment_id"] + if segment_id: + if segment_id not in segment_ids: + segment_ids.append(segment_id) + _ = ( + session.query(DocumentSegment) + .where(DocumentSegment.id == segment_id) + .update( + {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, + synchronize_session=False, + ) + ) + else: + query = None + if ( + "doc_type" not in document.metadata + or document.metadata.get("doc_type") == DocType.TEXT + ): + if document.metadata["doc_id"] not in segment_index_node_ids: + segment = ( + session.query(DocumentSegment) + .where(DocumentSegment.index_node_id == document.metadata["doc_id"]) + .first() + ) + if segment: + segment_index_node_ids.append(document.metadata["doc_id"]) + segment_ids.append(segment.id) + query = session.query(DocumentSegment).where( + DocumentSegment.id == segment.id + ) + elif ( + "doc_type" in document.metadata + and document.metadata.get("doc_type") == DocType.IMAGE + ): + attachment_info_dict = RetrievalService.get_segment_attachment_info( + dataset_document.dataset_id, + dataset_document.tenant_id, + document.metadata.get("doc_id") or "", + session, + ) + if attachment_info_dict: + segment_id = attachment_info_dict["segment_id"] + if segment_id not in segment_ids: + segment_ids.append(segment_id) + query = session.query(DocumentSegment).where(DocumentSegment.id == segment_id) + if query: + # if 'dataset_id' in document.metadata: + if "dataset_id" in document.metadata: + query = query.where( + DocumentSegment.dataset_id == document.metadata["dataset_id"] + ) + + # add hit count to document segment + query.update( + {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, + synchronize_session=False, + ) + + db.session.commit() + + # get tracing instance + trace_manager: TraceQueueManager | None = ( + self.application_generate_entity.trace_manager if self.application_generate_entity else None + ) + if trace_manager: + trace_manager.add_trace_task( + TraceTask( + TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer + ) + ) + + def _on_query( + self, + query: str | None, + attachment_ids: list[str] | None, + dataset_ids: list[str], + app_id: str, + user_from: str, + user_id: str, + ): """ Handle query. """ - if not query: + if not query and not attachment_ids: return dataset_queries = [] for dataset_id in dataset_ids: - dataset_query = DatasetQuery( - dataset_id=dataset_id, - content=query, - source="app", - source_app_id=app_id, - created_by_role=user_from, - created_by=user_id, - ) - dataset_queries.append(dataset_query) - if dataset_queries: - db.session.add_all(dataset_queries) + contents = [] + if query: + contents.append({"content_type": QueryType.TEXT_QUERY, "content": query}) + if attachment_ids: + for attachment_id in attachment_ids: + contents.append({"content_type": QueryType.IMAGE_QUERY, "content": attachment_id}) + if contents: + dataset_query = DatasetQuery( + dataset_id=dataset_id, + content=json.dumps(contents), + source="app", + source_app_id=app_id, + created_by_role=user_from, + created_by=user_id, + ) + dataset_queries.append(dataset_query) + if dataset_queries: + db.session.add_all(dataset_queries) db.session.commit() def _retriever( @@ -603,6 +743,7 @@ class DatasetRetrieval: all_documents: list, document_ids_filter: list[str] | None = None, metadata_condition: MetadataCondition | None = None, + attachment_ids: list[str] | None = None, ): with flask_app.app_context(): dataset_stmt = select(Dataset).where(Dataset.id == dataset_id) @@ -611,7 +752,7 @@ class DatasetRetrieval: if not dataset: return [] - if dataset.provider == "external": + if dataset.provider == "external" and query: external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( tenant_id=dataset.tenant_id, dataset_id=dataset_id, @@ -663,6 +804,7 @@ class DatasetRetrieval: reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", weights=retrieval_model.get("weights", None), document_ids_filter=document_ids_filter, + attachment_ids=attachment_ids, ) all_documents.extend(documents) @@ -1222,3 +1364,86 @@ class DatasetRetrieval: usage = LLMUsage.empty_usage() return full_text, usage + + def _multiple_retrieve_thread( + self, + flask_app: Flask, + available_datasets: list, + metadata_condition: MetadataCondition | None, + metadata_filter_document_ids: dict[str, list[str]] | None, + all_documents: list[Document], + tenant_id: str, + reranking_enable: bool, + reranking_mode: str, + reranking_model: dict | None, + weights: dict[str, Any] | None, + top_k: int, + score_threshold: float, + query: str | None, + attachment_id: str | None, + ): + with flask_app.app_context(): + threads = [] + all_documents_item: list[Document] = [] + index_type = None + for dataset in available_datasets: + index_type = dataset.indexing_technique + document_ids_filter = None + if dataset.provider != "external": + if metadata_condition and not metadata_filter_document_ids: + continue + if metadata_filter_document_ids: + document_ids = metadata_filter_document_ids.get(dataset.id, []) + if document_ids: + document_ids_filter = document_ids + else: + continue + retrieval_thread = threading.Thread( + target=self._retriever, + kwargs={ + "flask_app": flask_app, + "dataset_id": dataset.id, + "query": query, + "top_k": top_k, + "all_documents": all_documents_item, + "document_ids_filter": document_ids_filter, + "metadata_condition": metadata_condition, + "attachment_ids": [attachment_id] if attachment_id else None, + }, + ) + threads.append(retrieval_thread) + retrieval_thread.start() + for thread in threads: + thread.join() + + if reranking_enable: + # do rerank for searched documents + data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False) + if query: + all_documents_item = data_post_processor.invoke( + query=query, + documents=all_documents_item, + score_threshold=score_threshold, + top_n=top_k, + query_type=QueryType.TEXT_QUERY, + ) + if attachment_id: + all_documents_item = data_post_processor.invoke( + documents=all_documents_item, + score_threshold=score_threshold, + top_n=top_k, + query_type=QueryType.IMAGE_QUERY, + query=attachment_id, + ) + else: + if index_type == IndexTechniqueType.ECONOMY: + if not query: + all_documents_item = [] + else: + all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k) + elif index_type == IndexTechniqueType.HIGH_QUALITY: + all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold) + else: + all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item + if all_documents_item: + all_documents.extend(all_documents_item) diff --git a/api/core/schemas/builtin/schemas/v1/multimodal_general_structure.json b/api/core/schemas/builtin/schemas/v1/multimodal_general_structure.json new file mode 100644 index 0000000000..1a07869662 --- /dev/null +++ b/api/core/schemas/builtin/schemas/v1/multimodal_general_structure.json @@ -0,0 +1,65 @@ +{ + "$id": "https://dify.ai/schemas/v1/multimodal_general_structure.json", + "$schema": "http://json-schema.org/draft-07/schema#", + "version": "1.0.0", + "type": "array", + "title": "Multimodal General Structure", + "description": "Schema for multimodal general structure (v1) - array of objects", + "properties": { + "general_chunks": { + "type": "array", + "items": { + "type": "object", + "properties": { + "content": { + "type": "string", + "description": "The content" + }, + "files": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "file name" + }, + "size": { + "type": "number", + "description": "file size" + }, + "extension": { + "type": "string", + "description": "file extension" + }, + "type": { + "type": "string", + "description": "file type" + }, + "mime_type": { + "type": "string", + "description": "file mime type" + }, + "transfer_method": { + "type": "string", + "description": "file transfer method" + }, + "url": { + "type": "string", + "description": "file url" + }, + "related_id": { + "type": "string", + "description": "file related id" + } + }, + "description": "List of files" + } + } + }, + "required": ["content"] + }, + "description": "List of content and files" + } + } +} \ No newline at end of file diff --git a/api/core/schemas/builtin/schemas/v1/multimodal_parent_child_structure.json b/api/core/schemas/builtin/schemas/v1/multimodal_parent_child_structure.json new file mode 100644 index 0000000000..4ffb590519 --- /dev/null +++ b/api/core/schemas/builtin/schemas/v1/multimodal_parent_child_structure.json @@ -0,0 +1,78 @@ +{ + "$id": "https://dify.ai/schemas/v1/multimodal_parent_child_structure.json", + "$schema": "http://json-schema.org/draft-07/schema#", + "version": "1.0.0", + "type": "object", + "title": "Multimodal Parent-Child Structure", + "description": "Schema for multimodal parent-child structure (v1)", + "properties": { + "parent_mode": { + "type": "string", + "description": "The mode of parent-child relationship" + }, + "parent_child_chunks": { + "type": "array", + "items": { + "type": "object", + "properties": { + "parent_content": { + "type": "string", + "description": "The parent content" + }, + "files": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "file name" + }, + "size": { + "type": "number", + "description": "file size" + }, + "extension": { + "type": "string", + "description": "file extension" + }, + "type": { + "type": "string", + "description": "file type" + }, + "mime_type": { + "type": "string", + "description": "file mime type" + }, + "transfer_method": { + "type": "string", + "description": "file transfer method" + }, + "url": { + "type": "string", + "description": "file url" + }, + "related_id": { + "type": "string", + "description": "file related id" + } + }, + "required": ["name", "size", "extension", "type", "mime_type", "transfer_method", "url", "related_id"] + }, + "description": "List of files" + }, + "child_contents": { + "type": "array", + "items": { + "type": "string" + }, + "description": "List of child contents" + } + }, + "required": ["parent_content", "child_contents"] + }, + "description": "List of parent-child chunk pairs" + } + }, + "required": ["parent_mode", "parent_child_chunks"] +} \ No newline at end of file diff --git a/api/core/tools/signature.py b/api/core/tools/signature.py index 5cdf473542..fef3157f27 100644 --- a/api/core/tools/signature.py +++ b/api/core/tools/signature.py @@ -25,6 +25,24 @@ def sign_tool_file(tool_file_id: str, extension: str) -> str: return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" +def sign_upload_file(upload_file_id: str, extension: str) -> str: + """ + sign file to get a temporary url for plugin access + """ + # Use internal URL for plugin/tool file access in Docker environments + base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + file_preview_url = f"{base_url}/files/{upload_file_id}/image-preview" + + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + + def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: str) -> bool: """ verify signature diff --git a/api/core/tools/utils/text_processing_utils.py b/api/core/tools/utils/text_processing_utils.py index 105823f896..80c69e94c8 100644 --- a/api/core/tools/utils/text_processing_utils.py +++ b/api/core/tools/utils/text_processing_utils.py @@ -13,5 +13,5 @@ def remove_leading_symbols(text: str) -> str: """ # Match Unicode ranges for punctuation and symbols # FIXME this pattern is confused quick fix for #11868 maybe refactor it later - pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+" + pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F\"#$%&'()*+,./:;<=>?@^_`~]+" return re.sub(pattern, "", text) diff --git a/api/core/workflow/node_events/node.py b/api/core/workflow/node_events/node.py index ebf93f2fc2..e4fa52f444 100644 --- a/api/core/workflow/node_events/node.py +++ b/api/core/workflow/node_events/node.py @@ -3,6 +3,7 @@ from datetime import datetime from pydantic import Field +from core.file import File from core.model_runtime.entities.llm_entities import LLMUsage from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities.pause_reason import PauseReason @@ -14,6 +15,7 @@ from .base import NodeEventBase class RunRetrieverResourceEvent(NodeEventBase): retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources") context: str = Field(..., description="context") + context_files: list[File] | None = Field(default=None, description="context files") class ModelInvokeCompletedEvent(NodeEventBase): diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 8aa6a5016f..86bb2495e7 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -114,7 +114,8 @@ class KnowledgeRetrievalNodeData(BaseNodeData): """ type: str = "knowledge-retrieval" - query_variable_selector: list[str] + query_variable_selector: list[str] | None | str = None + query_attachment_selector: list[str] | None | str = None dataset_ids: list[str] retrieval_mode: Literal["single", "multiple"] multiple_retrieval_config: MultipleRetrievalConfig | None = None diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 1b57d23e24..adc474bd60 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -25,6 +25,8 @@ from core.rag.entities.metadata_entities import Condition, MetadataCondition from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.variables import ( + ArrayFileSegment, + FileSegment, StringSegment, ) from core.variables.segments import ArrayObjectSegment @@ -119,20 +121,41 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD return "1" def _run(self) -> NodeRunResult: - # extract variables - variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector) - if not isinstance(variable, StringSegment): + if not self._node_data.query_variable_selector and not self._node_data.query_attachment_selector: return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, - error="Query variable is not string type.", - ) - query = variable.value - variables = {"query": query} - if not query: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required." + process_data={}, + outputs={}, + metadata={}, + llm_usage=LLMUsage.empty_usage(), ) + variables: dict[str, Any] = {} + # extract variables + if self._node_data.query_variable_selector: + variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector) + if not isinstance(variable, StringSegment): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs={}, + error="Query variable is not string type.", + ) + query = variable.value + variables["query"] = query + + if self._node_data.query_attachment_selector: + variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_attachment_selector) + if not isinstance(variable, ArrayFileSegment) and not isinstance(variable, FileSegment): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs={}, + error="Attachments variable is not array file or file type.", + ) + if isinstance(variable, ArrayFileSegment): + variables["attachments"] = variable.value + else: + variables["attachments"] = [variable.value] + # TODO(-LAN-): Move this check outside. # check rate limit knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id) @@ -161,7 +184,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD # retrieve knowledge usage = LLMUsage.empty_usage() try: - results, usage = self._fetch_dataset_retriever(node_data=self.node_data, query=query) + results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables) outputs = {"result": ArrayObjectSegment(value=results)} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -198,12 +221,16 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD db.session.close() def _fetch_dataset_retriever( - self, node_data: KnowledgeRetrievalNodeData, query: str + self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any] ) -> tuple[list[dict[str, Any]], LLMUsage]: usage = LLMUsage.empty_usage() available_datasets = [] dataset_ids = node_data.dataset_ids - + query = variables.get("query") + attachments = variables.get("attachments") + metadata_filter_document_ids = None + metadata_condition = None + metadata_usage = LLMUsage.empty_usage() # Subquery: Count the number of available documents for each dataset subquery = ( db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count")) @@ -234,13 +261,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD if not dataset: continue available_datasets.append(dataset) - metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition( - [dataset.id for dataset in available_datasets], query, node_data - ) - usage = self._merge_usage(usage, metadata_usage) + if query: + metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition( + [dataset.id for dataset in available_datasets], query, node_data + ) + usage = self._merge_usage(usage, metadata_usage) all_documents = [] dataset_retrieval = DatasetRetrieval() - if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: + if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query: # fetch model config if node_data.single_retrieval_config is None: raise ValueError("single_retrieval_config is required") @@ -272,7 +300,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD metadata_filter_document_ids=metadata_filter_document_ids, metadata_condition=metadata_condition, ) - elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: + elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: if node_data.multiple_retrieval_config is None: raise ValueError("multiple_retrieval_config is required") if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": @@ -319,6 +347,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD reranking_enable=node_data.multiple_retrieval_config.reranking_enable, metadata_filter_document_ids=metadata_filter_document_ids, metadata_condition=metadata_condition, + attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None, ) usage = self._merge_usage(usage, dataset_retrieval.llm_usage) @@ -327,7 +356,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD retrieval_resource_list = [] # deal with external documents for item in external_documents: - source = { + source: dict[str, dict[str, str | Any | dict[Any, Any] | None] | Any | str | None] = { "metadata": { "_source": "knowledge", "dataset_id": item.metadata.get("dataset_id"), @@ -384,6 +413,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD "doc_metadata": document.doc_metadata, }, "title": document.name, + "files": list(record.files) if record.files else None, } if segment.answer: source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}" @@ -393,13 +423,21 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD if retrieval_resource_list: retrieval_resource_list = sorted( retrieval_resource_list, - key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0, + key=self._score, # type: ignore[arg-type, return-value] reverse=True, ) for position, item in enumerate(retrieval_resource_list, start=1): - item["metadata"]["position"] = position + item["metadata"]["position"] = position # type: ignore[index] return retrieval_resource_list, usage + def _score(self, item: dict[str, Any]) -> float: + meta = item.get("metadata") + if isinstance(meta, dict): + s = meta.get("score") + if isinstance(s, (int, float)): + return float(s) + return 0.0 + def _get_metadata_filter_condition( self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]: @@ -659,7 +697,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data) variable_mapping = {} - variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector + if typed_node_data.query_variable_selector: + variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector + if typed_node_data.query_attachment_selector: + variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector return variable_mapping def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 1a2473e0bb..a5973862b2 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -7,8 +7,10 @@ import time from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal +from sqlalchemy import select + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.file import FileType, file_manager +from core.file import File, FileTransferMethod, FileType, file_manager from core.helper.code_executor import CodeExecutor, CodeLanguage from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output @@ -44,6 +46,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from core.tools.signature import sign_upload_file from core.variables import ( ArrayFileSegment, ArraySegment, @@ -72,6 +75,9 @@ from core.workflow.nodes.base.entities import VariableSelector from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser from core.workflow.runtime import VariablePool +from extensions.ext_database import db +from models.dataset import SegmentAttachmentBinding +from models.model import UploadFile from . import llm_utils from .entities import ( @@ -179,12 +185,17 @@ class LLMNode(Node[LLMNodeData]): # fetch context value generator = self._fetch_context(node_data=self.node_data) context = None + context_files: list[File] = [] for event in generator: context = event.context + context_files = event.context_files or [] yield event if context: node_inputs["#context#"] = context + if context_files: + node_inputs["#context_files#"] = [file.model_dump() for file in context_files] + # fetch model config model_instance, model_config = LLMNode._fetch_model_config( node_data_model=self.node_data.model, @@ -220,6 +231,7 @@ class LLMNode(Node[LLMNodeData]): variable_pool=variable_pool, jinja2_variables=self.node_data.prompt_config.jinja2_variables, tenant_id=self.tenant_id, + context_files=context_files, ) # handle invoke result @@ -322,6 +334,7 @@ class LLMNode(Node[LLMNodeData]): inputs=node_inputs, process_data=process_data, error_type=type(e).__name__, + llm_usage=usage, ) ) except Exception as e: @@ -332,6 +345,8 @@ class LLMNode(Node[LLMNodeData]): error=str(e), inputs=node_inputs, process_data=process_data, + error_type=type(e).__name__, + llm_usage=usage, ) ) @@ -654,10 +669,13 @@ class LLMNode(Node[LLMNodeData]): context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector) if context_value_variable: if isinstance(context_value_variable, StringSegment): - yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value) + yield RunRetrieverResourceEvent( + retriever_resources=[], context=context_value_variable.value, context_files=[] + ) elif isinstance(context_value_variable, ArraySegment): context_str = "" original_retriever_resource: list[RetrievalSourceMetadata] = [] + context_files: list[File] = [] for item in context_value_variable.value: if isinstance(item, str): context_str += item + "\n" @@ -670,9 +688,34 @@ class LLMNode(Node[LLMNodeData]): retriever_resource = self._convert_to_original_retriever_resource(item) if retriever_resource: original_retriever_resource.append(retriever_resource) - + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where( + SegmentAttachmentBinding.segment_id == retriever_resource.segment_id, + ) + ).all() + if attachments_with_bindings: + for _, upload_file in attachments_with_bindings: + attchment_info = File( + id=upload_file.id, + filename=upload_file.name, + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + tenant_id=self.tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.LOCAL_FILE, + remote_url=upload_file.source_url, + related_id=upload_file.id, + size=upload_file.size, + storage_key=upload_file.key, + url=sign_upload_file(upload_file.id, upload_file.extension), + ) + context_files.append(attchment_info) yield RunRetrieverResourceEvent( - retriever_resources=original_retriever_resource, context=context_str.strip() + retriever_resources=original_retriever_resource, + context=context_str.strip(), + context_files=context_files, ) def _convert_to_original_retriever_resource(self, context_dict: dict) -> RetrievalSourceMetadata | None: @@ -700,6 +743,7 @@ class LLMNode(Node[LLMNodeData]): content=context_dict.get("content"), page=metadata.get("page"), doc_metadata=metadata.get("doc_metadata"), + files=context_dict.get("files"), ) return source @@ -741,6 +785,7 @@ class LLMNode(Node[LLMNodeData]): variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], tenant_id: str, + context_files: list["File"] | None = None, ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: prompt_messages: list[PromptMessage] = [] @@ -853,6 +898,23 @@ class LLMNode(Node[LLMNodeData]): else: prompt_messages.append(UserPromptMessage(content=file_prompts)) + # The context_files + if vision_enabled and context_files: + file_prompts = [] + for file in context_files: + file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) + file_prompts.append(file_prompt) + # If last prompt is a user prompt, add files into its contents, + # otherwise append a new user prompt + if ( + len(prompt_messages) > 0 + and isinstance(prompt_messages[-1], UserPromptMessage) + and isinstance(prompt_messages[-1].content, list) + ): + prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content) + else: + prompt_messages.append(UserPromptMessage(content=file_prompts)) + # Remove empty messages and filter unsupported content filtered_prompt_messages = [] for prompt_message in prompt_messages: diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index db3d4d4aac..4a3e8e56f8 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -221,6 +221,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), + error_type=type(e).__name__, metadata={ WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 89c4d8fba9..1e5ec7d200 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -97,11 +97,27 @@ dataset_detail_fields = { "total_documents": fields.Integer, "total_available_documents": fields.Integer, "enable_api": fields.Boolean, + "is_multimodal": fields.Boolean, +} + +file_info_fields = { + "id": fields.String, + "name": fields.String, + "size": fields.Integer, + "extension": fields.String, + "mime_type": fields.String, + "source_url": fields.String, +} + +content_fields = { + "content_type": fields.String, + "content": fields.String, + "file_info": fields.Nested(file_info_fields, allow_null=True), } dataset_query_detail_fields = { "id": fields.String, - "content": fields.String, + "queries": fields.Nested(content_fields), "source": fields.String, "source_app_id": fields.String, "created_by_role": fields.String, diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index c12ebc09c8..a707500445 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -9,6 +9,8 @@ upload_config_fields = { "video_file_size_limit": fields.Integer, "audio_file_size_limit": fields.Integer, "workflow_file_upload_limit": fields.Integer, + "image_file_batch_limit": fields.Integer, + "single_chunk_attachment_limit": fields.Integer, } diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py index 75bdff1803..e70f9fa722 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -43,9 +43,19 @@ child_chunk_fields = { "score": fields.Float, } +files_fields = { + "id": fields.String, + "name": fields.String, + "size": fields.Integer, + "extension": fields.String, + "mime_type": fields.String, + "source_url": fields.String, +} + hit_testing_record_fields = { "segment": fields.Nested(segment_fields), "child_chunks": fields.List(fields.Nested(child_chunk_fields)), "score": fields.Float, "tsne_position": fields.Raw, + "files": fields.List(fields.Nested(files_fields)), } diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py index 2ff917d6bc..56d6b68378 100644 --- a/api/fields/segment_fields.py +++ b/api/fields/segment_fields.py @@ -13,6 +13,15 @@ child_chunk_fields = { "updated_at": TimestampField, } +attachment_fields = { + "id": fields.String, + "name": fields.String, + "size": fields.Integer, + "extension": fields.String, + "mime_type": fields.String, + "source_url": fields.String, +} + segment_fields = { "id": fields.String, "position": fields.Integer, @@ -39,4 +48,5 @@ segment_fields = { "error": fields.String, "stopped_at": TimestampField, "child_chunks": fields.List(fields.Nested(child_chunk_fields)), + "attachments": fields.List(fields.Nested(attachment_fields)), } diff --git a/api/migrations/versions/2025_11_12_1537-d57accd375ae_support_multi_modal.py b/api/migrations/versions/2025_11_12_1537-d57accd375ae_support_multi_modal.py new file mode 100644 index 0000000000..187bf7136d --- /dev/null +++ b/api/migrations/versions/2025_11_12_1537-d57accd375ae_support_multi_modal.py @@ -0,0 +1,57 @@ +"""support-multi-modal + +Revision ID: d57accd375ae +Revises: 03f8dcbc611e +Create Date: 2025-11-12 15:37:12.363670 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'd57accd375ae' +down_revision = '7bb281b7a422' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('segment_attachment_bindings', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('document_id', models.types.StringUUID(), nullable=False), + sa.Column('segment_id', models.types.StringUUID(), nullable=False), + sa.Column('attachment_id', models.types.StringUUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.PrimaryKeyConstraint('id', name='segment_attachment_binding_pkey') + ) + with op.batch_alter_table('segment_attachment_bindings', schema=None) as batch_op: + batch_op.create_index( + 'segment_attachment_binding_tenant_dataset_document_segment_idx', + ['tenant_id', 'dataset_id', 'document_id', 'segment_id'], + unique=False + ) + batch_op.create_index('segment_attachment_binding_attachment_idx', ['attachment_id'], unique=False) + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.add_column(sa.Column('is_multimodal', sa.Boolean(), server_default=sa.text('false'), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.drop_column('is_multimodal') + + + with op.batch_alter_table('segment_attachment_bindings', schema=None) as batch_op: + batch_op.drop_index('segment_attachment_binding_attachment_idx') + batch_op.drop_index('segment_attachment_binding_tenant_dataset_document_segment_idx') + + op.drop_table('segment_attachment_bindings') + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index e072711b82..ba2eaf6749 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -19,7 +19,9 @@ from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource +from core.rag.index_processor.constant.query_type import QueryType from core.rag.retrieval.retrieval_methods import RetrievalMethod +from core.tools.signature import sign_upload_file from extensions.ext_storage import storage from libs.uuid_utils import uuidv7 from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule @@ -76,6 +78,7 @@ class Dataset(Base): pipeline_id = mapped_column(StringUUID, nullable=True) chunk_structure = mapped_column(sa.String(255), nullable=True) enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) + is_multimodal = mapped_column(sa.Boolean, default=False, nullable=False, server_default=db.text("false")) @property def total_documents(self): @@ -728,9 +731,7 @@ class DocumentSegment(Base): created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at: Mapped[datetime] = mapped_column( - DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() - ) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) error = mapped_column(LongText, nullable=True) @@ -866,6 +867,47 @@ class DocumentSegment(Base): return text + @property + def attachments(self) -> list[dict[str, Any]]: + # Use JOIN to fetch attachments in a single query instead of two separate queries + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where( + SegmentAttachmentBinding.tenant_id == self.tenant_id, + SegmentAttachmentBinding.dataset_id == self.dataset_id, + SegmentAttachmentBinding.document_id == self.document_id, + SegmentAttachmentBinding.segment_id == self.id, + ) + ).all() + if not attachments_with_bindings: + return [] + attachment_list = [] + for _, attachment in attachments_with_bindings: + upload_file_id = attachment.id + nonce = os.urandom(16).hex() + timestamp = str(int(time.time())) + data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + reference_url = dify_config.CONSOLE_API_URL or "" + base_url = f"{reference_url}/files/{upload_file_id}/image-preview" + source_url = f"{base_url}?{params}" + attachment_list.append( + { + "id": attachment.id, + "name": attachment.name, + "size": attachment.size, + "extension": attachment.extension, + "mime_type": attachment.mime_type, + "source_url": source_url, + } + ) + return attachment_list + class ChildChunk(Base): __tablename__ = "child_chunks" @@ -963,6 +1005,38 @@ class DatasetQuery(TypeBase): DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False ) + @property + def queries(self) -> list[dict[str, Any]]: + try: + queries = json.loads(self.content) + if isinstance(queries, list): + for query in queries: + if query["content_type"] == QueryType.IMAGE_QUERY: + file_info = db.session.query(UploadFile).filter_by(id=query["content"]).first() + if file_info: + query["file_info"] = { + "id": file_info.id, + "name": file_info.name, + "size": file_info.size, + "extension": file_info.extension, + "mime_type": file_info.mime_type, + "source_url": sign_upload_file(file_info.id, file_info.extension), + } + else: + query["file_info"] = None + + return queries + else: + return [queries] + except JSONDecodeError: + return [ + { + "content_type": QueryType.TEXT_QUERY, + "content": self.content, + "file_info": None, + } + ] + class DatasetKeywordTable(TypeBase): __tablename__ = "dataset_keyword_tables" @@ -1470,3 +1544,25 @@ class PipelineRecommendedPlugin(TypeBase): onupdate=func.current_timestamp(), init=False, ) + + +class SegmentAttachmentBinding(Base): + __tablename__ = "segment_attachment_bindings" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="segment_attachment_binding_pkey"), + sa.Index( + "segment_attachment_binding_tenant_dataset_document_segment_idx", + "tenant_id", + "dataset_id", + "document_id", + "segment_id", + ), + sa.Index("segment_attachment_binding_attachment_idx", "attachment_id"), + ) + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7())) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + document_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + attachment_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/model.py b/api/models/model.py index 1731ff5699..6b0bf4b4a2 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -259,7 +259,7 @@ class App(Base): provider_id = tool.get("provider_id", "") if provider_type == ToolProviderType.API: - if uuid.UUID(provider_id) not in existing_api_providers: + if provider_id not in existing_api_providers: deleted_tools.append( { "type": ToolProviderType.API, diff --git a/api/models/workflow.py b/api/models/workflow.py index 42ee8a1f2b..853d5afefc 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -907,19 +907,29 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo @property def extras(self) -> dict[str, Any]: from core.tools.tool_manager import ToolManager + from core.trigger.trigger_manager import TriggerManager extras: dict[str, Any] = {} - if self.execution_metadata_dict: - if self.node_type == NodeType.TOOL and "tool_info" in self.execution_metadata_dict: - tool_info: dict[str, Any] = self.execution_metadata_dict["tool_info"] + execution_metadata = self.execution_metadata_dict + if execution_metadata: + if self.node_type == NodeType.TOOL and "tool_info" in execution_metadata: + tool_info: dict[str, Any] = execution_metadata["tool_info"] extras["icon"] = ToolManager.get_tool_icon( tenant_id=self.tenant_id, provider_type=tool_info["provider_type"], provider_id=tool_info["provider_id"], ) - elif self.node_type == NodeType.DATASOURCE and "datasource_info" in self.execution_metadata_dict: - datasource_info = self.execution_metadata_dict["datasource_info"] + elif self.node_type == NodeType.DATASOURCE and "datasource_info" in execution_metadata: + datasource_info = execution_metadata["datasource_info"] extras["icon"] = datasource_info.get("icon") + elif self.node_type == NodeType.TRIGGER_PLUGIN and "trigger_info" in execution_metadata: + trigger_info = execution_metadata["trigger_info"] or {} + provider_id = trigger_info.get("provider_id") + if provider_id: + extras["icon"] = TriggerManager.get_trigger_plugin_icon( + tenant_id=self.tenant_id, + provider_id=provider_id, + ) return extras def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]: diff --git a/api/services/attachment_service.py b/api/services/attachment_service.py new file mode 100644 index 0000000000..2bd5627d5e --- /dev/null +++ b/api/services/attachment_service.py @@ -0,0 +1,31 @@ +import base64 + +from sqlalchemy import Engine +from sqlalchemy.orm import sessionmaker +from werkzeug.exceptions import NotFound + +from extensions.ext_storage import storage +from models.model import UploadFile + +PREVIEW_WORDS_LIMIT = 3000 + + +class AttachmentService: + _session_maker: sessionmaker + + def __init__(self, session_factory: sessionmaker | Engine | None = None): + if isinstance(session_factory, Engine): + self._session_maker = sessionmaker(bind=session_factory) + elif isinstance(session_factory, sessionmaker): + self._session_maker = session_factory + else: + raise AssertionError("must be a sessionmaker or an Engine.") + + def get_file_base64(self, file_id: str) -> str: + upload_file = ( + self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first() + ) + if not upload_file: + raise NotFound("File not found") + blob = storage.load_once(upload_file.key) + return base64.b64encode(blob).decode() diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index bb09311349..00f06e9405 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -7,7 +7,7 @@ import time import uuid from collections import Counter from collections.abc import Sequence -from typing import Any, Literal +from typing import Any, Literal, cast import sqlalchemy as sa from redis.exceptions import LockNotOwnedError @@ -19,9 +19,10 @@ from configs import dify_config from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.helper.name_generator import generate_incremental_name from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.model_entities import ModelFeature, ModelType +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.rag.index_processor.constant.built_in_field import BuiltInField -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.retrieval.retrieval_methods import RetrievalMethod from enums.cloud_plan import CloudPlan from events.dataset_event import dataset_was_deleted @@ -46,6 +47,7 @@ from models.dataset import ( DocumentSegment, ExternalKnowledgeBindings, Pipeline, + SegmentAttachmentBinding, ) from models.model import UploadFile from models.provider_ids import ModelProviderID @@ -363,6 +365,27 @@ class DatasetService: except ProviderTokenNotInitError as ex: raise ValueError(ex.description) + @staticmethod + def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str): + try: + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=tenant_id, + provider=model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=model, + ) + text_embedding_model = cast(TextEmbeddingModel, model_instance.model_type_instance) + model_schema = text_embedding_model.get_model_schema(model_instance.model, model_instance.credentials) + if not model_schema: + raise ValueError("Model schema not found") + if model_schema.features and ModelFeature.VISION in model_schema.features: + return True + else: + return False + except LLMBadRequestError: + raise ValueError("No Model available. Please configure a valid provider in the Settings -> Model Provider.") + @staticmethod def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str): try: @@ -402,13 +425,13 @@ class DatasetService: if not dataset: raise ValueError("Dataset not found") # check if dataset name is exists - - if DatasetService._has_dataset_same_name( - tenant_id=dataset.tenant_id, - dataset_id=dataset_id, - name=data.get("name", dataset.name), - ): - raise ValueError("Dataset name already exists") + if data.get("name") and data.get("name") != dataset.name: + if DatasetService._has_dataset_same_name( + tenant_id=dataset.tenant_id, + dataset_id=dataset_id, + name=data.get("name", dataset.name), + ): + raise ValueError("Dataset name already exists") # Verify user has permission to update this dataset DatasetService.check_dataset_permission(dataset, user) @@ -844,6 +867,12 @@ class DatasetService: model_type=ModelType.TEXT_EMBEDDING, model=knowledge_configuration.embedding_model or "", ) + is_multimodal = DatasetService.check_is_multimodal_model( + current_user.current_tenant_id, + knowledge_configuration.embedding_model_provider, + knowledge_configuration.embedding_model, + ) + dataset.is_multimodal = is_multimodal dataset.embedding_model = embedding_model.model dataset.embedding_model_provider = embedding_model.provider dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( @@ -880,6 +909,12 @@ class DatasetService: dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( embedding_model.provider, embedding_model.model ) + is_multimodal = DatasetService.check_is_multimodal_model( + current_user.current_tenant_id, + knowledge_configuration.embedding_model_provider, + knowledge_configuration.embedding_model, + ) + dataset.is_multimodal = is_multimodal dataset.collection_binding_id = dataset_collection_binding.id dataset.indexing_technique = knowledge_configuration.indexing_technique except LLMBadRequestError: @@ -937,6 +972,12 @@ class DatasetService: ) ) dataset.collection_binding_id = dataset_collection_binding.id + is_multimodal = DatasetService.check_is_multimodal_model( + current_user.current_tenant_id, + knowledge_configuration.embedding_model_provider, + knowledge_configuration.embedding_model, + ) + dataset.is_multimodal = is_multimodal except LLMBadRequestError: raise ValueError( "No Embedding Model available. Please configure a valid provider " @@ -2305,6 +2346,7 @@ class DocumentService: embedding_model_provider=knowledge_config.embedding_model_provider, collection_binding_id=dataset_collection_binding_id, retrieval_model=retrieval_model.model_dump() if retrieval_model else None, + is_multimodal=knowledge_config.is_multimodal, ) db.session.add(dataset) @@ -2685,6 +2727,13 @@ class SegmentService: if "content" not in args or not args["content"] or not args["content"].strip(): raise ValueError("Content is empty") + if args.get("attachment_ids"): + if not isinstance(args["attachment_ids"], list): + raise ValueError("Attachment IDs is invalid") + single_chunk_attachment_limit = dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT + if len(args["attachment_ids"]) > single_chunk_attachment_limit: + raise ValueError(f"Exceeded maximum attachment limit of {single_chunk_attachment_limit}") + @classmethod def create_segment(cls, args: dict, document: Document, dataset: Dataset): assert isinstance(current_user, Account) @@ -2731,11 +2780,23 @@ class SegmentService: segment_document.word_count += len(args["answer"]) segment_document.answer = args["answer"] - db.session.add(segment_document) - # update document word count - assert document.word_count is not None - document.word_count += segment_document.word_count - db.session.add(document) + db.session.add(segment_document) + # update document word count + assert document.word_count is not None + document.word_count += segment_document.word_count + db.session.add(document) + db.session.commit() + + if args["attachment_ids"]: + for attachment_id in args["attachment_ids"]: + binding = SegmentAttachmentBinding( + tenant_id=current_user.current_tenant_id, + dataset_id=document.dataset_id, + document_id=document.id, + segment_id=segment_document.id, + attachment_id=attachment_id, + ) + db.session.add(binding) db.session.commit() # save vector index @@ -2899,7 +2960,7 @@ class SegmentService: document.word_count = max(0, document.word_count + word_count_change) db.session.add(document) # update segment index task - if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: + if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: # regenerate child chunks # get embedding model instance if dataset.indexing_technique == "high_quality": @@ -2926,12 +2987,11 @@ class SegmentService: .where(DatasetProcessRule.id == document.dataset_process_rule_id) .first() ) - if not processing_rule: - raise ValueError("No processing rule found.") - VectorService.generate_child_chunks( - segment, document, dataset, embedding_model_instance, processing_rule, True - ) - elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX): + if processing_rule: + VectorService.generate_child_chunks( + segment, document, dataset, embedding_model_instance, processing_rule, True + ) + elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX): if args.enabled or keyword_changed: # update segment vector index VectorService.update_segment_vector(args.keywords, segment, dataset) @@ -2976,7 +3036,7 @@ class SegmentService: db.session.add(document) db.session.add(segment) db.session.commit() - if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: + if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: # get embedding model instance if dataset.indexing_technique == "high_quality": # check embedding model setting @@ -3002,15 +3062,15 @@ class SegmentService: .where(DatasetProcessRule.id == document.dataset_process_rule_id) .first() ) - if not processing_rule: - raise ValueError("No processing rule found.") - VectorService.generate_child_chunks( - segment, document, dataset, embedding_model_instance, processing_rule, True - ) - elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX): + if processing_rule: + VectorService.generate_child_chunks( + segment, document, dataset, embedding_model_instance, processing_rule, True + ) + elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX): # update segment vector index VectorService.update_segment_vector(args.keywords, segment, dataset) - + # update multimodel vector index + VectorService.update_multimodel_vector(segment, args.attachment_ids or [], dataset) except Exception as e: logger.exception("update segment index failed") segment.enabled = False @@ -3048,7 +3108,9 @@ class SegmentService: ) child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]] - delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id, child_node_ids) + delete_segment_from_index_task.delay( + [segment.index_node_id], dataset.id, document.id, [segment.id], child_node_ids + ) db.session.delete(segment) # update document word count @@ -3097,7 +3159,9 @@ class SegmentService: # Start async cleanup with both parent and child node IDs if index_node_ids or child_node_ids: - delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id, child_node_ids) + delete_segment_from_index_task.delay( + index_node_ids, dataset.id, document.id, segment_db_ids, child_node_ids + ) if document.word_count is None: document.word_count = 0 diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 81e0c0ecd4..eeb14072bd 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -29,8 +29,14 @@ def get_current_user(): from models.account import Account from models.model import EndUser - if not isinstance(current_user._get_current_object(), (Account, EndUser)): # type: ignore - raise TypeError(f"current_user must be Account or EndUser, got {type(current_user).__name__}") + try: + user_object = current_user._get_current_object() + except AttributeError: + # Handle case where current_user might not be a LocalProxy in test environments + user_object = current_user + + if not isinstance(user_object, (Account, EndUser)): + raise TypeError(f"current_user must be Account or EndUser, got {type(user_object).__name__}") return current_user diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index 131e90e195..7959734e89 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -124,6 +124,14 @@ class KnowledgeConfig(BaseModel): embedding_model: str | None = None embedding_model_provider: str | None = None name: str | None = None + is_multimodal: bool = False + + +class SegmentCreateArgs(BaseModel): + content: str | None = None + answer: str | None = None + keywords: list[str] | None = None + attachment_ids: list[str] | None = None class SegmentUpdateArgs(BaseModel): @@ -132,6 +140,7 @@ class SegmentUpdateArgs(BaseModel): keywords: list[str] | None = None regenerate_child_chunks: bool = False enabled: bool | None = None + attachment_ids: list[str] | None = None class ChildChunkUpdateArgs(BaseModel): diff --git a/api/services/file_service.py b/api/services/file_service.py index 1980cd8d59..0911cf38c4 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -1,3 +1,4 @@ +import base64 import hashlib import os import uuid @@ -123,6 +124,15 @@ class FileService: return file_size <= file_size_limit + def get_file_base64(self, file_id: str) -> str: + upload_file = ( + self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first() + ) + if not upload_file: + raise NotFound("File not found") + blob = storage.load_once(upload_file.key) + return base64.b64encode(blob).decode() + def upload_text(self, text: str, text_name: str, user_id: str, tenant_id: str) -> UploadFile: if len(text_name) > 200: text_name = text_name[:200] diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index dfb49cf2bd..8e8e78f83f 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -1,3 +1,4 @@ +import json import logging import time from typing import Any @@ -5,6 +6,7 @@ from typing import Any from core.app.app_config.entities import ModelConfig from core.model_runtime.entities import LLMMode from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -32,6 +34,7 @@ class HitTestingService: account: Account, retrieval_model: Any, # FIXME drop this any external_retrieval_model: dict, + attachment_ids: list | None = None, limit: int = 10, ): start = time.perf_counter() @@ -41,7 +44,7 @@ class HitTestingService: retrieval_model = dataset.retrieval_model or default_retrieval_model document_ids_filter = None metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {}) - if metadata_filtering_conditions: + if metadata_filtering_conditions and query: dataset_retrieval = DatasetRetrieval() from core.app.app_config.entities import MetadataFilteringCondition @@ -66,6 +69,7 @@ class HitTestingService: retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)), dataset_id=dataset.id, query=query, + attachment_ids=attachment_ids, top_k=retrieval_model.get("top_k", 4), score_threshold=retrieval_model.get("score_threshold", 0.0) if retrieval_model["score_threshold_enabled"] @@ -80,17 +84,24 @@ class HitTestingService: end = time.perf_counter() logger.debug("Hit testing retrieve in %s seconds", end - start) - - dataset_query = DatasetQuery( - dataset_id=dataset.id, - content=query, - source="hit_testing", - source_app_id=None, - created_by_role="account", - created_by=account.id, - ) - - db.session.add(dataset_query) + dataset_queries = [] + if query: + content = {"content_type": QueryType.TEXT_QUERY, "content": query} + dataset_queries.append(content) + if attachment_ids: + for attachment_id in attachment_ids: + content = {"content_type": QueryType.IMAGE_QUERY, "content": attachment_id} + dataset_queries.append(content) + if dataset_queries: + dataset_query = DatasetQuery( + dataset_id=dataset.id, + content=json.dumps(dataset_queries), + source="hit_testing", + source_app_id=None, + created_by_role="account", + created_by=account.id, + ) + db.session.add(dataset_query) db.session.commit() return cls.compact_retrieve_response(query, all_documents) @@ -168,9 +179,14 @@ class HitTestingService: @classmethod def hit_testing_args_check(cls, args): query = args["query"] + attachment_ids = args["attachment_ids"] - if not query or len(query) > 250: - raise ValueError("Query is required and cannot exceed 250 characters") + if not attachment_ids and not query: + raise ValueError("Query or attachment_ids is required") + if query and len(query) > 250: + raise ValueError("Query cannot exceed 250 characters") + if attachment_ids and not isinstance(attachment_ids, list): + raise ValueError("Attachment_ids must be a list") @staticmethod def escape_query_for_search(query: str) -> str: diff --git a/api/services/vector_service.py b/api/services/vector_service.py index abc92a0181..f1fa33cb75 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -4,11 +4,14 @@ from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import Document +from core.rag.models.document import AttachmentDocument, Document from extensions.ext_database import db -from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment +from models import UploadFile +from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument from services.entities.knowledge_entities.knowledge_entities import ParentMode @@ -21,9 +24,10 @@ class VectorService: cls, keywords_list: list[list[str]] | None, segments: list[DocumentSegment], dataset: Dataset, doc_form: str ): documents: list[Document] = [] + multimodal_documents: list[AttachmentDocument] = [] for segment in segments: - if doc_form == IndexType.PARENT_CHILD_INDEX: + if doc_form == IndexStructureType.PARENT_CHILD_INDEX: dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first() if not dataset_document: logger.warning( @@ -70,12 +74,29 @@ class VectorService: "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, + "doc_type": DocType.TEXT, }, ) documents.append(rag_document) + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_document: AttachmentDocument = AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + multimodal_documents.append(multimodal_document) + index_processor: BaseIndexProcessor = IndexProcessorFactory(doc_form).init_index_processor() + if len(documents) > 0: - index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list) + index_processor.load(dataset, documents, None, with_keywords=True, keywords_list=keywords_list) + if len(multimodal_documents) > 0: + index_processor.load(dataset, [], multimodal_documents, with_keywords=False) @classmethod def update_segment_vector(cls, keywords: list[str] | None, segment: DocumentSegment, dataset: Dataset): @@ -130,6 +151,7 @@ class VectorService: "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, + "doc_type": DocType.TEXT, }, ) # use full doc mode to generate segment's child chunk @@ -226,3 +248,92 @@ class VectorService: def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset): vector = Vector(dataset=dataset) vector.delete_by_ids([child_chunk.index_node_id]) + + @classmethod + def update_multimodel_vector(cls, segment: DocumentSegment, attachment_ids: list[str], dataset: Dataset): + if dataset.indexing_technique != "high_quality": + return + + attachments = segment.attachments + old_attachment_ids = [attachment["id"] for attachment in attachments] if attachments else [] + + # Check if there's any actual change needed + if set(attachment_ids) == set(old_attachment_ids): + return + + try: + vector = Vector(dataset=dataset) + if dataset.is_multimodal: + # Delete old vectors if they exist + if old_attachment_ids: + vector.delete_by_ids(old_attachment_ids) + + # Delete existing segment attachment bindings in one operation + db.session.query(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id == segment.id).delete( + synchronize_session=False + ) + + if not attachment_ids: + db.session.commit() + return + + # Bulk fetch upload files - only fetch needed fields + upload_file_list = db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all() + + if not upload_file_list: + db.session.commit() + return + + # Create a mapping for quick lookup + upload_file_map = {upload_file.id: upload_file for upload_file in upload_file_list} + + # Prepare batch operations + bindings = [] + documents = [] + + # Create common metadata base to avoid repetition + base_metadata = { + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + } + + # Process attachments in the order specified by attachment_ids + for attachment_id in attachment_ids: + upload_file = upload_file_map.get(attachment_id) + if not upload_file: + logger.warning("Upload file not found for attachment_id: %s", attachment_id) + continue + + # Create segment attachment binding + bindings.append( + SegmentAttachmentBinding( + tenant_id=segment.tenant_id, + dataset_id=segment.dataset_id, + document_id=segment.document_id, + segment_id=segment.id, + attachment_id=upload_file.id, + ) + ) + + # Create document for vector indexing + documents.append( + Document(page_content=upload_file.name, metadata={**base_metadata, "doc_id": upload_file.id}) + ) + + # Bulk insert all bindings at once + if bindings: + db.session.add_all(bindings) + + # Add documents to vector store if any + if documents and dataset.is_multimodal: + vector.add_texts(documents, duplicate_check=True) + + # Single commit for all operations + db.session.commit() + + except Exception: + logger.exception("Failed to update multimodal vector for segment %s", segment.id) + db.session.rollback() + raise diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index 933ad6b9e2..e7dead8a56 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -4,9 +4,10 @@ import time import click from celery import shared_task -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import ChildDocument, Document +from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now @@ -55,6 +56,7 @@ def add_document_to_index_task(dataset_document_id: str): ) documents = [] + multimodal_documents = [] for segment in segments: document = Document( page_content=segment.content, @@ -65,7 +67,7 @@ def add_document_to_index_task(dataset_document_id: str): "dataset_id": segment.dataset_id, }, ) - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] @@ -81,11 +83,25 @@ def add_document_to_index_task(dataset_document_id: str): ) child_documents.append(child_document) document.children = child_documents + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) documents.append(document) index_type = dataset.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() - index_processor.load(dataset, documents) + index_processor.load(dataset, documents, multimodal_documents=multimodal_documents) # delete auto disable log db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id).delete() diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index 5f2a355d16..8608df6b8e 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -18,6 +18,7 @@ from models.dataset import ( DatasetQuery, Document, DocumentSegment, + SegmentAttachmentBinding, ) from models.model import UploadFile @@ -58,14 +59,20 @@ def clean_dataset_task( ) documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all() segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all() + # Use JOIN to fetch attachments with bindings in a single query + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where(SegmentAttachmentBinding.tenant_id == tenant_id, SegmentAttachmentBinding.dataset_id == dataset_id) + ).all() # Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace # This ensures all invalid doc_form values are properly handled if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()): # Use default paragraph index type for empty/invalid datasets to enable vector database cleanup - from core.rag.index_processor.constant.index_type import IndexType + from core.rag.index_processor.constant.index_type import IndexStructureType - doc_form = IndexType.PARAGRAPH_INDEX + doc_form = IndexStructureType.PARAGRAPH_INDEX logger.info( click.style(f"Invalid doc_form detected, using default index type for cleanup: {doc_form}", fg="yellow") ) @@ -90,6 +97,7 @@ def clean_dataset_task( for document in documents: db.session.delete(document) + # delete document file for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) @@ -107,6 +115,19 @@ def clean_dataset_task( ) db.session.delete(image_file) db.session.delete(segment) + # delete segment attachments + if attachments_with_bindings: + for binding, attachment_file in attachments_with_bindings: + try: + storage.delete(attachment_file.key) + except Exception: + logger.exception( + "Delete attachment_file failed when storage deleted, \ + attachment_file_id: %s", + binding.attachment_id, + ) + db.session.delete(attachment_file) + db.session.delete(binding) db.session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete() db.session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete() diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 62200715cc..6d2feb1da3 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -9,7 +9,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto from core.tools.utils.web_reader_tool import get_image_upload_file_ids from extensions.ext_database import db from extensions.ext_storage import storage -from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment +from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment, SegmentAttachmentBinding from models.model import UploadFile logger = logging.getLogger(__name__) @@ -36,6 +36,16 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i raise Exception("Document has no dataset") segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() + # Use JOIN to fetch attachments with bindings in a single query + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where( + SegmentAttachmentBinding.tenant_id == dataset.tenant_id, + SegmentAttachmentBinding.dataset_id == dataset_id, + SegmentAttachmentBinding.document_id == document_id, + ) + ).all() # check segment is exist if segments: index_node_ids = [segment.index_node_id for segment in segments] @@ -69,6 +79,19 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i logger.exception("Delete file failed when document deleted, file_id: %s", file_id) db.session.delete(file) db.session.commit() + # delete segment attachments + if attachments_with_bindings: + for binding, attachment_file in attachments_with_bindings: + try: + storage.delete(attachment_file.key) + except Exception: + logger.exception( + "Delete attachment_file failed when storage deleted, \ + attachment_file_id: %s", + binding.attachment_id, + ) + db.session.delete(attachment_file) + db.session.delete(binding) # delete dataset metadata binding db.session.query(DatasetMetadataBinding).where( diff --git a/api/tasks/deal_dataset_index_update_task.py b/api/tasks/deal_dataset_index_update_task.py index 713f149c38..3d13afdec0 100644 --- a/api/tasks/deal_dataset_index_update_task.py +++ b/api/tasks/deal_dataset_index_update_task.py @@ -4,9 +4,10 @@ import time import click from celery import shared_task # type: ignore -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import ChildDocument, Document +from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument @@ -28,7 +29,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): if not dataset: raise Exception("Dataset not found") - index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX + index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX index_processor = IndexProcessorFactory(index_type).init_index_processor() if action == "upgrade": dataset_documents = ( @@ -119,6 +120,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): ) if segments: documents = [] + multimodal_documents = [] for segment in segments: document = Document( page_content=segment.content, @@ -129,7 +131,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): "dataset_id": segment.dataset_id, }, ) - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] @@ -145,9 +147,25 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): ) child_documents.append(child_document) document.children = child_documents + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) documents.append(document) # save vector index - index_processor.load(dataset, documents, with_keywords=False) + index_processor.load( + dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False + ) db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( {"indexing_status": "completed"}, synchronize_session=False ) diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index dc6ef6fb61..1c7de3b1ce 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -1,14 +1,14 @@ import logging import time -from typing import Literal import click from celery import shared_task from sqlalchemy import select -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import ChildDocument, Document +from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) @shared_task(queue="dataset") -def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "add", "update"]): +def deal_dataset_vector_index_task(dataset_id: str, action: str): """ Async deal dataset from index :param dataset_id: dataset_id @@ -32,7 +32,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a if not dataset: raise Exception("Dataset not found") - index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX + index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX index_processor = IndexProcessorFactory(index_type).init_index_processor() if action == "remove": index_processor.clean(dataset, None, with_keywords=False) @@ -119,6 +119,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a ) if segments: documents = [] + multimodal_documents = [] for segment in segments: document = Document( page_content=segment.content, @@ -129,7 +130,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a "dataset_id": segment.dataset_id, }, ) - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] @@ -145,9 +146,25 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a ) child_documents.append(child_document) document.children = child_documents + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) documents.append(document) # save vector index - index_processor.load(dataset, documents, with_keywords=False) + index_processor.load( + dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False + ) db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( {"indexing_status": "completed"}, synchronize_session=False ) diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index e8cbd0f250..bea5c952cf 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -6,14 +6,15 @@ from celery import shared_task from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db -from models.dataset import Dataset, Document +from models.dataset import Dataset, Document, SegmentAttachmentBinding +from models.model import UploadFile logger = logging.getLogger(__name__) @shared_task(queue="dataset") def delete_segment_from_index_task( - index_node_ids: list, dataset_id: str, document_id: str, child_node_ids: list | None = None + index_node_ids: list, dataset_id: str, document_id: str, segment_ids: list, child_node_ids: list | None = None ): """ Async Remove segment from index @@ -49,6 +50,21 @@ def delete_segment_from_index_task( delete_child_chunks=True, precomputed_child_node_ids=child_node_ids, ) + if dataset.is_multimodal: + # delete segment attachment binding + segment_attachment_bindings = ( + db.session.query(SegmentAttachmentBinding) + .where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) + .all() + ) + if segment_attachment_bindings: + attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings] + index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False) + for binding in segment_attachment_bindings: + db.session.delete(binding) + # delete upload file + db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False) + db.session.commit() end_at = time.perf_counter() logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green")) diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py index 9038dc179b..c2a3de29f4 100644 --- a/api/tasks/disable_segments_from_index_task.py +++ b/api/tasks/disable_segments_from_index_task.py @@ -8,7 +8,7 @@ from sqlalchemy import select from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.dataset import Dataset, DocumentSegment +from models.dataset import Dataset, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument logger = logging.getLogger(__name__) @@ -59,6 +59,16 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen try: index_node_ids = [segment.index_node_id for segment in segments] + if dataset.is_multimodal: + segment_ids = [segment.id for segment in segments] + segment_attachment_bindings = ( + db.session.query(SegmentAttachmentBinding) + .where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) + .all() + ) + if segment_attachment_bindings: + attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings] + index_node_ids.extend(attachment_ids) index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) end_at = time.perf_counter() diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 07c44f333e..7615469ed0 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -4,9 +4,10 @@ import time import click from celery import shared_task -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import ChildDocument, Document +from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now @@ -67,7 +68,7 @@ def enable_segment_to_index_task(segment_id: str): return index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] @@ -83,8 +84,24 @@ def enable_segment_to_index_task(segment_id: str): ) child_documents.append(child_document) document.children = child_documents + multimodel_documents = [] + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodel_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) + # save vector index - index_processor.load(dataset, [document]) + index_processor.load(dataset, [document], multimodal_documents=multimodel_documents) end_at = time.perf_counter() logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green")) diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py index c5ca7a6171..9f17d09e18 100644 --- a/api/tasks/enable_segments_to_index_task.py +++ b/api/tasks/enable_segments_to_index_task.py @@ -5,9 +5,10 @@ import click from celery import shared_task from sqlalchemy import select -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import ChildDocument, Document +from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now @@ -60,6 +61,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i try: documents = [] + multimodal_documents = [] for segment in segments: document = Document( page_content=segment.content, @@ -71,7 +73,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i }, ) - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] @@ -87,9 +89,24 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i ) child_documents.append(child_document) document.children = child_documents + + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) documents.append(document) # save vector index - index_processor.load(dataset, documents) + index_processor.load(dataset, documents, multimodal_documents=multimodal_documents) end_at = time.perf_counter() logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green")) diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index 9478bb9ddb..088d6ba6ba 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.ext_database import db from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -95,7 +95,7 @@ class TestAddDocumentToIndexTask: created_by=account.id, indexing_status="completed", enabled=True, - doc_form=IndexType.PARAGRAPH_INDEX, + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db.session.add(document) db.session.commit() @@ -172,7 +172,9 @@ class TestAddDocumentToIndexTask: # Assert: Verify the expected outcomes # Verify index processor was called correctly - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.PARAGRAPH_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify database state changes @@ -204,7 +206,7 @@ class TestAddDocumentToIndexTask: ) # Update document to use different index type - document.doc_form = IndexType.QA_INDEX + document.doc_form = IndexStructureType.QA_INDEX db.session.commit() # Refresh dataset to ensure doc_form property reflects the updated document @@ -221,7 +223,9 @@ class TestAddDocumentToIndexTask: add_document_to_index_task(document.id) # Assert: Verify different index type handling - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.QA_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.QA_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify the load method was called with correct parameters @@ -360,7 +364,7 @@ class TestAddDocumentToIndexTask: ) # Update document to use parent-child index type - document.doc_form = IndexType.PARENT_CHILD_INDEX + document.doc_form = IndexStructureType.PARENT_CHILD_INDEX db.session.commit() # Refresh dataset to ensure doc_form property reflects the updated document @@ -391,7 +395,7 @@ class TestAddDocumentToIndexTask: # Assert: Verify parent-child index processing mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( - IndexType.PARENT_CHILD_INDEX + IndexStructureType.PARENT_CHILD_INDEX ) mock_external_service_dependencies["index_processor"].load.assert_called_once() @@ -465,8 +469,10 @@ class TestAddDocumentToIndexTask: # Act: Execute the task add_document_to_index_task(document.id) - # Assert: Verify index processing occurred with all completed segments - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX) + # Assert: Verify index processing occurred but with empty documents list + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.PARAGRAPH_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify the load method was called with all completed segments @@ -532,7 +538,9 @@ class TestAddDocumentToIndexTask: assert len(remaining_logs) == 0 # Verify index processing occurred normally - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.PARAGRAPH_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify segments were enabled @@ -699,7 +707,9 @@ class TestAddDocumentToIndexTask: add_document_to_index_task(document.id) # Assert: Verify only eligible segments were processed - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.PARAGRAPH_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify the load method was called with correct parameters diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py index 94e9b76965..37d886f569 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -12,7 +12,7 @@ from unittest.mock import MagicMock, patch from faker import Faker -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from models import Account, Dataset, Document, DocumentSegment, Tenant from tasks.delete_segment_from_index_task import delete_segment_from_index_task @@ -164,7 +164,7 @@ class TestDeleteSegmentFromIndexTask: document.updated_at = fake.date_time_this_year() document.doc_type = kwargs.get("doc_type", "text") document.doc_metadata = kwargs.get("doc_metadata", {}) - document.doc_form = kwargs.get("doc_form", IndexType.PARAGRAPH_INDEX) + document.doc_form = kwargs.get("doc_form", IndexStructureType.PARAGRAPH_INDEX) document.doc_language = kwargs.get("doc_language", "en") db_session_with_containers.add(document) @@ -244,8 +244,11 @@ class TestDeleteSegmentFromIndexTask: mock_processor = MagicMock() mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor + # Extract segment IDs for the task + segment_ids = [segment.id for segment in segments] + # Execute the task - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed successfully assert result is None # Task should return None on success @@ -279,7 +282,7 @@ class TestDeleteSegmentFromIndexTask: index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)] # Execute the task with non-existent dataset - result = delete_segment_from_index_task(index_node_ids, non_existent_dataset_id, non_existent_document_id) + result = delete_segment_from_index_task(index_node_ids, non_existent_dataset_id, non_existent_document_id, []) # Verify the task completed without exceptions assert result is None # Task should return None when dataset not found @@ -305,7 +308,7 @@ class TestDeleteSegmentFromIndexTask: index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)] # Execute the task with non-existent document - result = delete_segment_from_index_task(index_node_ids, dataset.id, non_existent_document_id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, non_existent_document_id, []) # Verify the task completed without exceptions assert result is None # Task should return None when document not found @@ -330,9 +333,10 @@ class TestDeleteSegmentFromIndexTask: segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Execute the task with disabled document - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed without exceptions assert result is None # Task should return None when document is disabled @@ -357,9 +361,10 @@ class TestDeleteSegmentFromIndexTask: segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Execute the task with archived document - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed without exceptions assert result is None # Task should return None when document is archived @@ -386,9 +391,10 @@ class TestDeleteSegmentFromIndexTask: segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Execute the task with incomplete indexing - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed without exceptions assert result is None # Task should return None when indexing is not completed @@ -409,7 +415,11 @@ class TestDeleteSegmentFromIndexTask: fake = Faker() # Test different document forms - document_forms = [IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX, IndexType.PARENT_CHILD_INDEX] + document_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + ] for doc_form in document_forms: # Create test data for each document form @@ -420,13 +430,14 @@ class TestDeleteSegmentFromIndexTask: segments = self._create_test_document_segments(db_session_with_containers, document, account, 2, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Mock the index processor mock_processor = MagicMock() mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor # Execute the task - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed successfully assert result is None @@ -469,6 +480,7 @@ class TestDeleteSegmentFromIndexTask: segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Mock the index processor to raise an exception mock_processor = MagicMock() @@ -476,7 +488,7 @@ class TestDeleteSegmentFromIndexTask: mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor # Execute the task - should not raise exception - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed without raising exceptions assert result is None # Task should return None even when exceptions occur @@ -518,7 +530,7 @@ class TestDeleteSegmentFromIndexTask: mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor # Execute the task - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, []) # Verify the task completed successfully assert result is None @@ -555,13 +567,14 @@ class TestDeleteSegmentFromIndexTask: # Create large number of segments segments = self._create_test_document_segments(db_session_with_containers, document, account, 50, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Mock the index processor mock_processor = MagicMock() mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor # Execute the task - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed successfully assert result is None diff --git a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py index 798fe091ab..b738646736 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.ext_database import db from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -95,7 +95,7 @@ class TestEnableSegmentsToIndexTask: created_by=account.id, indexing_status="completed", enabled=True, - doc_form=IndexType.PARAGRAPH_INDEX, + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db.session.add(document) db.session.commit() @@ -166,7 +166,7 @@ class TestEnableSegmentsToIndexTask: ) # Update document to use different index type - document.doc_form = IndexType.QA_INDEX + document.doc_form = IndexStructureType.QA_INDEX db.session.commit() # Refresh dataset to ensure doc_form property reflects the updated document @@ -185,7 +185,9 @@ class TestEnableSegmentsToIndexTask: enable_segments_to_index_task(segment_ids, dataset.id, document.id) # Assert: Verify different index type handling - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.QA_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.QA_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify the load method was called with correct parameters @@ -328,7 +330,9 @@ class TestEnableSegmentsToIndexTask: enable_segments_to_index_task(non_existent_segment_ids, dataset.id, document.id) # Assert: Verify index processor was created but load was not called - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.PARAGRAPH_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_not_called() def test_enable_segments_to_index_with_parent_child_structure( @@ -350,7 +354,7 @@ class TestEnableSegmentsToIndexTask: ) # Update document to use parent-child index type - document.doc_form = IndexType.PARENT_CHILD_INDEX + document.doc_form = IndexStructureType.PARENT_CHILD_INDEX db.session.commit() # Refresh dataset to ensure doc_form property reflects the updated document @@ -383,7 +387,7 @@ class TestEnableSegmentsToIndexTask: # Assert: Verify parent-child index processing mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( - IndexType.PARENT_CHILD_INDEX + IndexStructureType.PARENT_CHILD_INDEX ) mock_external_service_dependencies["index_processor"].load.assert_called_once() diff --git a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py index d9f6dcc43c..025a0d8d70 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py +++ b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py @@ -53,7 +53,7 @@ from sqlalchemy.exc import IntegrityError from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import ModelPropertyKey -from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from core.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeConnectionError, @@ -99,10 +99,10 @@ class TestCacheEmbeddingDocuments: @pytest.fixture def sample_embedding_result(self): - """Create a sample TextEmbeddingResult for testing. + """Create a sample EmbeddingResult for testing. Returns: - TextEmbeddingResult: Mock embedding result with proper structure + EmbeddingResult: Mock embedding result with proper structure """ # Create normalized embedding vectors (dimension 1536 for ada-002) embedding_vector = np.random.randn(1536) @@ -118,7 +118,7 @@ class TestCacheEmbeddingDocuments: latency=0.5, ) - return TextEmbeddingResult( + return EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized_vector], usage=usage, @@ -197,7 +197,7 @@ class TestCacheEmbeddingDocuments: latency=0.8, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -296,7 +296,7 @@ class TestCacheEmbeddingDocuments: latency=0.6, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=new_embeddings, usage=usage, @@ -386,7 +386,7 @@ class TestCacheEmbeddingDocuments: latency=0.5, ) - return TextEmbeddingResult( + return EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -449,7 +449,7 @@ class TestCacheEmbeddingDocuments: latency=0.5, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[valid_vector.tolist(), nan_vector], usage=usage, @@ -629,7 +629,7 @@ class TestCacheEmbeddingQuery: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -728,7 +728,7 @@ class TestCacheEmbeddingQuery: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[nan_vector], usage=usage, @@ -793,7 +793,7 @@ class TestCacheEmbeddingQuery: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -873,13 +873,13 @@ class TestEmbeddingModelSwitching: latency=0.3, ) - result_ada = TextEmbeddingResult( + result_ada = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized_ada], usage=usage, ) - result_3_small = TextEmbeddingResult( + result_3_small = EmbeddingResult( model="text-embedding-3-small", embeddings=[normalized_3_small], usage=usage, @@ -953,13 +953,13 @@ class TestEmbeddingModelSwitching: latency=0.4, ) - result_openai = TextEmbeddingResult( + result_openai = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized_openai], usage=usage_openai, ) - result_cohere = TextEmbeddingResult( + result_cohere = EmbeddingResult( model="embed-english-v3.0", embeddings=[normalized_cohere], usage=usage_cohere, @@ -1042,7 +1042,7 @@ class TestEmbeddingDimensionValidation: latency=0.7, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1095,7 +1095,7 @@ class TestEmbeddingDimensionValidation: latency=0.5, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1148,7 +1148,7 @@ class TestEmbeddingDimensionValidation: latency=0.3, ) - result_ada = TextEmbeddingResult( + result_ada = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized_ada], usage=usage_ada, @@ -1181,7 +1181,7 @@ class TestEmbeddingDimensionValidation: latency=0.4, ) - result_cohere = TextEmbeddingResult( + result_cohere = EmbeddingResult( model="embed-english-v3.0", embeddings=[normalized_cohere], usage=usage_cohere, @@ -1279,7 +1279,7 @@ class TestEmbeddingEdgeCases: latency=0.1, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -1322,7 +1322,7 @@ class TestEmbeddingEdgeCases: latency=1.5, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -1370,7 +1370,7 @@ class TestEmbeddingEdgeCases: latency=0.5, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1422,7 +1422,7 @@ class TestEmbeddingEdgeCases: latency=0.2, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1478,7 +1478,7 @@ class TestEmbeddingEdgeCases: ) # Model returns embeddings for all texts - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1546,7 +1546,7 @@ class TestEmbeddingEdgeCases: latency=0.8, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1603,7 +1603,7 @@ class TestEmbeddingEdgeCases: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -1657,7 +1657,7 @@ class TestEmbeddingEdgeCases: latency=0.5, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1757,7 +1757,7 @@ class TestEmbeddingCachePerformance: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -1826,7 +1826,7 @@ class TestEmbeddingCachePerformance: latency=0.5, ) - return TextEmbeddingResult( + return EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1888,7 +1888,7 @@ class TestEmbeddingCachePerformance: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index d26e98db8d..c00fee8fe5 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -62,7 +62,7 @@ from core.indexing_runner import ( IndexingRunner, ) from core.model_runtime.entities.model_entities import ModelType -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.models.document import ChildDocument, Document from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, DatasetProcessRule @@ -112,7 +112,7 @@ def create_mock_dataset_document( document_id: str | None = None, dataset_id: str | None = None, tenant_id: str | None = None, - doc_form: str = IndexType.PARAGRAPH_INDEX, + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, data_source_type: str = "upload_file", doc_language: str = "English", ) -> Mock: @@ -133,8 +133,8 @@ def create_mock_dataset_document( Mock: A configured mock DatasetDocument object with all required attributes. Example: - >>> doc = create_mock_dataset_document(doc_form=IndexType.QA_INDEX) - >>> assert doc.doc_form == IndexType.QA_INDEX + >>> doc = create_mock_dataset_document(doc_form=IndexStructureType.QA_INDEX) + >>> assert doc.doc_form == IndexStructureType.QA_INDEX """ doc = Mock(spec=DatasetDocument) doc.id = document_id or str(uuid.uuid4()) @@ -276,7 +276,7 @@ class TestIndexingRunnerExtract: doc.id = str(uuid.uuid4()) doc.dataset_id = str(uuid.uuid4()) doc.tenant_id = str(uuid.uuid4()) - doc.doc_form = IndexType.PARAGRAPH_INDEX + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX doc.data_source_type = "upload_file" doc.data_source_info_dict = {"upload_file_id": str(uuid.uuid4())} return doc @@ -616,7 +616,7 @@ class TestIndexingRunnerLoad: doc = Mock(spec=DatasetDocument) doc.id = str(uuid.uuid4()) doc.dataset_id = str(uuid.uuid4()) - doc.doc_form = IndexType.PARAGRAPH_INDEX + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX return doc @pytest.fixture @@ -700,7 +700,7 @@ class TestIndexingRunnerLoad: """Test loading with parent-child index structure.""" # Arrange runner = IndexingRunner() - sample_dataset_document.doc_form = IndexType.PARENT_CHILD_INDEX + sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX sample_dataset.indexing_technique = "high_quality" # Add child documents @@ -775,7 +775,7 @@ class TestIndexingRunnerRun: doc.id = str(uuid.uuid4()) doc.dataset_id = str(uuid.uuid4()) doc.tenant_id = str(uuid.uuid4()) - doc.doc_form = IndexType.PARAGRAPH_INDEX + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX doc.doc_language = "English" doc.data_source_type = "upload_file" doc.data_source_info_dict = {"upload_file_id": str(uuid.uuid4())} @@ -802,6 +802,21 @@ class TestIndexingRunnerRun: mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}} mock_dependencies["db"].session.scalar.return_value = mock_process_rule + # Mock current_user (Account) for _transform + mock_current_user = MagicMock() + mock_current_user.set_tenant_id = MagicMock() + + # Setup db.session.query to return different results based on the model + def mock_query_side_effect(model): + mock_query_result = MagicMock() + if model.__name__ == "Dataset": + mock_query_result.filter_by.return_value.first.return_value = mock_dataset + elif model.__name__ == "Account": + mock_query_result.filter_by.return_value.first.return_value = mock_current_user + return mock_query_result + + mock_dependencies["db"].session.query.side_effect = mock_query_side_effect + # Mock processor mock_processor = MagicMock() mock_dependencies["factory"].return_value.init_index_processor.return_value = mock_processor @@ -1268,7 +1283,7 @@ class TestIndexingRunnerLoadSegments: doc.id = str(uuid.uuid4()) doc.dataset_id = str(uuid.uuid4()) doc.created_by = str(uuid.uuid4()) - doc.doc_form = IndexType.PARAGRAPH_INDEX + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX return doc @pytest.fixture @@ -1316,7 +1331,7 @@ class TestIndexingRunnerLoadSegments: """Test loading segments for parent-child index.""" # Arrange runner = IndexingRunner() - sample_dataset_document.doc_form = IndexType.PARENT_CHILD_INDEX + sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX # Add child documents for doc in sample_documents: @@ -1413,7 +1428,7 @@ class TestIndexingRunnerEstimate: tenant_id=tenant_id, extract_settings=extract_settings, tmp_processing_rule={"mode": "automatic", "rules": {}}, - doc_form=IndexType.PARAGRAPH_INDEX, + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) diff --git a/api/tests/unit_tests/core/rag/rerank/test_reranker.py b/api/tests/unit_tests/core/rag/rerank/test_reranker.py index 4912884c55..ebe6c37818 100644 --- a/api/tests/unit_tests/core/rag/rerank/test_reranker.py +++ b/api/tests/unit_tests/core/rag/rerank/test_reranker.py @@ -26,6 +26,18 @@ from core.rag.rerank.rerank_type import RerankMode from core.rag.rerank.weight_rerank import WeightRerankRunner +def create_mock_model_instance(): + """Create a properly configured mock ModelInstance for reranking tests.""" + mock_instance = Mock(spec=ModelInstance) + # Setup provider_model_bundle chain for check_model_support_vision + mock_instance.provider_model_bundle = Mock() + mock_instance.provider_model_bundle.configuration = Mock() + mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id" + mock_instance.provider = "test-provider" + mock_instance.model = "test-model" + return mock_instance + + class TestRerankModelRunner: """Unit tests for RerankModelRunner. @@ -37,10 +49,23 @@ class TestRerankModelRunner: - Metadata preservation and score injection """ + @pytest.fixture(autouse=True) + def mock_model_manager(self): + """Auto-use fixture to patch ModelManager for all tests in this class.""" + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + yield mock_mm + @pytest.fixture def mock_model_instance(self): """Create a mock ModelInstance for reranking.""" mock_instance = Mock(spec=ModelInstance) + # Setup provider_model_bundle chain for check_model_support_vision + mock_instance.provider_model_bundle = Mock() + mock_instance.provider_model_bundle.configuration = Mock() + mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id" + mock_instance.provider = "test-provider" + mock_instance.model = "test-model" return mock_instance @pytest.fixture @@ -803,7 +828,7 @@ class TestRerankRunnerFactory: - Parameters are forwarded to runner constructor """ # Arrange: Mock model instance - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() # Act: Create runner via factory runner = RerankRunnerFactory.create_rerank_runner( @@ -865,7 +890,7 @@ class TestRerankRunnerFactory: - String values are properly matched """ # Arrange: Mock model instance - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() # Act: Create runner using enum value runner = RerankRunnerFactory.create_rerank_runner( @@ -886,6 +911,13 @@ class TestRerankIntegration: - Real-world usage scenarios """ + @pytest.fixture(autouse=True) + def mock_model_manager(self): + """Auto-use fixture to patch ModelManager for all tests in this class.""" + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + yield mock_mm + def test_model_reranking_full_workflow(self): """Test complete model-based reranking workflow. @@ -895,7 +927,7 @@ class TestRerankIntegration: - Top results are returned correctly """ # Arrange: Create mock model and documents - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -951,7 +983,7 @@ class TestRerankIntegration: - Normalization is consistent """ # Arrange: Create mock model with various scores - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -991,6 +1023,13 @@ class TestRerankEdgeCases: - Concurrent reranking scenarios """ + @pytest.fixture(autouse=True) + def mock_model_manager(self): + """Auto-use fixture to patch ModelManager for all tests in this class.""" + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + yield mock_mm + def test_rerank_with_empty_metadata(self): """Test reranking when documents have empty metadata. @@ -1000,7 +1039,7 @@ class TestRerankEdgeCases: - Empty metadata documents are processed correctly """ # Arrange: Create documents with empty metadata - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -1046,7 +1085,7 @@ class TestRerankEdgeCases: - Score comparison logic works at boundary """ # Arrange: Create mock with various scores including negatives - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -1082,7 +1121,7 @@ class TestRerankEdgeCases: - No overflow or precision issues """ # Arrange: All documents with perfect scores - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -1117,7 +1156,7 @@ class TestRerankEdgeCases: - Content encoding is preserved """ # Arrange: Documents with special characters - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -1159,7 +1198,7 @@ class TestRerankEdgeCases: - Content is not truncated unexpectedly """ # Arrange: Documents with very long content - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() long_content = "This is a very long document. " * 1000 # ~30,000 characters mock_rerank_result = RerankResult( @@ -1196,7 +1235,7 @@ class TestRerankEdgeCases: - All documents are processed correctly """ # Arrange: Create 100 documents - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() num_docs = 100 # Create rerank results for all documents @@ -1287,7 +1326,7 @@ class TestRerankEdgeCases: - Documents can still be ranked """ # Arrange: Empty query - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -1325,6 +1364,13 @@ class TestRerankPerformance: - Score calculation optimization """ + @pytest.fixture(autouse=True) + def mock_model_manager(self): + """Auto-use fixture to patch ModelManager for all tests in this class.""" + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + yield mock_mm + def test_rerank_batch_processing(self): """Test that documents are processed in a single batch. @@ -1334,7 +1380,7 @@ class TestRerankPerformance: - Efficient batch processing """ # Arrange: Multiple documents - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[RerankDocument(index=i, text=f"Doc {i}", score=0.9 - i * 0.1) for i in range(5)], @@ -1435,6 +1481,13 @@ class TestRerankErrorHandling: - Error propagation """ + @pytest.fixture(autouse=True) + def mock_model_manager(self): + """Auto-use fixture to patch ModelManager for all tests in this class.""" + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + yield mock_mm + def test_rerank_model_invocation_error(self): """Test handling of model invocation errors. @@ -1444,7 +1497,7 @@ class TestRerankErrorHandling: - Error context is preserved """ # Arrange: Mock model that raises exception - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_model_instance.invoke_rerank.side_effect = RuntimeError("Model invocation failed") documents = [ @@ -1470,7 +1523,7 @@ class TestRerankErrorHandling: - Invalid results don't corrupt output """ # Arrange: Rerank result with invalid index - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index 0163e42992..affd6c648f 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -425,15 +425,15 @@ class TestRetrievalService: # ==================== Vector Search Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_vector_search_basic(self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents): + def test_vector_search_basic(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents): """ Test basic vector/semantic search functionality. This test validates the core vector search flow: 1. Dataset is retrieved from database - 2. embedding_search is called via ThreadPoolExecutor + 2. _retrieve is called via ThreadPoolExecutor 3. Documents are added to shared all_documents list 4. Results are returned to caller @@ -447,28 +447,28 @@ class TestRetrievalService: # Set up the mock dataset that will be "retrieved" from database mock_get_dataset.return_value = mock_dataset - # Create a side effect function that simulates embedding_search behavior - # In the real implementation, embedding_search: - # 1. Gets the dataset - # 2. Creates a Vector instance - # 3. Calls search_by_vector with embeddings - # 4. Extends all_documents with results - def side_effect_embedding_search( + # Create a side effect function that simulates _retrieve behavior + # _retrieve modifies the all_documents list in place + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - """Simulate embedding_search adding documents to the shared list.""" - all_documents.extend(sample_documents) + """Simulate _retrieve adding documents to the shared list.""" + if all_documents is not None: + all_documents.extend(sample_documents) - mock_embedding_search.side_effect = side_effect_embedding_search + mock_retrieve.side_effect = side_effect_retrieve # Define test parameters query = "What is Python?" # Natural language query @@ -481,7 +481,7 @@ class TestRetrievalService: # 1. Check if query is empty (early return if so) # 2. Get the dataset using _get_dataset # 3. Create ThreadPoolExecutor - # 4. Submit embedding_search task + # 4. Submit _retrieve task # 5. Wait for completion # 6. Return all_documents list results = RetrievalService.retrieve( @@ -502,15 +502,13 @@ class TestRetrievalService: # Verify documents maintain their scores (highest score first in sample_documents) assert results[0].metadata["score"] == 0.95, "First document should have highest score from sample_documents" - # Verify embedding_search was called exactly once + # Verify _retrieve was called exactly once # This confirms the search method was invoked by ThreadPoolExecutor - mock_embedding_search.assert_called_once() + mock_retrieve.assert_called_once() - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_vector_search_with_document_filter( - self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents - ): + def test_vector_search_with_document_filter(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents): """ Test vector search with document ID filtering. @@ -522,21 +520,25 @@ class TestRetrievalService: mock_get_dataset.return_value = mock_dataset filtered_docs = [sample_documents[0]] - def side_effect_embedding_search( + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - all_documents.extend(filtered_docs) + if all_documents is not None: + all_documents.extend(filtered_docs) - mock_embedding_search.side_effect = side_effect_embedding_search + mock_retrieve.side_effect = side_effect_retrieve document_ids_filter = [sample_documents[0].metadata["document_id"]] # Act @@ -552,12 +554,12 @@ class TestRetrievalService: assert len(results) == 1 assert results[0].metadata["doc_id"] == "doc1" # Verify document_ids_filter was passed - call_kwargs = mock_embedding_search.call_args.kwargs + call_kwargs = mock_retrieve.call_args.kwargs assert call_kwargs["document_ids_filter"] == document_ids_filter - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_vector_search_empty_results(self, mock_get_dataset, mock_embedding_search, mock_dataset): + def test_vector_search_empty_results(self, mock_get_dataset, mock_retrieve, mock_dataset): """ Test vector search when no results match the query. @@ -567,8 +569,8 @@ class TestRetrievalService: """ # Arrange mock_get_dataset.return_value = mock_dataset - # embedding_search doesn't add anything to all_documents - mock_embedding_search.side_effect = lambda *args, **kwargs: None + # _retrieve doesn't add anything to all_documents + mock_retrieve.side_effect = lambda *args, **kwargs: None # Act results = RetrievalService.retrieve( @@ -583,9 +585,9 @@ class TestRetrievalService: # ==================== Keyword Search Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_keyword_search_basic(self, mock_get_dataset, mock_keyword_search, mock_dataset, sample_documents): + def test_keyword_search_basic(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents): """ Test basic keyword search functionality. @@ -597,12 +599,25 @@ class TestRetrievalService: # Arrange mock_get_dataset.return_value = mock_dataset - def side_effect_keyword_search( - flask_app, dataset_id, query, top_k, all_documents, exceptions, document_ids_filter=None + def side_effect_retrieve( + flask_app, + retrieval_method, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, + document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - all_documents.extend(sample_documents) + if all_documents is not None: + all_documents.extend(sample_documents) - mock_keyword_search.side_effect = side_effect_keyword_search + mock_retrieve.side_effect = side_effect_retrieve query = "Python programming" top_k = 3 @@ -618,7 +633,7 @@ class TestRetrievalService: # Assert assert len(results) == 3 assert all(isinstance(doc, Document) for doc in results) - mock_keyword_search.assert_called_once() + mock_retrieve.assert_called_once() @patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") @@ -1147,11 +1162,9 @@ class TestRetrievalService: # ==================== Metadata Filtering Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_vector_search_with_metadata_filter( - self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents - ): + def test_vector_search_with_metadata_filter(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents): """ Test vector search with metadata-based document filtering. @@ -1166,21 +1179,25 @@ class TestRetrievalService: filtered_doc = sample_documents[0] filtered_doc.metadata["category"] = "programming" - def side_effect_embedding( + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - all_documents.append(filtered_doc) + if all_documents is not None: + all_documents.append(filtered_doc) - mock_embedding_search.side_effect = side_effect_embedding + mock_retrieve.side_effect = side_effect_retrieve # Act results = RetrievalService.retrieve( @@ -1243,9 +1260,9 @@ class TestRetrievalService: # Assert assert results == [] - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_retrieve_with_exception_handling(self, mock_get_dataset, mock_embedding_search, mock_dataset): + def test_retrieve_with_exception_handling(self, mock_get_dataset, mock_retrieve, mock_dataset): """ Test that exceptions during retrieval are properly handled. @@ -1256,22 +1273,26 @@ class TestRetrievalService: # Arrange mock_get_dataset.return_value = mock_dataset - # Make embedding_search add an exception to the exceptions list + # Make _retrieve add an exception to the exceptions list def side_effect_with_exception( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - exceptions.append("Search failed") + if exceptions is not None: + exceptions.append("Search failed") - mock_embedding_search.side_effect = side_effect_with_exception + mock_retrieve.side_effect = side_effect_with_exception # Act & Assert with pytest.raises(ValueError) as exc_info: @@ -1286,9 +1307,9 @@ class TestRetrievalService: # ==================== Score Threshold Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_vector_search_with_score_threshold(self, mock_get_dataset, mock_embedding_search, mock_dataset): + def test_vector_search_with_score_threshold(self, mock_get_dataset, mock_retrieve, mock_dataset): """ Test vector search with score threshold filtering. @@ -1306,21 +1327,25 @@ class TestRetrievalService: provider="dify", ) - def side_effect_embedding( + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - all_documents.append(high_score_doc) + if all_documents is not None: + all_documents.append(high_score_doc) - mock_embedding_search.side_effect = side_effect_embedding + mock_retrieve.side_effect = side_effect_retrieve score_threshold = 0.8 @@ -1339,9 +1364,9 @@ class TestRetrievalService: # ==================== Top-K Limiting Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_retrieve_respects_top_k_limit(self, mock_get_dataset, mock_embedding_search, mock_dataset): + def test_retrieve_respects_top_k_limit(self, mock_get_dataset, mock_retrieve, mock_dataset): """ Test that retrieval respects top_k parameter. @@ -1362,22 +1387,26 @@ class TestRetrievalService: for i in range(10) ] - def side_effect_embedding( + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): # Return only top_k documents - all_documents.extend(many_docs[:top_k]) + if all_documents is not None: + all_documents.extend(many_docs[:top_k]) - mock_embedding_search.side_effect = side_effect_embedding + mock_retrieve.side_effect = side_effect_retrieve top_k = 3 @@ -1390,9 +1419,9 @@ class TestRetrievalService: ) # Assert - # Verify top_k was passed to embedding_search - assert mock_embedding_search.called - call_kwargs = mock_embedding_search.call_args.kwargs + # Verify _retrieve was called + assert mock_retrieve.called + call_kwargs = mock_retrieve.call_args.kwargs assert call_kwargs["top_k"] == top_k # Verify we got the right number of results assert len(results) == top_k @@ -1421,11 +1450,9 @@ class TestRetrievalService: # ==================== Reranking Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_semantic_search_with_reranking( - self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents - ): + def test_semantic_search_with_reranking(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents): """ Test semantic search with reranking model. @@ -1439,22 +1466,26 @@ class TestRetrievalService: # Simulate reranking changing order reranked_docs = list(reversed(sample_documents)) - def side_effect_embedding( + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - # embedding_search handles reranking internally - all_documents.extend(reranked_docs) + # _retrieve handles reranking internally + if all_documents is not None: + all_documents.extend(reranked_docs) - mock_embedding_search.side_effect = side_effect_embedding + mock_retrieve.side_effect = side_effect_retrieve reranking_model = { "reranking_provider_name": "cohere", @@ -1473,7 +1504,7 @@ class TestRetrievalService: # Assert # For semantic search with reranking, reranking_model should be passed assert len(results) == 3 - call_kwargs = mock_embedding_search.call_args.kwargs + call_kwargs = mock_retrieve.call_args.kwargs assert call_kwargs["reranking_model"] == reranking_model diff --git a/api/tests/unit_tests/utils/test_text_processing.py b/api/tests/unit_tests/utils/test_text_processing.py index 8bfc97ae63..8af47e8967 100644 --- a/api/tests/unit_tests/utils/test_text_processing.py +++ b/api/tests/unit_tests/utils/test_text_processing.py @@ -8,7 +8,9 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols [ ("...Hello, World!", "Hello, World!"), ("。测试中文标点", "测试中文标点"), - ("!@#Test symbols", "Test symbols"), + # Note: ! is not in the removal pattern, only @# are removed, leaving "!Test symbols" + # The pattern intentionally excludes ! as per #11868 fix + ("@#Test symbols", "Test symbols"), ("Hello, World!", "Hello, World!"), ("", ""), (" ", " "), diff --git a/docker/.env.example b/docker/.env.example index b71c38e07a..85e8b1dc7f 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -808,6 +808,19 @@ UPLOAD_FILE_BATCH_LIMIT=5 # Recommended: exe,bat,cmd,com,scr,vbs,ps1,msi,dll UPLOAD_FILE_EXTENSION_BLACKLIST= +# Maximum number of files allowed in a single chunk attachment, default 10. +SINGLE_CHUNK_ATTACHMENT_LIMIT=10 + +# Maximum number of files allowed in a image batch upload operation +IMAGE_FILE_BATCH_LIMIT=10 + +# Maximum allowed image file size for attachments in megabytes, default 2. +ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2 + +# Timeout for downloading image attachments in seconds, default 60. +ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60 + + # ETL type, support: `dify`, `Unstructured` # `dify` Dify's proprietary file extraction scheme # `Unstructured` Unstructured.io file extraction scheme @@ -1116,6 +1129,9 @@ WEAVIATE_AUTHENTICATION_APIKEY_USERS=hello@dify.ai WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED=true WEAVIATE_AUTHORIZATION_ADMINLIST_USERS=hello@dify.ai WEAVIATE_DISABLE_TELEMETRY=false +WEAVIATE_ENABLE_TOKENIZER_GSE=false +WEAVIATE_ENABLE_TOKENIZER_KAGOME_JA=false +WEAVIATE_ENABLE_TOKENIZER_KAGOME_KR=false # ------------------------------ # Environment Variables for Chroma @@ -1415,4 +1431,4 @@ WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE=100 WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0 # Tenant isolated task queue configuration -TENANT_ISOLATED_TASK_CONCURRENCY=1 \ No newline at end of file +TENANT_ISOLATED_TASK_CONCURRENCY=1 diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index f1061ef5f9..3c01274ce8 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -451,6 +451,9 @@ services: AUTHORIZATION_ADMINLIST_ENABLED: ${WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED:-true} AUTHORIZATION_ADMINLIST_USERS: ${WEAVIATE_AUTHORIZATION_ADMINLIST_USERS:-hello@dify.ai} DISABLE_TELEMETRY: ${WEAVIATE_DISABLE_TELEMETRY:-false} + ENABLE_TOKENIZER_GSE: ${WEAVIATE_ENABLE_TOKENIZER_GSE:-false} + ENABLE_TOKENIZER_KAGOME_JA: ${WEAVIATE_ENABLE_TOKENIZER_KAGOME_JA:-false} + ENABLE_TOKENIZER_KAGOME_KR: ${WEAVIATE_ENABLE_TOKENIZER_KAGOME_KR:-false} # OceanBase vector database oceanbase: diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 7ae8a70699..809aa1f841 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -364,6 +364,10 @@ x-shared-env: &shared-api-worker-env UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15} UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5} UPLOAD_FILE_EXTENSION_BLACKLIST: ${UPLOAD_FILE_EXTENSION_BLACKLIST:-} + SINGLE_CHUNK_ATTACHMENT_LIMIT: ${SINGLE_CHUNK_ATTACHMENT_LIMIT:-10} + IMAGE_FILE_BATCH_LIMIT: ${IMAGE_FILE_BATCH_LIMIT:-10} + ATTACHMENT_IMAGE_FILE_SIZE_LIMIT: ${ATTACHMENT_IMAGE_FILE_SIZE_LIMIT:-2} + ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT: ${ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT:-60} ETL_TYPE: ${ETL_TYPE:-dify} UNSTRUCTURED_API_URL: ${UNSTRUCTURED_API_URL:-} UNSTRUCTURED_API_KEY: ${UNSTRUCTURED_API_KEY:-} @@ -475,6 +479,9 @@ x-shared-env: &shared-api-worker-env WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED: ${WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED:-true} WEAVIATE_AUTHORIZATION_ADMINLIST_USERS: ${WEAVIATE_AUTHORIZATION_ADMINLIST_USERS:-hello@dify.ai} WEAVIATE_DISABLE_TELEMETRY: ${WEAVIATE_DISABLE_TELEMETRY:-false} + WEAVIATE_ENABLE_TOKENIZER_GSE: ${WEAVIATE_ENABLE_TOKENIZER_GSE:-false} + WEAVIATE_ENABLE_TOKENIZER_KAGOME_JA: ${WEAVIATE_ENABLE_TOKENIZER_KAGOME_JA:-false} + WEAVIATE_ENABLE_TOKENIZER_KAGOME_KR: ${WEAVIATE_ENABLE_TOKENIZER_KAGOME_KR:-false} CHROMA_SERVER_AUTHN_CREDENTIALS: ${CHROMA_SERVER_AUTHN_CREDENTIALS:-difyai123456} CHROMA_SERVER_AUTHN_PROVIDER: ${CHROMA_SERVER_AUTHN_PROVIDER:-chromadb.auth.token_authn.TokenAuthenticationServerProvider} CHROMA_IS_PERSISTENT: ${CHROMA_IS_PERSISTENT:-TRUE} @@ -1081,6 +1088,9 @@ services: AUTHORIZATION_ADMINLIST_ENABLED: ${WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED:-true} AUTHORIZATION_ADMINLIST_USERS: ${WEAVIATE_AUTHORIZATION_ADMINLIST_USERS:-hello@dify.ai} DISABLE_TELEMETRY: ${WEAVIATE_DISABLE_TELEMETRY:-false} + ENABLE_TOKENIZER_GSE: ${WEAVIATE_ENABLE_TOKENIZER_GSE:-false} + ENABLE_TOKENIZER_KAGOME_JA: ${WEAVIATE_ENABLE_TOKENIZER_KAGOME_JA:-false} + ENABLE_TOKENIZER_KAGOME_KR: ${WEAVIATE_ENABLE_TOKENIZER_KAGOME_KR:-false} # OceanBase vector database oceanbase: diff --git a/sdks/python-client/LICENSE b/sdks/python-client/LICENSE deleted file mode 100644 index 873e44b4bc..0000000000 --- a/sdks/python-client/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2023 LangGenius - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/sdks/python-client/MANIFEST.in b/sdks/python-client/MANIFEST.in deleted file mode 100644 index 34b7e8711c..0000000000 --- a/sdks/python-client/MANIFEST.in +++ /dev/null @@ -1,3 +0,0 @@ -recursive-include dify_client *.py -include README.md -include LICENSE diff --git a/sdks/python-client/README.md b/sdks/python-client/README.md deleted file mode 100644 index ebfb5f5397..0000000000 --- a/sdks/python-client/README.md +++ /dev/null @@ -1,409 +0,0 @@ -# dify-client - -A Dify App Service-API Client, using for build a webapp by request Service-API - -## Usage - -First, install `dify-client` python sdk package: - -``` -pip install dify-client -``` - -### Synchronous Usage - -Write your code with sdk: - -- completion generate with `blocking` response_mode - -```python -from dify_client import CompletionClient - -api_key = "your_api_key" - -# Initialize CompletionClient -completion_client = CompletionClient(api_key) - -# Create Completion Message using CompletionClient -completion_response = completion_client.create_completion_message(inputs={"query": "What's the weather like today?"}, - response_mode="blocking", user="user_id") -completion_response.raise_for_status() - -result = completion_response.json() - -print(result.get('answer')) -``` - -- completion using vision model, like gpt-4-vision - -```python -from dify_client import CompletionClient - -api_key = "your_api_key" - -# Initialize CompletionClient -completion_client = CompletionClient(api_key) - -files = [{ - "type": "image", - "transfer_method": "remote_url", - "url": "your_image_url" -}] - -# files = [{ -# "type": "image", -# "transfer_method": "local_file", -# "upload_file_id": "your_file_id" -# }] - -# Create Completion Message using CompletionClient -completion_response = completion_client.create_completion_message(inputs={"query": "Describe the picture."}, - response_mode="blocking", user="user_id", files=files) -completion_response.raise_for_status() - -result = completion_response.json() - -print(result.get('answer')) -``` - -- chat generate with `streaming` response_mode - -```python -import json -from dify_client import ChatClient - -api_key = "your_api_key" - -# Initialize ChatClient -chat_client = ChatClient(api_key) - -# Create Chat Message using ChatClient -chat_response = chat_client.create_chat_message(inputs={}, query="Hello", user="user_id", response_mode="streaming") -chat_response.raise_for_status() - -for line in chat_response.iter_lines(decode_unicode=True): - line = line.split('data:', 1)[-1] - if line.strip(): - line = json.loads(line.strip()) - print(line.get('answer')) -``` - -- chat using vision model, like gpt-4-vision - -```python -from dify_client import ChatClient - -api_key = "your_api_key" - -# Initialize ChatClient -chat_client = ChatClient(api_key) - -files = [{ - "type": "image", - "transfer_method": "remote_url", - "url": "your_image_url" -}] - -# files = [{ -# "type": "image", -# "transfer_method": "local_file", -# "upload_file_id": "your_file_id" -# }] - -# Create Chat Message using ChatClient -chat_response = chat_client.create_chat_message(inputs={}, query="Describe the picture.", user="user_id", - response_mode="blocking", files=files) -chat_response.raise_for_status() - -result = chat_response.json() - -print(result.get("answer")) -``` - -- upload file when using vision model - -```python -from dify_client import DifyClient - -api_key = "your_api_key" - -# Initialize Client -dify_client = DifyClient(api_key) - -file_path = "your_image_file_path" -file_name = "panda.jpeg" -mime_type = "image/jpeg" - -with open(file_path, "rb") as file: - files = { - "file": (file_name, file, mime_type) - } - response = dify_client.file_upload("user_id", files) - - result = response.json() - print(f'upload_file_id: {result.get("id")}') -``` - -- Others - -```python -from dify_client import ChatClient - -api_key = "your_api_key" - -# Initialize Client -client = ChatClient(api_key) - -# Get App parameters -parameters = client.get_application_parameters(user="user_id") -parameters.raise_for_status() - -print('[parameters]') -print(parameters.json()) - -# Get Conversation List (only for chat) -conversations = client.get_conversations(user="user_id") -conversations.raise_for_status() - -print('[conversations]') -print(conversations.json()) - -# Get Message List (only for chat) -messages = client.get_conversation_messages(user="user_id", conversation_id="conversation_id") -messages.raise_for_status() - -print('[messages]') -print(messages.json()) - -# Rename Conversation (only for chat) -rename_conversation_response = client.rename_conversation(conversation_id="conversation_id", - name="new_name", user="user_id") -rename_conversation_response.raise_for_status() - -print('[rename result]') -print(rename_conversation_response.json()) -``` - -- Using the Workflow Client - -```python -import json -import requests -from dify_client import WorkflowClient - -api_key = "your_api_key" - -# Initialize Workflow Client -client = WorkflowClient(api_key) - -# Prepare parameters for Workflow Client -user_id = "your_user_id" -context = "previous user interaction / metadata" -user_prompt = "What is the capital of France?" - -inputs = { - "context": context, - "user_prompt": user_prompt, - # Add other input fields expected by your workflow (e.g., additional context, task parameters) - -} - -# Set response mode (default: streaming) -response_mode = "blocking" - -# Run the workflow -response = client.run(inputs=inputs, response_mode=response_mode, user=user_id) -response.raise_for_status() - -# Parse result -result = json.loads(response.text) - -answer = result.get("data").get("outputs") - -print(answer["answer"]) - -``` - -- Dataset Management - -```python -from dify_client import KnowledgeBaseClient - -api_key = "your_api_key" -dataset_id = "your_dataset_id" - -# Use context manager to ensure proper resource cleanup -with KnowledgeBaseClient(api_key, dataset_id) as kb_client: - # Get dataset information - dataset_info = kb_client.get_dataset() - dataset_info.raise_for_status() - print(dataset_info.json()) - - # Update dataset configuration - update_response = kb_client.update_dataset( - name="Updated Dataset Name", - description="Updated description", - indexing_technique="high_quality" - ) - update_response.raise_for_status() - print(update_response.json()) - - # Batch update document status - batch_response = kb_client.batch_update_document_status( - action="enable", - document_ids=["doc_id_1", "doc_id_2", "doc_id_3"] - ) - batch_response.raise_for_status() - print(batch_response.json()) -``` - -- Conversation Variables Management - -```python -from dify_client import ChatClient - -api_key = "your_api_key" - -# Use context manager to ensure proper resource cleanup -with ChatClient(api_key) as chat_client: - # Get all conversation variables - variables = chat_client.get_conversation_variables( - conversation_id="conversation_id", - user="user_id" - ) - variables.raise_for_status() - print(variables.json()) - - # Update a specific conversation variable - update_var = chat_client.update_conversation_variable( - conversation_id="conversation_id", - variable_id="variable_id", - value="new_value", - user="user_id" - ) - update_var.raise_for_status() - print(update_var.json()) -``` - -### Asynchronous Usage - -The SDK provides full async/await support for all API operations using `httpx.AsyncClient`. All async clients mirror their synchronous counterparts but require `await` for method calls. - -- async chat with `blocking` response_mode - -```python -import asyncio -from dify_client import AsyncChatClient - -api_key = "your_api_key" - -async def main(): - # Use async context manager for proper resource cleanup - async with AsyncChatClient(api_key) as client: - response = await client.create_chat_message( - inputs={}, - query="Hello, how are you?", - user="user_id", - response_mode="blocking" - ) - response.raise_for_status() - result = response.json() - print(result.get('answer')) - -# Run the async function -asyncio.run(main()) -``` - -- async completion with `streaming` response_mode - -```python -import asyncio -import json -from dify_client import AsyncCompletionClient - -api_key = "your_api_key" - -async def main(): - async with AsyncCompletionClient(api_key) as client: - response = await client.create_completion_message( - inputs={"query": "What's the weather?"}, - response_mode="streaming", - user="user_id" - ) - response.raise_for_status() - - # Stream the response - async for line in response.aiter_lines(): - if line.startswith('data:'): - data = line[5:].strip() - if data: - chunk = json.loads(data) - print(chunk.get('answer', ''), end='', flush=True) - -asyncio.run(main()) -``` - -- async workflow execution - -```python -import asyncio -from dify_client import AsyncWorkflowClient - -api_key = "your_api_key" - -async def main(): - async with AsyncWorkflowClient(api_key) as client: - response = await client.run( - inputs={"query": "What is machine learning?"}, - response_mode="blocking", - user="user_id" - ) - response.raise_for_status() - result = response.json() - print(result.get("data").get("outputs")) - -asyncio.run(main()) -``` - -- async dataset management - -```python -import asyncio -from dify_client import AsyncKnowledgeBaseClient - -api_key = "your_api_key" -dataset_id = "your_dataset_id" - -async def main(): - async with AsyncKnowledgeBaseClient(api_key, dataset_id) as kb_client: - # Get dataset information - dataset_info = await kb_client.get_dataset() - dataset_info.raise_for_status() - print(dataset_info.json()) - - # List documents - docs = await kb_client.list_documents(page=1, page_size=10) - docs.raise_for_status() - print(docs.json()) - -asyncio.run(main()) -``` - -**Benefits of Async Usage:** - -- **Better Performance**: Handle multiple concurrent API requests efficiently -- **Non-blocking I/O**: Don't block the event loop during network operations -- **Scalability**: Ideal for applications handling many simultaneous requests -- **Modern Python**: Leverages Python's native async/await syntax - -**Available Async Clients:** - -- `AsyncDifyClient` - Base async client -- `AsyncChatClient` - Async chat operations -- `AsyncCompletionClient` - Async completion operations -- `AsyncWorkflowClient` - Async workflow operations -- `AsyncKnowledgeBaseClient` - Async dataset/knowledge base operations -- `AsyncWorkspaceClient` - Async workspace operations - -``` -``` diff --git a/sdks/python-client/build.sh b/sdks/python-client/build.sh deleted file mode 100755 index 525f57c1ef..0000000000 --- a/sdks/python-client/build.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash - -set -e - -rm -rf build dist *.egg-info - -pip install setuptools wheel twine -python setup.py sdist bdist_wheel -twine upload dist/* diff --git a/sdks/python-client/dify_client/__init__.py b/sdks/python-client/dify_client/__init__.py deleted file mode 100644 index ced093b20a..0000000000 --- a/sdks/python-client/dify_client/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -from dify_client.client import ( - ChatClient, - CompletionClient, - DifyClient, - KnowledgeBaseClient, - WorkflowClient, - WorkspaceClient, -) - -from dify_client.async_client import ( - AsyncChatClient, - AsyncCompletionClient, - AsyncDifyClient, - AsyncKnowledgeBaseClient, - AsyncWorkflowClient, - AsyncWorkspaceClient, -) - -__all__ = [ - # Synchronous clients - "ChatClient", - "CompletionClient", - "DifyClient", - "KnowledgeBaseClient", - "WorkflowClient", - "WorkspaceClient", - # Asynchronous clients - "AsyncChatClient", - "AsyncCompletionClient", - "AsyncDifyClient", - "AsyncKnowledgeBaseClient", - "AsyncWorkflowClient", - "AsyncWorkspaceClient", -] diff --git a/sdks/python-client/dify_client/async_client.py b/sdks/python-client/dify_client/async_client.py deleted file mode 100644 index 23126cf326..0000000000 --- a/sdks/python-client/dify_client/async_client.py +++ /dev/null @@ -1,2074 +0,0 @@ -"""Asynchronous Dify API client. - -This module provides async/await support for all Dify API operations using httpx.AsyncClient. -All client classes mirror their synchronous counterparts but require `await` for method calls. - -Example: - import asyncio - from dify_client import AsyncChatClient - - async def main(): - async with AsyncChatClient(api_key="your-key") as client: - response = await client.create_chat_message( - inputs={}, - query="Hello", - user="user-123" - ) - print(response.json()) - - asyncio.run(main()) -""" - -import json -import os -from typing import Literal, Dict, List, Any, IO, Optional, Union - -import aiofiles -import httpx - - -class AsyncDifyClient: - """Asynchronous Dify API client. - - This client uses httpx.AsyncClient for efficient async connection pooling. - It's recommended to use this client as a context manager: - - Example: - async with AsyncDifyClient(api_key="your-key") as client: - response = await client.get_app_info() - """ - - def __init__( - self, - api_key: str, - base_url: str = "https://api.dify.ai/v1", - timeout: float = 60.0, - ): - """Initialize the async Dify client. - - Args: - api_key: Your Dify API key - base_url: Base URL for the Dify API - timeout: Request timeout in seconds (default: 60.0) - """ - self.api_key = api_key - self.base_url = base_url - self._client = httpx.AsyncClient( - base_url=base_url, - timeout=httpx.Timeout(timeout, connect=5.0), - ) - - async def __aenter__(self): - """Support async context manager protocol.""" - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - """Clean up resources when exiting async context.""" - await self.aclose() - - async def aclose(self): - """Close the async HTTP client and release resources.""" - if hasattr(self, "_client"): - await self._client.aclose() - - async def _send_request( - self, - method: str, - endpoint: str, - json: Dict | None = None, - params: Dict | None = None, - stream: bool = False, - **kwargs, - ): - """Send an async HTTP request to the Dify API. - - Args: - method: HTTP method (GET, POST, PUT, PATCH, DELETE) - endpoint: API endpoint path - json: JSON request body - params: Query parameters - stream: Whether to stream the response - **kwargs: Additional arguments to pass to httpx.request - - Returns: - httpx.Response object - """ - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - } - - response = await self._client.request( - method, - endpoint, - json=json, - params=params, - headers=headers, - **kwargs, - ) - - return response - - async def _send_request_with_files(self, method: str, endpoint: str, data: dict, files: dict): - """Send an async HTTP request with file uploads. - - Args: - method: HTTP method (POST, PUT, etc.) - endpoint: API endpoint path - data: Form data - files: Files to upload - - Returns: - httpx.Response object - """ - headers = {"Authorization": f"Bearer {self.api_key}"} - - response = await self._client.request( - method, - endpoint, - data=data, - headers=headers, - files=files, - ) - - return response - - async def message_feedback(self, message_id: str, rating: Literal["like", "dislike"], user: str): - """Send feedback for a message.""" - data = {"rating": rating, "user": user} - return await self._send_request("POST", f"/messages/{message_id}/feedbacks", data) - - async def get_application_parameters(self, user: str): - """Get application parameters.""" - params = {"user": user} - return await self._send_request("GET", "/parameters", params=params) - - async def file_upload(self, user: str, files: dict): - """Upload a file.""" - data = {"user": user} - return await self._send_request_with_files("POST", "/files/upload", data=data, files=files) - - async def text_to_audio(self, text: str, user: str, streaming: bool = False): - """Convert text to audio.""" - data = {"text": text, "user": user, "streaming": streaming} - return await self._send_request("POST", "/text-to-audio", json=data) - - async def get_meta(self, user: str): - """Get metadata.""" - params = {"user": user} - return await self._send_request("GET", "/meta", params=params) - - async def get_app_info(self): - """Get basic application information including name, description, tags, and mode.""" - return await self._send_request("GET", "/info") - - async def get_app_site_info(self): - """Get application site information.""" - return await self._send_request("GET", "/site") - - async def get_file_preview(self, file_id: str): - """Get file preview by file ID.""" - return await self._send_request("GET", f"/files/{file_id}/preview") - - # App Configuration APIs - async def get_app_site_config(self, app_id: str): - """Get app site configuration. - - Args: - app_id: ID of the app - - Returns: - App site configuration - """ - url = f"/apps/{app_id}/site/config" - return await self._send_request("GET", url) - - async def update_app_site_config(self, app_id: str, config_data: Dict[str, Any]): - """Update app site configuration. - - Args: - app_id: ID of the app - config_data: Configuration data to update - - Returns: - Updated app site configuration - """ - url = f"/apps/{app_id}/site/config" - return await self._send_request("PUT", url, json=config_data) - - async def get_app_api_tokens(self, app_id: str): - """Get API tokens for an app. - - Args: - app_id: ID of the app - - Returns: - List of API tokens - """ - url = f"/apps/{app_id}/api-tokens" - return await self._send_request("GET", url) - - async def create_app_api_token(self, app_id: str, name: str, description: str | None = None): - """Create a new API token for an app. - - Args: - app_id: ID of the app - name: Name for the API token - description: Description for the API token (optional) - - Returns: - Created API token information - """ - data = {"name": name, "description": description} - url = f"/apps/{app_id}/api-tokens" - return await self._send_request("POST", url, json=data) - - async def delete_app_api_token(self, app_id: str, token_id: str): - """Delete an API token. - - Args: - app_id: ID of the app - token_id: ID of the token to delete - - Returns: - Deletion result - """ - url = f"/apps/{app_id}/api-tokens/{token_id}" - return await self._send_request("DELETE", url) - - -class AsyncCompletionClient(AsyncDifyClient): - """Async client for Completion API operations.""" - - async def create_completion_message( - self, - inputs: dict, - response_mode: Literal["blocking", "streaming"], - user: str, - files: Dict | None = None, - ): - """Create a completion message. - - Args: - inputs: Input variables for the completion - response_mode: Response mode ('blocking' or 'streaming') - user: User identifier - files: Optional files to include - - Returns: - httpx.Response object - """ - data = { - "inputs": inputs, - "response_mode": response_mode, - "user": user, - "files": files, - } - return await self._send_request( - "POST", - "/completion-messages", - data, - stream=(response_mode == "streaming"), - ) - - -class AsyncChatClient(AsyncDifyClient): - """Async client for Chat API operations.""" - - async def create_chat_message( - self, - inputs: dict, - query: str, - user: str, - response_mode: Literal["blocking", "streaming"] = "blocking", - conversation_id: str | None = None, - files: Dict | None = None, - ): - """Create a chat message. - - Args: - inputs: Input variables for the chat - query: User query/message - user: User identifier - response_mode: Response mode ('blocking' or 'streaming') - conversation_id: Optional conversation ID for context - files: Optional files to include - - Returns: - httpx.Response object - """ - data = { - "inputs": inputs, - "query": query, - "user": user, - "response_mode": response_mode, - "files": files, - } - if conversation_id: - data["conversation_id"] = conversation_id - - return await self._send_request( - "POST", - "/chat-messages", - data, - stream=(response_mode == "streaming"), - ) - - async def get_suggested(self, message_id: str, user: str): - """Get suggested questions for a message.""" - params = {"user": user} - return await self._send_request("GET", f"/messages/{message_id}/suggested", params=params) - - async def stop_message(self, task_id: str, user: str): - """Stop a running message generation.""" - data = {"user": user} - return await self._send_request("POST", f"/chat-messages/{task_id}/stop", data) - - async def get_conversations( - self, - user: str, - last_id: str | None = None, - limit: int | None = None, - pinned: bool | None = None, - ): - """Get list of conversations.""" - params = {"user": user, "last_id": last_id, "limit": limit, "pinned": pinned} - return await self._send_request("GET", "/conversations", params=params) - - async def get_conversation_messages( - self, - user: str, - conversation_id: str | None = None, - first_id: str | None = None, - limit: int | None = None, - ): - """Get messages from a conversation.""" - params = { - "user": user, - "conversation_id": conversation_id, - "first_id": first_id, - "limit": limit, - } - return await self._send_request("GET", "/messages", params=params) - - async def rename_conversation(self, conversation_id: str, name: str, auto_generate: bool, user: str): - """Rename a conversation.""" - data = {"name": name, "auto_generate": auto_generate, "user": user} - return await self._send_request("POST", f"/conversations/{conversation_id}/name", data) - - async def delete_conversation(self, conversation_id: str, user: str): - """Delete a conversation.""" - data = {"user": user} - return await self._send_request("DELETE", f"/conversations/{conversation_id}", data) - - async def audio_to_text(self, audio_file: Union[IO[bytes], tuple], user: str): - """Convert audio to text.""" - data = {"user": user} - files = {"file": audio_file} - return await self._send_request_with_files("POST", "/audio-to-text", data, files) - - # Annotation APIs - async def annotation_reply_action( - self, - action: Literal["enable", "disable"], - score_threshold: float, - embedding_provider_name: str, - embedding_model_name: str, - ): - """Enable or disable annotation reply feature.""" - data = { - "score_threshold": score_threshold, - "embedding_provider_name": embedding_provider_name, - "embedding_model_name": embedding_model_name, - } - return await self._send_request("POST", f"/apps/annotation-reply/{action}", json=data) - - async def get_annotation_reply_status(self, action: Literal["enable", "disable"], job_id: str): - """Get the status of an annotation reply action job.""" - return await self._send_request("GET", f"/apps/annotation-reply/{action}/status/{job_id}") - - async def list_annotations(self, page: int = 1, limit: int = 20, keyword: str | None = None): - """List annotations for the application.""" - params = {"page": page, "limit": limit, "keyword": keyword} - return await self._send_request("GET", "/apps/annotations", params=params) - - async def create_annotation(self, question: str, answer: str): - """Create a new annotation.""" - data = {"question": question, "answer": answer} - return await self._send_request("POST", "/apps/annotations", json=data) - - async def update_annotation(self, annotation_id: str, question: str, answer: str): - """Update an existing annotation.""" - data = {"question": question, "answer": answer} - return await self._send_request("PUT", f"/apps/annotations/{annotation_id}", json=data) - - async def delete_annotation(self, annotation_id: str): - """Delete an annotation.""" - return await self._send_request("DELETE", f"/apps/annotations/{annotation_id}") - - # Enhanced Annotation APIs - async def get_annotation_reply_job_status(self, action: str, job_id: str): - """Get status of an annotation reply action job.""" - url = f"/apps/annotation-reply/{action}/status/{job_id}" - return await self._send_request("GET", url) - - async def list_annotations_with_pagination(self, page: int = 1, limit: int = 20, keyword: str | None = None): - """List annotations for application with pagination.""" - params = {"page": page, "limit": limit} - if keyword: - params["keyword"] = keyword - return await self._send_request("GET", "/apps/annotations", params=params) - - async def create_annotation_with_response(self, question: str, answer: str): - """Create a new annotation with full response handling.""" - data = {"question": question, "answer": answer} - return await self._send_request("POST", "/apps/annotations", json=data) - - async def update_annotation_with_response(self, annotation_id: str, question: str, answer: str): - """Update an existing annotation with full response handling.""" - data = {"question": question, "answer": answer} - url = f"/apps/annotations/{annotation_id}" - return await self._send_request("PUT", url, json=data) - - async def delete_annotation_with_response(self, annotation_id: str): - """Delete an annotation with full response handling.""" - url = f"/apps/annotations/{annotation_id}" - return await self._send_request("DELETE", url) - - # Conversation Variables APIs - async def get_conversation_variables(self, conversation_id: str, user: str): - """Get all variables for a specific conversation. - - Args: - conversation_id: The conversation ID to query variables for - user: User identifier - - Returns: - Response from the API containing: - - variables: List of conversation variables with their values - - conversation_id: The conversation ID - """ - params = {"user": user} - url = f"/conversations/{conversation_id}/variables" - return await self._send_request("GET", url, params=params) - - async def update_conversation_variable(self, conversation_id: str, variable_id: str, value: Any, user: str): - """Update a specific conversation variable. - - Args: - conversation_id: The conversation ID - variable_id: The variable ID to update - value: New value for the variable - user: User identifier - - Returns: - Response from the API with updated variable information - """ - data = {"value": value, "user": user} - url = f"/conversations/{conversation_id}/variables/{variable_id}" - return await self._send_request("PATCH", url, json=data) - - # Enhanced Conversation Variable APIs - async def list_conversation_variables_with_pagination( - self, conversation_id: str, user: str, page: int = 1, limit: int = 20 - ): - """List conversation variables with pagination.""" - params = {"page": page, "limit": limit, "user": user} - url = f"/conversations/{conversation_id}/variables" - return await self._send_request("GET", url, params=params) - - async def update_conversation_variable_with_response( - self, conversation_id: str, variable_id: str, user: str, value: Any - ): - """Update a conversation variable with full response handling.""" - data = {"value": value, "user": user} - url = f"/conversations/{conversation_id}/variables/{variable_id}" - return await self._send_request("PUT", url, data=data) - - # Additional annotation methods for API parity - async def get_annotation_reply_job_status(self, action: str, job_id: str): - """Get status of an annotation reply action job.""" - url = f"/apps/annotation-reply/{action}/status/{job_id}" - return await self._send_request("GET", url) - - async def list_annotations_with_pagination(self, page: int = 1, limit: int = 20, keyword: str | None = None): - """List annotations for application with pagination.""" - params = {"page": page, "limit": limit} - if keyword: - params["keyword"] = keyword - return await self._send_request("GET", "/apps/annotations", params=params) - - async def create_annotation_with_response(self, question: str, answer: str): - """Create a new annotation with full response handling.""" - data = {"question": question, "answer": answer} - return await self._send_request("POST", "/apps/annotations", json=data) - - async def update_annotation_with_response(self, annotation_id: str, question: str, answer: str): - """Update an existing annotation with full response handling.""" - data = {"question": question, "answer": answer} - url = f"/apps/annotations/{annotation_id}" - return await self._send_request("PUT", url, json=data) - - async def delete_annotation_with_response(self, annotation_id: str): - """Delete an annotation with full response handling.""" - url = f"/apps/annotations/{annotation_id}" - return await self._send_request("DELETE", url) - - -class AsyncWorkflowClient(AsyncDifyClient): - """Async client for Workflow API operations.""" - - async def run( - self, - inputs: dict, - response_mode: Literal["blocking", "streaming"] = "streaming", - user: str = "abc-123", - ): - """Run a workflow.""" - data = {"inputs": inputs, "response_mode": response_mode, "user": user} - return await self._send_request("POST", "/workflows/run", data) - - async def stop(self, task_id: str, user: str): - """Stop a running workflow task.""" - data = {"user": user} - return await self._send_request("POST", f"/workflows/tasks/{task_id}/stop", data) - - async def get_result(self, workflow_run_id: str): - """Get workflow run result.""" - return await self._send_request("GET", f"/workflows/run/{workflow_run_id}") - - async def get_workflow_logs( - self, - keyword: str = None, - status: Literal["succeeded", "failed", "stopped"] | None = None, - page: int = 1, - limit: int = 20, - created_at__before: str = None, - created_at__after: str = None, - created_by_end_user_session_id: str = None, - created_by_account: str = None, - ): - """Get workflow execution logs with optional filtering.""" - params = { - "page": page, - "limit": limit, - "keyword": keyword, - "status": status, - "created_at__before": created_at__before, - "created_at__after": created_at__after, - "created_by_end_user_session_id": created_by_end_user_session_id, - "created_by_account": created_by_account, - } - return await self._send_request("GET", "/workflows/logs", params=params) - - async def run_specific_workflow( - self, - workflow_id: str, - inputs: dict, - response_mode: Literal["blocking", "streaming"] = "streaming", - user: str = "abc-123", - ): - """Run a specific workflow by workflow ID.""" - data = {"inputs": inputs, "response_mode": response_mode, "user": user} - return await self._send_request( - "POST", - f"/workflows/{workflow_id}/run", - data, - stream=(response_mode == "streaming"), - ) - - # Enhanced Workflow APIs - async def get_workflow_draft(self, app_id: str): - """Get workflow draft configuration. - - Args: - app_id: ID of the workflow app - - Returns: - Workflow draft configuration - """ - url = f"/apps/{app_id}/workflow/draft" - return await self._send_request("GET", url) - - async def update_workflow_draft(self, app_id: str, workflow_data: Dict[str, Any]): - """Update workflow draft configuration. - - Args: - app_id: ID of the workflow app - workflow_data: Workflow configuration data - - Returns: - Updated workflow draft - """ - url = f"/apps/{app_id}/workflow/draft" - return await self._send_request("PUT", url, json=workflow_data) - - async def publish_workflow(self, app_id: str): - """Publish workflow from draft. - - Args: - app_id: ID of the workflow app - - Returns: - Published workflow information - """ - url = f"/apps/{app_id}/workflow/publish" - return await self._send_request("POST", url) - - async def get_workflow_run_history( - self, - app_id: str, - page: int = 1, - limit: int = 20, - status: Literal["succeeded", "failed", "stopped"] | None = None, - ): - """Get workflow run history. - - Args: - app_id: ID of the workflow app - page: Page number (default: 1) - limit: Number of items per page (default: 20) - status: Filter by status (optional) - - Returns: - Paginated workflow run history - """ - params = {"page": page, "limit": limit} - if status: - params["status"] = status - url = f"/apps/{app_id}/workflow/runs" - return await self._send_request("GET", url, params=params) - - -class AsyncWorkspaceClient(AsyncDifyClient): - """Async client for workspace-related operations.""" - - async def get_available_models(self, model_type: str): - """Get available models by model type.""" - url = f"/workspaces/current/models/model-types/{model_type}" - return await self._send_request("GET", url) - - async def get_available_models_by_type(self, model_type: str): - """Get available models by model type (enhanced version).""" - url = f"/workspaces/current/models/model-types/{model_type}" - return await self._send_request("GET", url) - - async def get_model_providers(self): - """Get all model providers.""" - return await self._send_request("GET", "/workspaces/current/model-providers") - - async def get_model_provider_models(self, provider_name: str): - """Get models for a specific provider.""" - url = f"/workspaces/current/model-providers/{provider_name}/models" - return await self._send_request("GET", url) - - async def validate_model_provider_credentials(self, provider_name: str, credentials: Dict[str, Any]): - """Validate model provider credentials.""" - url = f"/workspaces/current/model-providers/{provider_name}/credentials/validate" - return await self._send_request("POST", url, json=credentials) - - # File Management APIs - async def get_file_info(self, file_id: str): - """Get information about a specific file.""" - url = f"/files/{file_id}/info" - return await self._send_request("GET", url) - - async def get_file_download_url(self, file_id: str): - """Get download URL for a file.""" - url = f"/files/{file_id}/download-url" - return await self._send_request("GET", url) - - async def delete_file(self, file_id: str): - """Delete a file.""" - url = f"/files/{file_id}" - return await self._send_request("DELETE", url) - - -class AsyncKnowledgeBaseClient(AsyncDifyClient): - """Async client for Knowledge Base API operations.""" - - def __init__( - self, - api_key: str, - base_url: str = "https://api.dify.ai/v1", - dataset_id: str | None = None, - timeout: float = 60.0, - ): - """Construct an AsyncKnowledgeBaseClient object. - - Args: - api_key: API key of Dify - base_url: Base URL of Dify API - dataset_id: ID of the dataset - timeout: Request timeout in seconds - """ - super().__init__(api_key=api_key, base_url=base_url, timeout=timeout) - self.dataset_id = dataset_id - - def _get_dataset_id(self): - """Get the dataset ID, raise error if not set.""" - if self.dataset_id is None: - raise ValueError("dataset_id is not set") - return self.dataset_id - - async def create_dataset(self, name: str, **kwargs): - """Create a new dataset.""" - return await self._send_request("POST", "/datasets", {"name": name}, **kwargs) - - async def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs): - """List all datasets.""" - return await self._send_request("GET", "/datasets", params={"page": page, "limit": page_size}, **kwargs) - - async def create_document_by_text(self, name: str, text: str, extra_params: Dict | None = None, **kwargs): - """Create a document by text. - - Args: - name: Name of the document - text: Text content of the document - extra_params: Extra parameters for the API - - Returns: - Response from the API - """ - data = { - "indexing_technique": "high_quality", - "process_rule": {"mode": "automatic"}, - "name": name, - "text": text, - } - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - url = f"/datasets/{self._get_dataset_id()}/document/create_by_text" - return await self._send_request("POST", url, json=data, **kwargs) - - async def update_document_by_text( - self, - document_id: str, - name: str, - text: str, - extra_params: Dict | None = None, - **kwargs, - ): - """Update a document by text.""" - data = {"name": name, "text": text} - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text" - return await self._send_request("POST", url, json=data, **kwargs) - - async def create_document_by_file( - self, - file_path: str, - original_document_id: str | None = None, - extra_params: Dict | None = None, - ): - """Create a document by file.""" - async with aiofiles.open(file_path, "rb") as f: - files = {"file": (os.path.basename(file_path), f)} - data = { - "process_rule": {"mode": "automatic"}, - "indexing_technique": "high_quality", - } - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - if original_document_id is not None: - data["original_document_id"] = original_document_id - url = f"/datasets/{self._get_dataset_id()}/document/create_by_file" - return await self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files) - - async def update_document_by_file(self, document_id: str, file_path: str, extra_params: Dict | None = None): - """Update a document by file.""" - async with aiofiles.open(file_path, "rb") as f: - files = {"file": (os.path.basename(file_path), f)} - data = {} - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file" - return await self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files) - - async def batch_indexing_status(self, batch_id: str, **kwargs): - """Get the status of the batch indexing.""" - url = f"/datasets/{self._get_dataset_id()}/documents/{batch_id}/indexing-status" - return await self._send_request("GET", url, **kwargs) - - async def delete_dataset(self): - """Delete this dataset.""" - url = f"/datasets/{self._get_dataset_id()}" - return await self._send_request("DELETE", url) - - async def delete_document(self, document_id: str): - """Delete a document.""" - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}" - return await self._send_request("DELETE", url) - - async def list_documents( - self, - page: int | None = None, - page_size: int | None = None, - keyword: str | None = None, - **kwargs, - ): - """Get a list of documents in this dataset.""" - params = { - "page": page, - "limit": page_size, - "keyword": keyword, - } - url = f"/datasets/{self._get_dataset_id()}/documents" - return await self._send_request("GET", url, params=params, **kwargs) - - async def add_segments(self, document_id: str, segments: list[dict], **kwargs): - """Add segments to a document.""" - data = {"segments": segments} - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments" - return await self._send_request("POST", url, json=data, **kwargs) - - async def query_segments( - self, - document_id: str, - keyword: str | None = None, - status: str | None = None, - **kwargs, - ): - """Query segments in this document. - - Args: - document_id: ID of the document - keyword: Query keyword (optional) - status: Status of the segment (optional, e.g., 'completed') - **kwargs: Additional parameters to pass to the API. - Can include a 'params' dict for extra query parameters. - - Returns: - Response from the API - """ - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments" - params = { - "keyword": keyword, - "status": status, - } - if "params" in kwargs: - params.update(kwargs.pop("params")) - return await self._send_request("GET", url, params=params, **kwargs) - - async def delete_document_segment(self, document_id: str, segment_id: str): - """Delete a segment from a document.""" - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}" - return await self._send_request("DELETE", url) - - async def update_document_segment(self, document_id: str, segment_id: str, segment_data: dict, **kwargs): - """Update a segment in a document.""" - data = {"segment": segment_data} - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}" - return await self._send_request("POST", url, json=data, **kwargs) - - # Advanced Knowledge Base APIs - async def hit_testing( - self, - query: str, - retrieval_model: Dict[str, Any] = None, - external_retrieval_model: Dict[str, Any] = None, - ): - """Perform hit testing on the dataset.""" - data = {"query": query} - if retrieval_model: - data["retrieval_model"] = retrieval_model - if external_retrieval_model: - data["external_retrieval_model"] = external_retrieval_model - url = f"/datasets/{self._get_dataset_id()}/hit-testing" - return await self._send_request("POST", url, json=data) - - async def get_dataset_metadata(self): - """Get dataset metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata" - return await self._send_request("GET", url) - - async def create_dataset_metadata(self, metadata_data: Dict[str, Any]): - """Create dataset metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata" - return await self._send_request("POST", url, json=metadata_data) - - async def update_dataset_metadata(self, metadata_id: str, metadata_data: Dict[str, Any]): - """Update dataset metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata/{metadata_id}" - return await self._send_request("PATCH", url, json=metadata_data) - - async def get_built_in_metadata(self): - """Get built-in metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata/built-in" - return await self._send_request("GET", url) - - async def manage_built_in_metadata(self, action: str, metadata_data: Dict[str, Any] = None): - """Manage built-in metadata with specified action.""" - data = metadata_data or {} - url = f"/datasets/{self._get_dataset_id()}/metadata/built-in/{action}" - return await self._send_request("POST", url, json=data) - - async def update_documents_metadata(self, operation_data: List[Dict[str, Any]]): - """Update metadata for multiple documents.""" - url = f"/datasets/{self._get_dataset_id()}/documents/metadata" - data = {"operation_data": operation_data} - return await self._send_request("POST", url, json=data) - - # Dataset Tags APIs - async def list_dataset_tags(self): - """List all dataset tags.""" - return await self._send_request("GET", "/datasets/tags") - - async def bind_dataset_tags(self, tag_ids: List[str]): - """Bind tags to dataset.""" - data = {"tag_ids": tag_ids, "target_id": self._get_dataset_id()} - return await self._send_request("POST", "/datasets/tags/binding", json=data) - - async def unbind_dataset_tag(self, tag_id: str): - """Unbind a single tag from dataset.""" - data = {"tag_id": tag_id, "target_id": self._get_dataset_id()} - return await self._send_request("POST", "/datasets/tags/unbinding", json=data) - - async def get_dataset_tags(self): - """Get tags for current dataset.""" - url = f"/datasets/{self._get_dataset_id()}/tags" - return await self._send_request("GET", url) - - # RAG Pipeline APIs - async def get_datasource_plugins(self, is_published: bool = True): - """Get datasource plugins for RAG pipeline.""" - params = {"is_published": is_published} - url = f"/datasets/{self._get_dataset_id()}/pipeline/datasource-plugins" - return await self._send_request("GET", url, params=params) - - async def run_datasource_node( - self, - node_id: str, - inputs: Dict[str, Any], - datasource_type: str, - is_published: bool = True, - credential_id: str = None, - ): - """Run a datasource node in RAG pipeline.""" - data = { - "inputs": inputs, - "datasource_type": datasource_type, - "is_published": is_published, - } - if credential_id: - data["credential_id"] = credential_id - url = f"/datasets/{self._get_dataset_id()}/pipeline/datasource/nodes/{node_id}/run" - return await self._send_request("POST", url, json=data, stream=True) - - async def run_rag_pipeline( - self, - inputs: Dict[str, Any], - datasource_type: str, - datasource_info_list: List[Dict[str, Any]], - start_node_id: str, - is_published: bool = True, - response_mode: Literal["streaming", "blocking"] = "blocking", - ): - """Run RAG pipeline.""" - data = { - "inputs": inputs, - "datasource_type": datasource_type, - "datasource_info_list": datasource_info_list, - "start_node_id": start_node_id, - "is_published": is_published, - "response_mode": response_mode, - } - url = f"/datasets/{self._get_dataset_id()}/pipeline/run" - return await self._send_request("POST", url, json=data, stream=response_mode == "streaming") - - async def upload_pipeline_file(self, file_path: str): - """Upload file for RAG pipeline.""" - async with aiofiles.open(file_path, "rb") as f: - files = {"file": (os.path.basename(file_path), f)} - return await self._send_request_with_files("POST", "/datasets/pipeline/file-upload", {}, files) - - # Dataset Management APIs - async def get_dataset(self, dataset_id: str | None = None): - """Get detailed information about a specific dataset.""" - ds_id = dataset_id or self._get_dataset_id() - url = f"/datasets/{ds_id}" - return await self._send_request("GET", url) - - async def update_dataset( - self, - dataset_id: str | None = None, - name: str | None = None, - description: str | None = None, - indexing_technique: str | None = None, - embedding_model: str | None = None, - embedding_model_provider: str | None = None, - retrieval_model: Dict[str, Any] | None = None, - **kwargs, - ): - """Update dataset configuration. - - Args: - dataset_id: Dataset ID (optional, uses current dataset_id if not provided) - name: New dataset name - description: New dataset description - indexing_technique: Indexing technique ('high_quality' or 'economy') - embedding_model: Embedding model name - embedding_model_provider: Embedding model provider - retrieval_model: Retrieval model configuration dict - **kwargs: Additional parameters to pass to the API - - Returns: - Response from the API with updated dataset information - """ - ds_id = dataset_id or self._get_dataset_id() - url = f"/datasets/{ds_id}" - - payload = { - "name": name, - "description": description, - "indexing_technique": indexing_technique, - "embedding_model": embedding_model, - "embedding_model_provider": embedding_model_provider, - "retrieval_model": retrieval_model, - } - - data = {k: v for k, v in payload.items() if v is not None} - data.update(kwargs) - - return await self._send_request("PATCH", url, json=data) - - async def batch_update_document_status( - self, - action: Literal["enable", "disable", "archive", "un_archive"], - document_ids: List[str], - dataset_id: str | None = None, - ): - """Batch update document status.""" - ds_id = dataset_id or self._get_dataset_id() - url = f"/datasets/{ds_id}/documents/status/{action}" - data = {"document_ids": document_ids} - return await self._send_request("PATCH", url, json=data) - - # Enhanced Dataset APIs - - async def create_dataset_from_template(self, template_name: str, name: str, description: str | None = None): - """Create a dataset from a predefined template. - - Args: - template_name: Name of the template to use - name: Name for the new dataset - description: Description for the dataset (optional) - - Returns: - Created dataset information - """ - data = { - "template_name": template_name, - "name": name, - "description": description, - } - return await self._send_request("POST", "/datasets/from-template", json=data) - - async def duplicate_dataset(self, dataset_id: str, name: str): - """Duplicate an existing dataset. - - Args: - dataset_id: ID of dataset to duplicate - name: Name for duplicated dataset - - Returns: - New dataset information - """ - data = {"name": name} - url = f"/datasets/{dataset_id}/duplicate" - return await self._send_request("POST", url, json=data) - - async def update_conversation_variable_with_response( - self, conversation_id: str, variable_id: str, user: str, value: Any - ): - """Update a conversation variable with full response handling.""" - data = {"value": value, "user": user} - url = f"/conversations/{conversation_id}/variables/{variable_id}" - return await self._send_request("PUT", url, json=data) - - async def list_conversation_variables_with_pagination( - self, conversation_id: str, user: str, page: int = 1, limit: int = 20 - ): - """List conversation variables with pagination.""" - params = {"page": page, "limit": limit, "user": user} - url = f"/conversations/{conversation_id}/variables" - return await self._send_request("GET", url, params=params) - - -class AsyncEnterpriseClient(AsyncDifyClient): - """Async Enterprise and Account Management APIs for Dify platform administration.""" - - async def get_account_info(self): - """Get current account information.""" - return await self._send_request("GET", "/account") - - async def update_account_info(self, account_data: Dict[str, Any]): - """Update account information.""" - return await self._send_request("PUT", "/account", json=account_data) - - # Member Management APIs - async def list_members(self, page: int = 1, limit: int = 20, keyword: str | None = None): - """List workspace members with pagination.""" - params = {"page": page, "limit": limit} - if keyword: - params["keyword"] = keyword - return await self._send_request("GET", "/members", params=params) - - async def invite_member(self, email: str, role: str, name: str | None = None): - """Invite a new member to the workspace.""" - data = {"email": email, "role": role} - if name: - data["name"] = name - return await self._send_request("POST", "/members/invite", json=data) - - async def get_member(self, member_id: str): - """Get detailed information about a specific member.""" - url = f"/members/{member_id}" - return await self._send_request("GET", url) - - async def update_member(self, member_id: str, member_data: Dict[str, Any]): - """Update member information.""" - url = f"/members/{member_id}" - return await self._send_request("PUT", url, json=member_data) - - async def remove_member(self, member_id: str): - """Remove a member from the workspace.""" - url = f"/members/{member_id}" - return await self._send_request("DELETE", url) - - async def deactivate_member(self, member_id: str): - """Deactivate a member account.""" - url = f"/members/{member_id}/deactivate" - return await self._send_request("POST", url) - - async def reactivate_member(self, member_id: str): - """Reactivate a deactivated member account.""" - url = f"/members/{member_id}/reactivate" - return await self._send_request("POST", url) - - # Role Management APIs - async def list_roles(self): - """List all available roles in the workspace.""" - return await self._send_request("GET", "/roles") - - async def create_role(self, name: str, description: str, permissions: List[str]): - """Create a new role with specified permissions.""" - data = {"name": name, "description": description, "permissions": permissions} - return await self._send_request("POST", "/roles", json=data) - - async def get_role(self, role_id: str): - """Get detailed information about a specific role.""" - url = f"/roles/{role_id}" - return await self._send_request("GET", url) - - async def update_role(self, role_id: str, role_data: Dict[str, Any]): - """Update role information.""" - url = f"/roles/{role_id}" - return await self._send_request("PUT", url, json=role_data) - - async def delete_role(self, role_id: str): - """Delete a role.""" - url = f"/roles/{role_id}" - return await self._send_request("DELETE", url) - - # Permission Management APIs - async def list_permissions(self): - """List all available permissions.""" - return await self._send_request("GET", "/permissions") - - async def get_role_permissions(self, role_id: str): - """Get permissions for a specific role.""" - url = f"/roles/{role_id}/permissions" - return await self._send_request("GET", url) - - async def update_role_permissions(self, role_id: str, permissions: List[str]): - """Update permissions for a role.""" - url = f"/roles/{role_id}/permissions" - data = {"permissions": permissions} - return await self._send_request("PUT", url, json=data) - - # Workspace Settings APIs - async def get_workspace_settings(self): - """Get workspace settings and configuration.""" - return await self._send_request("GET", "/workspace/settings") - - async def update_workspace_settings(self, settings_data: Dict[str, Any]): - """Update workspace settings.""" - return await self._send_request("PUT", "/workspace/settings", json=settings_data) - - async def get_workspace_statistics(self): - """Get workspace usage statistics.""" - return await self._send_request("GET", "/workspace/statistics") - - # Billing and Subscription APIs - async def get_billing_info(self): - """Get current billing information.""" - return await self._send_request("GET", "/billing") - - async def get_subscription_info(self): - """Get current subscription information.""" - return await self._send_request("GET", "/subscription") - - async def update_subscription(self, subscription_data: Dict[str, Any]): - """Update subscription settings.""" - return await self._send_request("PUT", "/subscription", json=subscription_data) - - async def get_billing_history(self, page: int = 1, limit: int = 20): - """Get billing history with pagination.""" - params = {"page": page, "limit": limit} - return await self._send_request("GET", "/billing/history", params=params) - - async def get_usage_metrics(self, start_date: str, end_date: str, metric_type: str | None = None): - """Get usage metrics for a date range.""" - params = {"start_date": start_date, "end_date": end_date} - if metric_type: - params["metric_type"] = metric_type - return await self._send_request("GET", "/usage/metrics", params=params) - - # Audit Logs APIs - async def get_audit_logs( - self, - page: int = 1, - limit: int = 20, - action: str | None = None, - user_id: str | None = None, - start_date: str | None = None, - end_date: str | None = None, - ): - """Get audit logs with filtering options.""" - params = {"page": page, "limit": limit} - if action: - params["action"] = action - if user_id: - params["user_id"] = user_id - if start_date: - params["start_date"] = start_date - if end_date: - params["end_date"] = end_date - return await self._send_request("GET", "/audit/logs", params=params) - - async def export_audit_logs(self, format: str = "csv", filters: Dict[str, Any] | None = None): - """Export audit logs in specified format.""" - params = {"format": format} - if filters: - params.update(filters) - return await self._send_request("GET", "/audit/logs/export", params=params) - - -class AsyncSecurityClient(AsyncDifyClient): - """Async Security and Access Control APIs for Dify platform security management.""" - - # API Key Management APIs - async def list_api_keys(self, page: int = 1, limit: int = 20, status: str | None = None): - """List all API keys with pagination and filtering.""" - params = {"page": page, "limit": limit} - if status: - params["status"] = status - return await self._send_request("GET", "/security/api-keys", params=params) - - async def create_api_key( - self, - name: str, - permissions: List[str], - expires_at: str | None = None, - description: str | None = None, - ): - """Create a new API key with specified permissions.""" - data = {"name": name, "permissions": permissions} - if expires_at: - data["expires_at"] = expires_at - if description: - data["description"] = description - return await self._send_request("POST", "/security/api-keys", json=data) - - async def get_api_key(self, key_id: str): - """Get detailed information about an API key.""" - url = f"/security/api-keys/{key_id}" - return await self._send_request("GET", url) - - async def update_api_key(self, key_id: str, key_data: Dict[str, Any]): - """Update API key information.""" - url = f"/security/api-keys/{key_id}" - return await self._send_request("PUT", url, json=key_data) - - async def revoke_api_key(self, key_id: str): - """Revoke an API key.""" - url = f"/security/api-keys/{key_id}/revoke" - return await self._send_request("POST", url) - - async def rotate_api_key(self, key_id: str): - """Rotate an API key (generate new key).""" - url = f"/security/api-keys/{key_id}/rotate" - return await self._send_request("POST", url) - - # Rate Limiting APIs - async def get_rate_limits(self): - """Get current rate limiting configuration.""" - return await self._send_request("GET", "/security/rate-limits") - - async def update_rate_limits(self, limits_config: Dict[str, Any]): - """Update rate limiting configuration.""" - return await self._send_request("PUT", "/security/rate-limits", json=limits_config) - - async def get_rate_limit_usage(self, timeframe: str = "1h"): - """Get rate limit usage statistics.""" - params = {"timeframe": timeframe} - return await self._send_request("GET", "/security/rate-limits/usage", params=params) - - # Access Control Lists APIs - async def list_access_policies(self, page: int = 1, limit: int = 20): - """List access control policies.""" - params = {"page": page, "limit": limit} - return await self._send_request("GET", "/security/access-policies", params=params) - - async def create_access_policy(self, policy_data: Dict[str, Any]): - """Create a new access control policy.""" - return await self._send_request("POST", "/security/access-policies", json=policy_data) - - async def get_access_policy(self, policy_id: str): - """Get detailed information about an access policy.""" - url = f"/security/access-policies/{policy_id}" - return await self._send_request("GET", url) - - async def update_access_policy(self, policy_id: str, policy_data: Dict[str, Any]): - """Update an access control policy.""" - url = f"/security/access-policies/{policy_id}" - return await self._send_request("PUT", url, json=policy_data) - - async def delete_access_policy(self, policy_id: str): - """Delete an access control policy.""" - url = f"/security/access-policies/{policy_id}" - return await self._send_request("DELETE", url) - - # Security Settings APIs - async def get_security_settings(self): - """Get security configuration settings.""" - return await self._send_request("GET", "/security/settings") - - async def update_security_settings(self, settings_data: Dict[str, Any]): - """Update security configuration settings.""" - return await self._send_request("PUT", "/security/settings", json=settings_data) - - async def get_security_audit_logs( - self, - page: int = 1, - limit: int = 20, - event_type: str | None = None, - start_date: str | None = None, - end_date: str | None = None, - ): - """Get security-specific audit logs.""" - params = {"page": page, "limit": limit} - if event_type: - params["event_type"] = event_type - if start_date: - params["start_date"] = start_date - if end_date: - params["end_date"] = end_date - return await self._send_request("GET", "/security/audit-logs", params=params) - - # IP Whitelist/Blacklist APIs - async def get_ip_whitelist(self): - """Get IP whitelist configuration.""" - return await self._send_request("GET", "/security/ip-whitelist") - - async def update_ip_whitelist(self, ip_list: List[str], description: str | None = None): - """Update IP whitelist configuration.""" - data = {"ip_list": ip_list} - if description: - data["description"] = description - return await self._send_request("PUT", "/security/ip-whitelist", json=data) - - async def get_ip_blacklist(self): - """Get IP blacklist configuration.""" - return await self._send_request("GET", "/security/ip-blacklist") - - async def update_ip_blacklist(self, ip_list: List[str], description: str | None = None): - """Update IP blacklist configuration.""" - data = {"ip_list": ip_list} - if description: - data["description"] = description - return await self._send_request("PUT", "/security/ip-blacklist", json=data) - - # Authentication Settings APIs - async def get_auth_settings(self): - """Get authentication configuration settings.""" - return await self._send_request("GET", "/security/auth-settings") - - async def update_auth_settings(self, auth_data: Dict[str, Any]): - """Update authentication configuration settings.""" - return await self._send_request("PUT", "/security/auth-settings", json=auth_data) - - async def test_auth_configuration(self, auth_config: Dict[str, Any]): - """Test authentication configuration.""" - return await self._send_request("POST", "/security/auth-settings/test", json=auth_config) - - -class AsyncAnalyticsClient(AsyncDifyClient): - """Async Analytics and Monitoring APIs for Dify platform insights and metrics.""" - - # Usage Analytics APIs - async def get_usage_analytics( - self, - start_date: str, - end_date: str, - granularity: str = "day", - metrics: List[str] | None = None, - ): - """Get usage analytics for specified date range.""" - params = { - "start_date": start_date, - "end_date": end_date, - "granularity": granularity, - } - if metrics: - params["metrics"] = ",".join(metrics) - return await self._send_request("GET", "/analytics/usage", params=params) - - async def get_app_usage_analytics(self, app_id: str, start_date: str, end_date: str, granularity: str = "day"): - """Get usage analytics for a specific app.""" - params = { - "start_date": start_date, - "end_date": end_date, - "granularity": granularity, - } - url = f"/analytics/apps/{app_id}/usage" - return await self._send_request("GET", url, params=params) - - async def get_user_analytics(self, start_date: str, end_date: str, user_segment: str | None = None): - """Get user analytics and behavior insights.""" - params = {"start_date": start_date, "end_date": end_date} - if user_segment: - params["user_segment"] = user_segment - return await self._send_request("GET", "/analytics/users", params=params) - - # Performance Metrics APIs - async def get_performance_metrics(self, start_date: str, end_date: str, metric_type: str | None = None): - """Get performance metrics for the platform.""" - params = {"start_date": start_date, "end_date": end_date} - if metric_type: - params["metric_type"] = metric_type - return await self._send_request("GET", "/analytics/performance", params=params) - - async def get_app_performance_metrics(self, app_id: str, start_date: str, end_date: str): - """Get performance metrics for a specific app.""" - params = {"start_date": start_date, "end_date": end_date} - url = f"/analytics/apps/{app_id}/performance" - return await self._send_request("GET", url, params=params) - - async def get_model_performance_metrics(self, model_provider: str, model_name: str, start_date: str, end_date: str): - """Get performance metrics for a specific model.""" - params = {"start_date": start_date, "end_date": end_date} - url = f"/analytics/models/{model_provider}/{model_name}/performance" - return await self._send_request("GET", url, params=params) - - # Cost Tracking APIs - async def get_cost_analytics(self, start_date: str, end_date: str, cost_type: str | None = None): - """Get cost analytics and breakdown.""" - params = {"start_date": start_date, "end_date": end_date} - if cost_type: - params["cost_type"] = cost_type - return await self._send_request("GET", "/analytics/costs", params=params) - - async def get_app_cost_analytics(self, app_id: str, start_date: str, end_date: str): - """Get cost analytics for a specific app.""" - params = {"start_date": start_date, "end_date": end_date} - url = f"/analytics/apps/{app_id}/costs" - return await self._send_request("GET", url, params=params) - - async def get_cost_forecast(self, forecast_period: str = "30d"): - """Get cost forecast for specified period.""" - params = {"forecast_period": forecast_period} - return await self._send_request("GET", "/analytics/costs/forecast", params=params) - - # Real-time Monitoring APIs - async def get_real_time_metrics(self): - """Get real-time platform metrics.""" - return await self._send_request("GET", "/analytics/realtime") - - async def get_app_real_time_metrics(self, app_id: str): - """Get real-time metrics for a specific app.""" - url = f"/analytics/apps/{app_id}/realtime" - return await self._send_request("GET", url) - - async def get_system_health(self): - """Get overall system health status.""" - return await self._send_request("GET", "/analytics/health") - - # Custom Reports APIs - async def create_custom_report(self, report_config: Dict[str, Any]): - """Create a custom analytics report.""" - return await self._send_request("POST", "/analytics/reports", json=report_config) - - async def list_custom_reports(self, page: int = 1, limit: int = 20): - """List custom analytics reports.""" - params = {"page": page, "limit": limit} - return await self._send_request("GET", "/analytics/reports", params=params) - - async def get_custom_report(self, report_id: str): - """Get a specific custom report.""" - url = f"/analytics/reports/{report_id}" - return await self._send_request("GET", url) - - async def update_custom_report(self, report_id: str, report_config: Dict[str, Any]): - """Update a custom analytics report.""" - url = f"/analytics/reports/{report_id}" - return await self._send_request("PUT", url, json=report_config) - - async def delete_custom_report(self, report_id: str): - """Delete a custom analytics report.""" - url = f"/analytics/reports/{report_id}" - return await self._send_request("DELETE", url) - - async def generate_report(self, report_id: str, format: str = "pdf"): - """Generate and download a custom report.""" - params = {"format": format} - url = f"/analytics/reports/{report_id}/generate" - return await self._send_request("GET", url, params=params) - - # Export APIs - async def export_analytics_data(self, data_type: str, start_date: str, end_date: str, format: str = "csv"): - """Export analytics data in specified format.""" - params = { - "data_type": data_type, - "start_date": start_date, - "end_date": end_date, - "format": format, - } - return await self._send_request("GET", "/analytics/export", params=params) - - -class AsyncIntegrationClient(AsyncDifyClient): - """Async Integration and Plugin APIs for Dify platform extensibility.""" - - # Webhook Management APIs - async def list_webhooks(self, page: int = 1, limit: int = 20, status: str | None = None): - """List webhooks with pagination and filtering.""" - params = {"page": page, "limit": limit} - if status: - params["status"] = status - return await self._send_request("GET", "/integrations/webhooks", params=params) - - async def create_webhook(self, webhook_data: Dict[str, Any]): - """Create a new webhook.""" - return await self._send_request("POST", "/integrations/webhooks", json=webhook_data) - - async def get_webhook(self, webhook_id: str): - """Get detailed information about a webhook.""" - url = f"/integrations/webhooks/{webhook_id}" - return await self._send_request("GET", url) - - async def update_webhook(self, webhook_id: str, webhook_data: Dict[str, Any]): - """Update webhook configuration.""" - url = f"/integrations/webhooks/{webhook_id}" - return await self._send_request("PUT", url, json=webhook_data) - - async def delete_webhook(self, webhook_id: str): - """Delete a webhook.""" - url = f"/integrations/webhooks/{webhook_id}" - return await self._send_request("DELETE", url) - - async def test_webhook(self, webhook_id: str): - """Test webhook delivery.""" - url = f"/integrations/webhooks/{webhook_id}/test" - return await self._send_request("POST", url) - - async def get_webhook_logs(self, webhook_id: str, page: int = 1, limit: int = 20): - """Get webhook delivery logs.""" - params = {"page": page, "limit": limit} - url = f"/integrations/webhooks/{webhook_id}/logs" - return await self._send_request("GET", url, params=params) - - # Plugin Management APIs - async def list_plugins(self, page: int = 1, limit: int = 20, category: str | None = None): - """List available plugins.""" - params = {"page": page, "limit": limit} - if category: - params["category"] = category - return await self._send_request("GET", "/integrations/plugins", params=params) - - async def install_plugin(self, plugin_id: str, config: Dict[str, Any] | None = None): - """Install a plugin.""" - data = {"plugin_id": plugin_id} - if config: - data["config"] = config - return await self._send_request("POST", "/integrations/plugins/install", json=data) - - async def get_installed_plugin(self, installation_id: str): - """Get information about an installed plugin.""" - url = f"/integrations/plugins/{installation_id}" - return await self._send_request("GET", url) - - async def update_plugin_config(self, installation_id: str, config: Dict[str, Any]): - """Update plugin configuration.""" - url = f"/integrations/plugins/{installation_id}/config" - return await self._send_request("PUT", url, json=config) - - async def uninstall_plugin(self, installation_id: str): - """Uninstall a plugin.""" - url = f"/integrations/plugins/{installation_id}" - return await self._send_request("DELETE", url) - - async def enable_plugin(self, installation_id: str): - """Enable a plugin.""" - url = f"/integrations/plugins/{installation_id}/enable" - return await self._send_request("POST", url) - - async def disable_plugin(self, installation_id: str): - """Disable a plugin.""" - url = f"/integrations/plugins/{installation_id}/disable" - return await self._send_request("POST", url) - - # Import/Export APIs - async def export_app_data(self, app_id: str, format: str = "json", include_data: bool = True): - """Export application data.""" - params = {"format": format, "include_data": include_data} - url = f"/integrations/export/apps/{app_id}" - return await self._send_request("GET", url, params=params) - - async def import_app_data(self, import_data: Dict[str, Any]): - """Import application data.""" - return await self._send_request("POST", "/integrations/import/apps", json=import_data) - - async def get_import_status(self, import_id: str): - """Get import operation status.""" - url = f"/integrations/import/{import_id}/status" - return await self._send_request("GET", url) - - async def export_workspace_data(self, format: str = "json", include_data: bool = True): - """Export workspace data.""" - params = {"format": format, "include_data": include_data} - return await self._send_request("GET", "/integrations/export/workspace", params=params) - - async def import_workspace_data(self, import_data: Dict[str, Any]): - """Import workspace data.""" - return await self._send_request("POST", "/integrations/import/workspace", json=import_data) - - # Backup and Restore APIs - async def create_backup(self, backup_config: Dict[str, Any] | None = None): - """Create a system backup.""" - data = backup_config or {} - return await self._send_request("POST", "/integrations/backup/create", json=data) - - async def list_backups(self, page: int = 1, limit: int = 20): - """List available backups.""" - params = {"page": page, "limit": limit} - return await self._send_request("GET", "/integrations/backup", params=params) - - async def get_backup(self, backup_id: str): - """Get backup information.""" - url = f"/integrations/backup/{backup_id}" - return await self._send_request("GET", url) - - async def restore_backup(self, backup_id: str, restore_config: Dict[str, Any] | None = None): - """Restore from backup.""" - data = restore_config or {} - url = f"/integrations/backup/{backup_id}/restore" - return await self._send_request("POST", url, json=data) - - async def delete_backup(self, backup_id: str): - """Delete a backup.""" - url = f"/integrations/backup/{backup_id}" - return await self._send_request("DELETE", url) - - -class AsyncAdvancedModelClient(AsyncDifyClient): - """Async Advanced Model Management APIs for fine-tuning and custom deployments.""" - - # Fine-tuning Job Management APIs - async def list_fine_tuning_jobs( - self, - page: int = 1, - limit: int = 20, - status: str | None = None, - model_provider: str | None = None, - ): - """List fine-tuning jobs with filtering.""" - params = {"page": page, "limit": limit} - if status: - params["status"] = status - if model_provider: - params["model_provider"] = model_provider - return await self._send_request("GET", "/models/fine-tuning/jobs", params=params) - - async def create_fine_tuning_job(self, job_config: Dict[str, Any]): - """Create a new fine-tuning job.""" - return await self._send_request("POST", "/models/fine-tuning/jobs", json=job_config) - - async def get_fine_tuning_job(self, job_id: str): - """Get fine-tuning job details.""" - url = f"/models/fine-tuning/jobs/{job_id}" - return await self._send_request("GET", url) - - async def update_fine_tuning_job(self, job_id: str, job_config: Dict[str, Any]): - """Update fine-tuning job configuration.""" - url = f"/models/fine-tuning/jobs/{job_id}" - return await self._send_request("PUT", url, json=job_config) - - async def cancel_fine_tuning_job(self, job_id: str): - """Cancel a fine-tuning job.""" - url = f"/models/fine-tuning/jobs/{job_id}/cancel" - return await self._send_request("POST", url) - - async def resume_fine_tuning_job(self, job_id: str): - """Resume a paused fine-tuning job.""" - url = f"/models/fine-tuning/jobs/{job_id}/resume" - return await self._send_request("POST", url) - - async def get_fine_tuning_job_metrics(self, job_id: str): - """Get fine-tuning job training metrics.""" - url = f"/models/fine-tuning/jobs/{job_id}/metrics" - return await self._send_request("GET", url) - - async def get_fine_tuning_job_logs(self, job_id: str, page: int = 1, limit: int = 50): - """Get fine-tuning job logs.""" - params = {"page": page, "limit": limit} - url = f"/models/fine-tuning/jobs/{job_id}/logs" - return await self._send_request("GET", url, params=params) - - # Custom Model Deployment APIs - async def list_custom_deployments(self, page: int = 1, limit: int = 20, status: str | None = None): - """List custom model deployments.""" - params = {"page": page, "limit": limit} - if status: - params["status"] = status - return await self._send_request("GET", "/models/custom/deployments", params=params) - - async def create_custom_deployment(self, deployment_config: Dict[str, Any]): - """Create a custom model deployment.""" - return await self._send_request("POST", "/models/custom/deployments", json=deployment_config) - - async def get_custom_deployment(self, deployment_id: str): - """Get custom deployment details.""" - url = f"/models/custom/deployments/{deployment_id}" - return await self._send_request("GET", url) - - async def update_custom_deployment(self, deployment_id: str, deployment_config: Dict[str, Any]): - """Update custom deployment configuration.""" - url = f"/models/custom/deployments/{deployment_id}" - return await self._send_request("PUT", url, json=deployment_config) - - async def delete_custom_deployment(self, deployment_id: str): - """Delete a custom deployment.""" - url = f"/models/custom/deployments/{deployment_id}" - return await self._send_request("DELETE", url) - - async def scale_custom_deployment(self, deployment_id: str, scale_config: Dict[str, Any]): - """Scale custom deployment resources.""" - url = f"/models/custom/deployments/{deployment_id}/scale" - return await self._send_request("POST", url, json=scale_config) - - async def restart_custom_deployment(self, deployment_id: str): - """Restart a custom deployment.""" - url = f"/models/custom/deployments/{deployment_id}/restart" - return await self._send_request("POST", url) - - # Model Performance Monitoring APIs - async def get_model_performance_history( - self, - model_provider: str, - model_name: str, - start_date: str, - end_date: str, - metrics: List[str] | None = None, - ): - """Get model performance history.""" - params = {"start_date": start_date, "end_date": end_date} - if metrics: - params["metrics"] = ",".join(metrics) - url = f"/models/{model_provider}/{model_name}/performance/history" - return await self._send_request("GET", url, params=params) - - async def get_model_health_metrics(self, model_provider: str, model_name: str): - """Get real-time model health metrics.""" - url = f"/models/{model_provider}/{model_name}/health" - return await self._send_request("GET", url) - - async def get_model_usage_stats( - self, - model_provider: str, - model_name: str, - start_date: str, - end_date: str, - granularity: str = "day", - ): - """Get model usage statistics.""" - params = { - "start_date": start_date, - "end_date": end_date, - "granularity": granularity, - } - url = f"/models/{model_provider}/{model_name}/usage" - return await self._send_request("GET", url, params=params) - - async def get_model_cost_analysis(self, model_provider: str, model_name: str, start_date: str, end_date: str): - """Get model cost analysis.""" - params = {"start_date": start_date, "end_date": end_date} - url = f"/models/{model_provider}/{model_name}/costs" - return await self._send_request("GET", url, params=params) - - # Model Versioning APIs - async def list_model_versions(self, model_provider: str, model_name: str, page: int = 1, limit: int = 20): - """List model versions.""" - params = {"page": page, "limit": limit} - url = f"/models/{model_provider}/{model_name}/versions" - return await self._send_request("GET", url, params=params) - - async def create_model_version(self, model_provider: str, model_name: str, version_config: Dict[str, Any]): - """Create a new model version.""" - url = f"/models/{model_provider}/{model_name}/versions" - return await self._send_request("POST", url, json=version_config) - - async def get_model_version(self, model_provider: str, model_name: str, version_id: str): - """Get model version details.""" - url = f"/models/{model_provider}/{model_name}/versions/{version_id}" - return await self._send_request("GET", url) - - async def promote_model_version(self, model_provider: str, model_name: str, version_id: str): - """Promote model version to production.""" - url = f"/models/{model_provider}/{model_name}/versions/{version_id}/promote" - return await self._send_request("POST", url) - - async def rollback_model_version(self, model_provider: str, model_name: str, version_id: str): - """Rollback to a specific model version.""" - url = f"/models/{model_provider}/{model_name}/versions/{version_id}/rollback" - return await self._send_request("POST", url) - - # Model Registry APIs - async def list_registry_models(self, page: int = 1, limit: int = 20, filter: str | None = None): - """List models in registry.""" - params = {"page": page, "limit": limit} - if filter: - params["filter"] = filter - return await self._send_request("GET", "/models/registry", params=params) - - async def register_model(self, model_config: Dict[str, Any]): - """Register a new model in the registry.""" - return await self._send_request("POST", "/models/registry", json=model_config) - - async def get_registry_model(self, model_id: str): - """Get registered model details.""" - url = f"/models/registry/{model_id}" - return await self._send_request("GET", url) - - async def update_registry_model(self, model_id: str, model_config: Dict[str, Any]): - """Update registered model information.""" - url = f"/models/registry/{model_id}" - return await self._send_request("PUT", url, json=model_config) - - async def unregister_model(self, model_id: str): - """Unregister a model from the registry.""" - url = f"/models/registry/{model_id}" - return await self._send_request("DELETE", url) - - -class AsyncAdvancedAppClient(AsyncDifyClient): - """Async Advanced App Configuration APIs for comprehensive app management.""" - - # App Creation and Management APIs - async def create_app(self, app_config: Dict[str, Any]): - """Create a new application.""" - return await self._send_request("POST", "/apps", json=app_config) - - async def list_apps( - self, - page: int = 1, - limit: int = 20, - app_type: str | None = None, - status: str | None = None, - ): - """List applications with filtering.""" - params = {"page": page, "limit": limit} - if app_type: - params["app_type"] = app_type - if status: - params["status"] = status - return await self._send_request("GET", "/apps", params=params) - - async def get_app(self, app_id: str): - """Get detailed application information.""" - url = f"/apps/{app_id}" - return await self._send_request("GET", url) - - async def update_app(self, app_id: str, app_config: Dict[str, Any]): - """Update application configuration.""" - url = f"/apps/{app_id}" - return await self._send_request("PUT", url, json=app_config) - - async def delete_app(self, app_id: str): - """Delete an application.""" - url = f"/apps/{app_id}" - return await self._send_request("DELETE", url) - - async def duplicate_app(self, app_id: str, duplicate_config: Dict[str, Any]): - """Duplicate an application.""" - url = f"/apps/{app_id}/duplicate" - return await self._send_request("POST", url, json=duplicate_config) - - async def archive_app(self, app_id: str): - """Archive an application.""" - url = f"/apps/{app_id}/archive" - return await self._send_request("POST", url) - - async def restore_app(self, app_id: str): - """Restore an archived application.""" - url = f"/apps/{app_id}/restore" - return await self._send_request("POST", url) - - # App Publishing and Versioning APIs - async def publish_app(self, app_id: str, publish_config: Dict[str, Any] | None = None): - """Publish an application.""" - data = publish_config or {} - url = f"/apps/{app_id}/publish" - return await self._send_request("POST", url, json=data) - - async def unpublish_app(self, app_id: str): - """Unpublish an application.""" - url = f"/apps/{app_id}/unpublish" - return await self._send_request("POST", url) - - async def list_app_versions(self, app_id: str, page: int = 1, limit: int = 20): - """List application versions.""" - params = {"page": page, "limit": limit} - url = f"/apps/{app_id}/versions" - return await self._send_request("GET", url, params=params) - - async def create_app_version(self, app_id: str, version_config: Dict[str, Any]): - """Create a new application version.""" - url = f"/apps/{app_id}/versions" - return await self._send_request("POST", url, json=version_config) - - async def get_app_version(self, app_id: str, version_id: str): - """Get application version details.""" - url = f"/apps/{app_id}/versions/{version_id}" - return await self._send_request("GET", url) - - async def rollback_app_version(self, app_id: str, version_id: str): - """Rollback application to a specific version.""" - url = f"/apps/{app_id}/versions/{version_id}/rollback" - return await self._send_request("POST", url) - - # App Template APIs - async def list_app_templates(self, page: int = 1, limit: int = 20, category: str | None = None): - """List available app templates.""" - params = {"page": page, "limit": limit} - if category: - params["category"] = category - return await self._send_request("GET", "/apps/templates", params=params) - - async def get_app_template(self, template_id: str): - """Get app template details.""" - url = f"/apps/templates/{template_id}" - return await self._send_request("GET", url) - - async def create_app_from_template(self, template_id: str, app_config: Dict[str, Any]): - """Create an app from a template.""" - url = f"/apps/templates/{template_id}/create" - return await self._send_request("POST", url, json=app_config) - - async def create_custom_template(self, app_id: str, template_config: Dict[str, Any]): - """Create a custom template from an existing app.""" - url = f"/apps/{app_id}/create-template" - return await self._send_request("POST", url, json=template_config) - - # App Analytics and Metrics APIs - async def get_app_analytics( - self, - app_id: str, - start_date: str, - end_date: str, - metrics: List[str] | None = None, - ): - """Get application analytics.""" - params = {"start_date": start_date, "end_date": end_date} - if metrics: - params["metrics"] = ",".join(metrics) - url = f"/apps/{app_id}/analytics" - return await self._send_request("GET", url, params=params) - - async def get_app_user_feedback(self, app_id: str, page: int = 1, limit: int = 20, rating: int | None = None): - """Get user feedback for an application.""" - params = {"page": page, "limit": limit} - if rating: - params["rating"] = rating - url = f"/apps/{app_id}/feedback" - return await self._send_request("GET", url, params=params) - - async def get_app_error_logs( - self, - app_id: str, - start_date: str, - end_date: str, - error_type: str | None = None, - page: int = 1, - limit: int = 20, - ): - """Get application error logs.""" - params = { - "start_date": start_date, - "end_date": end_date, - "page": page, - "limit": limit, - } - if error_type: - params["error_type"] = error_type - url = f"/apps/{app_id}/errors" - return await self._send_request("GET", url, params=params) - - # Advanced Configuration APIs - async def get_app_advanced_config(self, app_id: str): - """Get advanced application configuration.""" - url = f"/apps/{app_id}/advanced-config" - return await self._send_request("GET", url) - - async def update_app_advanced_config(self, app_id: str, config: Dict[str, Any]): - """Update advanced application configuration.""" - url = f"/apps/{app_id}/advanced-config" - return await self._send_request("PUT", url, json=config) - - async def get_app_environment_variables(self, app_id: str): - """Get application environment variables.""" - url = f"/apps/{app_id}/environment" - return await self._send_request("GET", url) - - async def update_app_environment_variables(self, app_id: str, variables: Dict[str, str]): - """Update application environment variables.""" - url = f"/apps/{app_id}/environment" - return await self._send_request("PUT", url, json=variables) - - async def get_app_resource_limits(self, app_id: str): - """Get application resource limits.""" - url = f"/apps/{app_id}/resource-limits" - return await self._send_request("GET", url) - - async def update_app_resource_limits(self, app_id: str, limits: Dict[str, Any]): - """Update application resource limits.""" - url = f"/apps/{app_id}/resource-limits" - return await self._send_request("PUT", url, json=limits) - - # App Integration APIs - async def get_app_integrations(self, app_id: str): - """Get application integrations.""" - url = f"/apps/{app_id}/integrations" - return await self._send_request("GET", url) - - async def add_app_integration(self, app_id: str, integration_config: Dict[str, Any]): - """Add integration to application.""" - url = f"/apps/{app_id}/integrations" - return await self._send_request("POST", url, json=integration_config) - - async def update_app_integration(self, app_id: str, integration_id: str, config: Dict[str, Any]): - """Update application integration.""" - url = f"/apps/{app_id}/integrations/{integration_id}" - return await self._send_request("PUT", url, json=config) - - async def remove_app_integration(self, app_id: str, integration_id: str): - """Remove integration from application.""" - url = f"/apps/{app_id}/integrations/{integration_id}" - return await self._send_request("DELETE", url) - - async def test_app_integration(self, app_id: str, integration_id: str): - """Test application integration.""" - url = f"/apps/{app_id}/integrations/{integration_id}/test" - return await self._send_request("POST", url) diff --git a/sdks/python-client/dify_client/base_client.py b/sdks/python-client/dify_client/base_client.py deleted file mode 100644 index 0ad6e07b23..0000000000 --- a/sdks/python-client/dify_client/base_client.py +++ /dev/null @@ -1,228 +0,0 @@ -"""Base client with common functionality for both sync and async clients.""" - -import json -import time -import logging -from typing import Dict, Callable, Optional - -try: - # Python 3.10+ - from typing import ParamSpec -except ImportError: - # Python < 3.10 - from typing_extensions import ParamSpec - -from urllib.parse import urljoin - -import httpx - -P = ParamSpec("P") - -from .exceptions import ( - DifyClientError, - APIError, - AuthenticationError, - RateLimitError, - ValidationError, - NetworkError, - TimeoutError, -) - - -class BaseClientMixin: - """Mixin class providing common functionality for Dify clients.""" - - def __init__( - self, - api_key: str, - base_url: str = "https://api.dify.ai/v1", - timeout: float = 60.0, - max_retries: int = 3, - retry_delay: float = 1.0, - enable_logging: bool = False, - ): - """Initialize the base client. - - Args: - api_key: Your Dify API key - base_url: Base URL for the Dify API - timeout: Request timeout in seconds - max_retries: Maximum number of retry attempts - retry_delay: Delay between retries in seconds - enable_logging: Enable detailed logging - """ - if not api_key: - raise ValidationError("API key is required") - - self.api_key = api_key - self.base_url = base_url.rstrip("/") - self.timeout = timeout - self.max_retries = max_retries - self.retry_delay = retry_delay - self.enable_logging = enable_logging - - # Setup logging - self.logger = logging.getLogger(f"dify_client.{self.__class__.__name__.lower()}") - if enable_logging and not self.logger.handlers: - # Create console handler with formatter - handler = logging.StreamHandler() - formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") - handler.setFormatter(formatter) - self.logger.addHandler(handler) - self.logger.setLevel(logging.INFO) - self.enable_logging = True - else: - self.enable_logging = enable_logging - - def _get_headers(self, content_type: str = "application/json") -> Dict[str, str]: - """Get common request headers.""" - return { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": content_type, - "User-Agent": "dify-client-python/0.1.12", - } - - def _build_url(self, endpoint: str) -> str: - """Build full URL from endpoint.""" - return urljoin(self.base_url + "/", endpoint.lstrip("/")) - - def _handle_response(self, response: httpx.Response) -> httpx.Response: - """Handle HTTP response and raise appropriate exceptions.""" - try: - if response.status_code == 401: - raise AuthenticationError( - "Authentication failed. Check your API key.", - status_code=response.status_code, - response=response.json() if response.content else None, - ) - elif response.status_code == 429: - retry_after = response.headers.get("Retry-After") - raise RateLimitError( - "Rate limit exceeded. Please try again later.", - retry_after=int(retry_after) if retry_after else None, - ) - elif response.status_code >= 400: - try: - error_data = response.json() - message = error_data.get("message", f"HTTP {response.status_code}") - except: - message = f"HTTP {response.status_code}: {response.text}" - - raise APIError( - message, - status_code=response.status_code, - response=response.json() if response.content else None, - ) - - return response - - except json.JSONDecodeError: - raise APIError( - f"Invalid JSON response: {response.text}", - status_code=response.status_code, - ) - - def _retry_request( - self, - request_func: Callable[P, httpx.Response], - request_context: str | None = None, - *args: P.args, - **kwargs: P.kwargs, - ) -> httpx.Response: - """Retry a request with exponential backoff. - - Args: - request_func: Function that performs the HTTP request - request_context: Context description for logging (e.g., "GET /v1/messages") - *args: Positional arguments to pass to request_func - **kwargs: Keyword arguments to pass to request_func - - Returns: - httpx.Response: Successful response - - Raises: - NetworkError: On network failures after retries - TimeoutError: On timeout failures after retries - APIError: On API errors (4xx/5xx responses) - DifyClientError: On unexpected failures - """ - last_exception = None - - for attempt in range(self.max_retries + 1): - try: - response = request_func(*args, **kwargs) - return response # Let caller handle response processing - - except (httpx.NetworkError, httpx.TimeoutException) as e: - last_exception = e - context_msg = f" {request_context}" if request_context else "" - - if attempt < self.max_retries: - delay = self.retry_delay * (2**attempt) # Exponential backoff - self.logger.warning( - f"Request failed{context_msg} (attempt {attempt + 1}/{self.max_retries + 1}): {e}. " - f"Retrying in {delay:.2f} seconds..." - ) - time.sleep(delay) - else: - self.logger.error(f"Request failed{context_msg} after {self.max_retries + 1} attempts: {e}") - # Convert to custom exceptions - if isinstance(e, httpx.TimeoutException): - from .exceptions import TimeoutError - - raise TimeoutError(f"Request timed out after {self.max_retries} retries{context_msg}") from e - else: - from .exceptions import NetworkError - - raise NetworkError( - f"Network error after {self.max_retries} retries{context_msg}: {str(e)}" - ) from e - - if last_exception: - raise last_exception - raise DifyClientError("Request failed after retries") - - def _validate_params(self, **params) -> None: - """Validate request parameters.""" - for key, value in params.items(): - if value is None: - continue - - # String validations - if isinstance(value, str): - if not value.strip(): - raise ValidationError(f"Parameter '{key}' cannot be empty or whitespace only") - if len(value) > 10000: - raise ValidationError(f"Parameter '{key}' exceeds maximum length of 10000 characters") - - # List validations - elif isinstance(value, list): - if len(value) > 1000: - raise ValidationError(f"Parameter '{key}' exceeds maximum size of 1000 items") - - # Dictionary validations - elif isinstance(value, dict): - if len(value) > 100: - raise ValidationError(f"Parameter '{key}' exceeds maximum size of 100 items") - - # Type-specific validations - if key == "user" and not isinstance(value, str): - raise ValidationError(f"Parameter '{key}' must be a string") - elif key in ["page", "limit", "page_size"] and not isinstance(value, int): - raise ValidationError(f"Parameter '{key}' must be an integer") - elif key == "files" and not isinstance(value, (list, dict)): - raise ValidationError(f"Parameter '{key}' must be a list or dict") - elif key == "rating" and value not in ["like", "dislike"]: - raise ValidationError(f"Parameter '{key}' must be 'like' or 'dislike'") - - def _log_request(self, method: str, url: str, **kwargs) -> None: - """Log request details.""" - self.logger.info(f"Making {method} request to {url}") - if kwargs.get("json"): - self.logger.debug(f"Request body: {kwargs['json']}") - if kwargs.get("params"): - self.logger.debug(f"Query params: {kwargs['params']}") - - def _log_response(self, response: httpx.Response) -> None: - """Log response details.""" - self.logger.info(f"Received response: {response.status_code} ({len(response.content)} bytes)") diff --git a/sdks/python-client/dify_client/client.py b/sdks/python-client/dify_client/client.py deleted file mode 100644 index cebdf6845c..0000000000 --- a/sdks/python-client/dify_client/client.py +++ /dev/null @@ -1,1267 +0,0 @@ -import json -import logging -import os -from typing import Literal, Dict, List, Any, IO, Optional, Union - -import httpx -from .base_client import BaseClientMixin -from .exceptions import ( - APIError, - AuthenticationError, - RateLimitError, - ValidationError, - FileUploadError, -) - - -class DifyClient(BaseClientMixin): - """Synchronous Dify API client. - - This client uses httpx.Client for efficient connection pooling and resource management. - It's recommended to use this client as a context manager: - - Example: - with DifyClient(api_key="your-key") as client: - response = client.get_app_info() - """ - - def __init__( - self, - api_key: str, - base_url: str = "https://api.dify.ai/v1", - timeout: float = 60.0, - max_retries: int = 3, - retry_delay: float = 1.0, - enable_logging: bool = False, - ): - """Initialize the Dify client. - - Args: - api_key: Your Dify API key - base_url: Base URL for the Dify API - timeout: Request timeout in seconds (default: 60.0) - max_retries: Maximum number of retry attempts (default: 3) - retry_delay: Delay between retries in seconds (default: 1.0) - enable_logging: Whether to enable request logging (default: True) - """ - # Initialize base client functionality - BaseClientMixin.__init__(self, api_key, base_url, timeout, max_retries, retry_delay, enable_logging) - - self._client = httpx.Client( - base_url=base_url, - timeout=httpx.Timeout(timeout, connect=5.0), - ) - - def __enter__(self): - """Support context manager protocol.""" - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """Clean up resources when exiting context.""" - self.close() - - def close(self): - """Close the HTTP client and release resources.""" - if hasattr(self, "_client"): - self._client.close() - - def _send_request( - self, - method: str, - endpoint: str, - json: Dict[str, Any] | None = None, - params: Dict[str, Any] | None = None, - stream: bool = False, - **kwargs, - ): - """Send an HTTP request to the Dify API with retry logic. - - Args: - method: HTTP method (GET, POST, PUT, PATCH, DELETE) - endpoint: API endpoint path - json: JSON request body - params: Query parameters - stream: Whether to stream the response - **kwargs: Additional arguments to pass to httpx.request - - Returns: - httpx.Response object - """ - # Validate parameters - if json: - self._validate_params(**json) - if params: - self._validate_params(**params) - - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - } - - def make_request(): - """Inner function to perform the actual HTTP request.""" - # Log request if logging is enabled - if self.enable_logging: - self.logger.info(f"Sending {method} request to {endpoint}") - # Debug logging for detailed information - if self.logger.isEnabledFor(logging.DEBUG): - if json: - self.logger.debug(f"Request body: {json}") - if params: - self.logger.debug(f"Request params: {params}") - - # httpx.Client automatically prepends base_url - response = self._client.request( - method, - endpoint, - json=json, - params=params, - headers=headers, - **kwargs, - ) - - # Log response if logging is enabled - if self.enable_logging: - self.logger.info(f"Received response: {response.status_code}") - - return response - - # Use the retry mechanism from base client - request_context = f"{method} {endpoint}" - response = self._retry_request(make_request, request_context) - - # Handle error responses (API errors don't retry) - self._handle_error_response(response) - - return response - - def _handle_error_response(self, response, is_upload_request: bool = False) -> None: - """Handle HTTP error responses and raise appropriate exceptions.""" - - if response.status_code < 400: - return # Success response - - try: - error_data = response.json() - message = error_data.get("message", f"HTTP {response.status_code}") - except (ValueError, KeyError): - message = f"HTTP {response.status_code}" - error_data = None - - # Log error response if logging is enabled - if self.enable_logging: - self.logger.error(f"API error: {response.status_code} - {message}") - - if response.status_code == 401: - raise AuthenticationError(message, response.status_code, error_data) - elif response.status_code == 429: - retry_after = response.headers.get("Retry-After") - raise RateLimitError(message, retry_after) - elif response.status_code == 422: - raise ValidationError(message, response.status_code, error_data) - elif response.status_code == 400: - # Check if this is a file upload error based on the URL or context - current_url = getattr(response, "url", "") or "" - if is_upload_request or "upload" in str(current_url).lower() or "files" in str(current_url).lower(): - raise FileUploadError(message, response.status_code, error_data) - else: - raise APIError(message, response.status_code, error_data) - elif response.status_code >= 500: - # Server errors should raise APIError - raise APIError(message, response.status_code, error_data) - elif response.status_code >= 400: - raise APIError(message, response.status_code, error_data) - - def _send_request_with_files(self, method: str, endpoint: str, data: dict, files: dict): - """Send an HTTP request with file uploads. - - Args: - method: HTTP method (POST, PUT, etc.) - endpoint: API endpoint path - data: Form data - files: Files to upload - - Returns: - httpx.Response object - """ - headers = {"Authorization": f"Bearer {self.api_key}"} - - # Log file upload request if logging is enabled - if self.enable_logging: - self.logger.info(f"Sending {method} file upload request to {endpoint}") - self.logger.debug(f"Form data: {data}") - self.logger.debug(f"Files: {files}") - - response = self._client.request( - method, - endpoint, - data=data, - headers=headers, - files=files, - ) - - # Log response if logging is enabled - if self.enable_logging: - self.logger.info(f"Received file upload response: {response.status_code}") - - # Handle error responses - self._handle_error_response(response, is_upload_request=True) - - return response - - def message_feedback(self, message_id: str, rating: Literal["like", "dislike"], user: str): - self._validate_params(message_id=message_id, rating=rating, user=user) - data = {"rating": rating, "user": user} - return self._send_request("POST", f"/messages/{message_id}/feedbacks", data) - - def get_application_parameters(self, user: str): - params = {"user": user} - return self._send_request("GET", "/parameters", params=params) - - def file_upload(self, user: str, files: dict): - data = {"user": user} - return self._send_request_with_files("POST", "/files/upload", data=data, files=files) - - def text_to_audio(self, text: str, user: str, streaming: bool = False): - data = {"text": text, "user": user, "streaming": streaming} - return self._send_request("POST", "/text-to-audio", json=data) - - def get_meta(self, user: str): - params = {"user": user} - return self._send_request("GET", "/meta", params=params) - - def get_app_info(self): - """Get basic application information including name, description, tags, and mode.""" - return self._send_request("GET", "/info") - - def get_app_site_info(self): - """Get application site information.""" - return self._send_request("GET", "/site") - - def get_file_preview(self, file_id: str): - """Get file preview by file ID.""" - return self._send_request("GET", f"/files/{file_id}/preview") - - # App Configuration APIs - def get_app_site_config(self, app_id: str): - """Get app site configuration. - - Args: - app_id: ID of the app - - Returns: - App site configuration - """ - url = f"/apps/{app_id}/site/config" - return self._send_request("GET", url) - - def update_app_site_config(self, app_id: str, config_data: Dict[str, Any]): - """Update app site configuration. - - Args: - app_id: ID of the app - config_data: Configuration data to update - - Returns: - Updated app site configuration - """ - url = f"/apps/{app_id}/site/config" - return self._send_request("PUT", url, json=config_data) - - def get_app_api_tokens(self, app_id: str): - """Get API tokens for an app. - - Args: - app_id: ID of the app - - Returns: - List of API tokens - """ - url = f"/apps/{app_id}/api-tokens" - return self._send_request("GET", url) - - def create_app_api_token(self, app_id: str, name: str, description: str | None = None): - """Create a new API token for an app. - - Args: - app_id: ID of the app - name: Name for the API token - description: Description for the API token (optional) - - Returns: - Created API token information - """ - data = {"name": name, "description": description} - url = f"/apps/{app_id}/api-tokens" - return self._send_request("POST", url, json=data) - - def delete_app_api_token(self, app_id: str, token_id: str): - """Delete an API token. - - Args: - app_id: ID of the app - token_id: ID of the token to delete - - Returns: - Deletion result - """ - url = f"/apps/{app_id}/api-tokens/{token_id}" - return self._send_request("DELETE", url) - - -class CompletionClient(DifyClient): - def create_completion_message( - self, - inputs: dict, - response_mode: Literal["blocking", "streaming"], - user: str, - files: Dict[str, Any] | None = None, - ): - # Validate parameters - if not isinstance(inputs, dict): - raise ValidationError("inputs must be a dictionary") - if response_mode not in ["blocking", "streaming"]: - raise ValidationError("response_mode must be 'blocking' or 'streaming'") - - self._validate_params(inputs=inputs, response_mode=response_mode, user=user) - - data = { - "inputs": inputs, - "response_mode": response_mode, - "user": user, - "files": files, - } - return self._send_request( - "POST", - "/completion-messages", - data, - stream=(response_mode == "streaming"), - ) - - -class ChatClient(DifyClient): - def create_chat_message( - self, - inputs: dict, - query: str, - user: str, - response_mode: Literal["blocking", "streaming"] = "blocking", - conversation_id: str | None = None, - files: Dict[str, Any] | None = None, - ): - # Validate parameters - if not isinstance(inputs, dict): - raise ValidationError("inputs must be a dictionary") - if not isinstance(query, str) or not query.strip(): - raise ValidationError("query must be a non-empty string") - if response_mode not in ["blocking", "streaming"]: - raise ValidationError("response_mode must be 'blocking' or 'streaming'") - - self._validate_params(inputs=inputs, query=query, user=user, response_mode=response_mode) - - data = { - "inputs": inputs, - "query": query, - "user": user, - "response_mode": response_mode, - "files": files, - } - if conversation_id: - data["conversation_id"] = conversation_id - - return self._send_request( - "POST", - "/chat-messages", - data, - stream=(response_mode == "streaming"), - ) - - def get_suggested(self, message_id: str, user: str): - params = {"user": user} - return self._send_request("GET", f"/messages/{message_id}/suggested", params=params) - - def stop_message(self, task_id: str, user: str): - data = {"user": user} - return self._send_request("POST", f"/chat-messages/{task_id}/stop", data) - - def get_conversations( - self, - user: str, - last_id: str | None = None, - limit: int | None = None, - pinned: bool | None = None, - ): - params = {"user": user, "last_id": last_id, "limit": limit, "pinned": pinned} - return self._send_request("GET", "/conversations", params=params) - - def get_conversation_messages( - self, - user: str, - conversation_id: str | None = None, - first_id: str | None = None, - limit: int | None = None, - ): - params = {"user": user} - - if conversation_id: - params["conversation_id"] = conversation_id - if first_id: - params["first_id"] = first_id - if limit: - params["limit"] = limit - - return self._send_request("GET", "/messages", params=params) - - def rename_conversation(self, conversation_id: str, name: str, auto_generate: bool, user: str): - data = {"name": name, "auto_generate": auto_generate, "user": user} - return self._send_request("POST", f"/conversations/{conversation_id}/name", data) - - def delete_conversation(self, conversation_id: str, user: str): - data = {"user": user} - return self._send_request("DELETE", f"/conversations/{conversation_id}", data) - - def audio_to_text(self, audio_file: Union[IO[bytes], tuple], user: str): - data = {"user": user} - files = {"file": audio_file} - return self._send_request_with_files("POST", "/audio-to-text", data, files) - - # Annotation APIs - def annotation_reply_action( - self, - action: Literal["enable", "disable"], - score_threshold: float, - embedding_provider_name: str, - embedding_model_name: str, - ): - """Enable or disable annotation reply feature.""" - data = { - "score_threshold": score_threshold, - "embedding_provider_name": embedding_provider_name, - "embedding_model_name": embedding_model_name, - } - return self._send_request("POST", f"/apps/annotation-reply/{action}", json=data) - - def get_annotation_reply_status(self, action: Literal["enable", "disable"], job_id: str): - """Get the status of an annotation reply action job.""" - return self._send_request("GET", f"/apps/annotation-reply/{action}/status/{job_id}") - - def list_annotations(self, page: int = 1, limit: int = 20, keyword: str | None = None): - """List annotations for the application.""" - params = {"page": page, "limit": limit, "keyword": keyword} - return self._send_request("GET", "/apps/annotations", params=params) - - def create_annotation(self, question: str, answer: str): - """Create a new annotation.""" - data = {"question": question, "answer": answer} - return self._send_request("POST", "/apps/annotations", json=data) - - def update_annotation(self, annotation_id: str, question: str, answer: str): - """Update an existing annotation.""" - data = {"question": question, "answer": answer} - return self._send_request("PUT", f"/apps/annotations/{annotation_id}", json=data) - - def delete_annotation(self, annotation_id: str): - """Delete an annotation.""" - return self._send_request("DELETE", f"/apps/annotations/{annotation_id}") - - # Conversation Variables APIs - def get_conversation_variables(self, conversation_id: str, user: str): - """Get all variables for a specific conversation. - - Args: - conversation_id: The conversation ID to query variables for - user: User identifier - - Returns: - Response from the API containing: - - variables: List of conversation variables with their values - - conversation_id: The conversation ID - """ - params = {"user": user} - url = f"/conversations/{conversation_id}/variables" - return self._send_request("GET", url, params=params) - - def update_conversation_variable(self, conversation_id: str, variable_id: str, value: Any, user: str): - """Update a specific conversation variable. - - Args: - conversation_id: The conversation ID - variable_id: The variable ID to update - value: New value for the variable - user: User identifier - - Returns: - Response from the API with updated variable information - """ - data = {"value": value, "user": user} - url = f"/conversations/{conversation_id}/variables/{variable_id}" - return self._send_request("PUT", url, json=data) - - def delete_annotation_with_response(self, annotation_id: str): - """Delete an annotation with full response handling.""" - url = f"/apps/annotations/{annotation_id}" - return self._send_request("DELETE", url) - - def list_conversation_variables_with_pagination( - self, conversation_id: str, user: str, page: int = 1, limit: int = 20 - ): - """List conversation variables with pagination.""" - params = {"page": page, "limit": limit, "user": user} - url = f"/conversations/{conversation_id}/variables" - return self._send_request("GET", url, params=params) - - def update_conversation_variable_with_response(self, conversation_id: str, variable_id: str, user: str, value: Any): - """Update a conversation variable with full response handling.""" - data = {"value": value, "user": user} - url = f"/conversations/{conversation_id}/variables/{variable_id}" - return self._send_request("PUT", url, json=data) - - # Enhanced Annotation APIs - def get_annotation_reply_job_status(self, action: str, job_id: str): - """Get status of an annotation reply action job.""" - url = f"/apps/annotation-reply/{action}/status/{job_id}" - return self._send_request("GET", url) - - def list_annotations_with_pagination(self, page: int = 1, limit: int = 20, keyword: str | None = None): - """List annotations with pagination.""" - params = {"page": page, "limit": limit, "keyword": keyword} - return self._send_request("GET", "/apps/annotations", params=params) - - def create_annotation_with_response(self, question: str, answer: str): - """Create an annotation with full response handling.""" - data = {"question": question, "answer": answer} - return self._send_request("POST", "/apps/annotations", json=data) - - def update_annotation_with_response(self, annotation_id: str, question: str, answer: str): - """Update an annotation with full response handling.""" - data = {"question": question, "answer": answer} - url = f"/apps/annotations/{annotation_id}" - return self._send_request("PUT", url, json=data) - - -class WorkflowClient(DifyClient): - def run( - self, - inputs: dict, - response_mode: Literal["blocking", "streaming"] = "streaming", - user: str = "abc-123", - ): - data = {"inputs": inputs, "response_mode": response_mode, "user": user} - return self._send_request("POST", "/workflows/run", data) - - def stop(self, task_id, user): - data = {"user": user} - return self._send_request("POST", f"/workflows/tasks/{task_id}/stop", data) - - def get_result(self, workflow_run_id): - return self._send_request("GET", f"/workflows/run/{workflow_run_id}") - - def get_workflow_logs( - self, - keyword: str = None, - status: Literal["succeeded", "failed", "stopped"] | None = None, - page: int = 1, - limit: int = 20, - created_at__before: str = None, - created_at__after: str = None, - created_by_end_user_session_id: str = None, - created_by_account: str = None, - ): - """Get workflow execution logs with optional filtering.""" - params = {"page": page, "limit": limit} - if keyword: - params["keyword"] = keyword - if status: - params["status"] = status - if created_at__before: - params["created_at__before"] = created_at__before - if created_at__after: - params["created_at__after"] = created_at__after - if created_by_end_user_session_id: - params["created_by_end_user_session_id"] = created_by_end_user_session_id - if created_by_account: - params["created_by_account"] = created_by_account - return self._send_request("GET", "/workflows/logs", params=params) - - def run_specific_workflow( - self, - workflow_id: str, - inputs: dict, - response_mode: Literal["blocking", "streaming"] = "streaming", - user: str = "abc-123", - ): - """Run a specific workflow by workflow ID.""" - data = {"inputs": inputs, "response_mode": response_mode, "user": user} - return self._send_request( - "POST", - f"/workflows/{workflow_id}/run", - data, - stream=(response_mode == "streaming"), - ) - - # Enhanced Workflow APIs - def get_workflow_draft(self, app_id: str): - """Get workflow draft configuration. - - Args: - app_id: ID of the workflow app - - Returns: - Workflow draft configuration - """ - url = f"/apps/{app_id}/workflow/draft" - return self._send_request("GET", url) - - def update_workflow_draft(self, app_id: str, workflow_data: Dict[str, Any]): - """Update workflow draft configuration. - - Args: - app_id: ID of the workflow app - workflow_data: Workflow configuration data - - Returns: - Updated workflow draft - """ - url = f"/apps/{app_id}/workflow/draft" - return self._send_request("PUT", url, json=workflow_data) - - def publish_workflow(self, app_id: str): - """Publish workflow from draft. - - Args: - app_id: ID of the workflow app - - Returns: - Published workflow information - """ - url = f"/apps/{app_id}/workflow/publish" - return self._send_request("POST", url) - - def get_workflow_run_history( - self, - app_id: str, - page: int = 1, - limit: int = 20, - status: Literal["succeeded", "failed", "stopped"] | None = None, - ): - """Get workflow run history. - - Args: - app_id: ID of the workflow app - page: Page number (default: 1) - limit: Number of items per page (default: 20) - status: Filter by status (optional) - - Returns: - Paginated workflow run history - """ - params = {"page": page, "limit": limit} - if status: - params["status"] = status - url = f"/apps/{app_id}/workflow/runs" - return self._send_request("GET", url, params=params) - - -class WorkspaceClient(DifyClient): - """Client for workspace-related operations.""" - - def get_available_models(self, model_type: str): - """Get available models by model type.""" - url = f"/workspaces/current/models/model-types/{model_type}" - return self._send_request("GET", url) - - def get_available_models_by_type(self, model_type: str): - """Get available models by model type (enhanced version).""" - url = f"/workspaces/current/models/model-types/{model_type}" - return self._send_request("GET", url) - - def get_model_providers(self): - """Get all model providers.""" - return self._send_request("GET", "/workspaces/current/model-providers") - - def get_model_provider_models(self, provider_name: str): - """Get models for a specific provider.""" - url = f"/workspaces/current/model-providers/{provider_name}/models" - return self._send_request("GET", url) - - def validate_model_provider_credentials(self, provider_name: str, credentials: Dict[str, Any]): - """Validate model provider credentials.""" - url = f"/workspaces/current/model-providers/{provider_name}/credentials/validate" - return self._send_request("POST", url, json=credentials) - - # File Management APIs - def get_file_info(self, file_id: str): - """Get information about a specific file.""" - url = f"/files/{file_id}/info" - return self._send_request("GET", url) - - def get_file_download_url(self, file_id: str): - """Get download URL for a file.""" - url = f"/files/{file_id}/download-url" - return self._send_request("GET", url) - - def delete_file(self, file_id: str): - """Delete a file.""" - url = f"/files/{file_id}" - return self._send_request("DELETE", url) - - -class KnowledgeBaseClient(DifyClient): - def __init__( - self, - api_key: str, - base_url: str = "https://api.dify.ai/v1", - dataset_id: str | None = None, - ): - """ - Construct a KnowledgeBaseClient object. - - Args: - api_key (str): API key of Dify. - base_url (str, optional): Base URL of Dify API. Defaults to 'https://api.dify.ai/v1'. - dataset_id (str, optional): ID of the dataset. Defaults to None. You don't need this if you just want to - create a new dataset. or list datasets. otherwise you need to set this. - """ - super().__init__(api_key=api_key, base_url=base_url) - self.dataset_id = dataset_id - - def _get_dataset_id(self): - if self.dataset_id is None: - raise ValueError("dataset_id is not set") - return self.dataset_id - - def create_dataset(self, name: str, **kwargs): - return self._send_request("POST", "/datasets", {"name": name}, **kwargs) - - def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs): - return self._send_request("GET", "/datasets", params={"page": page, "limit": page_size}, **kwargs) - - def create_document_by_text(self, name, text, extra_params: Dict[str, Any] | None = None, **kwargs): - """ - Create a document by text. - - :param name: Name of the document - :param text: Text content of the document - :param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional) - e.g. - { - 'indexing_technique': 'high_quality', - 'process_rule': { - 'rules': { - 'pre_processing_rules': [ - {'id': 'remove_extra_spaces', 'enabled': True}, - {'id': 'remove_urls_emails', 'enabled': True} - ], - 'segmentation': { - 'separator': '\n', - 'max_tokens': 500 - } - }, - 'mode': 'custom' - } - } - :return: Response from the API - """ - data = { - "indexing_technique": "high_quality", - "process_rule": {"mode": "automatic"}, - "name": name, - "text": text, - } - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - url = f"/datasets/{self._get_dataset_id()}/document/create_by_text" - return self._send_request("POST", url, json=data, **kwargs) - - def update_document_by_text( - self, - document_id: str, - name: str, - text: str, - extra_params: Dict[str, Any] | None = None, - **kwargs, - ): - """ - Update a document by text. - - :param document_id: ID of the document - :param name: Name of the document - :param text: Text content of the document - :param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional) - e.g. - { - 'indexing_technique': 'high_quality', - 'process_rule': { - 'rules': { - 'pre_processing_rules': [ - {'id': 'remove_extra_spaces', 'enabled': True}, - {'id': 'remove_urls_emails', 'enabled': True} - ], - 'segmentation': { - 'separator': '\n', - 'max_tokens': 500 - } - }, - 'mode': 'custom' - } - } - :return: Response from the API - """ - data = {"name": name, "text": text} - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text" - return self._send_request("POST", url, json=data, **kwargs) - - def create_document_by_file( - self, - file_path: str, - original_document_id: str | None = None, - extra_params: Dict[str, Any] | None = None, - ): - """ - Create a document by file. - - :param file_path: Path to the file - :param original_document_id: pass this ID if you want to replace the original document (optional) - :param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional) - e.g. - { - 'indexing_technique': 'high_quality', - 'process_rule': { - 'rules': { - 'pre_processing_rules': [ - {'id': 'remove_extra_spaces', 'enabled': True}, - {'id': 'remove_urls_emails', 'enabled': True} - ], - 'segmentation': { - 'separator': '\n', - 'max_tokens': 500 - } - }, - 'mode': 'custom' - } - } - :return: Response from the API - """ - with open(file_path, "rb") as f: - files = {"file": (os.path.basename(file_path), f)} - data = { - "process_rule": {"mode": "automatic"}, - "indexing_technique": "high_quality", - } - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - if original_document_id is not None: - data["original_document_id"] = original_document_id - url = f"/datasets/{self._get_dataset_id()}/document/create_by_file" - return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files) - - def update_document_by_file( - self, - document_id: str, - file_path: str, - extra_params: Dict[str, Any] | None = None, - ): - """ - Update a document by file. - - :param document_id: ID of the document - :param file_path: Path to the file - :param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional) - e.g. - { - 'indexing_technique': 'high_quality', - 'process_rule': { - 'rules': { - 'pre_processing_rules': [ - {'id': 'remove_extra_spaces', 'enabled': True}, - {'id': 'remove_urls_emails', 'enabled': True} - ], - 'segmentation': { - 'separator': '\n', - 'max_tokens': 500 - } - }, - 'mode': 'custom' - } - } - :return: - """ - with open(file_path, "rb") as f: - files = {"file": (os.path.basename(file_path), f)} - data = {} - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file" - return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files) - - def batch_indexing_status(self, batch_id: str, **kwargs): - """ - Get the status of the batch indexing. - - :param batch_id: ID of the batch uploading - :return: Response from the API - """ - url = f"/datasets/{self._get_dataset_id()}/documents/{batch_id}/indexing-status" - return self._send_request("GET", url, **kwargs) - - def delete_dataset(self): - """ - Delete this dataset. - - :return: Response from the API - """ - url = f"/datasets/{self._get_dataset_id()}" - return self._send_request("DELETE", url) - - def delete_document(self, document_id: str): - """ - Delete a document. - - :param document_id: ID of the document - :return: Response from the API - """ - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}" - return self._send_request("DELETE", url) - - def list_documents( - self, - page: int | None = None, - page_size: int | None = None, - keyword: str | None = None, - **kwargs, - ): - """ - Get a list of documents in this dataset. - - :return: Response from the API - """ - params = {} - if page is not None: - params["page"] = page - if page_size is not None: - params["limit"] = page_size - if keyword is not None: - params["keyword"] = keyword - url = f"/datasets/{self._get_dataset_id()}/documents" - return self._send_request("GET", url, params=params, **kwargs) - - def add_segments(self, document_id: str, segments: list[dict], **kwargs): - """ - Add segments to a document. - - :param document_id: ID of the document - :param segments: List of segments to add, example: [{"content": "1", "answer": "1", "keyword": ["a"]}] - :return: Response from the API - """ - data = {"segments": segments} - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments" - return self._send_request("POST", url, json=data, **kwargs) - - def query_segments( - self, - document_id: str, - keyword: str | None = None, - status: str | None = None, - **kwargs, - ): - """ - Query segments in this document. - - :param document_id: ID of the document - :param keyword: query keyword, optional - :param status: status of the segment, optional, e.g. completed - :param kwargs: Additional parameters to pass to the API. - Can include a 'params' dict for extra query parameters. - """ - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments" - params = {} - if keyword is not None: - params["keyword"] = keyword - if status is not None: - params["status"] = status - if "params" in kwargs: - params.update(kwargs.pop("params")) - return self._send_request("GET", url, params=params, **kwargs) - - def delete_document_segment(self, document_id: str, segment_id: str): - """ - Delete a segment from a document. - - :param document_id: ID of the document - :param segment_id: ID of the segment - :return: Response from the API - """ - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}" - return self._send_request("DELETE", url) - - def update_document_segment(self, document_id: str, segment_id: str, segment_data: dict, **kwargs): - """ - Update a segment in a document. - - :param document_id: ID of the document - :param segment_id: ID of the segment - :param segment_data: Data of the segment, example: {"content": "1", "answer": "1", "keyword": ["a"], "enabled": True} - :return: Response from the API - """ - data = {"segment": segment_data} - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}" - return self._send_request("POST", url, json=data, **kwargs) - - # Advanced Knowledge Base APIs - def hit_testing( - self, - query: str, - retrieval_model: Dict[str, Any] = None, - external_retrieval_model: Dict[str, Any] = None, - ): - """Perform hit testing on the dataset.""" - data = {"query": query} - if retrieval_model: - data["retrieval_model"] = retrieval_model - if external_retrieval_model: - data["external_retrieval_model"] = external_retrieval_model - url = f"/datasets/{self._get_dataset_id()}/hit-testing" - return self._send_request("POST", url, json=data) - - def get_dataset_metadata(self): - """Get dataset metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata" - return self._send_request("GET", url) - - def create_dataset_metadata(self, metadata_data: Dict[str, Any]): - """Create dataset metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata" - return self._send_request("POST", url, json=metadata_data) - - def update_dataset_metadata(self, metadata_id: str, metadata_data: Dict[str, Any]): - """Update dataset metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata/{metadata_id}" - return self._send_request("PATCH", url, json=metadata_data) - - def get_built_in_metadata(self): - """Get built-in metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata/built-in" - return self._send_request("GET", url) - - def manage_built_in_metadata(self, action: str, metadata_data: Dict[str, Any] = None): - """Manage built-in metadata with specified action.""" - data = metadata_data or {} - url = f"/datasets/{self._get_dataset_id()}/metadata/built-in/{action}" - return self._send_request("POST", url, json=data) - - def update_documents_metadata(self, operation_data: List[Dict[str, Any]]): - """Update metadata for multiple documents.""" - url = f"/datasets/{self._get_dataset_id()}/documents/metadata" - data = {"operation_data": operation_data} - return self._send_request("POST", url, json=data) - - # Dataset Tags APIs - def list_dataset_tags(self): - """List all dataset tags.""" - return self._send_request("GET", "/datasets/tags") - - def bind_dataset_tags(self, tag_ids: List[str]): - """Bind tags to dataset.""" - data = {"tag_ids": tag_ids, "target_id": self._get_dataset_id()} - return self._send_request("POST", "/datasets/tags/binding", json=data) - - def unbind_dataset_tag(self, tag_id: str): - """Unbind a single tag from dataset.""" - data = {"tag_id": tag_id, "target_id": self._get_dataset_id()} - return self._send_request("POST", "/datasets/tags/unbinding", json=data) - - def get_dataset_tags(self): - """Get tags for current dataset.""" - url = f"/datasets/{self._get_dataset_id()}/tags" - return self._send_request("GET", url) - - # RAG Pipeline APIs - def get_datasource_plugins(self, is_published: bool = True): - """Get datasource plugins for RAG pipeline.""" - params = {"is_published": is_published} - url = f"/datasets/{self._get_dataset_id()}/pipeline/datasource-plugins" - return self._send_request("GET", url, params=params) - - def run_datasource_node( - self, - node_id: str, - inputs: Dict[str, Any], - datasource_type: str, - is_published: bool = True, - credential_id: str = None, - ): - """Run a datasource node in RAG pipeline.""" - data = { - "inputs": inputs, - "datasource_type": datasource_type, - "is_published": is_published, - } - if credential_id: - data["credential_id"] = credential_id - url = f"/datasets/{self._get_dataset_id()}/pipeline/datasource/nodes/{node_id}/run" - return self._send_request("POST", url, json=data, stream=True) - - def run_rag_pipeline( - self, - inputs: Dict[str, Any], - datasource_type: str, - datasource_info_list: List[Dict[str, Any]], - start_node_id: str, - is_published: bool = True, - response_mode: Literal["streaming", "blocking"] = "blocking", - ): - """Run RAG pipeline.""" - data = { - "inputs": inputs, - "datasource_type": datasource_type, - "datasource_info_list": datasource_info_list, - "start_node_id": start_node_id, - "is_published": is_published, - "response_mode": response_mode, - } - url = f"/datasets/{self._get_dataset_id()}/pipeline/run" - return self._send_request("POST", url, json=data, stream=response_mode == "streaming") - - def upload_pipeline_file(self, file_path: str): - """Upload file for RAG pipeline.""" - with open(file_path, "rb") as f: - files = {"file": (os.path.basename(file_path), f)} - return self._send_request_with_files("POST", "/datasets/pipeline/file-upload", {}, files) - - # Dataset Management APIs - def get_dataset(self, dataset_id: str | None = None): - """Get detailed information about a specific dataset. - - Args: - dataset_id: Dataset ID (optional, uses current dataset_id if not provided) - - Returns: - Response from the API containing dataset details including: - - name, description, permission - - indexing_technique, embedding_model, embedding_model_provider - - retrieval_model configuration - - document_count, word_count, app_count - - created_at, updated_at - """ - ds_id = dataset_id or self._get_dataset_id() - url = f"/datasets/{ds_id}" - return self._send_request("GET", url) - - def update_dataset( - self, - dataset_id: str | None = None, - name: str | None = None, - description: str | None = None, - indexing_technique: str | None = None, - embedding_model: str | None = None, - embedding_model_provider: str | None = None, - retrieval_model: Dict[str, Any] | None = None, - **kwargs, - ): - """Update dataset configuration. - - Args: - dataset_id: Dataset ID (optional, uses current dataset_id if not provided) - name: New dataset name - description: New dataset description - indexing_technique: Indexing technique ('high_quality' or 'economy') - embedding_model: Embedding model name - embedding_model_provider: Embedding model provider - retrieval_model: Retrieval model configuration dict - **kwargs: Additional parameters to pass to the API - - Returns: - Response from the API with updated dataset information - """ - ds_id = dataset_id or self._get_dataset_id() - url = f"/datasets/{ds_id}" - - # Build data dictionary with all possible parameters - payload = { - "name": name, - "description": description, - "indexing_technique": indexing_technique, - "embedding_model": embedding_model, - "embedding_model_provider": embedding_model_provider, - "retrieval_model": retrieval_model, - } - - # Filter out None values and merge with additional kwargs - data = {k: v for k, v in payload.items() if v is not None} - data.update(kwargs) - - return self._send_request("PATCH", url, json=data) - - def batch_update_document_status( - self, - action: Literal["enable", "disable", "archive", "un_archive"], - document_ids: List[str], - dataset_id: str | None = None, - ): - """Batch update document status (enable/disable/archive/unarchive). - - Args: - action: Action to perform on documents - - 'enable': Enable documents for retrieval - - 'disable': Disable documents from retrieval - - 'archive': Archive documents - - 'un_archive': Unarchive documents - document_ids: List of document IDs to update - dataset_id: Dataset ID (optional, uses current dataset_id if not provided) - - Returns: - Response from the API with operation result - """ - ds_id = dataset_id or self._get_dataset_id() - url = f"/datasets/{ds_id}/documents/status/{action}" - data = {"document_ids": document_ids} - return self._send_request("PATCH", url, json=data) - - # Enhanced Dataset APIs - def create_dataset_from_template(self, template_name: str, name: str, description: str | None = None): - """Create a dataset from a predefined template. - - Args: - template_name: Name of the template to use - name: Name for the new dataset - description: Description for the dataset (optional) - - Returns: - Created dataset information - """ - data = { - "template_name": template_name, - "name": name, - "description": description, - } - return self._send_request("POST", "/datasets/from-template", json=data) - - def duplicate_dataset(self, dataset_id: str, name: str): - """Duplicate an existing dataset. - - Args: - dataset_id: ID of dataset to duplicate - name: Name for duplicated dataset - - Returns: - New dataset information - """ - data = {"name": name} - url = f"/datasets/{dataset_id}/duplicate" - return self._send_request("POST", url, json=data) - - def list_conversation_variables_with_pagination( - self, conversation_id: str, user: str, page: int = 1, limit: int = 20 - ): - """List conversation variables with pagination.""" - params = {"page": page, "limit": limit, "user": user} - url = f"/conversations/{conversation_id}/variables" - return self._send_request("GET", url, params=params) - - def update_conversation_variable_with_response(self, conversation_id: str, variable_id: str, user: str, value: Any): - """Update a conversation variable with full response handling.""" - data = {"value": value, "user": user} - url = f"/conversations/{conversation_id}/variables/{variable_id}" - return self._send_request("PUT", url, json=data) diff --git a/sdks/python-client/dify_client/exceptions.py b/sdks/python-client/dify_client/exceptions.py deleted file mode 100644 index e7ba2ff4b2..0000000000 --- a/sdks/python-client/dify_client/exceptions.py +++ /dev/null @@ -1,71 +0,0 @@ -"""Custom exceptions for the Dify client.""" - -from typing import Optional, Dict, Any - - -class DifyClientError(Exception): - """Base exception for all Dify client errors.""" - - def __init__(self, message: str, status_code: int | None = None, response: Dict[str, Any] | None = None): - super().__init__(message) - self.message = message - self.status_code = status_code - self.response = response - - -class APIError(DifyClientError): - """Raised when the API returns an error response.""" - - def __init__(self, message: str, status_code: int, response: Dict[str, Any] | None = None): - super().__init__(message, status_code, response) - self.status_code = status_code - - -class AuthenticationError(DifyClientError): - """Raised when authentication fails.""" - - pass - - -class RateLimitError(DifyClientError): - """Raised when rate limit is exceeded.""" - - def __init__(self, message: str = "Rate limit exceeded", retry_after: int | None = None): - super().__init__(message) - self.retry_after = retry_after - - -class ValidationError(DifyClientError): - """Raised when request validation fails.""" - - pass - - -class NetworkError(DifyClientError): - """Raised when network-related errors occur.""" - - pass - - -class TimeoutError(DifyClientError): - """Raised when request times out.""" - - pass - - -class FileUploadError(DifyClientError): - """Raised when file upload fails.""" - - pass - - -class DatasetError(DifyClientError): - """Raised when dataset operations fail.""" - - pass - - -class WorkflowError(DifyClientError): - """Raised when workflow operations fail.""" - - pass diff --git a/sdks/python-client/dify_client/models.py b/sdks/python-client/dify_client/models.py deleted file mode 100644 index 0321e9c3f4..0000000000 --- a/sdks/python-client/dify_client/models.py +++ /dev/null @@ -1,396 +0,0 @@ -"""Response models for the Dify client with proper type hints.""" - -from typing import Optional, List, Dict, Any, Literal, Union -from dataclasses import dataclass, field -from datetime import datetime - - -@dataclass -class BaseResponse: - """Base response model.""" - - success: bool = True - message: str | None = None - - -@dataclass -class ErrorResponse(BaseResponse): - """Error response model.""" - - error_code: str | None = None - details: Dict[str, Any] | None = None - success: bool = False - - -@dataclass -class FileInfo: - """File information model.""" - - id: str - name: str - size: int - mime_type: str - url: str | None = None - created_at: datetime | None = None - - -@dataclass -class MessageResponse(BaseResponse): - """Message response model.""" - - id: str = "" - answer: str = "" - conversation_id: str | None = None - created_at: int | None = None - metadata: Dict[str, Any] | None = None - files: List[Dict[str, Any]] | None = None - - -@dataclass -class ConversationResponse(BaseResponse): - """Conversation response model.""" - - id: str = "" - name: str = "" - inputs: Dict[str, Any] | None = None - status: str | None = None - created_at: int | None = None - updated_at: int | None = None - - -@dataclass -class DatasetResponse(BaseResponse): - """Dataset response model.""" - - id: str = "" - name: str = "" - description: str | None = None - permission: str | None = None - indexing_technique: str | None = None - embedding_model: str | None = None - embedding_model_provider: str | None = None - retrieval_model: Dict[str, Any] | None = None - document_count: int | None = None - word_count: int | None = None - app_count: int | None = None - created_at: int | None = None - updated_at: int | None = None - - -@dataclass -class DocumentResponse(BaseResponse): - """Document response model.""" - - id: str = "" - name: str = "" - data_source_type: str | None = None - data_source_info: Dict[str, Any] | None = None - dataset_process_rule_id: str | None = None - batch: str | None = None - position: int | None = None - enabled: bool | None = None - disabled_at: float | None = None - disabled_by: str | None = None - archived: bool | None = None - archived_reason: str | None = None - archived_at: float | None = None - archived_by: str | None = None - word_count: int | None = None - hit_count: int | None = None - doc_form: str | None = None - doc_metadata: Dict[str, Any] | None = None - created_at: float | None = None - updated_at: float | None = None - indexing_status: str | None = None - completed_at: float | None = None - paused_at: float | None = None - error: str | None = None - stopped_at: float | None = None - - -@dataclass -class DocumentSegmentResponse(BaseResponse): - """Document segment response model.""" - - id: str = "" - position: int | None = None - document_id: str | None = None - content: str | None = None - answer: str | None = None - word_count: int | None = None - tokens: int | None = None - keywords: List[str] | None = None - index_node_id: str | None = None - index_node_hash: str | None = None - hit_count: int | None = None - enabled: bool | None = None - disabled_at: float | None = None - disabled_by: str | None = None - status: str | None = None - created_by: str | None = None - created_at: float | None = None - indexing_at: float | None = None - completed_at: float | None = None - error: str | None = None - stopped_at: float | None = None - - -@dataclass -class WorkflowRunResponse(BaseResponse): - """Workflow run response model.""" - - id: str = "" - workflow_id: str | None = None - status: Literal["running", "succeeded", "failed", "stopped"] | None = None - inputs: Dict[str, Any] | None = None - outputs: Dict[str, Any] | None = None - error: str | None = None - elapsed_time: float | None = None - total_tokens: int | None = None - total_steps: int | None = None - created_at: float | None = None - finished_at: float | None = None - - -@dataclass -class ApplicationParametersResponse(BaseResponse): - """Application parameters response model.""" - - opening_statement: str | None = None - suggested_questions: List[str] | None = None - speech_to_text: Dict[str, Any] | None = None - text_to_speech: Dict[str, Any] | None = None - retriever_resource: Dict[str, Any] | None = None - sensitive_word_avoidance: Dict[str, Any] | None = None - file_upload: Dict[str, Any] | None = None - system_parameters: Dict[str, Any] | None = None - user_input_form: List[Dict[str, Any]] | None = None - - -@dataclass -class AnnotationResponse(BaseResponse): - """Annotation response model.""" - - id: str = "" - question: str = "" - answer: str = "" - content: str | None = None - created_at: float | None = None - updated_at: float | None = None - created_by: str | None = None - updated_by: str | None = None - hit_count: int | None = None - - -@dataclass -class PaginatedResponse(BaseResponse): - """Paginated response model.""" - - data: List[Any] = field(default_factory=list) - has_more: bool = False - limit: int = 0 - total: int = 0 - page: int | None = None - - -@dataclass -class ConversationVariableResponse(BaseResponse): - """Conversation variable response model.""" - - conversation_id: str = "" - variables: List[Dict[str, Any]] = field(default_factory=list) - - -@dataclass -class FileUploadResponse(BaseResponse): - """File upload response model.""" - - id: str = "" - name: str = "" - size: int = 0 - mime_type: str = "" - url: str | None = None - created_at: float | None = None - - -@dataclass -class AudioResponse(BaseResponse): - """Audio generation/response model.""" - - audio: str | None = None # Base64 encoded audio data or URL - audio_url: str | None = None - duration: float | None = None - sample_rate: int | None = None - - -@dataclass -class SuggestedQuestionsResponse(BaseResponse): - """Suggested questions response model.""" - - message_id: str = "" - questions: List[str] = field(default_factory=list) - - -@dataclass -class AppInfoResponse(BaseResponse): - """App info response model.""" - - id: str = "" - name: str = "" - description: str | None = None - icon: str | None = None - icon_background: str | None = None - mode: str | None = None - tags: List[str] | None = None - enable_site: bool | None = None - enable_api: bool | None = None - api_token: str | None = None - - -@dataclass -class WorkspaceModelsResponse(BaseResponse): - """Workspace models response model.""" - - models: List[Dict[str, Any]] = field(default_factory=list) - - -@dataclass -class HitTestingResponse(BaseResponse): - """Hit testing response model.""" - - query: str = "" - records: List[Dict[str, Any]] = field(default_factory=list) - - -@dataclass -class DatasetTagsResponse(BaseResponse): - """Dataset tags response model.""" - - tags: List[Dict[str, Any]] = field(default_factory=list) - - -@dataclass -class WorkflowLogsResponse(BaseResponse): - """Workflow logs response model.""" - - logs: List[Dict[str, Any]] = field(default_factory=list) - total: int = 0 - page: int = 0 - limit: int = 0 - has_more: bool = False - - -@dataclass -class ModelProviderResponse(BaseResponse): - """Model provider response model.""" - - provider_name: str = "" - provider_type: str = "" - models: List[Dict[str, Any]] = field(default_factory=list) - is_enabled: bool = False - credentials: Dict[str, Any] | None = None - - -@dataclass -class FileInfoResponse(BaseResponse): - """File info response model.""" - - id: str = "" - name: str = "" - size: int = 0 - mime_type: str = "" - url: str | None = None - created_at: int | None = None - metadata: Dict[str, Any] | None = None - - -@dataclass -class WorkflowDraftResponse(BaseResponse): - """Workflow draft response model.""" - - id: str = "" - app_id: str = "" - draft_data: Dict[str, Any] = field(default_factory=dict) - version: int = 0 - created_at: int | None = None - updated_at: int | None = None - - -@dataclass -class ApiTokenResponse(BaseResponse): - """API token response model.""" - - id: str = "" - name: str = "" - token: str = "" - description: str | None = None - created_at: int | None = None - last_used_at: int | None = None - is_active: bool = True - - -@dataclass -class JobStatusResponse(BaseResponse): - """Job status response model.""" - - job_id: str = "" - job_status: str = "" - error_msg: str | None = None - progress: float | None = None - created_at: int | None = None - updated_at: int | None = None - - -@dataclass -class DatasetQueryResponse(BaseResponse): - """Dataset query response model.""" - - query: str = "" - records: List[Dict[str, Any]] = field(default_factory=list) - total: int = 0 - search_time: float | None = None - retrieval_model: Dict[str, Any] | None = None - - -@dataclass -class DatasetTemplateResponse(BaseResponse): - """Dataset template response model.""" - - template_name: str = "" - display_name: str = "" - description: str = "" - category: str = "" - icon: str | None = None - config_schema: Dict[str, Any] = field(default_factory=dict) - - -# Type aliases for common response types -ResponseType = Union[ - BaseResponse, - ErrorResponse, - MessageResponse, - ConversationResponse, - DatasetResponse, - DocumentResponse, - DocumentSegmentResponse, - WorkflowRunResponse, - ApplicationParametersResponse, - AnnotationResponse, - PaginatedResponse, - ConversationVariableResponse, - FileUploadResponse, - AudioResponse, - SuggestedQuestionsResponse, - AppInfoResponse, - WorkspaceModelsResponse, - HitTestingResponse, - DatasetTagsResponse, - WorkflowLogsResponse, - ModelProviderResponse, - FileInfoResponse, - WorkflowDraftResponse, - ApiTokenResponse, - JobStatusResponse, - DatasetQueryResponse, - DatasetTemplateResponse, -] diff --git a/sdks/python-client/examples/advanced_usage.py b/sdks/python-client/examples/advanced_usage.py deleted file mode 100644 index bc8720bef2..0000000000 --- a/sdks/python-client/examples/advanced_usage.py +++ /dev/null @@ -1,264 +0,0 @@ -""" -Advanced usage examples for the Dify Python SDK. - -This example demonstrates: -- Error handling and retries -- Logging configuration -- Context managers -- Async usage -- File uploads -- Dataset management -""" - -import asyncio -import logging -from pathlib import Path - -from dify_client import ( - ChatClient, - CompletionClient, - AsyncChatClient, - KnowledgeBaseClient, - DifyClient, -) -from dify_client.exceptions import ( - APIError, - RateLimitError, - AuthenticationError, - DifyClientError, -) - - -def setup_logging(): - """Setup logging for the SDK.""" - logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") - - -def example_chat_with_error_handling(): - """Example of chat with comprehensive error handling.""" - api_key = "your-api-key-here" - - try: - with ChatClient(api_key, enable_logging=True) as client: - # Simple chat message - response = client.create_chat_message( - inputs={}, query="Hello, how are you?", user="user-123", response_mode="blocking" - ) - - result = response.json() - print(f"Response: {result.get('answer')}") - - except AuthenticationError as e: - print(f"Authentication failed: {e}") - print("Please check your API key") - - except RateLimitError as e: - print(f"Rate limit exceeded: {e}") - if e.retry_after: - print(f"Retry after {e.retry_after} seconds") - - except APIError as e: - print(f"API error: {e.message}") - print(f"Status code: {e.status_code}") - - except DifyClientError as e: - print(f"Dify client error: {e}") - - except Exception as e: - print(f"Unexpected error: {e}") - - -def example_completion_with_files(): - """Example of completion with file upload.""" - api_key = "your-api-key-here" - - with CompletionClient(api_key) as client: - # Upload an image file first - file_path = "path/to/your/image.jpg" - - try: - with open(file_path, "rb") as f: - files = {"file": (Path(file_path).name, f, "image/jpeg")} - upload_response = client.file_upload("user-123", files) - upload_response.raise_for_status() - - file_id = upload_response.json().get("id") - print(f"File uploaded with ID: {file_id}") - - # Use the uploaded file in completion - files_list = [{"type": "image", "transfer_method": "local_file", "upload_file_id": file_id}] - - completion_response = client.create_completion_message( - inputs={"query": "Describe this image"}, response_mode="blocking", user="user-123", files=files_list - ) - - result = completion_response.json() - print(f"Completion result: {result.get('answer')}") - - except FileNotFoundError: - print(f"File not found: {file_path}") - except Exception as e: - print(f"Error during file upload/completion: {e}") - - -def example_dataset_management(): - """Example of dataset management operations.""" - api_key = "your-api-key-here" - - with KnowledgeBaseClient(api_key) as kb_client: - try: - # Create a new dataset - create_response = kb_client.create_dataset(name="My Test Dataset") - create_response.raise_for_status() - - dataset_id = create_response.json().get("id") - print(f"Created dataset with ID: {dataset_id}") - - # Create a client with the dataset ID - dataset_client = KnowledgeBaseClient(api_key, dataset_id=dataset_id) - - # Add a document by text - doc_response = dataset_client.create_document_by_text( - name="Test Document", text="This is a test document for the knowledge base." - ) - doc_response.raise_for_status() - - document_id = doc_response.json().get("document", {}).get("id") - print(f"Created document with ID: {document_id}") - - # List documents - list_response = dataset_client.list_documents() - list_response.raise_for_status() - - documents = list_response.json().get("data", []) - print(f"Dataset contains {len(documents)} documents") - - # Update dataset configuration - update_response = dataset_client.update_dataset( - name="Updated Dataset Name", description="Updated description", indexing_technique="high_quality" - ) - update_response.raise_for_status() - - print("Dataset updated successfully") - - except Exception as e: - print(f"Dataset management error: {e}") - - -async def example_async_chat(): - """Example of async chat usage.""" - api_key = "your-api-key-here" - - try: - async with AsyncChatClient(api_key) as client: - # Create chat message - response = await client.create_chat_message( - inputs={}, query="What's the weather like?", user="user-456", response_mode="blocking" - ) - - result = response.json() - print(f"Async response: {result.get('answer')}") - - # Get conversations - conversations = await client.get_conversations("user-456") - conversations.raise_for_status() - - conv_data = conversations.json() - print(f"Found {len(conv_data.get('data', []))} conversations") - - except Exception as e: - print(f"Async chat error: {e}") - - -def example_streaming_response(): - """Example of handling streaming responses.""" - api_key = "your-api-key-here" - - with ChatClient(api_key) as client: - try: - response = client.create_chat_message( - inputs={}, query="Tell me a story", user="user-789", response_mode="streaming" - ) - - print("Streaming response:") - for line in response.iter_lines(decode_unicode=True): - if line.startswith("data:"): - data = line[5:].strip() - if data: - import json - - try: - chunk = json.loads(data) - answer = chunk.get("answer", "") - if answer: - print(answer, end="", flush=True) - except json.JSONDecodeError: - continue - print() # New line after streaming - - except Exception as e: - print(f"Streaming error: {e}") - - -def example_application_info(): - """Example of getting application information.""" - api_key = "your-api-key-here" - - with DifyClient(api_key) as client: - try: - # Get app info - info_response = client.get_app_info() - info_response.raise_for_status() - - app_info = info_response.json() - print(f"App name: {app_info.get('name')}") - print(f"App mode: {app_info.get('mode')}") - print(f"App tags: {app_info.get('tags', [])}") - - # Get app parameters - params_response = client.get_application_parameters("user-123") - params_response.raise_for_status() - - params = params_response.json() - print(f"Opening statement: {params.get('opening_statement')}") - print(f"Suggested questions: {params.get('suggested_questions', [])}") - - except Exception as e: - print(f"App info error: {e}") - - -def main(): - """Run all examples.""" - setup_logging() - - print("=== Dify Python SDK Advanced Usage Examples ===\n") - - print("1. Chat with Error Handling:") - example_chat_with_error_handling() - print() - - print("2. Completion with Files:") - example_completion_with_files() - print() - - print("3. Dataset Management:") - example_dataset_management() - print() - - print("4. Async Chat:") - asyncio.run(example_async_chat()) - print() - - print("5. Streaming Response:") - example_streaming_response() - print() - - print("6. Application Info:") - example_application_info() - print() - - print("All examples completed!") - - -if __name__ == "__main__": - main() diff --git a/sdks/python-client/pyproject.toml b/sdks/python-client/pyproject.toml deleted file mode 100644 index a25cb9150c..0000000000 --- a/sdks/python-client/pyproject.toml +++ /dev/null @@ -1,43 +0,0 @@ -[project] -name = "dify-client" -version = "0.1.12" -description = "A package for interacting with the Dify Service-API" -readme = "README.md" -requires-python = ">=3.10" -dependencies = [ - "httpx[http2]>=0.27.0", - "aiofiles>=23.0.0", -] -authors = [ - {name = "Dify", email = "hello@dify.ai"} -] -license = {text = "MIT"} -keywords = ["dify", "nlp", "ai", "language-processing"] -classifiers = [ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", -] - -[project.urls] -Homepage = "https://github.com/langgenius/dify" - -[project.optional-dependencies] -dev = [ - "pytest>=7.0.0", - "pytest-asyncio>=0.21.0", -] - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.build.targets.wheel] -packages = ["dify_client"] - -[tool.pytest.ini_options] -testpaths = ["tests"] -python_files = ["test_*.py"] -python_classes = ["Test*"] -python_functions = ["test_*"] -asyncio_mode = "auto" diff --git a/sdks/python-client/tests/__init__.py b/sdks/python-client/tests/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/sdks/python-client/tests/test_async_client.py b/sdks/python-client/tests/test_async_client.py deleted file mode 100644 index 4f5001866f..0000000000 --- a/sdks/python-client/tests/test_async_client.py +++ /dev/null @@ -1,250 +0,0 @@ -#!/usr/bin/env python3 -""" -Test suite for async client implementation in the Python SDK. - -This test validates the async/await functionality using httpx.AsyncClient -and ensures API parity with sync clients. -""" - -import unittest -from unittest.mock import Mock, patch, AsyncMock - -from dify_client.async_client import ( - AsyncDifyClient, - AsyncChatClient, - AsyncCompletionClient, - AsyncWorkflowClient, - AsyncWorkspaceClient, - AsyncKnowledgeBaseClient, -) - - -class TestAsyncAPIParity(unittest.TestCase): - """Test that async clients have API parity with sync clients.""" - - def test_dify_client_api_parity(self): - """Test AsyncDifyClient has same methods as DifyClient.""" - from dify_client import DifyClient - - sync_methods = {name for name in dir(DifyClient) if not name.startswith("_")} - async_methods = {name for name in dir(AsyncDifyClient) if not name.startswith("_")} - - # aclose is async-specific, close is sync-specific - sync_methods.discard("close") - async_methods.discard("aclose") - - # Verify parity - self.assertEqual(sync_methods, async_methods, "API parity mismatch for DifyClient") - - def test_chat_client_api_parity(self): - """Test AsyncChatClient has same methods as ChatClient.""" - from dify_client import ChatClient - - sync_methods = {name for name in dir(ChatClient) if not name.startswith("_")} - async_methods = {name for name in dir(AsyncChatClient) if not name.startswith("_")} - - sync_methods.discard("close") - async_methods.discard("aclose") - - self.assertEqual(sync_methods, async_methods, "API parity mismatch for ChatClient") - - def test_completion_client_api_parity(self): - """Test AsyncCompletionClient has same methods as CompletionClient.""" - from dify_client import CompletionClient - - sync_methods = {name for name in dir(CompletionClient) if not name.startswith("_")} - async_methods = {name for name in dir(AsyncCompletionClient) if not name.startswith("_")} - - sync_methods.discard("close") - async_methods.discard("aclose") - - self.assertEqual(sync_methods, async_methods, "API parity mismatch for CompletionClient") - - def test_workflow_client_api_parity(self): - """Test AsyncWorkflowClient has same methods as WorkflowClient.""" - from dify_client import WorkflowClient - - sync_methods = {name for name in dir(WorkflowClient) if not name.startswith("_")} - async_methods = {name for name in dir(AsyncWorkflowClient) if not name.startswith("_")} - - sync_methods.discard("close") - async_methods.discard("aclose") - - self.assertEqual(sync_methods, async_methods, "API parity mismatch for WorkflowClient") - - def test_workspace_client_api_parity(self): - """Test AsyncWorkspaceClient has same methods as WorkspaceClient.""" - from dify_client import WorkspaceClient - - sync_methods = {name for name in dir(WorkspaceClient) if not name.startswith("_")} - async_methods = {name for name in dir(AsyncWorkspaceClient) if not name.startswith("_")} - - sync_methods.discard("close") - async_methods.discard("aclose") - - self.assertEqual(sync_methods, async_methods, "API parity mismatch for WorkspaceClient") - - def test_knowledge_base_client_api_parity(self): - """Test AsyncKnowledgeBaseClient has same methods as KnowledgeBaseClient.""" - from dify_client import KnowledgeBaseClient - - sync_methods = {name for name in dir(KnowledgeBaseClient) if not name.startswith("_")} - async_methods = {name for name in dir(AsyncKnowledgeBaseClient) if not name.startswith("_")} - - sync_methods.discard("close") - async_methods.discard("aclose") - - self.assertEqual(sync_methods, async_methods, "API parity mismatch for KnowledgeBaseClient") - - -class TestAsyncClientMocked(unittest.IsolatedAsyncioTestCase): - """Test async client with mocked httpx.AsyncClient.""" - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_client_initialization(self, mock_httpx_async_client): - """Test async client initializes with httpx.AsyncClient.""" - mock_client_instance = AsyncMock() - mock_httpx_async_client.return_value = mock_client_instance - - client = AsyncDifyClient("test-key", "https://api.dify.ai/v1") - - # Verify httpx.AsyncClient was called - mock_httpx_async_client.assert_called_once() - self.assertEqual(client.api_key, "test-key") - - await client.aclose() - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_context_manager(self, mock_httpx_async_client): - """Test async context manager works.""" - mock_client_instance = AsyncMock() - mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncDifyClient("test-key") as client: - self.assertEqual(client.api_key, "test-key") - - # Verify aclose was called - mock_client_instance.aclose.assert_called_once() - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_send_request(self, mock_httpx_async_client): - """Test async _send_request method.""" - mock_response = AsyncMock() - mock_response.json = AsyncMock(return_value={"result": "success"}) - mock_response.status_code = 200 - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncDifyClient("test-key") as client: - response = await client._send_request("GET", "/test") - - # Verify request was called - mock_client_instance.request.assert_called_once() - call_args = mock_client_instance.request.call_args - - # Verify parameters - self.assertEqual(call_args[0][0], "GET") - self.assertEqual(call_args[0][1], "/test") - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_chat_client(self, mock_httpx_async_client): - """Test AsyncChatClient functionality.""" - mock_response = AsyncMock() - mock_response.text = '{"answer": "Hello!"}' - mock_response.json = AsyncMock(return_value={"answer": "Hello!"}) - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncChatClient("test-key") as client: - response = await client.create_chat_message({}, "Hi", "user123") - self.assertIn("answer", response.text) - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_completion_client(self, mock_httpx_async_client): - """Test AsyncCompletionClient functionality.""" - mock_response = AsyncMock() - mock_response.text = '{"answer": "Response"}' - mock_response.json = AsyncMock(return_value={"answer": "Response"}) - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncCompletionClient("test-key") as client: - response = await client.create_completion_message({"query": "test"}, "blocking", "user123") - self.assertIn("answer", response.text) - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_workflow_client(self, mock_httpx_async_client): - """Test AsyncWorkflowClient functionality.""" - mock_response = AsyncMock() - mock_response.json = AsyncMock(return_value={"result": "success"}) - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncWorkflowClient("test-key") as client: - response = await client.run({"input": "test"}, "blocking", "user123") - data = await response.json() - self.assertEqual(data["result"], "success") - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_workspace_client(self, mock_httpx_async_client): - """Test AsyncWorkspaceClient functionality.""" - mock_response = AsyncMock() - mock_response.json = AsyncMock(return_value={"data": []}) - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncWorkspaceClient("test-key") as client: - response = await client.get_available_models("llm") - data = await response.json() - self.assertIn("data", data) - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_knowledge_base_client(self, mock_httpx_async_client): - """Test AsyncKnowledgeBaseClient functionality.""" - mock_response = AsyncMock() - mock_response.json = AsyncMock(return_value={"data": [], "total": 0}) - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncKnowledgeBaseClient("test-key") as client: - response = await client.list_datasets() - data = await response.json() - self.assertIn("data", data) - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_all_async_client_classes(self, mock_httpx_async_client): - """Test all async client classes work with httpx.AsyncClient.""" - mock_client_instance = AsyncMock() - mock_httpx_async_client.return_value = mock_client_instance - - clients = [ - AsyncDifyClient("key"), - AsyncChatClient("key"), - AsyncCompletionClient("key"), - AsyncWorkflowClient("key"), - AsyncWorkspaceClient("key"), - AsyncKnowledgeBaseClient("key"), - ] - - # Verify httpx.AsyncClient was called for each - self.assertEqual(mock_httpx_async_client.call_count, 6) - - # Clean up - for client in clients: - await client.aclose() - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/tests/test_client.py b/sdks/python-client/tests/test_client.py deleted file mode 100644 index b0d2f8ba23..0000000000 --- a/sdks/python-client/tests/test_client.py +++ /dev/null @@ -1,489 +0,0 @@ -import os -import time -import unittest -from unittest.mock import Mock, patch, mock_open - -from dify_client.client import ( - ChatClient, - CompletionClient, - DifyClient, - KnowledgeBaseClient, -) - -API_KEY = os.environ.get("API_KEY") -APP_ID = os.environ.get("APP_ID") -API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.dify.ai/v1") -FILE_PATH_BASE = os.path.dirname(__file__) - - -class TestKnowledgeBaseClient(unittest.TestCase): - def setUp(self): - self.api_key = "test-api-key" - self.base_url = "https://api.dify.ai/v1" - self.knowledge_base_client = KnowledgeBaseClient(self.api_key, base_url=self.base_url) - self.README_FILE_PATH = os.path.abspath(os.path.join(FILE_PATH_BASE, "../README.md")) - self.dataset_id = "test-dataset-id" - self.document_id = "test-document-id" - self.segment_id = "test-segment-id" - self.batch_id = "test-batch-id" - - def _get_dataset_kb_client(self): - return KnowledgeBaseClient(self.api_key, base_url=self.base_url, dataset_id=self.dataset_id) - - @patch("dify_client.client.httpx.Client") - def test_001_create_dataset(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.json.return_value = {"id": self.dataset_id, "name": "test_dataset"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Re-create client with mocked httpx - self.knowledge_base_client = KnowledgeBaseClient(self.api_key, base_url=self.base_url) - - response = self.knowledge_base_client.create_dataset(name="test_dataset") - data = response.json() - self.assertIn("id", data) - self.assertEqual("test_dataset", data["name"]) - - # the following tests require to be executed in order because they use - # the dataset/document/segment ids from the previous test - self._test_002_list_datasets() - self._test_003_create_document_by_text() - self._test_004_update_document_by_text() - self._test_006_update_document_by_file() - self._test_007_list_documents() - self._test_008_delete_document() - self._test_009_create_document_by_file() - self._test_010_add_segments() - self._test_011_query_segments() - self._test_012_update_document_segment() - self._test_013_delete_document_segment() - self._test_014_delete_dataset() - - def _test_002_list_datasets(self): - # Mock the response - using the already mocked client from test_001_create_dataset - mock_response = Mock() - mock_response.json.return_value = {"data": [], "total": 0} - mock_response.status_code = 200 - self.knowledge_base_client._client.request.return_value = mock_response - - response = self.knowledge_base_client.list_datasets() - data = response.json() - self.assertIn("data", data) - self.assertIn("total", data) - - def _test_003_create_document_by_text(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"document": {"id": self.document_id}, "batch": self.batch_id} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.create_document_by_text("test_document", "test_text") - data = response.json() - self.assertIn("document", data) - - def _test_004_update_document_by_text(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"document": {"id": self.document_id}, "batch": self.batch_id} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.update_document_by_text(self.document_id, "test_document_updated", "test_text_updated") - data = response.json() - self.assertIn("document", data) - self.assertIn("batch", data) - - def _test_006_update_document_by_file(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"document": {"id": self.document_id}, "batch": self.batch_id} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.update_document_by_file(self.document_id, self.README_FILE_PATH) - data = response.json() - self.assertIn("document", data) - self.assertIn("batch", data) - - def _test_007_list_documents(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"data": []} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.list_documents() - data = response.json() - self.assertIn("data", data) - - def _test_008_delete_document(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.delete_document(self.document_id) - data = response.json() - self.assertIn("result", data) - self.assertEqual("success", data["result"]) - - def _test_009_create_document_by_file(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"document": {"id": self.document_id}, "batch": self.batch_id} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.create_document_by_file(self.README_FILE_PATH) - data = response.json() - self.assertIn("document", data) - - def _test_010_add_segments(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"data": [{"id": self.segment_id, "content": "test text segment 1"}]} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.add_segments(self.document_id, [{"content": "test text segment 1"}]) - data = response.json() - self.assertIn("data", data) - self.assertGreater(len(data["data"]), 0) - - def _test_011_query_segments(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"data": [{"id": self.segment_id, "content": "test text segment 1"}]} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.query_segments(self.document_id) - data = response.json() - self.assertIn("data", data) - self.assertGreater(len(data["data"]), 0) - - def _test_012_update_document_segment(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"data": {"id": self.segment_id, "content": "test text segment 1 updated"}} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.update_document_segment( - self.document_id, - self.segment_id, - {"content": "test text segment 1 updated"}, - ) - data = response.json() - self.assertIn("data", data) - self.assertEqual("test text segment 1 updated", data["data"]["content"]) - - def _test_013_delete_document_segment(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.delete_document_segment(self.document_id, self.segment_id) - data = response.json() - self.assertIn("result", data) - self.assertEqual("success", data["result"]) - - def _test_014_delete_dataset(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.status_code = 204 - client._client.request.return_value = mock_response - - response = client.delete_dataset() - self.assertEqual(204, response.status_code) - - -class TestChatClient(unittest.TestCase): - @patch("dify_client.client.httpx.Client") - def setUp(self, mock_httpx_client): - self.api_key = "test-api-key" - self.chat_client = ChatClient(self.api_key) - - # Set up default mock response for the client - mock_response = Mock() - mock_response.text = '{"answer": "Hello! This is a test response."}' - mock_response.json.return_value = {"answer": "Hello! This is a test response."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - @patch("dify_client.client.httpx.Client") - def test_create_chat_message(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "Hello! This is a test response."}' - mock_response.json.return_value = {"answer": "Hello! This is a test response."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - chat_client = ChatClient(self.api_key) - response = chat_client.create_chat_message({}, "Hello, World!", "test_user") - self.assertIn("answer", response.text) - - @patch("dify_client.client.httpx.Client") - def test_create_chat_message_with_vision_model_by_remote_url(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "I can see this is a test image description."}' - mock_response.json.return_value = {"answer": "I can see this is a test image description."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - chat_client = ChatClient(self.api_key) - files = [{"type": "image", "transfer_method": "remote_url", "url": "https://example.com/test-image.jpg"}] - response = chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files) - self.assertIn("answer", response.text) - - @patch("dify_client.client.httpx.Client") - def test_create_chat_message_with_vision_model_by_local_file(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "I can see this is a test uploaded image."}' - mock_response.json.return_value = {"answer": "I can see this is a test uploaded image."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - chat_client = ChatClient(self.api_key) - files = [ - { - "type": "image", - "transfer_method": "local_file", - "upload_file_id": "test-file-id", - } - ] - response = chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files) - self.assertIn("answer", response.text) - - @patch("dify_client.client.httpx.Client") - def test_get_conversation_messages(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "Here are the conversation messages."}' - mock_response.json.return_value = {"answer": "Here are the conversation messages."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - chat_client = ChatClient(self.api_key) - response = chat_client.get_conversation_messages("test_user", "test-conversation-id") - self.assertIn("answer", response.text) - - @patch("dify_client.client.httpx.Client") - def test_get_conversations(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"data": [{"id": "conv1", "name": "Test Conversation"}]}' - mock_response.json.return_value = {"data": [{"id": "conv1", "name": "Test Conversation"}]} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - chat_client = ChatClient(self.api_key) - response = chat_client.get_conversations("test_user") - self.assertIn("data", response.text) - - -class TestCompletionClient(unittest.TestCase): - @patch("dify_client.client.httpx.Client") - def setUp(self, mock_httpx_client): - self.api_key = "test-api-key" - self.completion_client = CompletionClient(self.api_key) - - # Set up default mock response for the client - mock_response = Mock() - mock_response.text = '{"answer": "This is a test completion response."}' - mock_response.json.return_value = {"answer": "This is a test completion response."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - @patch("dify_client.client.httpx.Client") - def test_create_completion_message(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "The weather today is sunny with a temperature of 75°F."}' - mock_response.json.return_value = {"answer": "The weather today is sunny with a temperature of 75°F."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - completion_client = CompletionClient(self.api_key) - response = completion_client.create_completion_message( - {"query": "What's the weather like today?"}, "blocking", "test_user" - ) - self.assertIn("answer", response.text) - - @patch("dify_client.client.httpx.Client") - def test_create_completion_message_with_vision_model_by_remote_url(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "This is a test image description from completion API."}' - mock_response.json.return_value = {"answer": "This is a test image description from completion API."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - completion_client = CompletionClient(self.api_key) - files = [{"type": "image", "transfer_method": "remote_url", "url": "https://example.com/test-image.jpg"}] - response = completion_client.create_completion_message( - {"query": "Describe the picture."}, "blocking", "test_user", files - ) - self.assertIn("answer", response.text) - - @patch("dify_client.client.httpx.Client") - def test_create_completion_message_with_vision_model_by_local_file(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "This is a test uploaded image description from completion API."}' - mock_response.json.return_value = {"answer": "This is a test uploaded image description from completion API."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - completion_client = CompletionClient(self.api_key) - files = [ - { - "type": "image", - "transfer_method": "local_file", - "upload_file_id": "test-file-id", - } - ] - response = completion_client.create_completion_message( - {"query": "Describe the picture."}, "blocking", "test_user", files - ) - self.assertIn("answer", response.text) - - -class TestDifyClient(unittest.TestCase): - @patch("dify_client.client.httpx.Client") - def setUp(self, mock_httpx_client): - self.api_key = "test-api-key" - self.dify_client = DifyClient(self.api_key) - - # Set up default mock response for the client - mock_response = Mock() - mock_response.text = '{"result": "success"}' - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - @patch("dify_client.client.httpx.Client") - def test_message_feedback(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"success": true}' - mock_response.json.return_value = {"success": True} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - dify_client = DifyClient(self.api_key) - response = dify_client.message_feedback("test-message-id", "like", "test_user") - self.assertIn("success", response.text) - - @patch("dify_client.client.httpx.Client") - def test_get_application_parameters(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"user_input_form": [{"field": "text", "label": "Input"}]}' - mock_response.json.return_value = {"user_input_form": [{"field": "text", "label": "Input"}]} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - dify_client = DifyClient(self.api_key) - response = dify_client.get_application_parameters("test_user") - self.assertIn("user_input_form", response.text) - - @patch("dify_client.client.httpx.Client") - @patch("builtins.open", new_callable=mock_open, read_data=b"fake image data") - def test_file_upload(self, mock_file_open, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"name": "panda.jpeg", "id": "test-file-id"}' - mock_response.json.return_value = {"name": "panda.jpeg", "id": "test-file-id"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - dify_client = DifyClient(self.api_key) - file_path = "/path/to/test/panda.jpeg" - file_name = "panda.jpeg" - mime_type = "image/jpeg" - - with open(file_path, "rb") as file: - files = {"file": (file_name, file, mime_type)} - response = dify_client.file_upload("test_user", files) - self.assertIn("name", response.text) - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/tests/test_exceptions.py b/sdks/python-client/tests/test_exceptions.py deleted file mode 100644 index eb44895749..0000000000 --- a/sdks/python-client/tests/test_exceptions.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Tests for custom exceptions.""" - -import unittest -from dify_client.exceptions import ( - DifyClientError, - APIError, - AuthenticationError, - RateLimitError, - ValidationError, - NetworkError, - TimeoutError, - FileUploadError, - DatasetError, - WorkflowError, -) - - -class TestExceptions(unittest.TestCase): - """Test custom exception classes.""" - - def test_base_exception(self): - """Test base DifyClientError.""" - error = DifyClientError("Test message", 500, {"error": "details"}) - self.assertEqual(str(error), "Test message") - self.assertEqual(error.status_code, 500) - self.assertEqual(error.response, {"error": "details"}) - - def test_api_error(self): - """Test APIError.""" - error = APIError("API failed", 400) - self.assertEqual(error.status_code, 400) - self.assertEqual(error.message, "API failed") - - def test_authentication_error(self): - """Test AuthenticationError.""" - error = AuthenticationError("Invalid API key") - self.assertEqual(str(error), "Invalid API key") - - def test_rate_limit_error(self): - """Test RateLimitError.""" - error = RateLimitError("Rate limited", retry_after=60) - self.assertEqual(error.retry_after, 60) - - error_default = RateLimitError() - self.assertEqual(error_default.retry_after, None) - - def test_validation_error(self): - """Test ValidationError.""" - error = ValidationError("Invalid parameter") - self.assertEqual(str(error), "Invalid parameter") - - def test_network_error(self): - """Test NetworkError.""" - error = NetworkError("Connection failed") - self.assertEqual(str(error), "Connection failed") - - def test_timeout_error(self): - """Test TimeoutError.""" - error = TimeoutError("Request timed out") - self.assertEqual(str(error), "Request timed out") - - def test_file_upload_error(self): - """Test FileUploadError.""" - error = FileUploadError("Upload failed") - self.assertEqual(str(error), "Upload failed") - - def test_dataset_error(self): - """Test DatasetError.""" - error = DatasetError("Dataset operation failed") - self.assertEqual(str(error), "Dataset operation failed") - - def test_workflow_error(self): - """Test WorkflowError.""" - error = WorkflowError("Workflow failed") - self.assertEqual(str(error), "Workflow failed") - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/tests/test_httpx_migration.py b/sdks/python-client/tests/test_httpx_migration.py deleted file mode 100644 index cf26de6eba..0000000000 --- a/sdks/python-client/tests/test_httpx_migration.py +++ /dev/null @@ -1,333 +0,0 @@ -#!/usr/bin/env python3 -""" -Test suite for httpx migration in the Python SDK. - -This test validates that the migration from requests to httpx maintains -backward compatibility and proper resource management. -""" - -import unittest -from unittest.mock import Mock, patch - -from dify_client import ( - DifyClient, - ChatClient, - CompletionClient, - WorkflowClient, - WorkspaceClient, - KnowledgeBaseClient, -) - - -class TestHttpxMigrationMocked(unittest.TestCase): - """Test cases for httpx migration with mocked requests.""" - - def setUp(self): - """Set up test fixtures.""" - self.api_key = "test-api-key" - self.base_url = "https://api.dify.ai/v1" - - @patch("dify_client.client.httpx.Client") - def test_client_initialization(self, mock_httpx_client): - """Test that client initializes with httpx.Client.""" - mock_client_instance = Mock() - mock_httpx_client.return_value = mock_client_instance - - client = DifyClient(self.api_key, self.base_url) - - # Verify httpx.Client was called with correct parameters - mock_httpx_client.assert_called_once() - call_kwargs = mock_httpx_client.call_args[1] - self.assertEqual(call_kwargs["base_url"], self.base_url) - - # Verify client properties - self.assertEqual(client.api_key, self.api_key) - self.assertEqual(client.base_url, self.base_url) - - client.close() - - @patch("dify_client.client.httpx.Client") - def test_context_manager_support(self, mock_httpx_client): - """Test that client works as context manager.""" - mock_client_instance = Mock() - mock_httpx_client.return_value = mock_client_instance - - with DifyClient(self.api_key, self.base_url) as client: - self.assertEqual(client.api_key, self.api_key) - - # Verify close was called - mock_client_instance.close.assert_called_once() - - @patch("dify_client.client.httpx.Client") - def test_manual_close(self, mock_httpx_client): - """Test manual close() method.""" - mock_client_instance = Mock() - mock_httpx_client.return_value = mock_client_instance - - client = DifyClient(self.api_key, self.base_url) - client.close() - - # Verify close was called - mock_client_instance.close.assert_called_once() - - @patch("dify_client.client.httpx.Client") - def test_send_request_httpx_compatibility(self, mock_httpx_client): - """Test _send_request uses httpx.Client.request properly.""" - mock_response = Mock() - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - client = DifyClient(self.api_key, self.base_url) - response = client._send_request("GET", "/test-endpoint") - - # Verify httpx.Client.request was called correctly - mock_client_instance.request.assert_called_once() - call_args = mock_client_instance.request.call_args - - # Verify method and endpoint - self.assertEqual(call_args[0][0], "GET") - self.assertEqual(call_args[0][1], "/test-endpoint") - - # Verify headers contain authorization - headers = call_args[1]["headers"] - self.assertEqual(headers["Authorization"], f"Bearer {self.api_key}") - self.assertEqual(headers["Content-Type"], "application/json") - - client.close() - - @patch("dify_client.client.httpx.Client") - def test_response_compatibility(self, mock_httpx_client): - """Test httpx.Response is compatible with requests.Response API.""" - mock_response = Mock() - mock_response.json.return_value = {"key": "value"} - mock_response.text = '{"key": "value"}' - mock_response.content = b'{"key": "value"}' - mock_response.status_code = 200 - mock_response.headers = {"Content-Type": "application/json"} - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - client = DifyClient(self.api_key, self.base_url) - response = client._send_request("GET", "/test") - - # Verify all common response methods work - self.assertEqual(response.json(), {"key": "value"}) - self.assertEqual(response.text, '{"key": "value"}') - self.assertEqual(response.content, b'{"key": "value"}') - self.assertEqual(response.status_code, 200) - self.assertEqual(response.headers["Content-Type"], "application/json") - - client.close() - - @patch("dify_client.client.httpx.Client") - def test_all_client_classes_use_httpx(self, mock_httpx_client): - """Test that all client classes properly use httpx.""" - mock_client_instance = Mock() - mock_httpx_client.return_value = mock_client_instance - - clients = [ - DifyClient(self.api_key, self.base_url), - ChatClient(self.api_key, self.base_url), - CompletionClient(self.api_key, self.base_url), - WorkflowClient(self.api_key, self.base_url), - WorkspaceClient(self.api_key, self.base_url), - KnowledgeBaseClient(self.api_key, self.base_url), - ] - - # Verify httpx.Client was called for each client - self.assertEqual(mock_httpx_client.call_count, 6) - - # Clean up - for client in clients: - client.close() - - @patch("dify_client.client.httpx.Client") - def test_json_parameter_handling(self, mock_httpx_client): - """Test that json parameter is passed correctly.""" - mock_response = Mock() - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 # Add status_code attribute - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - client = DifyClient(self.api_key, self.base_url) - test_data = {"key": "value", "number": 123} - - client._send_request("POST", "/test", json=test_data) - - # Verify json parameter was passed - call_args = mock_client_instance.request.call_args - self.assertEqual(call_args[1]["json"], test_data) - - client.close() - - @patch("dify_client.client.httpx.Client") - def test_params_parameter_handling(self, mock_httpx_client): - """Test that params parameter is passed correctly.""" - mock_response = Mock() - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 # Add status_code attribute - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - client = DifyClient(self.api_key, self.base_url) - test_params = {"page": 1, "limit": 20} - - client._send_request("GET", "/test", params=test_params) - - # Verify params parameter was passed - call_args = mock_client_instance.request.call_args - self.assertEqual(call_args[1]["params"], test_params) - - client.close() - - @patch("dify_client.client.httpx.Client") - def test_inheritance_chain(self, mock_httpx_client): - """Test that inheritance chain is maintained.""" - mock_client_instance = Mock() - mock_httpx_client.return_value = mock_client_instance - - # ChatClient inherits from DifyClient - chat_client = ChatClient(self.api_key, self.base_url) - self.assertIsInstance(chat_client, DifyClient) - - # CompletionClient inherits from DifyClient - completion_client = CompletionClient(self.api_key, self.base_url) - self.assertIsInstance(completion_client, DifyClient) - - # WorkflowClient inherits from DifyClient - workflow_client = WorkflowClient(self.api_key, self.base_url) - self.assertIsInstance(workflow_client, DifyClient) - - # Clean up - chat_client.close() - completion_client.close() - workflow_client.close() - - @patch("dify_client.client.httpx.Client") - def test_nested_context_managers(self, mock_httpx_client): - """Test nested context managers work correctly.""" - mock_client_instance = Mock() - mock_httpx_client.return_value = mock_client_instance - - with DifyClient(self.api_key, self.base_url) as client1: - with ChatClient(self.api_key, self.base_url) as client2: - self.assertEqual(client1.api_key, self.api_key) - self.assertEqual(client2.api_key, self.api_key) - - # Both close methods should have been called - self.assertEqual(mock_client_instance.close.call_count, 2) - - -class TestChatClientHttpx(unittest.TestCase): - """Test ChatClient specific httpx integration.""" - - @patch("dify_client.client.httpx.Client") - def test_create_chat_message_httpx(self, mock_httpx_client): - """Test create_chat_message works with httpx.""" - mock_response = Mock() - mock_response.text = '{"answer": "Hello!"}' - mock_response.json.return_value = {"answer": "Hello!"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - with ChatClient("test-key") as client: - response = client.create_chat_message({}, "Hi", "user123") - self.assertIn("answer", response.text) - self.assertEqual(response.json()["answer"], "Hello!") - - -class TestCompletionClientHttpx(unittest.TestCase): - """Test CompletionClient specific httpx integration.""" - - @patch("dify_client.client.httpx.Client") - def test_create_completion_message_httpx(self, mock_httpx_client): - """Test create_completion_message works with httpx.""" - mock_response = Mock() - mock_response.text = '{"answer": "Response"}' - mock_response.json.return_value = {"answer": "Response"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - with CompletionClient("test-key") as client: - response = client.create_completion_message({"query": "test"}, "blocking", "user123") - self.assertIn("answer", response.text) - - -class TestKnowledgeBaseClientHttpx(unittest.TestCase): - """Test KnowledgeBaseClient specific httpx integration.""" - - @patch("dify_client.client.httpx.Client") - def test_list_datasets_httpx(self, mock_httpx_client): - """Test list_datasets works with httpx.""" - mock_response = Mock() - mock_response.json.return_value = {"data": [], "total": 0} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - with KnowledgeBaseClient("test-key") as client: - response = client.list_datasets() - data = response.json() - self.assertIn("data", data) - self.assertIn("total", data) - - -class TestWorkflowClientHttpx(unittest.TestCase): - """Test WorkflowClient specific httpx integration.""" - - @patch("dify_client.client.httpx.Client") - def test_run_workflow_httpx(self, mock_httpx_client): - """Test run workflow works with httpx.""" - mock_response = Mock() - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - with WorkflowClient("test-key") as client: - response = client.run({"input": "test"}, "blocking", "user123") - self.assertEqual(response.json()["result"], "success") - - -class TestWorkspaceClientHttpx(unittest.TestCase): - """Test WorkspaceClient specific httpx integration.""" - - @patch("dify_client.client.httpx.Client") - def test_get_available_models_httpx(self, mock_httpx_client): - """Test get_available_models works with httpx.""" - mock_response = Mock() - mock_response.json.return_value = {"data": []} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - with WorkspaceClient("test-key") as client: - response = client.get_available_models("llm") - self.assertIn("data", response.json()) - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/tests/test_integration.py b/sdks/python-client/tests/test_integration.py deleted file mode 100644 index 6f38c5de56..0000000000 --- a/sdks/python-client/tests/test_integration.py +++ /dev/null @@ -1,539 +0,0 @@ -"""Integration tests with proper mocking.""" - -import unittest -from unittest.mock import Mock, patch, MagicMock -import json -import httpx -from dify_client import ( - DifyClient, - ChatClient, - CompletionClient, - WorkflowClient, - KnowledgeBaseClient, - WorkspaceClient, -) -from dify_client.exceptions import ( - APIError, - AuthenticationError, - RateLimitError, - ValidationError, -) - - -class TestDifyClientIntegration(unittest.TestCase): - """Integration tests for DifyClient with mocked HTTP responses.""" - - def setUp(self): - self.api_key = "test_api_key" - self.base_url = "https://api.dify.ai/v1" - self.client = DifyClient(api_key=self.api_key, base_url=self.base_url, enable_logging=False) - - @patch("httpx.Client.request") - def test_get_app_info_integration(self, mock_request): - """Test get_app_info integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "app_123", - "name": "Test App", - "description": "A test application", - "mode": "chat", - } - mock_request.return_value = mock_response - - response = self.client.get_app_info() - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["id"], "app_123") - self.assertEqual(data["name"], "Test App") - mock_request.assert_called_once_with( - "GET", - "/info", - json=None, - params=None, - headers={ - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - }, - ) - - @patch("httpx.Client.request") - def test_get_application_parameters_integration(self, mock_request): - """Test get_application_parameters integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "opening_statement": "Hello! How can I help you?", - "suggested_questions": ["What is AI?", "How does this work?"], - "speech_to_text": {"enabled": True}, - "text_to_speech": {"enabled": False}, - } - mock_request.return_value = mock_response - - response = self.client.get_application_parameters("user_123") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["opening_statement"], "Hello! How can I help you?") - self.assertEqual(len(data["suggested_questions"]), 2) - mock_request.assert_called_once_with( - "GET", - "/parameters", - json=None, - params={"user": "user_123"}, - headers={ - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - }, - ) - - @patch("httpx.Client.request") - def test_file_upload_integration(self, mock_request): - """Test file_upload integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "file_123", - "name": "test.txt", - "size": 1024, - "mime_type": "text/plain", - } - mock_request.return_value = mock_response - - files = {"file": ("test.txt", "test content", "text/plain")} - response = self.client.file_upload("user_123", files) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["id"], "file_123") - self.assertEqual(data["name"], "test.txt") - - @patch("httpx.Client.request") - def test_message_feedback_integration(self, mock_request): - """Test message_feedback integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = {"success": True} - mock_request.return_value = mock_response - - response = self.client.message_feedback("msg_123", "like", "user_123") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertTrue(data["success"]) - mock_request.assert_called_once_with( - "POST", - "/messages/msg_123/feedbacks", - json={"rating": "like", "user": "user_123"}, - params=None, - headers={ - "Authorization": "Bearer test_api_key", - "Content-Type": "application/json", - }, - ) - - -class TestChatClientIntegration(unittest.TestCase): - """Integration tests for ChatClient.""" - - def setUp(self): - self.client = ChatClient("test_api_key", enable_logging=False) - - @patch("httpx.Client.request") - def test_create_chat_message_blocking(self, mock_request): - """Test create_chat_message with blocking response.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "msg_123", - "answer": "Hello! How can I help you today?", - "conversation_id": "conv_123", - "created_at": 1234567890, - } - mock_request.return_value = mock_response - - response = self.client.create_chat_message( - inputs={"query": "Hello"}, - query="Hello, AI!", - user="user_123", - response_mode="blocking", - ) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["answer"], "Hello! How can I help you today?") - self.assertEqual(data["conversation_id"], "conv_123") - - @patch("httpx.Client.request") - def test_create_chat_message_streaming(self, mock_request): - """Test create_chat_message with streaming response.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.iter_lines.return_value = [ - b'data: {"answer": "Hello"}', - b'data: {"answer": " world"}', - b'data: {"answer": "!"}', - ] - mock_request.return_value = mock_response - - response = self.client.create_chat_message(inputs={}, query="Hello", user="user_123", response_mode="streaming") - - self.assertEqual(response.status_code, 200) - lines = list(response.iter_lines()) - self.assertEqual(len(lines), 3) - - @patch("httpx.Client.request") - def test_get_conversations_integration(self, mock_request): - """Test get_conversations integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "data": [ - {"id": "conv_1", "name": "Conversation 1"}, - {"id": "conv_2", "name": "Conversation 2"}, - ], - "has_more": False, - "limit": 20, - } - mock_request.return_value = mock_response - - response = self.client.get_conversations("user_123", limit=20) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(len(data["data"]), 2) - self.assertEqual(data["data"][0]["name"], "Conversation 1") - - @patch("httpx.Client.request") - def test_get_conversation_messages_integration(self, mock_request): - """Test get_conversation_messages integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "data": [ - {"id": "msg_1", "role": "user", "content": "Hello"}, - {"id": "msg_2", "role": "assistant", "content": "Hi there!"}, - ] - } - mock_request.return_value = mock_response - - response = self.client.get_conversation_messages("user_123", conversation_id="conv_123") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(len(data["data"]), 2) - self.assertEqual(data["data"][0]["role"], "user") - - -class TestCompletionClientIntegration(unittest.TestCase): - """Integration tests for CompletionClient.""" - - def setUp(self): - self.client = CompletionClient("test_api_key", enable_logging=False) - - @patch("httpx.Client.request") - def test_create_completion_message_blocking(self, mock_request): - """Test create_completion_message with blocking response.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "comp_123", - "answer": "This is a completion response.", - "created_at": 1234567890, - } - mock_request.return_value = mock_response - - response = self.client.create_completion_message( - inputs={"prompt": "Complete this sentence"}, - response_mode="blocking", - user="user_123", - ) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["answer"], "This is a completion response.") - - @patch("httpx.Client.request") - def test_create_completion_message_with_files(self, mock_request): - """Test create_completion_message with files.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "comp_124", - "answer": "I can see the image shows...", - "files": [{"id": "file_1", "type": "image"}], - } - mock_request.return_value = mock_response - - files = { - "file": { - "type": "image", - "transfer_method": "remote_url", - "url": "https://example.com/image.jpg", - } - } - response = self.client.create_completion_message( - inputs={"prompt": "Describe this image"}, - response_mode="blocking", - user="user_123", - files=files, - ) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertIn("image", data["answer"]) - self.assertEqual(len(data["files"]), 1) - - -class TestWorkflowClientIntegration(unittest.TestCase): - """Integration tests for WorkflowClient.""" - - def setUp(self): - self.client = WorkflowClient("test_api_key", enable_logging=False) - - @patch("httpx.Client.request") - def test_run_workflow_blocking(self, mock_request): - """Test run workflow with blocking response.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "run_123", - "workflow_id": "workflow_123", - "status": "succeeded", - "inputs": {"query": "Test input"}, - "outputs": {"result": "Test output"}, - "elapsed_time": 2.5, - } - mock_request.return_value = mock_response - - response = self.client.run(inputs={"query": "Test input"}, response_mode="blocking", user="user_123") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["status"], "succeeded") - self.assertEqual(data["outputs"]["result"], "Test output") - - @patch("httpx.Client.request") - def test_get_workflow_logs(self, mock_request): - """Test get_workflow_logs integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "logs": [ - {"id": "log_1", "status": "succeeded", "created_at": 1234567890}, - {"id": "log_2", "status": "failed", "created_at": 1234567891}, - ], - "total": 2, - "page": 1, - "limit": 20, - } - mock_request.return_value = mock_response - - response = self.client.get_workflow_logs(page=1, limit=20) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(len(data["logs"]), 2) - self.assertEqual(data["logs"][0]["status"], "succeeded") - - -class TestKnowledgeBaseClientIntegration(unittest.TestCase): - """Integration tests for KnowledgeBaseClient.""" - - def setUp(self): - self.client = KnowledgeBaseClient("test_api_key") - - @patch("httpx.Client.request") - def test_create_dataset(self, mock_request): - """Test create_dataset integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "dataset_123", - "name": "Test Dataset", - "description": "A test dataset", - "created_at": 1234567890, - } - mock_request.return_value = mock_response - - response = self.client.create_dataset(name="Test Dataset") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["name"], "Test Dataset") - self.assertEqual(data["id"], "dataset_123") - - @patch("httpx.Client.request") - def test_list_datasets(self, mock_request): - """Test list_datasets integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "data": [ - {"id": "dataset_1", "name": "Dataset 1"}, - {"id": "dataset_2", "name": "Dataset 2"}, - ], - "has_more": False, - "limit": 20, - } - mock_request.return_value = mock_response - - response = self.client.list_datasets(page=1, page_size=20) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(len(data["data"]), 2) - - @patch("httpx.Client.request") - def test_create_document_by_text(self, mock_request): - """Test create_document_by_text integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "document": { - "id": "doc_123", - "name": "Test Document", - "word_count": 100, - "status": "indexing", - } - } - mock_request.return_value = mock_response - - # Mock dataset_id - self.client.dataset_id = "dataset_123" - - response = self.client.create_document_by_text(name="Test Document", text="This is test document content.") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["document"]["name"], "Test Document") - self.assertEqual(data["document"]["word_count"], 100) - - -class TestWorkspaceClientIntegration(unittest.TestCase): - """Integration tests for WorkspaceClient.""" - - def setUp(self): - self.client = WorkspaceClient("test_api_key", enable_logging=False) - - @patch("httpx.Client.request") - def test_get_available_models(self, mock_request): - """Test get_available_models integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "models": [ - {"id": "gpt-4", "name": "GPT-4", "provider": "openai"}, - {"id": "claude-3", "name": "Claude 3", "provider": "anthropic"}, - ] - } - mock_request.return_value = mock_response - - response = self.client.get_available_models("llm") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(len(data["models"]), 2) - self.assertEqual(data["models"][0]["id"], "gpt-4") - - -class TestErrorScenariosIntegration(unittest.TestCase): - """Integration tests for error scenarios.""" - - def setUp(self): - self.client = DifyClient("test_api_key", enable_logging=False) - - @patch("httpx.Client.request") - def test_authentication_error_integration(self, mock_request): - """Test authentication error in integration.""" - mock_response = Mock() - mock_response.status_code = 401 - mock_response.json.return_value = {"message": "Invalid API key"} - mock_request.return_value = mock_response - - with self.assertRaises(AuthenticationError) as context: - self.client.get_app_info() - - self.assertEqual(str(context.exception), "Invalid API key") - self.assertEqual(context.exception.status_code, 401) - - @patch("httpx.Client.request") - def test_rate_limit_error_integration(self, mock_request): - """Test rate limit error in integration.""" - mock_response = Mock() - mock_response.status_code = 429 - mock_response.json.return_value = {"message": "Rate limit exceeded"} - mock_response.headers = {"Retry-After": "60"} - mock_request.return_value = mock_response - - with self.assertRaises(RateLimitError) as context: - self.client.get_app_info() - - self.assertEqual(str(context.exception), "Rate limit exceeded") - self.assertEqual(context.exception.retry_after, "60") - - @patch("httpx.Client.request") - def test_server_error_with_retry_integration(self, mock_request): - """Test server error with retry in integration.""" - # API errors don't retry by design - only network/timeout errors retry - mock_response_500 = Mock() - mock_response_500.status_code = 500 - mock_response_500.json.return_value = {"message": "Internal server error"} - - mock_request.return_value = mock_response_500 - - with patch("time.sleep"): # Skip actual sleep - with self.assertRaises(APIError) as context: - self.client.get_app_info() - - self.assertEqual(str(context.exception), "Internal server error") - self.assertEqual(mock_request.call_count, 1) - - @patch("httpx.Client.request") - def test_validation_error_integration(self, mock_request): - """Test validation error in integration.""" - mock_response = Mock() - mock_response.status_code = 422 - mock_response.json.return_value = { - "message": "Validation failed", - "details": {"field": "query", "error": "required"}, - } - mock_request.return_value = mock_response - - with self.assertRaises(ValidationError) as context: - self.client.get_app_info() - - self.assertEqual(str(context.exception), "Validation failed") - self.assertEqual(context.exception.status_code, 422) - - -class TestContextManagerIntegration(unittest.TestCase): - """Integration tests for context manager usage.""" - - @patch("httpx.Client.close") - @patch("httpx.Client.request") - def test_context_manager_usage(self, mock_request, mock_close): - """Test context manager properly closes connections.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = {"id": "app_123", "name": "Test App"} - mock_request.return_value = mock_response - - with DifyClient("test_api_key") as client: - response = client.get_app_info() - self.assertEqual(response.status_code, 200) - - # Verify close was called - mock_close.assert_called_once() - - @patch("httpx.Client.close") - def test_manual_close(self, mock_close): - """Test manual close method.""" - client = DifyClient("test_api_key") - client.close() - mock_close.assert_called_once() - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/tests/test_models.py b/sdks/python-client/tests/test_models.py deleted file mode 100644 index db9d92ad5b..0000000000 --- a/sdks/python-client/tests/test_models.py +++ /dev/null @@ -1,640 +0,0 @@ -"""Unit tests for response models.""" - -import unittest -import json -from datetime import datetime -from dify_client.models import ( - BaseResponse, - ErrorResponse, - FileInfo, - MessageResponse, - ConversationResponse, - DatasetResponse, - DocumentResponse, - DocumentSegmentResponse, - WorkflowRunResponse, - ApplicationParametersResponse, - AnnotationResponse, - PaginatedResponse, - ConversationVariableResponse, - FileUploadResponse, - AudioResponse, - SuggestedQuestionsResponse, - AppInfoResponse, - WorkspaceModelsResponse, - HitTestingResponse, - DatasetTagsResponse, - WorkflowLogsResponse, - ModelProviderResponse, - FileInfoResponse, - WorkflowDraftResponse, - ApiTokenResponse, - JobStatusResponse, - DatasetQueryResponse, - DatasetTemplateResponse, -) - - -class TestResponseModels(unittest.TestCase): - """Test cases for response model classes.""" - - def test_base_response(self): - """Test BaseResponse model.""" - response = BaseResponse(success=True, message="Operation successful") - self.assertTrue(response.success) - self.assertEqual(response.message, "Operation successful") - - def test_base_response_defaults(self): - """Test BaseResponse with default values.""" - response = BaseResponse(success=True) - self.assertTrue(response.success) - self.assertIsNone(response.message) - - def test_error_response(self): - """Test ErrorResponse model.""" - response = ErrorResponse( - success=False, - message="Error occurred", - error_code="VALIDATION_ERROR", - details={"field": "invalid_value"}, - ) - self.assertFalse(response.success) - self.assertEqual(response.message, "Error occurred") - self.assertEqual(response.error_code, "VALIDATION_ERROR") - self.assertEqual(response.details["field"], "invalid_value") - - def test_file_info(self): - """Test FileInfo model.""" - now = datetime.now() - file_info = FileInfo( - id="file_123", - name="test.txt", - size=1024, - mime_type="text/plain", - url="https://example.com/file.txt", - created_at=now, - ) - self.assertEqual(file_info.id, "file_123") - self.assertEqual(file_info.name, "test.txt") - self.assertEqual(file_info.size, 1024) - self.assertEqual(file_info.mime_type, "text/plain") - self.assertEqual(file_info.url, "https://example.com/file.txt") - self.assertEqual(file_info.created_at, now) - - def test_message_response(self): - """Test MessageResponse model.""" - response = MessageResponse( - success=True, - id="msg_123", - answer="Hello, world!", - conversation_id="conv_123", - created_at=1234567890, - metadata={"model": "gpt-4"}, - files=[{"id": "file_1", "type": "image"}], - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "msg_123") - self.assertEqual(response.answer, "Hello, world!") - self.assertEqual(response.conversation_id, "conv_123") - self.assertEqual(response.created_at, 1234567890) - self.assertEqual(response.metadata["model"], "gpt-4") - self.assertEqual(response.files[0]["id"], "file_1") - - def test_conversation_response(self): - """Test ConversationResponse model.""" - response = ConversationResponse( - success=True, - id="conv_123", - name="Test Conversation", - inputs={"query": "Hello"}, - status="active", - created_at=1234567890, - updated_at=1234567891, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "conv_123") - self.assertEqual(response.name, "Test Conversation") - self.assertEqual(response.inputs["query"], "Hello") - self.assertEqual(response.status, "active") - self.assertEqual(response.created_at, 1234567890) - self.assertEqual(response.updated_at, 1234567891) - - def test_dataset_response(self): - """Test DatasetResponse model.""" - response = DatasetResponse( - success=True, - id="dataset_123", - name="Test Dataset", - description="A test dataset", - permission="read", - indexing_technique="high_quality", - embedding_model="text-embedding-ada-002", - embedding_model_provider="openai", - retrieval_model={"search_type": "semantic"}, - document_count=10, - word_count=5000, - app_count=2, - created_at=1234567890, - updated_at=1234567891, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "dataset_123") - self.assertEqual(response.name, "Test Dataset") - self.assertEqual(response.description, "A test dataset") - self.assertEqual(response.permission, "read") - self.assertEqual(response.indexing_technique, "high_quality") - self.assertEqual(response.embedding_model, "text-embedding-ada-002") - self.assertEqual(response.embedding_model_provider, "openai") - self.assertEqual(response.retrieval_model["search_type"], "semantic") - self.assertEqual(response.document_count, 10) - self.assertEqual(response.word_count, 5000) - self.assertEqual(response.app_count, 2) - - def test_document_response(self): - """Test DocumentResponse model.""" - response = DocumentResponse( - success=True, - id="doc_123", - name="test_document.txt", - data_source_type="upload_file", - position=1, - enabled=True, - word_count=1000, - hit_count=5, - doc_form="text_model", - created_at=1234567890.0, - indexing_status="completed", - completed_at=1234567891.0, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "doc_123") - self.assertEqual(response.name, "test_document.txt") - self.assertEqual(response.data_source_type, "upload_file") - self.assertEqual(response.position, 1) - self.assertTrue(response.enabled) - self.assertEqual(response.word_count, 1000) - self.assertEqual(response.hit_count, 5) - self.assertEqual(response.doc_form, "text_model") - self.assertEqual(response.created_at, 1234567890.0) - self.assertEqual(response.indexing_status, "completed") - self.assertEqual(response.completed_at, 1234567891.0) - - def test_document_segment_response(self): - """Test DocumentSegmentResponse model.""" - response = DocumentSegmentResponse( - success=True, - id="seg_123", - position=1, - document_id="doc_123", - content="This is a test segment.", - answer="Test answer", - word_count=5, - tokens=10, - keywords=["test", "segment"], - hit_count=2, - enabled=True, - status="completed", - created_at=1234567890.0, - completed_at=1234567891.0, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "seg_123") - self.assertEqual(response.position, 1) - self.assertEqual(response.document_id, "doc_123") - self.assertEqual(response.content, "This is a test segment.") - self.assertEqual(response.answer, "Test answer") - self.assertEqual(response.word_count, 5) - self.assertEqual(response.tokens, 10) - self.assertEqual(response.keywords, ["test", "segment"]) - self.assertEqual(response.hit_count, 2) - self.assertTrue(response.enabled) - self.assertEqual(response.status, "completed") - self.assertEqual(response.created_at, 1234567890.0) - self.assertEqual(response.completed_at, 1234567891.0) - - def test_workflow_run_response(self): - """Test WorkflowRunResponse model.""" - response = WorkflowRunResponse( - success=True, - id="run_123", - workflow_id="workflow_123", - status="succeeded", - inputs={"query": "test"}, - outputs={"answer": "result"}, - elapsed_time=5.5, - total_tokens=100, - total_steps=3, - created_at=1234567890.0, - finished_at=1234567895.5, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "run_123") - self.assertEqual(response.workflow_id, "workflow_123") - self.assertEqual(response.status, "succeeded") - self.assertEqual(response.inputs["query"], "test") - self.assertEqual(response.outputs["answer"], "result") - self.assertEqual(response.elapsed_time, 5.5) - self.assertEqual(response.total_tokens, 100) - self.assertEqual(response.total_steps, 3) - self.assertEqual(response.created_at, 1234567890.0) - self.assertEqual(response.finished_at, 1234567895.5) - - def test_application_parameters_response(self): - """Test ApplicationParametersResponse model.""" - response = ApplicationParametersResponse( - success=True, - opening_statement="Hello! How can I help you?", - suggested_questions=["What is AI?", "How does this work?"], - speech_to_text={"enabled": True}, - text_to_speech={"enabled": False, "voice": "alloy"}, - retriever_resource={"enabled": True}, - sensitive_word_avoidance={"enabled": False}, - file_upload={"enabled": True, "file_size_limit": 10485760}, - system_parameters={"max_tokens": 1000}, - user_input_form=[{"type": "text", "label": "Query"}], - ) - self.assertTrue(response.success) - self.assertEqual(response.opening_statement, "Hello! How can I help you?") - self.assertEqual(response.suggested_questions, ["What is AI?", "How does this work?"]) - self.assertTrue(response.speech_to_text["enabled"]) - self.assertFalse(response.text_to_speech["enabled"]) - self.assertEqual(response.text_to_speech["voice"], "alloy") - self.assertTrue(response.retriever_resource["enabled"]) - self.assertFalse(response.sensitive_word_avoidance["enabled"]) - self.assertTrue(response.file_upload["enabled"]) - self.assertEqual(response.file_upload["file_size_limit"], 10485760) - self.assertEqual(response.system_parameters["max_tokens"], 1000) - self.assertEqual(response.user_input_form[0]["type"], "text") - - def test_annotation_response(self): - """Test AnnotationResponse model.""" - response = AnnotationResponse( - success=True, - id="annotation_123", - question="What is the capital of France?", - answer="Paris", - content="Additional context", - created_at=1234567890.0, - updated_at=1234567891.0, - created_by="user_123", - updated_by="user_123", - hit_count=5, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "annotation_123") - self.assertEqual(response.question, "What is the capital of France?") - self.assertEqual(response.answer, "Paris") - self.assertEqual(response.content, "Additional context") - self.assertEqual(response.created_at, 1234567890.0) - self.assertEqual(response.updated_at, 1234567891.0) - self.assertEqual(response.created_by, "user_123") - self.assertEqual(response.updated_by, "user_123") - self.assertEqual(response.hit_count, 5) - - def test_paginated_response(self): - """Test PaginatedResponse model.""" - response = PaginatedResponse( - success=True, - data=[{"id": 1}, {"id": 2}, {"id": 3}], - has_more=True, - limit=10, - total=100, - page=1, - ) - self.assertTrue(response.success) - self.assertEqual(len(response.data), 3) - self.assertEqual(response.data[0]["id"], 1) - self.assertTrue(response.has_more) - self.assertEqual(response.limit, 10) - self.assertEqual(response.total, 100) - self.assertEqual(response.page, 1) - - def test_conversation_variable_response(self): - """Test ConversationVariableResponse model.""" - response = ConversationVariableResponse( - success=True, - conversation_id="conv_123", - variables=[ - {"id": "var_1", "name": "user_name", "value": "John"}, - {"id": "var_2", "name": "preferences", "value": {"theme": "dark"}}, - ], - ) - self.assertTrue(response.success) - self.assertEqual(response.conversation_id, "conv_123") - self.assertEqual(len(response.variables), 2) - self.assertEqual(response.variables[0]["name"], "user_name") - self.assertEqual(response.variables[0]["value"], "John") - self.assertEqual(response.variables[1]["name"], "preferences") - self.assertEqual(response.variables[1]["value"]["theme"], "dark") - - def test_file_upload_response(self): - """Test FileUploadResponse model.""" - response = FileUploadResponse( - success=True, - id="file_123", - name="test.txt", - size=1024, - mime_type="text/plain", - url="https://example.com/files/test.txt", - created_at=1234567890.0, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "file_123") - self.assertEqual(response.name, "test.txt") - self.assertEqual(response.size, 1024) - self.assertEqual(response.mime_type, "text/plain") - self.assertEqual(response.url, "https://example.com/files/test.txt") - self.assertEqual(response.created_at, 1234567890.0) - - def test_audio_response(self): - """Test AudioResponse model.""" - response = AudioResponse( - success=True, - audio="base64_encoded_audio_data", - audio_url="https://example.com/audio.mp3", - duration=10.5, - sample_rate=44100, - ) - self.assertTrue(response.success) - self.assertEqual(response.audio, "base64_encoded_audio_data") - self.assertEqual(response.audio_url, "https://example.com/audio.mp3") - self.assertEqual(response.duration, 10.5) - self.assertEqual(response.sample_rate, 44100) - - def test_suggested_questions_response(self): - """Test SuggestedQuestionsResponse model.""" - response = SuggestedQuestionsResponse( - success=True, - message_id="msg_123", - questions=[ - "What is machine learning?", - "How does AI work?", - "Can you explain neural networks?", - ], - ) - self.assertTrue(response.success) - self.assertEqual(response.message_id, "msg_123") - self.assertEqual(len(response.questions), 3) - self.assertEqual(response.questions[0], "What is machine learning?") - - def test_app_info_response(self): - """Test AppInfoResponse model.""" - response = AppInfoResponse( - success=True, - id="app_123", - name="Test App", - description="A test application", - icon="🤖", - icon_background="#FF6B6B", - mode="chat", - tags=["AI", "Chat", "Test"], - enable_site=True, - enable_api=True, - api_token="app_token_123", - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "app_123") - self.assertEqual(response.name, "Test App") - self.assertEqual(response.description, "A test application") - self.assertEqual(response.icon, "🤖") - self.assertEqual(response.icon_background, "#FF6B6B") - self.assertEqual(response.mode, "chat") - self.assertEqual(response.tags, ["AI", "Chat", "Test"]) - self.assertTrue(response.enable_site) - self.assertTrue(response.enable_api) - self.assertEqual(response.api_token, "app_token_123") - - def test_workspace_models_response(self): - """Test WorkspaceModelsResponse model.""" - response = WorkspaceModelsResponse( - success=True, - models=[ - {"id": "gpt-4", "name": "GPT-4", "provider": "openai"}, - {"id": "claude-3", "name": "Claude 3", "provider": "anthropic"}, - ], - ) - self.assertTrue(response.success) - self.assertEqual(len(response.models), 2) - self.assertEqual(response.models[0]["id"], "gpt-4") - self.assertEqual(response.models[0]["name"], "GPT-4") - self.assertEqual(response.models[0]["provider"], "openai") - - def test_hit_testing_response(self): - """Test HitTestingResponse model.""" - response = HitTestingResponse( - success=True, - query="What is machine learning?", - records=[ - {"content": "Machine learning is a subset of AI...", "score": 0.95}, - {"content": "ML algorithms learn from data...", "score": 0.87}, - ], - ) - self.assertTrue(response.success) - self.assertEqual(response.query, "What is machine learning?") - self.assertEqual(len(response.records), 2) - self.assertEqual(response.records[0]["score"], 0.95) - - def test_dataset_tags_response(self): - """Test DatasetTagsResponse model.""" - response = DatasetTagsResponse( - success=True, - tags=[ - {"id": "tag_1", "name": "Technology", "color": "#FF0000"}, - {"id": "tag_2", "name": "Science", "color": "#00FF00"}, - ], - ) - self.assertTrue(response.success) - self.assertEqual(len(response.tags), 2) - self.assertEqual(response.tags[0]["name"], "Technology") - self.assertEqual(response.tags[0]["color"], "#FF0000") - - def test_workflow_logs_response(self): - """Test WorkflowLogsResponse model.""" - response = WorkflowLogsResponse( - success=True, - logs=[ - {"id": "log_1", "status": "succeeded", "created_at": 1234567890}, - {"id": "log_2", "status": "failed", "created_at": 1234567891}, - ], - total=50, - page=1, - limit=10, - has_more=True, - ) - self.assertTrue(response.success) - self.assertEqual(len(response.logs), 2) - self.assertEqual(response.logs[0]["status"], "succeeded") - self.assertEqual(response.total, 50) - self.assertEqual(response.page, 1) - self.assertEqual(response.limit, 10) - self.assertTrue(response.has_more) - - def test_model_serialization(self): - """Test that models can be serialized to JSON.""" - response = MessageResponse( - success=True, - id="msg_123", - answer="Hello, world!", - conversation_id="conv_123", - ) - - # Convert to dict and then to JSON - response_dict = { - "success": response.success, - "id": response.id, - "answer": response.answer, - "conversation_id": response.conversation_id, - } - - json_str = json.dumps(response_dict) - parsed = json.loads(json_str) - - self.assertTrue(parsed["success"]) - self.assertEqual(parsed["id"], "msg_123") - self.assertEqual(parsed["answer"], "Hello, world!") - self.assertEqual(parsed["conversation_id"], "conv_123") - - # Tests for new response models - def test_model_provider_response(self): - """Test ModelProviderResponse model.""" - response = ModelProviderResponse( - success=True, - provider_name="openai", - provider_type="llm", - models=[ - {"id": "gpt-4", "name": "GPT-4", "max_tokens": 8192}, - {"id": "gpt-3.5-turbo", "name": "GPT-3.5 Turbo", "max_tokens": 4096}, - ], - is_enabled=True, - credentials={"api_key": "sk-..."}, - ) - self.assertTrue(response.success) - self.assertEqual(response.provider_name, "openai") - self.assertEqual(response.provider_type, "llm") - self.assertEqual(len(response.models), 2) - self.assertEqual(response.models[0]["id"], "gpt-4") - self.assertTrue(response.is_enabled) - self.assertEqual(response.credentials["api_key"], "sk-...") - - def test_file_info_response(self): - """Test FileInfoResponse model.""" - response = FileInfoResponse( - success=True, - id="file_123", - name="document.pdf", - size=2048576, - mime_type="application/pdf", - url="https://example.com/files/document.pdf", - created_at=1234567890, - metadata={"pages": 10, "author": "John Doe"}, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "file_123") - self.assertEqual(response.name, "document.pdf") - self.assertEqual(response.size, 2048576) - self.assertEqual(response.mime_type, "application/pdf") - self.assertEqual(response.url, "https://example.com/files/document.pdf") - self.assertEqual(response.created_at, 1234567890) - self.assertEqual(response.metadata["pages"], 10) - - def test_workflow_draft_response(self): - """Test WorkflowDraftResponse model.""" - response = WorkflowDraftResponse( - success=True, - id="draft_123", - app_id="app_456", - draft_data={"nodes": [], "edges": [], "config": {"name": "Test Workflow"}}, - version=1, - created_at=1234567890, - updated_at=1234567891, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "draft_123") - self.assertEqual(response.app_id, "app_456") - self.assertEqual(response.draft_data["config"]["name"], "Test Workflow") - self.assertEqual(response.version, 1) - self.assertEqual(response.created_at, 1234567890) - self.assertEqual(response.updated_at, 1234567891) - - def test_api_token_response(self): - """Test ApiTokenResponse model.""" - response = ApiTokenResponse( - success=True, - id="token_123", - name="Production Token", - token="app-xxxxxxxxxxxx", - description="Token for production environment", - created_at=1234567890, - last_used_at=1234567891, - is_active=True, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "token_123") - self.assertEqual(response.name, "Production Token") - self.assertEqual(response.token, "app-xxxxxxxxxxxx") - self.assertEqual(response.description, "Token for production environment") - self.assertEqual(response.created_at, 1234567890) - self.assertEqual(response.last_used_at, 1234567891) - self.assertTrue(response.is_active) - - def test_job_status_response(self): - """Test JobStatusResponse model.""" - response = JobStatusResponse( - success=True, - job_id="job_123", - job_status="running", - error_msg=None, - progress=0.75, - created_at=1234567890, - updated_at=1234567891, - ) - self.assertTrue(response.success) - self.assertEqual(response.job_id, "job_123") - self.assertEqual(response.job_status, "running") - self.assertIsNone(response.error_msg) - self.assertEqual(response.progress, 0.75) - self.assertEqual(response.created_at, 1234567890) - self.assertEqual(response.updated_at, 1234567891) - - def test_dataset_query_response(self): - """Test DatasetQueryResponse model.""" - response = DatasetQueryResponse( - success=True, - query="What is machine learning?", - records=[ - {"content": "Machine learning is...", "score": 0.95}, - {"content": "ML algorithms...", "score": 0.87}, - ], - total=2, - search_time=0.123, - retrieval_model={"method": "semantic_search", "top_k": 3}, - ) - self.assertTrue(response.success) - self.assertEqual(response.query, "What is machine learning?") - self.assertEqual(len(response.records), 2) - self.assertEqual(response.total, 2) - self.assertEqual(response.search_time, 0.123) - self.assertEqual(response.retrieval_model["method"], "semantic_search") - - def test_dataset_template_response(self): - """Test DatasetTemplateResponse model.""" - response = DatasetTemplateResponse( - success=True, - template_name="customer_support", - display_name="Customer Support", - description="Template for customer support knowledge base", - category="support", - icon="🎧", - config_schema={"fields": [{"name": "category", "type": "string"}]}, - ) - self.assertTrue(response.success) - self.assertEqual(response.template_name, "customer_support") - self.assertEqual(response.display_name, "Customer Support") - self.assertEqual(response.description, "Template for customer support knowledge base") - self.assertEqual(response.category, "support") - self.assertEqual(response.icon, "🎧") - self.assertEqual(response.config_schema["fields"][0]["name"], "category") - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/tests/test_retry_and_error_handling.py b/sdks/python-client/tests/test_retry_and_error_handling.py deleted file mode 100644 index bd415bde43..0000000000 --- a/sdks/python-client/tests/test_retry_and_error_handling.py +++ /dev/null @@ -1,313 +0,0 @@ -"""Unit tests for retry mechanism and error handling.""" - -import unittest -from unittest.mock import Mock, patch, MagicMock -import httpx -from dify_client.client import DifyClient -from dify_client.exceptions import ( - APIError, - AuthenticationError, - RateLimitError, - ValidationError, - NetworkError, - TimeoutError, - FileUploadError, -) - - -class TestRetryMechanism(unittest.TestCase): - """Test cases for retry mechanism.""" - - def setUp(self): - self.api_key = "test_api_key" - self.base_url = "https://api.dify.ai/v1" - self.client = DifyClient( - api_key=self.api_key, - base_url=self.base_url, - max_retries=3, - retry_delay=0.1, # Short delay for tests - enable_logging=False, - ) - - @patch("httpx.Client.request") - def test_successful_request_no_retry(self, mock_request): - """Test that successful requests don't trigger retries.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.content = b'{"success": true}' - mock_request.return_value = mock_response - - response = self.client._send_request("GET", "/test") - - self.assertEqual(response, mock_response) - self.assertEqual(mock_request.call_count, 1) - - @patch("httpx.Client.request") - @patch("time.sleep") - def test_retry_on_network_error(self, mock_sleep, mock_request): - """Test retry on network errors.""" - # First two calls raise network error, third succeeds - mock_request.side_effect = [ - httpx.NetworkError("Connection failed"), - httpx.NetworkError("Connection failed"), - Mock(status_code=200, content=b'{"success": true}'), - ] - mock_response = Mock() - mock_response.status_code = 200 - mock_response.content = b'{"success": true}' - - response = self.client._send_request("GET", "/test") - - self.assertEqual(response.status_code, 200) - self.assertEqual(mock_request.call_count, 3) - self.assertEqual(mock_sleep.call_count, 2) - - @patch("httpx.Client.request") - @patch("time.sleep") - def test_retry_on_timeout_error(self, mock_sleep, mock_request): - """Test retry on timeout errors.""" - mock_request.side_effect = [ - httpx.TimeoutException("Request timed out"), - httpx.TimeoutException("Request timed out"), - Mock(status_code=200, content=b'{"success": true}'), - ] - - response = self.client._send_request("GET", "/test") - - self.assertEqual(response.status_code, 200) - self.assertEqual(mock_request.call_count, 3) - self.assertEqual(mock_sleep.call_count, 2) - - @patch("httpx.Client.request") - @patch("time.sleep") - def test_max_retries_exceeded(self, mock_sleep, mock_request): - """Test behavior when max retries are exceeded.""" - mock_request.side_effect = httpx.NetworkError("Persistent network error") - - with self.assertRaises(NetworkError): - self.client._send_request("GET", "/test") - - self.assertEqual(mock_request.call_count, 4) # 1 initial + 3 retries - self.assertEqual(mock_sleep.call_count, 3) - - @patch("httpx.Client.request") - def test_no_retry_on_client_error(self, mock_request): - """Test that client errors (4xx) don't trigger retries.""" - mock_response = Mock() - mock_response.status_code = 401 - mock_response.json.return_value = {"message": "Unauthorized"} - mock_request.return_value = mock_response - - with self.assertRaises(AuthenticationError): - self.client._send_request("GET", "/test") - - self.assertEqual(mock_request.call_count, 1) - - @patch("httpx.Client.request") - def test_retry_on_server_error(self, mock_request): - """Test that server errors (5xx) don't retry - they raise APIError immediately.""" - mock_response_500 = Mock() - mock_response_500.status_code = 500 - mock_response_500.json.return_value = {"message": "Internal server error"} - - mock_request.return_value = mock_response_500 - - with self.assertRaises(APIError) as context: - self.client._send_request("GET", "/test") - - self.assertEqual(str(context.exception), "Internal server error") - self.assertEqual(context.exception.status_code, 500) - # Should not retry server errors - self.assertEqual(mock_request.call_count, 1) - - @patch("httpx.Client.request") - def test_exponential_backoff(self, mock_request): - """Test exponential backoff timing.""" - mock_request.side_effect = [ - httpx.NetworkError("Connection failed"), - httpx.NetworkError("Connection failed"), - httpx.NetworkError("Connection failed"), - httpx.NetworkError("Connection failed"), # All attempts fail - ] - - with patch("time.sleep") as mock_sleep: - with self.assertRaises(NetworkError): - self.client._send_request("GET", "/test") - - # Check exponential backoff: 0.1, 0.2, 0.4 - expected_calls = [0.1, 0.2, 0.4] - actual_calls = [call[0][0] for call in mock_sleep.call_args_list] - self.assertEqual(actual_calls, expected_calls) - - -class TestErrorHandling(unittest.TestCase): - """Test cases for error handling.""" - - def setUp(self): - self.client = DifyClient(api_key="test_api_key", enable_logging=False) - - @patch("httpx.Client.request") - def test_authentication_error(self, mock_request): - """Test AuthenticationError handling.""" - mock_response = Mock() - mock_response.status_code = 401 - mock_response.json.return_value = {"message": "Invalid API key"} - mock_request.return_value = mock_response - - with self.assertRaises(AuthenticationError) as context: - self.client._send_request("GET", "/test") - - self.assertEqual(str(context.exception), "Invalid API key") - self.assertEqual(context.exception.status_code, 401) - - @patch("httpx.Client.request") - def test_rate_limit_error(self, mock_request): - """Test RateLimitError handling.""" - mock_response = Mock() - mock_response.status_code = 429 - mock_response.json.return_value = {"message": "Rate limit exceeded"} - mock_response.headers = {"Retry-After": "60"} - mock_request.return_value = mock_response - - with self.assertRaises(RateLimitError) as context: - self.client._send_request("GET", "/test") - - self.assertEqual(str(context.exception), "Rate limit exceeded") - self.assertEqual(context.exception.retry_after, "60") - - @patch("httpx.Client.request") - def test_validation_error(self, mock_request): - """Test ValidationError handling.""" - mock_response = Mock() - mock_response.status_code = 422 - mock_response.json.return_value = {"message": "Invalid parameters"} - mock_request.return_value = mock_response - - with self.assertRaises(ValidationError) as context: - self.client._send_request("GET", "/test") - - self.assertEqual(str(context.exception), "Invalid parameters") - self.assertEqual(context.exception.status_code, 422) - - @patch("httpx.Client.request") - def test_api_error(self, mock_request): - """Test general APIError handling.""" - mock_response = Mock() - mock_response.status_code = 500 - mock_response.json.return_value = {"message": "Internal server error"} - mock_request.return_value = mock_response - - with self.assertRaises(APIError) as context: - self.client._send_request("GET", "/test") - - self.assertEqual(str(context.exception), "Internal server error") - self.assertEqual(context.exception.status_code, 500) - - @patch("httpx.Client.request") - def test_error_response_without_json(self, mock_request): - """Test error handling when response doesn't contain valid JSON.""" - mock_response = Mock() - mock_response.status_code = 500 - mock_response.content = b"Internal Server Error" - mock_response.json.side_effect = ValueError("No JSON object could be decoded") - mock_request.return_value = mock_response - - with self.assertRaises(APIError) as context: - self.client._send_request("GET", "/test") - - self.assertEqual(str(context.exception), "HTTP 500") - - @patch("httpx.Client.request") - def test_file_upload_error(self, mock_request): - """Test FileUploadError handling.""" - mock_response = Mock() - mock_response.status_code = 400 - mock_response.json.return_value = {"message": "File upload failed"} - mock_request.return_value = mock_response - - with self.assertRaises(FileUploadError) as context: - self.client._send_request_with_files("POST", "/upload", {}, {}) - - self.assertEqual(str(context.exception), "File upload failed") - self.assertEqual(context.exception.status_code, 400) - - -class TestParameterValidation(unittest.TestCase): - """Test cases for parameter validation.""" - - def setUp(self): - self.client = DifyClient(api_key="test_api_key", enable_logging=False) - - def test_empty_string_validation(self): - """Test validation of empty strings.""" - with self.assertRaises(ValidationError): - self.client._validate_params(empty_string="") - - def test_whitespace_only_string_validation(self): - """Test validation of whitespace-only strings.""" - with self.assertRaises(ValidationError): - self.client._validate_params(whitespace_string=" ") - - def test_long_string_validation(self): - """Test validation of overly long strings.""" - long_string = "a" * 10001 # Exceeds 10000 character limit - with self.assertRaises(ValidationError): - self.client._validate_params(long_string=long_string) - - def test_large_list_validation(self): - """Test validation of overly large lists.""" - large_list = list(range(1001)) # Exceeds 1000 item limit - with self.assertRaises(ValidationError): - self.client._validate_params(large_list=large_list) - - def test_large_dict_validation(self): - """Test validation of overly large dictionaries.""" - large_dict = {f"key_{i}": i for i in range(101)} # Exceeds 100 item limit - with self.assertRaises(ValidationError): - self.client._validate_params(large_dict=large_dict) - - def test_valid_parameters_pass(self): - """Test that valid parameters pass validation.""" - # Should not raise any exception - self.client._validate_params( - valid_string="Hello, World!", - valid_list=[1, 2, 3], - valid_dict={"key": "value"}, - none_value=None, - ) - - def test_message_feedback_validation(self): - """Test validation in message_feedback method.""" - with self.assertRaises(ValidationError): - self.client.message_feedback("msg_id", "invalid_rating", "user") - - def test_completion_message_validation(self): - """Test validation in create_completion_message method.""" - from dify_client.client import CompletionClient - - client = CompletionClient("test_api_key") - - with self.assertRaises(ValidationError): - client.create_completion_message( - inputs="not_a_dict", # Should be a dict - response_mode="invalid_mode", # Should be 'blocking' or 'streaming' - user="test_user", - ) - - def test_chat_message_validation(self): - """Test validation in create_chat_message method.""" - from dify_client.client import ChatClient - - client = ChatClient("test_api_key") - - with self.assertRaises(ValidationError): - client.create_chat_message( - inputs="not_a_dict", # Should be a dict - query="", # Should not be empty - user="test_user", - response_mode="invalid_mode", # Should be 'blocking' or 'streaming' - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/uv.lock b/sdks/python-client/uv.lock deleted file mode 100644 index 4a9d7d5193..0000000000 --- a/sdks/python-client/uv.lock +++ /dev/null @@ -1,307 +0,0 @@ -version = 1 -revision = 3 -requires-python = ">=3.10" - -[[package]] -name = "aiofiles" -version = "25.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/41/c3/534eac40372d8ee36ef40df62ec129bee4fdb5ad9706e58a29be53b2c970/aiofiles-25.1.0.tar.gz", hash = "sha256:a8d728f0a29de45dc521f18f07297428d56992a742f0cd2701ba86e44d23d5b2", size = 46354, upload-time = "2025-10-09T20:51:04.358Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/bc/8a/340a1555ae33d7354dbca4faa54948d76d89a27ceef032c8c3bc661d003e/aiofiles-25.1.0-py3-none-any.whl", hash = "sha256:abe311e527c862958650f9438e859c1fa7568a141b22abcd015e120e86a85695", size = 14668, upload-time = "2025-10-09T20:51:03.174Z" }, -] - -[[package]] -name = "anyio" -version = "4.11.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, - { name = "idna" }, - { name = "sniffio" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/c6/78/7d432127c41b50bccba979505f272c16cbcadcc33645d5fa3a738110ae75/anyio-4.11.0.tar.gz", hash = "sha256:82a8d0b81e318cc5ce71a5f1f8b5c4e63619620b63141ef8c995fa0db95a57c4", size = 219094, upload-time = "2025-09-23T09:19:12.58Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/15/b3/9b1a8074496371342ec1e796a96f99c82c945a339cd81a8e73de28b4cf9e/anyio-4.11.0-py3-none-any.whl", hash = "sha256:0287e96f4d26d4149305414d4e3bc32f0dcd0862365a4bddea19d7a1ec38c4fc", size = 109097, upload-time = "2025-09-23T09:19:10.601Z" }, -] - -[[package]] -name = "backports-asyncio-runner" -version = "1.2.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/8e/ff/70dca7d7cb1cbc0edb2c6cc0c38b65cba36cccc491eca64cabd5fe7f8670/backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162", size = 69893, upload-time = "2025-07-02T02:27:15.685Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a0/59/76ab57e3fe74484f48a53f8e337171b4a2349e506eabe136d7e01d059086/backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5", size = 12313, upload-time = "2025-07-02T02:27:14.263Z" }, -] - -[[package]] -name = "certifi" -version = "2025.10.5" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/4c/5b/b6ce21586237c77ce67d01dc5507039d444b630dd76611bbca2d8e5dcd91/certifi-2025.10.5.tar.gz", hash = "sha256:47c09d31ccf2acf0be3f701ea53595ee7e0b8fa08801c6624be771df09ae7b43", size = 164519, upload-time = "2025-10-05T04:12:15.808Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e4/37/af0d2ef3967ac0d6113837b44a4f0bfe1328c2b9763bd5b1744520e5cfed/certifi-2025.10.5-py3-none-any.whl", hash = "sha256:0f212c2744a9bb6de0c56639a6f68afe01ecd92d91f14ae897c4fe7bbeeef0de", size = 163286, upload-time = "2025-10-05T04:12:14.03Z" }, -] - -[[package]] -name = "colorama" -version = "0.4.6" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, -] - -[[package]] -name = "dify-client" -version = "0.1.12" -source = { editable = "." } -dependencies = [ - { name = "aiofiles" }, - { name = "httpx", extra = ["http2"] }, -] - -[package.optional-dependencies] -dev = [ - { name = "pytest" }, - { name = "pytest-asyncio" }, -] - -[package.metadata] -requires-dist = [ - { name = "aiofiles", specifier = ">=23.0.0" }, - { name = "httpx", extras = ["http2"], specifier = ">=0.27.0" }, - { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0.0" }, - { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.21.0" }, -] -provides-extras = ["dev"] - -[[package]] -name = "exceptiongroup" -version = "1.3.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/36/f4/c6e662dade71f56cd2f3735141b265c3c79293c109549c1e6933b0651ffc/exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10", size = 16674, upload-time = "2025-05-10T17:42:49.33Z" }, -] - -[[package]] -name = "h11" -version = "0.16.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload-time = "2025-04-24T03:35:25.427Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, -] - -[[package]] -name = "h2" -version = "4.3.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "hpack" }, - { name = "hyperframe" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/1d/17/afa56379f94ad0fe8defd37d6eb3f89a25404ffc71d4d848893d270325fc/h2-4.3.0.tar.gz", hash = "sha256:6c59efe4323fa18b47a632221a1888bd7fde6249819beda254aeca909f221bf1", size = 2152026, upload-time = "2025-08-23T18:12:19.778Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/69/b2/119f6e6dcbd96f9069ce9a2665e0146588dc9f88f29549711853645e736a/h2-4.3.0-py3-none-any.whl", hash = "sha256:c438f029a25f7945c69e0ccf0fb951dc3f73a5f6412981daee861431b70e2bdd", size = 61779, upload-time = "2025-08-23T18:12:17.779Z" }, -] - -[[package]] -name = "hpack" -version = "4.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/2c/48/71de9ed269fdae9c8057e5a4c0aa7402e8bb16f2c6e90b3aa53327b113f8/hpack-4.1.0.tar.gz", hash = "sha256:ec5eca154f7056aa06f196a557655c5b009b382873ac8d1e66e79e87535f1dca", size = 51276, upload-time = "2025-01-22T21:44:58.347Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/07/c6/80c95b1b2b94682a72cbdbfb85b81ae2daffa4291fbfa1b1464502ede10d/hpack-4.1.0-py3-none-any.whl", hash = "sha256:157ac792668d995c657d93111f46b4535ed114f0c9c8d672271bbec7eae1b496", size = 34357, upload-time = "2025-01-22T21:44:56.92Z" }, -] - -[[package]] -name = "httpcore" -version = "1.0.9" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "certifi" }, - { name = "h11" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484, upload-time = "2025-04-24T22:06:22.219Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" }, -] - -[[package]] -name = "httpx" -version = "0.28.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio" }, - { name = "certifi" }, - { name = "httpcore" }, - { name = "idna" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, -] - -[package.optional-dependencies] -http2 = [ - { name = "h2" }, -] - -[[package]] -name = "hyperframe" -version = "6.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/02/e7/94f8232d4a74cc99514c13a9f995811485a6903d48e5d952771ef6322e30/hyperframe-6.1.0.tar.gz", hash = "sha256:f630908a00854a7adeabd6382b43923a4c4cd4b821fcb527e6ab9e15382a3b08", size = 26566, upload-time = "2025-01-22T21:41:49.302Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/48/30/47d0bf6072f7252e6521f3447ccfa40b421b6824517f82854703d0f5a98b/hyperframe-6.1.0-py3-none-any.whl", hash = "sha256:b03380493a519fce58ea5af42e4a42317bf9bd425596f7a0835ffce80f1a42e5", size = 13007, upload-time = "2025-01-22T21:41:47.295Z" }, -] - -[[package]] -name = "idna" -version = "3.10" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490, upload-time = "2024-09-15T18:07:39.745Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" }, -] - -[[package]] -name = "iniconfig" -version = "2.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" }, -] - -[[package]] -name = "packaging" -version = "25.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" }, -] - -[[package]] -name = "pluggy" -version = "1.6.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, -] - -[[package]] -name = "pygments" -version = "2.19.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, -] - -[[package]] -name = "pytest" -version = "8.4.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, - { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, - { name = "iniconfig" }, - { name = "packaging" }, - { name = "pluggy" }, - { name = "pygments" }, - { name = "tomli", marker = "python_full_version < '3.11'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a3/5c/00a0e072241553e1a7496d638deababa67c5058571567b92a7eaa258397c/pytest-8.4.2.tar.gz", hash = "sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01", size = 1519618, upload-time = "2025-09-04T14:34:22.711Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a8/a4/20da314d277121d6534b3a980b29035dcd51e6744bd79075a6ce8fa4eb8d/pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79", size = 365750, upload-time = "2025-09-04T14:34:20.226Z" }, -] - -[[package]] -name = "pytest-asyncio" -version = "1.2.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "backports-asyncio-runner", marker = "python_full_version < '3.11'" }, - { name = "pytest" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/42/86/9e3c5f48f7b7b638b216e4b9e645f54d199d7abbbab7a64a13b4e12ba10f/pytest_asyncio-1.2.0.tar.gz", hash = "sha256:c609a64a2a8768462d0c99811ddb8bd2583c33fd33cf7f21af1c142e824ffb57", size = 50119, upload-time = "2025-09-12T07:33:53.816Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/04/93/2fa34714b7a4ae72f2f8dad66ba17dd9a2c793220719e736dda28b7aec27/pytest_asyncio-1.2.0-py3-none-any.whl", hash = "sha256:8e17ae5e46d8e7efe51ab6494dd2010f4ca8dae51652aa3c8d55acf50bfb2e99", size = 15095, upload-time = "2025-09-12T07:33:52.639Z" }, -] - -[[package]] -name = "sniffio" -version = "1.3.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372, upload-time = "2024-02-25T23:20:04.057Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" }, -] - -[[package]] -name = "tomli" -version = "2.3.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/52/ed/3f73f72945444548f33eba9a87fc7a6e969915e7b1acc8260b30e1f76a2f/tomli-2.3.0.tar.gz", hash = "sha256:64be704a875d2a59753d80ee8a533c3fe183e3f06807ff7dc2232938ccb01549", size = 17392, upload-time = "2025-10-08T22:01:47.119Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b3/2e/299f62b401438d5fe1624119c723f5d877acc86a4c2492da405626665f12/tomli-2.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:88bd15eb972f3664f5ed4b57c1634a97153b4bac4479dcb6a495f41921eb7f45", size = 153236, upload-time = "2025-10-08T22:01:00.137Z" }, - { url = "https://files.pythonhosted.org/packages/86/7f/d8fffe6a7aefdb61bced88fcb5e280cfd71e08939da5894161bd71bea022/tomli-2.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:883b1c0d6398a6a9d29b508c331fa56adbcdff647f6ace4dfca0f50e90dfd0ba", size = 148084, upload-time = "2025-10-08T22:01:01.63Z" }, - { url = "https://files.pythonhosted.org/packages/47/5c/24935fb6a2ee63e86d80e4d3b58b222dafaf438c416752c8b58537c8b89a/tomli-2.3.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d1381caf13ab9f300e30dd8feadb3de072aeb86f1d34a8569453ff32a7dea4bf", size = 234832, upload-time = "2025-10-08T22:01:02.543Z" }, - { url = "https://files.pythonhosted.org/packages/89/da/75dfd804fc11e6612846758a23f13271b76d577e299592b4371a4ca4cd09/tomli-2.3.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a0e285d2649b78c0d9027570d4da3425bdb49830a6156121360b3f8511ea3441", size = 242052, upload-time = "2025-10-08T22:01:03.836Z" }, - { url = "https://files.pythonhosted.org/packages/70/8c/f48ac899f7b3ca7eb13af73bacbc93aec37f9c954df3c08ad96991c8c373/tomli-2.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0a154a9ae14bfcf5d8917a59b51ffd5a3ac1fd149b71b47a3a104ca4edcfa845", size = 239555, upload-time = "2025-10-08T22:01:04.834Z" }, - { url = "https://files.pythonhosted.org/packages/ba/28/72f8afd73f1d0e7829bfc093f4cb98ce0a40ffc0cc997009ee1ed94ba705/tomli-2.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:74bf8464ff93e413514fefd2be591c3b0b23231a77f901db1eb30d6f712fc42c", size = 245128, upload-time = "2025-10-08T22:01:05.84Z" }, - { url = "https://files.pythonhosted.org/packages/b6/eb/a7679c8ac85208706d27436e8d421dfa39d4c914dcf5fa8083a9305f58d9/tomli-2.3.0-cp311-cp311-win32.whl", hash = "sha256:00b5f5d95bbfc7d12f91ad8c593a1659b6387b43f054104cda404be6bda62456", size = 96445, upload-time = "2025-10-08T22:01:06.896Z" }, - { url = "https://files.pythonhosted.org/packages/0a/fe/3d3420c4cb1ad9cb462fb52967080575f15898da97e21cb6f1361d505383/tomli-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:4dc4ce8483a5d429ab602f111a93a6ab1ed425eae3122032db7e9acf449451be", size = 107165, upload-time = "2025-10-08T22:01:08.107Z" }, - { url = "https://files.pythonhosted.org/packages/ff/b7/40f36368fcabc518bb11c8f06379a0fd631985046c038aca08c6d6a43c6e/tomli-2.3.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d7d86942e56ded512a594786a5ba0a5e521d02529b3826e7761a05138341a2ac", size = 154891, upload-time = "2025-10-08T22:01:09.082Z" }, - { url = "https://files.pythonhosted.org/packages/f9/3f/d9dd692199e3b3aab2e4e4dd948abd0f790d9ded8cd10cbaae276a898434/tomli-2.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:73ee0b47d4dad1c5e996e3cd33b8a76a50167ae5f96a2607cbe8cc773506ab22", size = 148796, upload-time = "2025-10-08T22:01:10.266Z" }, - { url = "https://files.pythonhosted.org/packages/60/83/59bff4996c2cf9f9387a0f5a3394629c7efa5ef16142076a23a90f1955fa/tomli-2.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:792262b94d5d0a466afb5bc63c7daa9d75520110971ee269152083270998316f", size = 242121, upload-time = "2025-10-08T22:01:11.332Z" }, - { url = "https://files.pythonhosted.org/packages/45/e5/7c5119ff39de8693d6baab6c0b6dcb556d192c165596e9fc231ea1052041/tomli-2.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4f195fe57ecceac95a66a75ac24d9d5fbc98ef0962e09b2eddec5d39375aae52", size = 250070, upload-time = "2025-10-08T22:01:12.498Z" }, - { url = "https://files.pythonhosted.org/packages/45/12/ad5126d3a278f27e6701abde51d342aa78d06e27ce2bb596a01f7709a5a2/tomli-2.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e31d432427dcbf4d86958c184b9bfd1e96b5b71f8eb17e6d02531f434fd335b8", size = 245859, upload-time = "2025-10-08T22:01:13.551Z" }, - { url = "https://files.pythonhosted.org/packages/fb/a1/4d6865da6a71c603cfe6ad0e6556c73c76548557a8d658f9e3b142df245f/tomli-2.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7b0882799624980785240ab732537fcfc372601015c00f7fc367c55308c186f6", size = 250296, upload-time = "2025-10-08T22:01:14.614Z" }, - { url = "https://files.pythonhosted.org/packages/a0/b7/a7a7042715d55c9ba6e8b196d65d2cb662578b4d8cd17d882d45322b0d78/tomli-2.3.0-cp312-cp312-win32.whl", hash = "sha256:ff72b71b5d10d22ecb084d345fc26f42b5143c5533db5e2eaba7d2d335358876", size = 97124, upload-time = "2025-10-08T22:01:15.629Z" }, - { url = "https://files.pythonhosted.org/packages/06/1e/f22f100db15a68b520664eb3328fb0ae4e90530887928558112c8d1f4515/tomli-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:1cb4ed918939151a03f33d4242ccd0aa5f11b3547d0cf30f7c74a408a5b99878", size = 107698, upload-time = "2025-10-08T22:01:16.51Z" }, - { url = "https://files.pythonhosted.org/packages/89/48/06ee6eabe4fdd9ecd48bf488f4ac783844fd777f547b8d1b61c11939974e/tomli-2.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5192f562738228945d7b13d4930baffda67b69425a7f0da96d360b0a3888136b", size = 154819, upload-time = "2025-10-08T22:01:17.964Z" }, - { url = "https://files.pythonhosted.org/packages/f1/01/88793757d54d8937015c75dcdfb673c65471945f6be98e6a0410fba167ed/tomli-2.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:be71c93a63d738597996be9528f4abe628d1adf5e6eb11607bc8fe1a510b5dae", size = 148766, upload-time = "2025-10-08T22:01:18.959Z" }, - { url = "https://files.pythonhosted.org/packages/42/17/5e2c956f0144b812e7e107f94f1cc54af734eb17b5191c0bbfb72de5e93e/tomli-2.3.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c4665508bcbac83a31ff8ab08f424b665200c0e1e645d2bd9ab3d3e557b6185b", size = 240771, upload-time = "2025-10-08T22:01:20.106Z" }, - { url = "https://files.pythonhosted.org/packages/d5/f4/0fbd014909748706c01d16824eadb0307115f9562a15cbb012cd9b3512c5/tomli-2.3.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4021923f97266babc6ccab9f5068642a0095faa0a51a246a6a02fccbb3514eaf", size = 248586, upload-time = "2025-10-08T22:01:21.164Z" }, - { url = "https://files.pythonhosted.org/packages/30/77/fed85e114bde5e81ecf9bc5da0cc69f2914b38f4708c80ae67d0c10180c5/tomli-2.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a4ea38c40145a357d513bffad0ed869f13c1773716cf71ccaa83b0fa0cc4e42f", size = 244792, upload-time = "2025-10-08T22:01:22.417Z" }, - { url = "https://files.pythonhosted.org/packages/55/92/afed3d497f7c186dc71e6ee6d4fcb0acfa5f7d0a1a2878f8beae379ae0cc/tomli-2.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ad805ea85eda330dbad64c7ea7a4556259665bdf9d2672f5dccc740eb9d3ca05", size = 248909, upload-time = "2025-10-08T22:01:23.859Z" }, - { url = "https://files.pythonhosted.org/packages/f8/84/ef50c51b5a9472e7265ce1ffc7f24cd4023d289e109f669bdb1553f6a7c2/tomli-2.3.0-cp313-cp313-win32.whl", hash = "sha256:97d5eec30149fd3294270e889b4234023f2c69747e555a27bd708828353ab606", size = 96946, upload-time = "2025-10-08T22:01:24.893Z" }, - { url = "https://files.pythonhosted.org/packages/b2/b7/718cd1da0884f281f95ccfa3a6cc572d30053cba64603f79d431d3c9b61b/tomli-2.3.0-cp313-cp313-win_amd64.whl", hash = "sha256:0c95ca56fbe89e065c6ead5b593ee64b84a26fca063b5d71a1122bf26e533999", size = 107705, upload-time = "2025-10-08T22:01:26.153Z" }, - { url = "https://files.pythonhosted.org/packages/19/94/aeafa14a52e16163008060506fcb6aa1949d13548d13752171a755c65611/tomli-2.3.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:cebc6fe843e0733ee827a282aca4999b596241195f43b4cc371d64fc6639da9e", size = 154244, upload-time = "2025-10-08T22:01:27.06Z" }, - { url = "https://files.pythonhosted.org/packages/db/e4/1e58409aa78eefa47ccd19779fc6f36787edbe7d4cd330eeeedb33a4515b/tomli-2.3.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:4c2ef0244c75aba9355561272009d934953817c49f47d768070c3c94355c2aa3", size = 148637, upload-time = "2025-10-08T22:01:28.059Z" }, - { url = "https://files.pythonhosted.org/packages/26/b6/d1eccb62f665e44359226811064596dd6a366ea1f985839c566cd61525ae/tomli-2.3.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c22a8bf253bacc0cf11f35ad9808b6cb75ada2631c2d97c971122583b129afbc", size = 241925, upload-time = "2025-10-08T22:01:29.066Z" }, - { url = "https://files.pythonhosted.org/packages/70/91/7cdab9a03e6d3d2bb11beae108da5bdc1c34bdeb06e21163482544ddcc90/tomli-2.3.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0eea8cc5c5e9f89c9b90c4896a8deefc74f518db5927d0e0e8d4a80953d774d0", size = 249045, upload-time = "2025-10-08T22:01:31.98Z" }, - { url = "https://files.pythonhosted.org/packages/15/1b/8c26874ed1f6e4f1fcfeb868db8a794cbe9f227299402db58cfcc858766c/tomli-2.3.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:b74a0e59ec5d15127acdabd75ea17726ac4c5178ae51b85bfe39c4f8a278e879", size = 245835, upload-time = "2025-10-08T22:01:32.989Z" }, - { url = "https://files.pythonhosted.org/packages/fd/42/8e3c6a9a4b1a1360c1a2a39f0b972cef2cc9ebd56025168c4137192a9321/tomli-2.3.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:b5870b50c9db823c595983571d1296a6ff3e1b88f734a4c8f6fc6188397de005", size = 253109, upload-time = "2025-10-08T22:01:34.052Z" }, - { url = "https://files.pythonhosted.org/packages/22/0c/b4da635000a71b5f80130937eeac12e686eefb376b8dee113b4a582bba42/tomli-2.3.0-cp314-cp314-win32.whl", hash = "sha256:feb0dacc61170ed7ab602d3d972a58f14ee3ee60494292d384649a3dc38ef463", size = 97930, upload-time = "2025-10-08T22:01:35.082Z" }, - { url = "https://files.pythonhosted.org/packages/b9/74/cb1abc870a418ae99cd5c9547d6bce30701a954e0e721821df483ef7223c/tomli-2.3.0-cp314-cp314-win_amd64.whl", hash = "sha256:b273fcbd7fc64dc3600c098e39136522650c49bca95df2d11cf3b626422392c8", size = 107964, upload-time = "2025-10-08T22:01:36.057Z" }, - { url = "https://files.pythonhosted.org/packages/54/78/5c46fff6432a712af9f792944f4fcd7067d8823157949f4e40c56b8b3c83/tomli-2.3.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:940d56ee0410fa17ee1f12b817b37a4d4e4dc4d27340863cc67236c74f582e77", size = 163065, upload-time = "2025-10-08T22:01:37.27Z" }, - { url = "https://files.pythonhosted.org/packages/39/67/f85d9bd23182f45eca8939cd2bc7050e1f90c41f4a2ecbbd5963a1d1c486/tomli-2.3.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:f85209946d1fe94416debbb88d00eb92ce9cd5266775424ff81bc959e001acaf", size = 159088, upload-time = "2025-10-08T22:01:38.235Z" }, - { url = "https://files.pythonhosted.org/packages/26/5a/4b546a0405b9cc0659b399f12b6adb750757baf04250b148d3c5059fc4eb/tomli-2.3.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a56212bdcce682e56b0aaf79e869ba5d15a6163f88d5451cbde388d48b13f530", size = 268193, upload-time = "2025-10-08T22:01:39.712Z" }, - { url = "https://files.pythonhosted.org/packages/42/4f/2c12a72ae22cf7b59a7fe75b3465b7aba40ea9145d026ba41cb382075b0e/tomli-2.3.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c5f3ffd1e098dfc032d4d3af5c0ac64f6d286d98bc148698356847b80fa4de1b", size = 275488, upload-time = "2025-10-08T22:01:40.773Z" }, - { url = "https://files.pythonhosted.org/packages/92/04/a038d65dbe160c3aa5a624e93ad98111090f6804027d474ba9c37c8ae186/tomli-2.3.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:5e01decd096b1530d97d5d85cb4dff4af2d8347bd35686654a004f8dea20fc67", size = 272669, upload-time = "2025-10-08T22:01:41.824Z" }, - { url = "https://files.pythonhosted.org/packages/be/2f/8b7c60a9d1612a7cbc39ffcca4f21a73bf368a80fc25bccf8253e2563267/tomli-2.3.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:8a35dd0e643bb2610f156cca8db95d213a90015c11fee76c946aa62b7ae7e02f", size = 279709, upload-time = "2025-10-08T22:01:43.177Z" }, - { url = "https://files.pythonhosted.org/packages/7e/46/cc36c679f09f27ded940281c38607716c86cf8ba4a518d524e349c8b4874/tomli-2.3.0-cp314-cp314t-win32.whl", hash = "sha256:a1f7f282fe248311650081faafa5f4732bdbfef5d45fe3f2e702fbc6f2d496e0", size = 107563, upload-time = "2025-10-08T22:01:44.233Z" }, - { url = "https://files.pythonhosted.org/packages/84/ff/426ca8683cf7b753614480484f6437f568fd2fda2edbdf57a2d3d8b27a0b/tomli-2.3.0-cp314-cp314t-win_amd64.whl", hash = "sha256:70a251f8d4ba2d9ac2542eecf008b3c8a9fc5c3f9f02c56a9d7952612be2fdba", size = 119756, upload-time = "2025-10-08T22:01:45.234Z" }, - { url = "https://files.pythonhosted.org/packages/77/b8/0135fadc89e73be292b473cb820b4f5a08197779206b33191e801feeae40/tomli-2.3.0-py3-none-any.whl", hash = "sha256:e95b1af3c5b07d9e643909b5abbec77cd9f1217e6d0bca72b0234736b9fb1f1b", size = 14408, upload-time = "2025-10-08T22:01:46.04Z" }, -] - -[[package]] -name = "typing-extensions" -version = "4.15.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/72/94/1a15dd82efb362ac84269196e94cf00f187f7ed21c242792a923cdb1c61f/typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466", size = 109391, upload-time = "2025-08-25T13:49:26.313Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" }, -] diff --git a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx index b51c2ad94d..5716bfd92d 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx @@ -32,6 +32,7 @@ import { canFindTool } from '@/utils' import { useAllBuiltInTools, useAllCustomTools, useAllMCPTools, useAllWorkflowTools } from '@/service/use-tools' import type { ToolWithProvider } from '@/app/components/workflow/types' import { useMittContextSelector } from '@/context/mitt-context' +import { addDefaultValue, toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema' type AgentToolWithMoreInfo = AgentTool & { icon: any; collection?: Collection } | null const AgentTools: FC = () => { @@ -93,13 +94,17 @@ const AgentTools: FC = () => { const [isDeleting, setIsDeleting] = useState(-1) const getToolValue = (tool: ToolDefaultValue) => { + const currToolInCollections = collectionList.find(c => c.id === tool.provider_id) + const currToolWithConfigs = currToolInCollections?.tools.find(t => t.name === tool.tool_name) + const formSchemas = currToolWithConfigs ? toolParametersToFormSchemas(currToolWithConfigs.parameters) : [] + const paramsWithDefaultValue = addDefaultValue(tool.params, formSchemas) return { provider_id: tool.provider_id, provider_type: tool.provider_type as CollectionType, provider_name: tool.provider_name, tool_name: tool.tool_name, tool_label: tool.tool_label, - tool_parameters: tool.params, + tool_parameters: paramsWithDefaultValue, notAuthor: !tool.is_team_authorization, enabled: true, } @@ -119,7 +124,7 @@ const AgentTools: FC = () => { } const getProviderShowName = (item: AgentTool) => { const type = item.provider_type - if(type === CollectionType.builtIn) + if (type === CollectionType.builtIn) return item.provider_name.split('/').pop() return item.provider_name } diff --git a/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx index ef28dd222c..c5947495db 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx @@ -16,7 +16,7 @@ import Description from '@/app/components/plugins/card/base/description' import TabSlider from '@/app/components/base/tab-slider-plain' import Button from '@/app/components/base/button' import Form from '@/app/components/header/account-setting/model-provider-page/model-modal/Form' -import { addDefaultValue, toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema' +import { toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema' import type { Collection, Tool } from '@/app/components/tools/types' import { CollectionType } from '@/app/components/tools/types' import { fetchBuiltInToolList, fetchCustomToolList, fetchModelToolList, fetchWorkflowToolList } from '@/service/tools' @@ -92,15 +92,11 @@ const SettingBuiltInTool: FC = ({ }()) }) setTools(list) - const currTool = list.find(tool => tool.name === toolName) - if (currTool) { - const formSchemas = toolParametersToFormSchemas(currTool.parameters) - setTempSetting(addDefaultValue(setting, formSchemas)) - } } catch { } setIsLoading(false) })() + // eslint-disable-next-line react-hooks/exhaustive-deps }, [collection?.name, collection?.id, collection?.type]) useEffect(() => { @@ -249,7 +245,7 @@ const SettingBuiltInTool: FC = ({ {!readonly && !isInfoActive && (
- +
)} diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index d21d35eeee..0ff375d815 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -816,9 +816,12 @@ const CompletionConversationDetailComp: FC<{ appId?: string; conversationId?: st const { notify } = useContext(ToastContext) const { t } = useTranslation() - const handleFeedback = async (mid: string, { rating }: FeedbackType): Promise => { + const handleFeedback = async (mid: string, { rating, content }: FeedbackType): Promise => { try { - await updateLogMessageFeedbacks({ url: `/apps/${appId}/feedbacks`, body: { message_id: mid, rating } }) + await updateLogMessageFeedbacks({ + url: `/apps/${appId}/feedbacks`, + body: { message_id: mid, rating, content: content ?? undefined }, + }) conversationDetailMutate() notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) return true @@ -861,9 +864,12 @@ const ChatConversationDetailComp: FC<{ appId?: string; conversationId?: string } const { notify } = useContext(ToastContext) const { t } = useTranslation() - const handleFeedback = async (mid: string, { rating }: FeedbackType): Promise => { + const handleFeedback = async (mid: string, { rating, content }: FeedbackType): Promise => { try { - await updateLogMessageFeedbacks({ url: `/apps/${appId}/feedbacks`, body: { message_id: mid, rating } }) + await updateLogMessageFeedbacks({ + url: `/apps/${appId}/feedbacks`, + body: { message_id: mid, rating, content: content ?? undefined }, + }) notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) return true } diff --git a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx index 302fb9a3c7..94c80687ed 100644 --- a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx +++ b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx @@ -284,7 +284,6 @@ const ChatWrapper = () => { themeBuilder={themeBuilder} switchSibling={siblingMessageId => setTargetMessageId(siblingMessageId)} inputDisabled={inputDisabled} - isMobile={isMobile} sidebarCollapseState={sidebarCollapseState} questionIcon={ initUserVariables?.avatar_url diff --git a/web/app/components/base/chat/chat/index.tsx b/web/app/components/base/chat/chat/index.tsx index 51b5df4f32..0e947f8137 100644 --- a/web/app/components/base/chat/chat/index.tsx +++ b/web/app/components/base/chat/chat/index.tsx @@ -71,7 +71,6 @@ export type ChatProps = { onFeatureBarClick?: (state: boolean) => void noSpacing?: boolean inputDisabled?: boolean - isMobile?: boolean sidebarCollapseState?: boolean } @@ -110,7 +109,6 @@ const Chat: FC = ({ onFeatureBarClick, noSpacing, inputDisabled, - isMobile, sidebarCollapseState, }) => { const { t } = useTranslation() @@ -321,7 +319,6 @@ const Chat: FC = ({ ) } diff --git a/web/app/components/base/chat/chat/try-to-ask.tsx b/web/app/components/base/chat/chat/try-to-ask.tsx index 3fc690361e..665f7b3b13 100644 --- a/web/app/components/base/chat/chat/try-to-ask.tsx +++ b/web/app/components/base/chat/chat/try-to-ask.tsx @@ -8,12 +8,10 @@ import Divider from '@/app/components/base/divider' type TryToAskProps = { suggestedQuestions: string[] onSend: OnSend - isMobile?: boolean } const TryToAsk: FC = ({ suggestedQuestions, onSend, - isMobile, }) => { const { t } = useTranslation() diff --git a/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx b/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx index 5fba104d35..b0a880d78f 100644 --- a/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx +++ b/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx @@ -262,7 +262,6 @@ const ChatWrapper = () => { themeBuilder={themeBuilder} switchSibling={siblingMessageId => setTargetMessageId(siblingMessageId)} inputDisabled={inputDisabled} - isMobile={isMobile} questionIcon={ initUserVariables?.avatar_url ? { const { t } = useTranslation() const { isCurrentWorkspaceManager } = useAppContext() const { enableBilling } = useProviderContext() - const { data: billingUrl } = useSWR( - (!enableBilling || !isCurrentWorkspaceManager) ? null : ['/billing/invoices'], - () => fetchBillingUrl().then(data => data.url), - ) + const { data: billingUrl, isFetching, refetch } = useBillingUrl(enableBilling && isCurrentWorkspaceManager) + + const handleOpenBilling = async () => { + // Open synchronously to preserve user gesture for popup blockers + if (billingUrl) { + window.open(billingUrl, '_blank', 'noopener,noreferrer') + return + } + + const newWindow = window.open('', '_blank', 'noopener,noreferrer') + try { + const url = (await refetch()).data + if (url && newWindow) { + newWindow.location.href = url + return + } + } + catch (err) { + console.error('Failed to fetch billing url', err) + } + // Close the placeholder window if we failed to fetch the URL + newWindow?.close() + } return ( ) diff --git a/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx b/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx index 5d7b11c7e0..396dd4a1b0 100644 --- a/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx +++ b/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx @@ -8,7 +8,7 @@ import { ALL_PLANS } from '../../../config' import Toast from '../../../../base/toast' import { PlanRange } from '../../plan-switcher/plan-range-switcher' import { useAppContext } from '@/context/app-context' -import { fetchSubscriptionUrls } from '@/service/billing' +import { fetchBillingUrl, fetchSubscriptionUrls } from '@/service/billing' import List from './list' import Button from './button' import { Professional, Sandbox, Team } from '../../assets' @@ -39,7 +39,8 @@ const CloudPlanItem: FC = ({ const planInfo = ALL_PLANS[plan] const isYear = planRange === PlanRange.yearly const isCurrent = plan === currentPlan - const isPlanDisabled = planInfo.level <= ALL_PLANS[currentPlan].level + const isCurrentPaidPlan = isCurrent && !isFreePlan + const isPlanDisabled = isCurrentPaidPlan ? false : planInfo.level <= ALL_PLANS[currentPlan].level const { isCurrentWorkspaceManager } = useAppContext() const btnText = useMemo(() => { @@ -60,10 +61,6 @@ const CloudPlanItem: FC = ({ if (isPlanDisabled) return - if (isFreePlan) - return - - // Only workspace manager can buy plan if (!isCurrentWorkspaceManager) { Toast.notify({ type: 'error', @@ -74,6 +71,15 @@ const CloudPlanItem: FC = ({ } setLoading(true) try { + if (isCurrentPaidPlan) { + const res = await fetchBillingUrl() + window.open(res.url, '_blank') + return + } + + if (isFreePlan) + return + const res = await fetchSubscriptionUrls(plan, isYear ? 'year' : 'month') // Adb Block additional tracking block the gtag, so we need to redirect directly window.location.href = res.url diff --git a/web/i18n/de-DE/billing.ts b/web/i18n/de-DE/billing.ts index 2e7897a20f..c0626a0bb8 100644 --- a/web/i18n/de-DE/billing.ts +++ b/web/i18n/de-DE/billing.ts @@ -199,6 +199,9 @@ const translation = { usageTitle: 'AUSLÖSEEREIGNISSE', description: 'Sie haben das Limit der Workflow-Ereignisauslöser für diesen Plan erreicht.', }, + viewBillingTitle: 'Abrechnung und Abonnements', + viewBillingDescription: 'Zahlungsmethoden, Rechnungen und Abonnementänderungen verwalten', + viewBillingAction: 'Verwalten', } export default translation diff --git a/web/i18n/en-US/billing.ts b/web/i18n/en-US/billing.ts index 7272214f41..2531e5831a 100644 --- a/web/i18n/en-US/billing.ts +++ b/web/i18n/en-US/billing.ts @@ -25,6 +25,9 @@ const translation = { encourageShort: 'Upgrade', }, viewBilling: 'Manage billing and subscriptions', + viewBillingTitle: 'Billing and Subscriptions', + viewBillingDescription: 'Manage payment methods, invoices, and subscription changes', + viewBillingAction: 'Manage', buyPermissionDeniedTip: 'Please contact your enterprise administrator to subscribe', plansCommon: { title: { diff --git a/web/i18n/es-ES/billing.ts b/web/i18n/es-ES/billing.ts index 20fb64e024..29652af07a 100644 --- a/web/i18n/es-ES/billing.ts +++ b/web/i18n/es-ES/billing.ts @@ -161,7 +161,7 @@ const translation = { includesTitle: 'Todo de Community, además:', name: 'Premium', for: 'Para organizaciones y equipos de tamaño mediano', - features: ['Confiabilidad autogestionada por varios proveedores de la nube', 'Espacio de trabajo único', 'Personalización de Logotipo y Marca de la Aplicación Web', 'Soporte prioritario por correo electrónico y chat'], + features: ['Confiabilidad Autogestionada por Diversos Proveedores de Nube', 'Espacio de trabajo único', 'Personalización de Logotipo y Marca de la Aplicación Web', 'Soporte prioritario por correo electrónico y chat'], }, }, vectorSpace: { @@ -199,6 +199,9 @@ const translation = { title: 'Actualiza para desbloquear más eventos desencadenantes', description: 'Has alcanzado el límite de activadores de eventos de flujo de trabajo para este plan.', }, + viewBillingTitle: 'Facturación y Suscripciones', + viewBillingDescription: 'Gestiona métodos de pago, facturas y cambios de suscripción', + viewBillingAction: 'Gestionar', } export default translation diff --git a/web/i18n/fa-IR/billing.ts b/web/i18n/fa-IR/billing.ts index 933c4edabf..387be157e6 100644 --- a/web/i18n/fa-IR/billing.ts +++ b/web/i18n/fa-IR/billing.ts @@ -141,7 +141,7 @@ const translation = { btnText: 'تماس با فروش', for: 'برای تیم‌های بزرگ', priceTip: 'فقط صورتحساب سالیانه', - features: ['راه‌حل‌های مستقرسازی مقیاس‌پذیر با سطح سازمانی', 'مجوز استفاده تجاری', 'ویژگی‌های اختصاصی سازمانی', 'چند فضای کاری و مدیریت سازمانی', 'ورود یکپارچه', 'توافق‌نامه‌های سطح خدمات مذاکره شده توسط شرکای Dify', 'امنیت و کنترل‌های پیشرفته', 'به‌روزرسانی‌ها و نگهداری به‌طور رسمی توسط دیفی', 'پشتیبانی فنی حرفه‌ای'], + features: ['راه‌حل‌های مستقرسازی مقیاس‌پذیر با سطح سازمانی', 'مجوز استفاده تجاری', 'ویژگی‌های اختصاصی سازمانی', 'چند فضای کاری و مدیریت سازمانی', 'ورود یکپارچه', 'توافق‌نامه‌های سطح خدمات مذاکره شده توسط شرکای Dify', 'امنیت و کنترل‌های پیشرفته', 'به‌روزرسانی‌ها و نگهداری توسط دیفی به‌طور رسمی', 'پشتیبانی فنی حرفه‌ای'], }, community: { btnText: 'شروع کنید با جامعه', @@ -161,7 +161,7 @@ const translation = { name: 'پیشرفته', priceTip: 'بر اساس بازار ابری', comingSoon: 'پشتیبانی مایکروسافت آژور و گوگل کلود به زودی در دسترس خواهد بود', - features: ['قابلیت اطمینان خودمدیریتی توسط ارائه‌دهندگان مختلف ابری', 'فضای کاری تنها', 'سفارشی‌سازی لوگو و برندینگ وب‌اپ', 'پشتیبانی اولویت‌دار ایمیل و چت'], + features: ['قابلیت اطمینان خودمدیریتی توسط ارائه‌دهندگان مختلف ابری', 'فضای کاری تنها', 'سفارشی‌سازی لوگو و برند وب‌اپ', 'پشتیبانی اولویت‌دار ایمیل و چت'], }, }, vectorSpace: { @@ -199,6 +199,9 @@ const translation = { title: 'ارتقا دهید تا رویدادهای محرک بیشتری باز شود', usageTitle: 'رویدادهای محرک', }, + viewBillingTitle: 'صورتحساب و اشتراک‌ها', + viewBillingDescription: 'مدیریت روش‌های پرداخت، صورت‌حساب‌ها و تغییرات اشتراک', + viewBillingAction: 'مدیریت', } export default translation diff --git a/web/i18n/fr-FR/billing.ts b/web/i18n/fr-FR/billing.ts index 0e1fe4c566..83ea177557 100644 --- a/web/i18n/fr-FR/billing.ts +++ b/web/i18n/fr-FR/billing.ts @@ -137,7 +137,7 @@ const translation = { name: 'Entreprise', description: 'Obtenez toutes les capacités et le support pour les systèmes à grande échelle et critiques pour la mission.', includesTitle: 'Tout ce qui est inclus dans le plan Équipe, plus :', - features: ['Solutions de déploiement évolutives de niveau entreprise', 'Autorisation de licence commerciale', 'Fonctionnalités exclusives pour les entreprises', 'Espaces de travail multiples et gestion d\'entreprise', 'SSO', 'Accords sur les SLA négociés par les partenaires Dify', 'Sécurité et Contrôles Avancés', 'Mises à jour et maintenance par Dify Officiellement', 'Assistance technique professionnelle'], + features: ['Solutions de déploiement évolutives de niveau entreprise', 'Autorisation de licence commerciale', 'Fonctionnalités exclusives pour les entreprises', 'Espaces de travail multiples et gestion d\'entreprise', 'SSO', 'Accords de niveau de service négociés par les partenaires de Dify', 'Sécurité et contrôles avancés', 'Mises à jour et maintenance par Dify Officiellement', 'Assistance technique professionnelle'], for: 'Pour les équipes de grande taille', btnText: 'Contacter les ventes', priceTip: 'Facturation Annuel Seulement', @@ -153,7 +153,7 @@ const translation = { description: 'Pour les utilisateurs individuels, les petites équipes ou les projets non commerciaux', }, premium: { - features: ['Fiabilité autonome par divers fournisseurs de cloud', 'Espace de travail unique', 'Personnalisation du logo et de l\'identité visuelle de l\'application web', 'Assistance prioritaire par e-mail et chat'], + features: ['Fiabilité autonome par divers fournisseurs de cloud', 'Espace de travail unique', 'Personnalisation du logo et de l\'image de marque de l\'application web', 'Assistance prioritaire par e-mail et chat'], for: 'Pour les organisations et les équipes de taille moyenne', includesTitle: 'Tout de la communauté, en plus :', name: 'Premium', @@ -199,6 +199,9 @@ const translation = { dismiss: 'Fermer', title: 'Mettez à niveau pour débloquer plus d\'événements déclencheurs', }, + viewBillingTitle: 'Facturation et abonnements', + viewBillingDescription: 'Gérer les méthodes de paiement, les factures et les modifications d\'abonnement', + viewBillingAction: 'Gérer', } export default translation diff --git a/web/i18n/hi-IN/billing.ts b/web/i18n/hi-IN/billing.ts index 43e576101e..52aa3a0305 100644 --- a/web/i18n/hi-IN/billing.ts +++ b/web/i18n/hi-IN/billing.ts @@ -210,6 +210,9 @@ const translation = { title: 'अधिक ट्रिगर इवेंट्स अनलॉक करने के लिए अपग्रेड करें', description: 'आप इस योजना के लिए वर्कफ़्लो इवेंट ट्रिगर्स की सीमा तक पहुँच चुके हैं।', }, + viewBillingTitle: 'बिलिंग और सब्सक्रिप्शन', + viewBillingDescription: 'भुगतान के तरीकों, चालानों और सदस्यता में बदलावों का प्रबंधन करें', + viewBillingAction: 'प्रबंध करना', } export default translation diff --git a/web/i18n/id-ID/billing.ts b/web/i18n/id-ID/billing.ts index 5d14cbdf38..640bf9aeb6 100644 --- a/web/i18n/id-ID/billing.ts +++ b/web/i18n/id-ID/billing.ts @@ -199,6 +199,9 @@ const translation = { title: 'Tingkatkan untuk membuka lebih banyak peristiwa pemicu', description: 'Anda telah mencapai batas pemicu acara alur kerja untuk paket ini.', }, + viewBillingTitle: 'Penagihan dan Langganan', + viewBillingDescription: 'Kelola metode pembayaran, faktur, dan perubahan langganan', + viewBillingAction: 'Kelola', } export default translation diff --git a/web/i18n/it-IT/billing.ts b/web/i18n/it-IT/billing.ts index 58c41058ad..141cefa21f 100644 --- a/web/i18n/it-IT/billing.ts +++ b/web/i18n/it-IT/billing.ts @@ -148,7 +148,7 @@ const translation = { description: 'Ottieni tutte le capacità e il supporto per sistemi mission-critical su larga scala.', includesTitle: 'Tutto nel piano Team, più:', - features: ['Soluzioni di Distribuzione Scalabili di Classe Aziendale', 'Autorizzazione alla Licenza Commerciale', 'Funzionalità Esclusive per le Aziende', 'Molteplici Spazi di Lavoro e Gestione Aziendale', 'SSO', 'SLA negoziati dai partner Dify', 'Sicurezza e Controlli Avanzati', 'Aggiornamenti e manutenzione ufficiali di Dify', 'Assistenza Tecnica Professionale'], + features: ['Soluzioni di Distribuzione Scalabili di Classe Aziendale', 'Autorizzazione alla Licenza Commerciale', 'Funzionalità Esclusive per le Aziende', 'Molteplici Spazi di Lavoro e Gestione Aziendale', 'SSO', 'SLA negoziati dai partner Dify', 'Sicurezza e Controlli Avanzati', 'Aggiornamenti e manutenzione da Dify ufficialmente', 'Assistenza Tecnica Professionale'], price: 'Personalizzato', for: 'Per team di grandi dimensioni', btnText: 'Contatta le vendite', @@ -164,7 +164,7 @@ const translation = { for: 'Per utenti individuali, piccole squadre o progetti non commerciali', }, premium: { - features: ['Affidabilità Autogestita dai Vari Provider Cloud', 'Spazio di lavoro singolo', 'Personalizzazione del Logo e del Marchio dell\'App Web', 'Assistenza Prioritaria via Email e Chat'], + features: ['Affidabilità autogestita dai vari provider cloud', 'Spazio di lavoro singolo', 'Personalizzazione del Logo e del Marchio dell\'App Web', 'Assistenza Prioritaria via Email e Chat'], name: 'Premium', priceTip: 'Basato su Cloud Marketplace', includesTitle: 'Tutto dalla Community, oltre a:', @@ -210,6 +210,9 @@ const translation = { title: 'Aggiorna per sbloccare più eventi di attivazione', description: 'Hai raggiunto il limite degli eventi di attivazione del flusso di lavoro per questo piano.', }, + viewBillingTitle: 'Fatturazione e Abbonamenti', + viewBillingDescription: 'Gestisci metodi di pagamento, fatture e modifiche all\'abbonamento', + viewBillingAction: 'Gestire', } export default translation diff --git a/web/i18n/ja-JP/billing.ts b/web/i18n/ja-JP/billing.ts index 57ccef5491..97fa4eb0e6 100644 --- a/web/i18n/ja-JP/billing.ts +++ b/web/i18n/ja-JP/billing.ts @@ -24,6 +24,9 @@ const translation = { encourageShort: 'アップグレード', }, viewBilling: '請求とサブスクリプションの管理', + viewBillingTitle: '請求とサブスクリプション', + viewBillingDescription: '支払い方法、請求書、サブスクリプションの変更の管理。', + viewBillingAction: '管理', buyPermissionDeniedTip: 'サブスクリプションするには、エンタープライズ管理者に連絡してください', plansCommon: { title: { @@ -158,7 +161,7 @@ const translation = { price: '無料', btnText: 'コミュニティ版を始めましょう', includesTitle: '無料機能:', - features: ['パブリックリポジトリで公開されているすべてのコア機能', '単一ワークスペース', 'Dify オープンソースライセンスに準拠'], + features: ['すべてのコア機能がパブリックリポジトリで公開されました', '単一ワークスペース', 'Difyオープンソースライセンスに準拠'], }, premium: { name: 'プレミアム', @@ -169,7 +172,7 @@ const translation = { btnText: 'プレミアム版を取得', includesTitle: 'コミュニティ版機能に加えて:', comingSoon: 'Microsoft Azure & Google Cloud 近日対応', - features: ['複数のクラウドプロバイダーでのセルフマネージド導入', '単一ワークスペース', 'Webアプリのロゴとブランディングをカスタマイズ', '優先メール/チャットサポート'], + features: ['さまざまなクラウドプロバイダーによる自己管理型の信頼性', '単一ワークスペース', 'Webアプリのロゴとブランドカスタマイズ', '優先メール&チャットサポート'], }, enterprise: { name: 'エンタープライズ', @@ -179,7 +182,7 @@ const translation = { priceTip: '年間契約専用', btnText: '営業に相談', includesTitle: 'プレミアム版機能に加えて:', - features: ['エンタープライズ向けのスケーラブルなデプロイソリューション', '商用ライセンス認可', 'エンタープライズ専用機能', '複数ワークスペースとエンタープライズ管理', 'シングルサインオン(SSO)', 'Dify パートナーによる交渉済み SLA', '高度なセキュリティと制御', 'Dify 公式による更新とメンテナンス', 'プロフェッショナルな技術サポート'], + features: ['エンタープライズ向けスケーラブルな展開ソリューション', '商用ライセンス認可', 'エンタープライズ専用機能', '複数ワークスペースとエンタープライズ管理', 'シングルサインオン', 'Difyパートナーによる交渉済みSLA', '高度なセキュリティと制御', 'Dify公式による更新とメンテナンス', 'プロフェッショナル技術サポート'], }, }, vectorSpace: { diff --git a/web/i18n/ko-KR/billing.ts b/web/i18n/ko-KR/billing.ts index 2ab06beace..c11d2bd11c 100644 --- a/web/i18n/ko-KR/billing.ts +++ b/web/i18n/ko-KR/billing.ts @@ -152,7 +152,7 @@ const translation = { btnText: '판매 문의하기', for: '대규모 팀을 위해', priceTip: '연간 청구 전용', - features: ['기업용 확장 가능한 배포 솔루션', '상업용 라이선스 승인', '독점 기업 기능', '여러 작업 공간 및 기업 관리', '싱글 사인온', 'Dify 파트너가 협상한 SLA', '고급 보안 및 제어', 'Dify 공식 업데이트 및 유지 관리', '전문 기술 지원'], + features: ['기업용 확장형 배포 솔루션', '상업용 라이선스 승인', '독점 기업 기능', '여러 작업 공간 및 기업 관리', '싱글 사인온', 'Dify 파트너가 협상한 SLA', '고급 보안 및 제어', 'Dify 공식 업데이트 및 유지보수', '전문 기술 지원'], }, community: { btnText: '커뮤니티 시작하기', @@ -172,7 +172,7 @@ const translation = { price: '확장 가능', for: '중규모 조직 및 팀을 위한', includesTitle: '커뮤니티의 모든 것, 여기에 추가로:', - features: ['다양한 클라우드 제공업체의 자체 관리 신뢰성', '단일 작업 공간', '웹앱 로고 및 브랜딩 맞춤 설정', '우선 이메일 및 채팅 지원'], + features: ['다양한 클라우드 제공업체에 의한 자체 관리 신뢰성', '단일 작업 공간', '웹앱 로고 및 브랜딩 맞춤 설정', '우선 이메일 및 채팅 지원'], }, }, vectorSpace: { @@ -212,6 +212,9 @@ const translation = { description: '이 요금제의 워크플로 이벤트 트리거 한도에 도달했습니다.', upgrade: '업그레이드', }, + viewBillingTitle: '청구 및 구독', + viewBillingDescription: '결제 수단, 청구서 및 구독 변경 관리', + viewBillingAction: '관리하다', } export default translation diff --git a/web/i18n/pl-PL/billing.ts b/web/i18n/pl-PL/billing.ts index 63047dcd15..8131f74b1e 100644 --- a/web/i18n/pl-PL/billing.ts +++ b/web/i18n/pl-PL/billing.ts @@ -163,7 +163,7 @@ const translation = { for: 'Dla użytkowników indywidualnych, małych zespołów lub projektów niekomercyjnych', }, premium: { - features: ['Niezawodność zarządzana samodzielnie przez różnych dostawców chmury', 'Pojedyncza przestrzeń robocza', 'Dostosowywanie logo i marki aplikacji webowej', 'Priorytetowe wsparcie e-mail i czat'], + features: ['Niezawodność zarządzana samodzielnie przez różnych dostawców chmury', 'Pojedyncza przestrzeń robocza', 'Dostosowywanie logo i identyfikacji wizualnej aplikacji webowej', 'Priorytetowe wsparcie e-mail i czat'], description: 'Dla średnich organizacji i zespołów', for: 'Dla średnich organizacji i zespołów', name: 'Premium', @@ -209,6 +209,9 @@ const translation = { title: 'Uaktualnij, aby odblokować więcej zdarzeń wyzwalających', dismiss: 'Odrzuć', }, + viewBillingTitle: 'Rozliczenia i subskrypcje', + viewBillingDescription: 'Zarządzaj metodami płatności, fakturami i zmianami subskrypcji', + viewBillingAction: 'Zarządzać', } export default translation diff --git a/web/i18n/pt-BR/billing.ts b/web/i18n/pt-BR/billing.ts index 5320dfbe8a..08976988e6 100644 --- a/web/i18n/pt-BR/billing.ts +++ b/web/i18n/pt-BR/billing.ts @@ -153,7 +153,7 @@ const translation = { for: 'Para Usuários Individuais, Pequenas Equipes ou Projetos Não Comerciais', }, premium: { - features: ['Confiabilidade Autogerenciada por Diversos Provedores de Nuvem', 'Espaço de Trabalho Único', 'Personalização de Logo e Marca do WebApp', 'Suporte Prioritário por E-mail e Chat'], + features: ['Confiabilidade Autogerenciada por Diversos Provedores de Nuvem', 'Espaço de Trabalho Único', 'Personalização de Logo e Marca do WebApp', 'Suporte Prioritário por Email e Chat'], includesTitle: 'Tudo da Comunidade, além de:', for: 'Para organizações e equipes de médio porte', price: 'Escalável', @@ -199,6 +199,9 @@ const translation = { upgrade: 'Atualizar', description: 'Você atingiu o limite de eventos de gatilho de fluxo de trabalho para este plano.', }, + viewBillingTitle: 'Faturamento e Assinaturas', + viewBillingDescription: 'Gerencie métodos de pagamento, faturas e alterações de assinatura', + viewBillingAction: 'Gerenciar', } export default translation diff --git a/web/i18n/ro-RO/billing.ts b/web/i18n/ro-RO/billing.ts index 96718014c3..79bee0b320 100644 --- a/web/i18n/ro-RO/billing.ts +++ b/web/i18n/ro-RO/billing.ts @@ -137,7 +137,7 @@ const translation = { name: 'Întreprindere', description: 'Obțineți capacități și asistență complete pentru sisteme critice la scară largă.', includesTitle: 'Tot ce este în planul Echipă, plus:', - features: ['Soluții de implementare scalabile la nivel de întreprindere', 'Autorizație de licență comercială', 'Funcții Exclusive pentru Afaceri', 'Mai multe spații de lucru și managementul întreprinderii', 'SSO', 'SLA-uri negociate de partenerii Dify', 'Securitate și Control Avansate', 'Actualizări și întreținere de către Dify Oficial', 'Asistență Tehnică Profesională'], + features: ['Soluții de implementare scalabile la nivel de întreprindere', 'Autorizație de licență comercială', 'Funcții Exclusive pentru Afaceri', 'Multiple spații de lucru și gestionarea întreprinderii', 'Autentificare unică', 'SLA-uri negociate de partenerii Dify', 'Securitate și Control Avansate', 'Actualizări și întreținere de către Dify Oficial', 'Asistență Tehnică Profesională'], for: 'Pentru echipe de mari dimensiuni', price: 'Personalizat', priceTip: 'Facturare anuală doar', @@ -153,7 +153,7 @@ const translation = { includesTitle: 'Funcții gratuite:', }, premium: { - features: ['Fiabilitate autogestionată de diferiți furnizori de cloud', 'Spațiu de lucru unic', 'Personalizare logo și branding pentru aplicația web', 'Asistență prioritară prin e-mail și chat'], + features: ['Fiabilitate autogestionată de diferiți furnizori de cloud', 'Spațiu de lucru unic', 'Personalizare Logo și Branding pentru WebApp', 'Asistență prioritară prin email și chat'], btnText: 'Obține Premium în', description: 'Pentru organizații și echipe de dimensiuni medii', includesTitle: 'Totul din Comunitate, plus:', @@ -199,6 +199,9 @@ const translation = { description: 'Ai atins limita de evenimente declanșatoare de flux de lucru pentru acest plan.', title: 'Actualizează pentru a debloca mai multe evenimente declanșatoare', }, + viewBillingTitle: 'Facturare și abonamente', + viewBillingDescription: 'Gestionează metodele de plată, facturile și modificările abonamentului', + viewBillingAction: 'Gestiona', } export default translation diff --git a/web/i18n/ru-RU/billing.ts b/web/i18n/ru-RU/billing.ts index 32ec234e02..8f5c6dcf38 100644 --- a/web/i18n/ru-RU/billing.ts +++ b/web/i18n/ru-RU/billing.ts @@ -137,7 +137,7 @@ const translation = { name: 'Корпоративный', description: 'Получите полный набор возможностей и поддержку для крупномасштабных критически важных систем.', includesTitle: 'Все в командном плане, плюс:', - features: ['Масштабируемые решения для развертывания корпоративного уровня', 'Разрешение на коммерческую лицензию', 'Эксклюзивные корпоративные функции', 'Несколько рабочих пространств и корпоративное управление', 'SSO', 'Согласованные SLA с партнёрами Dify', 'Расширенные функции безопасности и управления', 'Обновления и обслуживание от Dify официально', 'Профессиональная техническая поддержка'], + features: ['Масштабируемые решения для развертывания корпоративного уровня', 'Разрешение на коммерческую лицензию', 'Эксклюзивные корпоративные функции', 'Несколько рабочих пространств и корпоративное управление', 'Единый вход (SSO)', 'Договоренные SLA с партнёрами Dify', 'Расширенные функции безопасности и управления', 'Обновления и обслуживание от Dify официально', 'Профессиональная техническая поддержка'], price: 'Пользовательский', priceTip: 'Только годовая подписка', for: 'Для команд большого размера', @@ -199,6 +199,9 @@ const translation = { description: 'Вы достигли предела триггеров событий рабочего процесса для этого плана.', title: 'Обновите, чтобы открыть больше событий срабатывания', }, + viewBillingTitle: 'Платежи и подписки', + viewBillingDescription: 'Управляйте способами оплаты, счетами и изменениями подписки', + viewBillingAction: 'Управлять', } export default translation diff --git a/web/i18n/sl-SI/billing.ts b/web/i18n/sl-SI/billing.ts index 02c10e5b05..fa6b4b7bc6 100644 --- a/web/i18n/sl-SI/billing.ts +++ b/web/i18n/sl-SI/billing.ts @@ -144,7 +144,7 @@ const translation = { for: 'Za velike ekipe', }, community: { - features: ['Vse osnovne funkcije so izdane v javni repozitorij', 'Enotno delovno okolje', 'V skladu z Dify licenco odprte kode'], + features: ['Vse osnovne funkcije so izdane v javnem repozitoriju', 'Enotno delovno okolje', 'V skladu z Dify licenco odprte kode'], includesTitle: 'Brezplačne funkcije:', price: 'Brezplačno', name: 'Skupnost', @@ -199,6 +199,9 @@ const translation = { title: 'Nadgradite za odklep več sprožilnih dogodkov', upgrade: 'Nadgradnja', }, + viewBillingTitle: 'Fakturiranje in naročnine', + viewBillingDescription: 'Upravljajte načine plačila, račune in spremembe naročnin', + viewBillingAction: 'Upravljaj', } export default translation diff --git a/web/i18n/th-TH/billing.ts b/web/i18n/th-TH/billing.ts index c2df6e5745..ce8d3bbbef 100644 --- a/web/i18n/th-TH/billing.ts +++ b/web/i18n/th-TH/billing.ts @@ -137,7 +137,7 @@ const translation = { name: 'กิจการ', description: 'รับความสามารถและการสนับสนุนเต็มรูปแบบสําหรับระบบที่สําคัญต่อภารกิจขนาดใหญ่', includesTitle: 'ทุกอย่างในแผนทีม รวมถึง:', - features: ['โซลูชันการปรับใช้ที่ปรับขนาดได้สำหรับองค์กร', 'การอนุญาตใบอนุญาตเชิงพาณิชย์', 'ฟีเจอร์สำหรับองค์กรแบบพิเศษ', 'หลายพื้นที่ทำงานและการจัดการองค์กร', 'SSO', 'ข้อตกลงระดับการให้บริการที่เจรจาโดยพันธมิตร Dify', 'ระบบความปลอดภัยและการควบคุมขั้นสูง', 'การอัปเดตและการบำรุงรักษาโดย Dify อย่างเป็นทางการ', 'การสนับสนุนทางเทคนิคระดับมืออาชีพ'], + features: ['โซลูชันการปรับใช้ที่ปรับขนาดได้สำหรับองค์กร', 'การอนุญาตใบอนุญาตเชิงพาณิชย์', 'ฟีเจอร์สำหรับองค์กรแบบพิเศษ', 'หลายพื้นที่ทำงานและการจัดการองค์กร', 'ระบบลงชื่อเพียงครั้งเดียว', 'ข้อตกลงระดับการให้บริการที่เจรจาโดยพันธมิตร Dify', 'ระบบความปลอดภัยและการควบคุมขั้นสูง', 'การอัปเดตและการบำรุงรักษาโดย Dify อย่างเป็นทางการ', 'การสนับสนุนทางเทคนิคระดับมืออาชีพ'], btnText: 'ติดต่อฝ่ายขาย', price: 'ที่กำหนดเอง', for: 'สำหรับทีมขนาดใหญ่', @@ -153,7 +153,7 @@ const translation = { for: 'สำหรับผู้ใช้ส่วนบุคคล ทีมขนาดเล็ก หรือโครงการที่ไม่ใช่เชิงพาณิชย์', }, premium: { - features: ['ความน่าเชื่อถือที่บริหารเองโดยผู้ให้บริการคลาวด์หลายราย', 'พื้นที่ทำงานเดียว', 'การปรับแต่งโลโก้และแบรนด์ของเว็บแอป', 'บริการอีเมลและแชทด่วน'], + features: ['ความน่าเชื่อถือที่บริหารเองโดยผู้ให้บริการคลาวด์ต่างๆ', 'พื้นที่ทำงานเดียว', 'การปรับแต่งโลโก้และแบรนด์ของเว็บแอป', 'บริการอีเมลและแชทด่วน'], priceTip: 'อิงตามตลาดคลาวด์', for: 'สำหรับองค์กรและทีมขนาดกลาง', btnText: 'รับพรีเมียมใน', @@ -199,6 +199,9 @@ const translation = { title: 'อัปเกรดเพื่อปลดล็อกเหตุการณ์ทริกเกอร์เพิ่มเติม', description: 'คุณได้ถึงขีดจำกัดของทริกเกอร์เหตุการณ์เวิร์กโฟลว์สำหรับแผนนี้แล้ว', }, + viewBillingTitle: 'การเรียกเก็บเงินและการสมัครสมาชิก', + viewBillingDescription: 'จัดการวิธีการชำระเงิน ใบแจ้งหนี้ และการเปลี่ยนแปลงการสมัครสมาชิก', + viewBillingAction: 'จัดการ', } export default translation diff --git a/web/i18n/tr-TR/billing.ts b/web/i18n/tr-TR/billing.ts index 5a11471488..06570c9818 100644 --- a/web/i18n/tr-TR/billing.ts +++ b/web/i18n/tr-TR/billing.ts @@ -144,7 +144,7 @@ const translation = { price: 'Özel', }, community: { - features: ['Tüm Temel Özellikler Açık Depoda Yayınlandı', 'Tek Çalışma Alanı', 'Dify Açık Kaynak Lisansına uygundur'], + features: ['Tüm Temel Özellikler Açık Kaynak Depoda Yayınlandı', 'Tek Çalışma Alanı', 'Dify Açık Kaynak Lisansına uygundur'], price: 'Ücretsiz', includesTitle: 'Ücretsiz Özellikler:', name: 'Topluluk', @@ -153,7 +153,7 @@ const translation = { description: 'Bireysel Kullanıcılar, Küçük Ekipler veya Ticari Olmayan Projeler İçin', }, premium: { - features: ['Çeşitli Bulut Sağlayıcıları Tarafından Kendi Kendine Yönetilen Güvenilirlik', 'Tek Çalışma Alanı', 'Web Uygulama Logo ve Marka Özelleştirmesi', 'Öncelikli E-posta ve Sohbet Desteği'], + features: ['Çeşitli Bulut Sağlayıcıları Tarafından Kendi Kendine Yönetilen Güvenilirlik', 'Tek Çalışma Alanı', 'Web Uygulaması Logo ve Marka Özelleştirme', 'Öncelikli E-posta ve Sohbet Desteği'], name: 'Premium', includesTitle: 'Topluluktan her şey, artı:', for: 'Orta Büyüklükteki Organizasyonlar ve Ekipler için', @@ -199,6 +199,9 @@ const translation = { description: 'Bu plan için iş akışı etkinliği tetikleyici sınırına ulaştınız.', usageTitle: 'TETİKLEYİCİ OLAYLAR', }, + viewBillingTitle: 'Faturalama ve Abonelikler', + viewBillingDescription: 'Ödeme yöntemlerini, faturaları ve abonelik değişikliklerini yönetin', + viewBillingAction: 'Yönet', } export default translation diff --git a/web/i18n/uk-UA/billing.ts b/web/i18n/uk-UA/billing.ts index ffe6778d19..5502afd608 100644 --- a/web/i18n/uk-UA/billing.ts +++ b/web/i18n/uk-UA/billing.ts @@ -137,14 +137,14 @@ const translation = { name: 'Ентерпрайз', description: 'Отримайте повні можливості та підтримку для масштабних критично важливих систем.', includesTitle: 'Все, що входить до плану Team, плюс:', - features: ['Масштабовані рішення для розгортання корпоративного рівня', 'Авторизація комерційної ліцензії', 'Ексклюзивні корпоративні функції', 'Кілька робочих просторів та управління підприємством', 'SSO', 'Укладені SLA партнерами Dify', 'Розширена безпека та контроль', 'Оновлення та обслуговування від Dify офіційно', 'Професійна технічна підтримка'], + features: ['Масштабовані рішення для розгортання корпоративного рівня', 'Авторизація комерційної ліцензії', 'Ексклюзивні корпоративні функції', 'Кілька робочих просторів та управління підприємством', 'ЄДИНА СИСТЕМА АВТОРИЗАЦІЇ', 'Узгоджені SLA партнерами Dify', 'Розширена безпека та контроль', 'Оновлення та обслуговування від Dify офіційно', 'Професійна технічна підтримка'], btnText: 'Зв\'язатися з відділом продажу', priceTip: 'Тільки річна оплата', for: 'Для великих команд', price: 'Користувацький', }, community: { - features: ['Всі основні функції опубліковані у публічному репозиторії', 'Один робочий простір', 'Відповідає ліцензії відкритого програмного забезпечення Dify'], + features: ['Всі основні функції випущені у публічному репозиторії', 'Один робочий простір', 'Відповідає ліцензії відкритого програмного забезпечення Dify'], btnText: 'Розпочніть з громади', includesTitle: 'Безкоштовні можливості:', for: 'Для індивідуальних користувачів, малих команд або некомерційних проектів', @@ -199,6 +199,9 @@ const translation = { title: 'Оновіть, щоб розблокувати більше подій-тригерів', description: 'Ви досягли ліміту тригерів подій робочого процесу для цього плану.', }, + viewBillingTitle: 'Білінг та підписки', + viewBillingDescription: 'Керуйте способами оплати, рахунками та змінами підписки', + viewBillingAction: 'Керувати', } export default translation diff --git a/web/i18n/vi-VN/billing.ts b/web/i18n/vi-VN/billing.ts index c95ce4c992..04e8385347 100644 --- a/web/i18n/vi-VN/billing.ts +++ b/web/i18n/vi-VN/billing.ts @@ -153,7 +153,7 @@ const translation = { includesTitle: 'Tính năng miễn phí:', }, premium: { - features: ['Độ tin cậy tự quản lý bởi các nhà cung cấp đám mây khác nhau', 'Không gian làm việc đơn', 'Tùy chỉnh Logo & Thương hiệu WebApp', 'Hỗ trợ Email & Trò chuyện Ưu tiên'], + features: ['Độ tin cậy tự quản lý bởi các nhà cung cấp đám mây khác nhau', 'Không gian làm việc đơn', 'Tùy Chỉnh Logo & Thương Hiệu Ứng Dụng Web', 'Hỗ trợ Email & Trò chuyện Ưu tiên'], comingSoon: 'Hỗ trợ Microsoft Azure & Google Cloud Sẽ Đến Sớm', priceTip: 'Dựa trên Thị trường Đám mây', btnText: 'Nhận Premium trong', @@ -199,6 +199,9 @@ const translation = { description: 'Bạn đã đạt đến giới hạn kích hoạt sự kiện quy trình cho gói này.', title: 'Nâng cấp để mở khóa thêm nhiều sự kiện kích hoạt', }, + viewBillingTitle: 'Thanh toán và Đăng ký', + viewBillingDescription: 'Quản lý phương thức thanh toán, hóa đơn và thay đổi đăng ký', + viewBillingAction: 'Quản lý', } export default translation diff --git a/web/i18n/zh-Hans/billing.ts b/web/i18n/zh-Hans/billing.ts index ee8f7af64d..b404240b3d 100644 --- a/web/i18n/zh-Hans/billing.ts +++ b/web/i18n/zh-Hans/billing.ts @@ -24,6 +24,9 @@ const translation = { encourageShort: '升级', }, viewBilling: '管理账单及订阅', + viewBillingTitle: '账单与订阅', + viewBillingDescription: '管理支付方式、发票和订阅变更。', + viewBillingAction: '管理', buyPermissionDeniedTip: '请联系企业管理员订阅', plansCommon: { title: { @@ -158,11 +161,7 @@ const translation = { price: '免费', btnText: '开始使用', includesTitle: '免费功能:', - features: [ - '所有核心功能均在公共存储库下发布', - '单一工作空间', - '符合 Dify 开源许可证', - ], + features: ['所有核心功能已在公共仓库发布', '单一工作区', '遵守 Dify 开源许可证'], }, premium: { name: 'Premium', @@ -173,12 +172,7 @@ const translation = { btnText: '获得 Premium 版', includesTitle: 'Community 版的所有功能,加上:', comingSoon: '即将支持 Microsoft Azure & Google Cloud', - features: [ - '各个云提供商自行管理的可靠性', - '单一工作空间', - '自定义 WebApp & 品牌', - '优先电子邮件 & 聊天支持', - ], + features: ['由各云服务提供商自主管理的可靠性', '单一工作区', 'WebApp 徽标与品牌定制', '优先电子邮件和聊天支持'], }, enterprise: { name: 'Enterprise', @@ -188,17 +182,7 @@ const translation = { priceTip: '仅按年计费', btnText: '联系销售', includesTitle: 'Premium 版的所有功能,加上:', - features: [ - '企业级可扩展部署解决方案', - '商业许可授权', - '专属企业级功能', - '多个工作空间 & 企业级管理', - 'SSO', - '由 Dify 合作伙伴支持的可协商的 SLAs', - '高级的安全 & 控制', - '由 Dify 官方提供的更新 & 维护', - '专业技术支持', - ], + features: ['企业级可扩展部署解决方案', '商业许可授权', '专属企业功能', '多个工作区与企业管理', '单点登录', '由 Dify 合作伙伴协商的服务水平协议', '高级安全与控制', '由 Dify 官方进行的更新和维护', '专业技术支持'], }, }, vectorSpace: { diff --git a/web/i18n/zh-Hant/billing.ts b/web/i18n/zh-Hant/billing.ts index f693f9ae35..42534756a5 100644 --- a/web/i18n/zh-Hant/billing.ts +++ b/web/i18n/zh-Hant/billing.ts @@ -137,7 +137,7 @@ const translation = { name: 'Enterprise', description: '獲得大規模關鍵任務系統的完整功能和支援。', includesTitle: 'Team 計劃中的一切,加上:', - features: ['企業級可擴展部署解決方案', '商業許可授權', '專屬企業功能', '多工作區與企業管理', '單一登入', '由 Dify 合作夥伴協商的服務水平協議', '進階安全與控制', 'Dify 官方的更新與維護', '專業技術支援'], + features: ['企業級可擴展部署解決方案', '商業許可授權', '企業專屬功能', '多工作區與企業管理', '單一登入', '由 Dify 合作夥伴協商的服務水平協議', '進階安全與控管', 'Dify 官方更新與維護', '專業技術支援'], price: '自訂', btnText: '聯繫銷售', priceTip: '年度計費のみ', @@ -199,6 +199,9 @@ const translation = { title: '升級以解鎖更多觸發事件', upgrade: '升級', }, + viewBillingTitle: '帳單與訂閱', + viewBillingDescription: '管理付款方式、發票和訂閱變更', + viewBillingAction: '管理', } export default translation diff --git a/web/service/billing.ts b/web/service/billing.ts index 7dfe5ac0a7..979a888582 100644 --- a/web/service/billing.ts +++ b/web/service/billing.ts @@ -1,4 +1,4 @@ -import { get } from './base' +import { get, put } from './base' import type { CurrentPlanInfoBackend, SubscriptionUrlsBackend } from '@/app/components/billing/type' export const fetchCurrentPlanInfo = () => { @@ -12,3 +12,13 @@ export const fetchSubscriptionUrls = (plan: string, interval: string) => { export const fetchBillingUrl = () => { return get<{ url: string }>('/billing/invoices') } + +export const bindPartnerStackInfo = (partnerKey: string, clickId: string) => { + return put(`/billing/partners/${partnerKey}/tenants`, { + body: { + click_id: clickId, + }, + }, { + silent: true, + }) +} diff --git a/web/service/use-billing.ts b/web/service/use-billing.ts index b48a75eab0..2701861bc0 100644 --- a/web/service/use-billing.ts +++ b/web/service/use-billing.ts @@ -1,19 +1,22 @@ -import { useMutation } from '@tanstack/react-query' -import { put } from './base' +import { useMutation, useQuery } from '@tanstack/react-query' +import { bindPartnerStackInfo, fetchBillingUrl } from '@/service/billing' const NAME_SPACE = 'billing' export const useBindPartnerStackInfo = () => { return useMutation({ mutationKey: [NAME_SPACE, 'bind-partner-stack'], - mutationFn: (data: { partnerKey: string; clickId: string }) => { - return put(`/billing/partners/${data.partnerKey}/tenants`, { - body: { - click_id: data.clickId, - }, - }, { - silent: true, - }) + mutationFn: (data: { partnerKey: string; clickId: string }) => bindPartnerStackInfo(data.partnerKey, data.clickId), + }) +} + +export const useBillingUrl = (enabled: boolean) => { + return useQuery({ + queryKey: [NAME_SPACE, 'url'], + enabled, + queryFn: async () => { + const res = await fetchBillingUrl() + return res.url }, }) }