From f15a8f02ef022db28b9555fb26c74d01c4ef1a1c Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Mon, 8 Jun 2026 13:53:12 +0900 Subject: [PATCH] ci: add flag for linter (#37018) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/core/ops/ops_trace_manager.py | 3 ++- api/core/rag/embedding/cached_embedding.py | 6 ++++- .../processor/paragraph_index_processor.py | 10 +++++++- .../processor/parent_child_index_processor.py | 10 +++++++- .../processor/qa_index_processor.py | 10 +++++++- api/core/rag/rerank/rerank_model.py | 2 ++ api/core/rag/rerank/weight_rerank.py | 2 ++ api/core/rag/splitter/fixed_text_splitter.py | 3 ++- api/core/rag/splitter/text_splitter.py | 6 ++++- api/core/tools/builtin_tool/provider.py | 4 +++- .../builtin_tool/providers/audio/audio.py | 3 ++- .../builtin_tool/providers/audio/tools/asr.py | 4 +++- .../builtin_tool/providers/audio/tools/tts.py | 4 +++- .../tools/builtin_tool/providers/code/code.py | 3 ++- .../providers/code/tools/simple_code.py | 3 ++- .../tools/builtin_tool/providers/time/time.py | 3 ++- .../providers/time/tools/current_time.py | 3 ++- .../time/tools/localtime_to_timestamp.py | 3 ++- .../time/tools/timestamp_to_localtime.py | 3 ++- .../time/tools/timezone_conversion.py | 3 ++- .../providers/time/tools/weekday.py | 3 ++- .../providers/webscraper/tools/webscraper.py | 3 ++- .../providers/webscraper/webscraper.py | 3 ++- api/core/tools/builtin_tool/tool.py | 4 ++++ api/core/tools/custom_tool/provider.py | 4 ++++ api/core/tools/custom_tool/tool.py | 5 +++- api/core/tools/mcp_tool/provider.py | 4 +++- api/core/tools/mcp_tool/tool.py | 5 +++- api/core/tools/plugin_tool/provider.py | 4 +++- api/core/tools/plugin_tool/tool.py | 6 ++++- .../dataset_multi_retriever_tool.py | 2 ++ .../dataset_retriever_tool.py | 3 ++- .../tools/utils/dataset_retriever_tool.py | 5 +++- api/core/tools/workflow_as_tool/provider.py | 2 ++ api/core/tools/workflow_as_tool/tool.py | 5 +++- api/core/trigger/debug/event_selectors.py | 5 +++- api/core/trigger/utils/encryption.py | 5 +++- api/core/workflow/human_input_adapter.py | 3 ++- api/core/workflow/node_factory.py | 5 ++++ api/core/workflow/node_runtime.py | 24 ++++++++++++++++++- api/core/workflow/nodes/agent/agent_node.py | 6 ++++- .../nodes/agent/plugin_strategy_adapter.py | 4 ++++ .../workflow/nodes/agent_v2/agent_node.py | 6 ++++- .../nodes/datasource/datasource_node.py | 6 ++++- .../knowledge_index/knowledge_index_node.py | 4 +++- .../knowledge_retrieval_node.py | 5 +++- .../trigger_plugin/trigger_event_node.py | 6 ++++- .../trigger_schedule/trigger_schedule_node.py | 4 ++++ .../workflow/nodes/trigger_webhook/node.py | 6 ++++- api/core/workflow/template_rendering.py | 3 ++- api/enterprise/telemetry/id_generator.py | 4 ++++ api/fields/app_fields.py | 2 ++ api/fields/raws.py | 3 +++ api/fields/workflow_fields.py | 3 +++ api/models/base.py | 3 +++ api/models/dataset.py | 3 ++- api/models/enums.py | 2 ++ api/models/model.py | 4 +++- api/models/provider.py | 2 ++ api/models/provider_ids.py | 3 +++ .../src/dify_trace_aliyun/aliyun_trace.py | 2 ++ .../arize_phoenix_trace.py | 3 ++- .../test_arize_phoenix_trace.py | 6 ++++- .../src/dify_trace_langfuse/langfuse_trace.py | 2 ++ .../langfuse_trace/test_langfuse_trace.py | 2 ++ .../dify_trace_langsmith/langsmith_trace.py | 3 ++- .../langsmith_trace/test_langsmith_trace.py | 2 ++ .../src/dify_trace_mlflow/mlflow_trace.py | 3 ++- .../src/dify_trace_opik/opik_trace.py | 3 ++- .../unit_tests/opik_trace/test_opik_trace.py | 2 ++ .../src/dify_trace_tencent/tencent_trace.py | 3 +++ .../src/dify_trace_weave/weave_trace.py | 3 ++- .../alibabacloud_mysql_vector.py | 12 +++++++++- .../test_alibabacloud_mysql_vector.py | 2 ++ .../dify_vdb_analyticdb/analyticdb_vector.py | 12 +++++++++- .../integration_tests/test_analyticdb.py | 3 +++ .../src/dify_vdb_baidu/baidu_vector.py | 12 +++++++++- .../tests/integration_tests/test_baidu.py | 4 ++++ .../src/dify_vdb_chroma/chroma_vector.py | 12 +++++++++- .../tests/integration_tests/test_chroma.py | 3 +++ .../dify_vdb_clickzetta/clickzetta_vector.py | 12 +++++++++- .../dify_vdb_couchbase/couchbase_vector.py | 12 +++++++++- .../tests/integration_tests/test_couchbase.py | 2 ++ .../elasticsearch_ja_vector.py | 4 +++- .../elasticsearch_vector.py | 12 +++++++++- .../src/dify_vdb_hologres/hologres_vector.py | 13 +++++++++- .../tests/integration_tests/test_hologres.py | 5 +++- .../huawei_cloud_vector.py | 12 +++++++++- .../integration_tests/test_huawei_cloud.py | 4 ++++ .../vdb-iris/src/dify_vdb_iris/iris_vector.py | 12 +++++++++- .../src/dify_vdb_lindorm/lindorm_vector.py | 13 +++++++++- .../tests/integration_tests/test_lindorm.py | 3 +++ .../dify_vdb_matrixone/matrixone_vector.py | 15 +++++++++++- .../tests/integration_tests/test_matrixone.py | 3 +++ .../src/dify_vdb_milvus/milvus_vector.py | 13 +++++++++- .../tests/integration_tests/test_milvus.py | 4 ++++ .../src/dify_vdb_myscale/myscale_vector.py | 13 +++++++++- .../tests/integration_tests/test_myscale.py | 3 +++ .../dify_vdb_oceanbase/oceanbase_vector.py | 13 +++++++++- .../tests/integration_tests/test_oceanbase.py | 3 +++ .../src/dify_vdb_opengauss/opengauss.py | 12 +++++++++- .../dify_vdb_opensearch/opensearch_vector.py | 13 +++++++++- .../src/dify_vdb_oracle/oraclevector.py | 12 +++++++++- .../integration_tests/test_oraclevector.py | 3 +++ .../src/dify_vdb_pgvecto_rs/pgvecto_rs.py | 21 ++++++++++++---- .../integration_tests/test_pgvecto_rs.py | 4 ++++ .../src/dify_vdb_pgvector/pgvector.py | 12 +++++++++- .../src/dify_vdb_qdrant/qdrant_vector.py | 12 +++++++++- .../tests/integration_tests/test_qdrant.py | 3 +++ .../tests/unit_tests/test_qdrant_vector.py | 2 ++ .../src/dify_vdb_relyt/relyt_vector.py | 13 +++++++++- .../dify_vdb_tablestore/tablestore_vector.py | 13 +++++++++- .../integration_tests/test_tablestore.py | 6 +++++ .../src/dify_vdb_tencent/tencent_vector.py | 12 +++++++++- .../tests/integration_tests/test_tencent.py | 3 +++ .../tidb_on_qdrant_vector.py | 12 +++++++++- .../src/dify_vdb_tidb_vector/tidb_vector.py | 13 +++++++++- .../integration_tests/test_tidb_vector.py | 4 ++++ .../src/dify_vdb_upstash/upstash_vector.py | 13 +++++++++- .../integration_tests/test_upstash_vector.py | 4 ++++ .../src/dify_vdb_vastbase/vastbase_vector.py | 12 +++++++++- .../src/dify_vdb_vikingdb/vikingdb_vector.py | 13 +++++++++- .../tests/integration_tests/test_vikingdb.py | 5 ++++ .../src/dify_vdb_weaviate/weaviate_vector.py | 13 +++++++++- .../tests/unit_tests/test_weaviate_vector.py | 4 ++++ api/pyproject.toml | 1 + api/services/api_token_service.py | 3 ++- api/services/app_service.py | 3 ++- api/services/auth/firecrawl/firecrawl.py | 2 ++ api/services/auth/jina.py | 2 ++ api/services/auth/jina/jina.py | 2 ++ api/services/auth/watercrawl/watercrawl.py | 2 ++ .../batch_indexing_base.py | 4 +++- .../document_indexing_task_proxy.py | 4 ++-- .../duplicate_document_indexing_task_proxy.py | 4 ++-- .../enterprise/plugin_manager_service.py | 2 ++ api/services/errors/llm.py | 4 ++++ api/services/legacy_model_type_migration.py | 5 +++- .../built_in/built_in_retrieval.py | 5 +++- .../customized/customized_retrieval.py | 5 +++- .../database/database_retrieval.py | 5 +++- .../remote/remote_retrieval.py | 5 +++- .../buildin/buildin_retrieval.py | 5 +++- .../database/database_retrieval.py | 5 +++- .../recommend_app/remote/remote_retrieval.py | 5 +++- .../conversation/messages_clean_policy.py | 3 +++ api/services/variable_truncator.py | 6 ++++- api/services/workflow/queue_dispatcher.py | 8 +++++++ .../workflow_collaboration_service.py | 2 ++ .../workflow_draft_variable_service.py | 3 ++- api/tasks/mail_inner_task.py | 3 ++- .../workflow_cfs_scheduler/cfs_scheduler.py | 3 +++ .../test_workflow_draft_variable_service.py | 8 ++++++- .../clients/agent_backend/test_client.py | 5 ++++ .../service_api/app/test_hitl_service_api.py | 7 ++++++ .../app/apps/agent_app/test_app_runner.py | 4 +++- .../core/external_data_tool/test_base.py | 19 +++++++++++---- .../datasource/keyword/test_keyword_base.py | 13 ++++++++++ .../rag/datasource/vdb/test_vector_base.py | 10 ++++++++ .../core/rag/embedding/test_embedding_base.py | 6 ++++- .../core/rag/extractor/blob/test_blob.py | 11 +++++---- .../core/rag/extractor/test_csv_extractor.py | 11 +++++---- .../core/rag/extractor/test_html_extractor.py | 6 +++-- .../rag/extractor/test_markdown_extractor.py | 2 +- .../core/rag/extractor/test_text_extractor.py | 2 +- .../core/rag/extractor/test_word_extractor.py | 4 +++- .../rag/indexing/test_index_processor_base.py | 12 +++++++++- .../unit_tests/core/schemas/test_registry.py | 13 +++++----- .../storage/test_supabase_storage.py | 3 ++- .../recommend_app/test_buildin_retrieval.py | 3 ++- .../test_workflow_event_snapshot_service.py | 14 ++++++++++- 171 files changed, 885 insertions(+), 135 deletions(-) diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index c9bee61541..0066d035e5 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -7,7 +7,7 @@ import threading import time from collections.abc import Mapping from datetime import timedelta -from typing import TYPE_CHECKING, Any, TypedDict +from typing import TYPE_CHECKING, Any, TypedDict, override from uuid import UUID, uuid4 from cachetools import LRUCache @@ -221,6 +221,7 @@ class TracingProviderConfigEntry(TypedDict): class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigEntry]): + @override def __getitem__(self, key: str) -> TracingProviderConfigEntry: try: match key: diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index a9995778f7..cbf37fc754 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -1,7 +1,7 @@ import base64 import logging import pickle -from typing import Any, cast +from typing import Any, cast, override import numpy as np from sqlalchemy import select @@ -25,6 +25,7 @@ class CacheEmbedding(Embeddings): def __init__(self, model_instance: ModelInstance): self._model_instance = model_instance + @override def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed search docs in batches of 10.""" # use doc embedding cache or store if not exists @@ -106,6 +107,7 @@ class CacheEmbedding(Embeddings): return text_embeddings + @override def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]: """Embed file documents.""" # use doc embedding cache or store if not exists @@ -189,6 +191,7 @@ class CacheEmbedding(Embeddings): return multimodel_embeddings + @override def embed_query(self, text: str) -> list[float]: """Embed query text.""" # use doc embedding cache or store if not exists @@ -232,6 +235,7 @@ class CacheEmbedding(Embeddings): return embedding_results # type: ignore + @override def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]: """Embed multimodal documents.""" # use doc embedding cache or store if not exists diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 7ffa9afafd..7c7e8ab09d 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -3,7 +3,7 @@ import logging import re import uuid -from typing import Any, TypedDict, cast +from typing import Any, TypedDict, cast, override logger = logging.getLogger(__name__) @@ -61,6 +61,7 @@ class ParagraphFormatPreviewDict(TypedDict): class ParagraphIndexProcessor(BaseIndexProcessor): + @override def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: text_docs = ExtractProcessor.extract( extract_setting=extract_setting, @@ -71,6 +72,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): return text_docs + @override def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]: process_rule = kwargs.get("process_rule") if not process_rule: @@ -120,6 +122,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): all_documents.extend(split_documents) return all_documents + @override def load( self, dataset: Dataset, @@ -142,6 +145,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): else: keyword.add_texts(documents) + @override def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None: # Note: Summary indexes are now disabled (not deleted) when segments are disabled. # This method is called for actual deletion scenarios (e.g., when segment is deleted). @@ -178,6 +182,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): else: keyword.delete() + @override def retrieve( self, retrieval_method: RetrievalMethod, @@ -206,6 +211,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): docs.append(doc) return docs + @override def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None: documents: list[Any] = [] all_multimodal_documents: list[Any] = [] @@ -271,6 +277,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): keyword = Keyword(dataset) keyword.add_texts(documents) + @override def format_preview(self, chunks: Any) -> ParagraphFormatPreviewDict: if isinstance(chunks, list): preview = [] @@ -285,6 +292,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): else: raise ValueError("Chunks is not a list") + @override def generate_summary_preview( self, tenant_id: str, diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index a26a900512..bf9145def1 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -3,7 +3,7 @@ import json import logging import uuid -from typing import Any, TypedDict +from typing import Any, TypedDict, override from sqlalchemy import delete, select @@ -44,6 +44,7 @@ class ParentChildFormatPreviewDict(TypedDict): class ParentChildIndexProcessor(BaseIndexProcessor): + @override def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: text_docs = ExtractProcessor.extract( extract_setting=extract_setting, @@ -54,6 +55,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): return text_docs + @override def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]: process_rule = kwargs.get("process_rule") if not process_rule: @@ -129,6 +131,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): return all_documents + @override def load( self, dataset: Dataset, @@ -149,6 +152,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): if multimodal_documents and dataset.is_multimodal: vector.create_multimodal(multimodal_documents) + @override def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None: # node_ids is segment's node_ids # Note: Summary indexes are now disabled (not deleted) when segments are disabled. @@ -219,6 +223,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): ) db.session.commit() + @override def retrieve( self, retrieval_method: RetrievalMethod, @@ -283,6 +288,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): child_nodes.append(child_document) return child_nodes + @override def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None: parent_childs = ParentChildStructureChunk.model_validate(chunks) documents = [] @@ -356,6 +362,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): if all_multimodal_documents and dataset.is_multimodal: vector.create_multimodal(all_multimodal_documents) + @override def format_preview(self, chunks: Any) -> ParentChildFormatPreviewDict: parent_childs = ParentChildStructureChunk.model_validate(chunks) preview = [] @@ -369,6 +376,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): } return result + @override def generate_summary_preview( self, tenant_id: str, diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index d3f311b08e..7d1e7333a8 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -4,7 +4,7 @@ import logging import re import threading import uuid -from typing import Any, TypedDict +from typing import Any, TypedDict, override import pandas as pd from flask import Flask, current_app @@ -43,6 +43,7 @@ class QAFormatPreviewDict(TypedDict): class QAIndexProcessor(BaseIndexProcessor): + @override def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: text_docs = ExtractProcessor.extract( extract_setting=extract_setting, @@ -52,6 +53,7 @@ class QAIndexProcessor(BaseIndexProcessor): ) return text_docs + @override def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]: preview = kwargs.get("preview") process_rule = kwargs.get("process_rule") @@ -139,6 +141,7 @@ class QAIndexProcessor(BaseIndexProcessor): raise ValueError(str(e)) return text_docs + @override def load( self, dataset: Dataset, @@ -153,6 +156,7 @@ class QAIndexProcessor(BaseIndexProcessor): if multimodal_documents and dataset.is_multimodal: vector.create_multimodal(multimodal_documents) + @override def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None: # Note: Summary indexes are now disabled (not deleted) when segments are disabled. # This method is called for actual deletion scenarios (e.g., when segment is deleted). @@ -183,6 +187,7 @@ class QAIndexProcessor(BaseIndexProcessor): else: vector.delete() + @override def retrieve( self, retrieval_method: RetrievalMethod, @@ -211,6 +216,7 @@ class QAIndexProcessor(BaseIndexProcessor): docs.append(doc) return docs + @override def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None: qa_chunks = QAStructureChunk.model_validate(chunks) documents = [] @@ -234,6 +240,7 @@ class QAIndexProcessor(BaseIndexProcessor): else: raise ValueError("Indexing technique must be high quality.") + @override def format_preview(self, chunks: Any) -> QAFormatPreviewDict: qa_chunks = QAStructureChunk.model_validate(chunks) preview = [] @@ -246,6 +253,7 @@ class QAIndexProcessor(BaseIndexProcessor): } return result + @override def generate_summary_preview( self, tenant_id: str, diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index bce08f998f..8552e7f65d 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -1,4 +1,5 @@ import base64 +from typing import override from core.model_manager import ModelInstance, ModelManager from core.rag.index_processor.constant.doc_type import DocType @@ -16,6 +17,7 @@ class RerankModelRunner(BaseRerankRunner): def __init__(self, rerank_model_instance: ModelInstance): self.rerank_model_instance = rerank_model_instance + @override def run( self, query: str, diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index d0732b269a..3743d98fb9 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -1,5 +1,6 @@ import math from collections import Counter +from typing import override import numpy as np @@ -19,6 +20,7 @@ class WeightRerankRunner(BaseRerankRunner): self.tenant_id = tenant_id self.weights = weights + @override def run( self, query: str, diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 52c9a02f97..98ba4d7fcc 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -5,7 +5,7 @@ from __future__ import annotations import codecs import re from collections.abc import Set as AbstractSet -from typing import Any, Literal +from typing import Any, Literal, override from core.model_manager import ModelInstance from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter @@ -51,6 +51,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) self._fixed_separator = codecs.decode(fixed_separator, "unicode_escape") self._separators = separators or ["\n\n", "\n", "。", ". ", " ", ""] + @override def split_text(self, text: str) -> list[str]: """Split incoming text and return chunks.""" if self._fixed_separator: diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py index a8d9013fbc..39e5482269 100644 --- a/api/core/rag/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Sequence from collections.abc import Set as AbstractSet from dataclasses import dataclass -from typing import Any, Literal +from typing import Any, Literal, override from core.rag.models.document import BaseDocumentTransformer, Document @@ -148,10 +148,12 @@ class TextSplitter(BaseDocumentTransformer, ABC): ) return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs) + @override def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: """Transform sequence of documents by splitting them.""" return self.split_documents(list(documents)) + @override async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: """Asynchronously transform a sequence of documents by splitting them.""" raise NotImplementedError @@ -211,6 +213,7 @@ class TokenTextSplitter(TextSplitter): self._allowed_special: Literal["all"] | AbstractSet[str] = allowed_special self._disallowed_special: Literal["all"] | AbstractSet[str] = disallowed_special + @override def split_text(self, text: str) -> list[str]: def _encode(_text: str) -> list[int]: return self._tokenizer.encode( @@ -287,5 +290,6 @@ class RecursiveCharacterTextSplitter(TextSplitter): return final_chunks + @override def split_text(self, text: str) -> list[str]: return self._split_text(text, self._separators) diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index 20cdb3e57f..52d86f0648 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -1,6 +1,6 @@ from abc import abstractmethod from os import listdir, path -from typing import Any +from typing import Any, override from core.entities.provider_entities import ProviderConfig from core.helper.module_import_helper import load_single_subclass_from_source @@ -105,6 +105,7 @@ class BuiltinToolProviderController(ToolProviderController): """ return self.tools + @override def get_credentials_schema(self) -> list[ProviderConfig]: """ returns the credentials schema of the provider @@ -182,6 +183,7 @@ class BuiltinToolProviderController(ToolProviderController): ) @property + @override def provider_type(self) -> ToolProviderType: """ returns the type of the provider diff --git a/api/core/tools/builtin_tool/providers/audio/audio.py b/api/core/tools/builtin_tool/providers/audio/audio.py index abf23559ec..39b29e00e3 100644 --- a/api/core/tools/builtin_tool/providers/audio/audio.py +++ b/api/core/tools/builtin_tool/providers/audio/audio.py @@ -1,8 +1,9 @@ -from typing import Any +from typing import Any, override from core.tools.builtin_tool.provider import BuiltinToolProviderController class AudioToolProvider(BuiltinToolProviderController): + @override def _validate_credentials(self, user_id: str, credentials: dict[str, Any]): pass diff --git a/api/core/tools/builtin_tool/providers/audio/tools/asr.py b/api/core/tools/builtin_tool/providers/audio/tools/asr.py index 95660ab93b..0c3047244c 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/asr.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/asr.py @@ -1,6 +1,6 @@ import io from collections.abc import Generator -from typing import Any +from typing import Any, override from core.model_manager import ModelManager from core.plugin.entities.parameters import PluginParameterOption @@ -14,6 +14,7 @@ from services.model_provider_service import ModelProviderService class ASRTool(BuiltinTool): + @override def _invoke( self, user_id: str, @@ -56,6 +57,7 @@ class ASRTool(BuiltinTool): items.append((provider, model.model)) return items + @override def get_runtime_parameters( self, conversation_id: str | None = None, diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py index ac3820f1ab..db65391610 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/tts.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py @@ -1,6 +1,6 @@ import io from collections.abc import Generator -from typing import Any +from typing import Any, override from core.model_manager import ModelManager from core.plugin.entities.parameters import PluginParameterOption @@ -12,6 +12,7 @@ from services.model_provider_service import ModelProviderService class TTSTool(BuiltinTool): + @override def _invoke( self, user_id: str, @@ -66,6 +67,7 @@ class TTSTool(BuiltinTool): items.append((provider, model.model, voices)) return items + @override def get_runtime_parameters( self, conversation_id: str | None = None, diff --git a/api/core/tools/builtin_tool/providers/code/code.py b/api/core/tools/builtin_tool/providers/code/code.py index 3e02a64e89..f34b63c370 100644 --- a/api/core/tools/builtin_tool/providers/code/code.py +++ b/api/core/tools/builtin_tool/providers/code/code.py @@ -1,8 +1,9 @@ -from typing import Any +from typing import Any, override from core.tools.builtin_tool.provider import BuiltinToolProviderController class CodeToolProvider(BuiltinToolProviderController): + @override def _validate_credentials(self, user_id: str, credentials: dict[str, Any]): pass diff --git a/api/core/tools/builtin_tool/providers/code/tools/simple_code.py b/api/core/tools/builtin_tool/providers/code/tools/simple_code.py index 4383943199..dd041ef5eb 100644 --- a/api/core/tools/builtin_tool/providers/code/tools/simple_code.py +++ b/api/core/tools/builtin_tool/providers/code/tools/simple_code.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any +from typing import Any, override from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage from core.tools.builtin_tool.tool import BuiltinTool @@ -8,6 +8,7 @@ from core.tools.errors import ToolInvokeError class SimpleCode(BuiltinTool): + @override def _invoke( self, user_id: str, diff --git a/api/core/tools/builtin_tool/providers/time/time.py b/api/core/tools/builtin_tool/providers/time/time.py index c8f33ec56b..d22b06c4eb 100644 --- a/api/core/tools/builtin_tool/providers/time/time.py +++ b/api/core/tools/builtin_tool/providers/time/time.py @@ -1,8 +1,9 @@ -from typing import Any +from typing import Any, override from core.tools.builtin_tool.provider import BuiltinToolProviderController class WikiPediaProvider(BuiltinToolProviderController): + @override def _validate_credentials(self, user_id: str, credentials: dict[str, Any]): pass diff --git a/api/core/tools/builtin_tool/providers/time/tools/current_time.py b/api/core/tools/builtin_tool/providers/time/tools/current_time.py index e07ca0d919..1164cc11c6 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/current_time.py +++ b/api/core/tools/builtin_tool/providers/time/tools/current_time.py @@ -1,6 +1,6 @@ from collections.abc import Generator from datetime import UTC, datetime -from typing import Any +from typing import Any, override from pytz import timezone as pytz_timezone # type: ignore[import-untyped] @@ -9,6 +9,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage class CurrentTimeTool(BuiltinTool): + @override def _invoke( self, user_id: str, diff --git a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py index dc49b64dd8..1ebb7ab3a7 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py +++ b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py @@ -1,6 +1,6 @@ from collections.abc import Generator from datetime import datetime -from typing import Any +from typing import Any, override import pytz # type: ignore[import-untyped] @@ -10,6 +10,7 @@ from core.tools.errors import ToolInvokeError class LocaltimeToTimestampTool(BuiltinTool): + @override def _invoke( self, user_id: str, diff --git a/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py b/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py index 8045e4b980..f1bf147144 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py +++ b/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py @@ -1,6 +1,6 @@ from collections.abc import Generator from datetime import datetime -from typing import Any +from typing import Any, override import pytz # type: ignore[import-untyped] @@ -10,6 +10,7 @@ from core.tools.errors import ToolInvokeError class TimestampToLocaltimeTool(BuiltinTool): + @override def _invoke( self, user_id: str, diff --git a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py index e2570811d6..6f00e0009e 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py +++ b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py @@ -1,6 +1,6 @@ from collections.abc import Generator from datetime import datetime -from typing import Any +from typing import Any, override import pytz # type: ignore[import-untyped] @@ -10,6 +10,7 @@ from core.tools.errors import ToolInvokeError class TimezoneConversionTool(BuiltinTool): + @override def _invoke( self, user_id: str, diff --git a/api/core/tools/builtin_tool/providers/time/tools/weekday.py b/api/core/tools/builtin_tool/providers/time/tools/weekday.py index e26b316bd5..6a88538afe 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/weekday.py +++ b/api/core/tools/builtin_tool/providers/time/tools/weekday.py @@ -1,13 +1,14 @@ import calendar from collections.abc import Generator from datetime import datetime -from typing import Any +from typing import Any, override from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ToolInvokeMessage class WeekdayTool(BuiltinTool): + @override def _invoke( self, user_id: str, diff --git a/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py b/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py index 9d668ac9eb..44c6f720d0 100644 --- a/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py +++ b/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any +from typing import Any, override from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ToolInvokeMessage @@ -8,6 +8,7 @@ from core.tools.utils.web_reader_tool import get_url class WebscraperTool(BuiltinTool): + @override def _invoke( self, user_id: str, diff --git a/api/core/tools/builtin_tool/providers/webscraper/webscraper.py b/api/core/tools/builtin_tool/providers/webscraper/webscraper.py index 7d8942d420..a74ae148cd 100644 --- a/api/core/tools/builtin_tool/providers/webscraper/webscraper.py +++ b/api/core/tools/builtin_tool/providers/webscraper/webscraper.py @@ -1,9 +1,10 @@ -from typing import Any +from typing import Any, override from core.tools.builtin_tool.provider import BuiltinToolProviderController class WebscraperProvider(BuiltinToolProviderController): + @override def _validate_credentials(self, user_id: str, credentials: dict[str, Any]): """ Validate credentials diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py index 1872cb46a9..0782856fea 100644 --- a/api/core/tools/builtin_tool/tool.py +++ b/api/core/tools/builtin_tool/tool.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import override + from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolProviderType @@ -26,6 +28,7 @@ class BuiltinTool(Tool): super().__init__(**kwargs) self.provider = provider + @override def fork_tool_runtime(self, runtime: ToolRuntime) -> BuiltinTool: """ fork a new tool with metadata @@ -56,6 +59,7 @@ class BuiltinTool(Tool): caller_user_id=self.runtime.user_id, ) + @override def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.BUILT_IN diff --git a/api/core/tools/custom_tool/provider.py b/api/core/tools/custom_tool/provider.py index e2f6c00555..520a55dbd3 100644 --- a/api/core/tools/custom_tool/provider.py +++ b/api/core/tools/custom_tool/provider.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import override + from pydantic import Field from sqlalchemy import select @@ -122,6 +124,7 @@ class ApiToolProviderController(ToolProviderController): ) @property + @override def provider_type(self) -> ToolProviderType: return ToolProviderType.API @@ -194,6 +197,7 @@ class ApiToolProviderController(ToolProviderController): self.tools = tools return tools + @override def get_tool(self, tool_name: str) -> ApiTool: """ get tool by name diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index 168e5f4493..2e618b7ea5 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -2,7 +2,7 @@ import json from collections.abc import Generator from dataclasses import dataclass from os import getenv -from typing import Any, Union +from typing import Any, Union, override from urllib.parse import urlencode import httpx @@ -45,6 +45,7 @@ class ApiTool(Tool): self.api_bundle = api_bundle self.provider_id = provider_id + @override def fork_tool_runtime(self, runtime: ToolRuntime): """ fork a new tool with metadata @@ -77,6 +78,7 @@ class ApiTool(Tool): # For credential validation, always return as string return parsed_response.to_string() + @override def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.API @@ -373,6 +375,7 @@ class ApiTool(Tool): except ValueError: return value + @override def _invoke( self, user_id: str, diff --git a/api/core/tools/mcp_tool/provider.py b/api/core/tools/mcp_tool/provider.py index f46eeff6c5..52414153b8 100644 --- a/api/core/tools/mcp_tool/provider.py +++ b/api/core/tools/mcp_tool/provider.py @@ -1,4 +1,4 @@ -from typing import Any, Self +from typing import Any, Self, override from core.entities.mcp_provider import IdentityMode, MCPProviderEntity from core.mcp.types import Tool as RemoteMCPTool @@ -41,6 +41,7 @@ class MCPToolProviderController(ToolProviderController): self.identity_mode: IdentityMode = identity_mode @property + @override def provider_type(self) -> ToolProviderType: """ returns the type of the provider @@ -116,6 +117,7 @@ class MCPToolProviderController(ToolProviderController): """ pass + @override def get_tool(self, tool_name: str) -> MCPTool: """ return tool with given name diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 4629c138c0..b0f1f7a5f2 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -4,7 +4,7 @@ import base64 import json import logging from collections.abc import Generator, Mapping -from typing import Any, cast +from typing import Any, cast, override from configs import dify_config from core.entities.mcp_provider import IdentityMode @@ -58,9 +58,11 @@ class MCPTool(Tool): self.identity_mode: IdentityMode = identity_mode self._latest_usage = LLMUsage.empty_usage() + @override def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.MCP + @override def _invoke( self, user_id: str, @@ -232,6 +234,7 @@ class MCPTool(Tool): return found return None + @override def fork_tool_runtime(self, runtime: ToolRuntime) -> MCPTool: return MCPTool( entity=self.entity, diff --git a/api/core/tools/plugin_tool/provider.py b/api/core/tools/plugin_tool/provider.py index 3fbbd4c9e5..fc6ec14284 100644 --- a/api/core/tools/plugin_tool/provider.py +++ b/api/core/tools/plugin_tool/provider.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, override from core.plugin.impl.tool import PluginToolManager from core.tools.__base.tool_runtime import ToolRuntime @@ -23,6 +23,7 @@ class PluginToolProviderController(BuiltinToolProviderController): self.plugin_unique_identifier = plugin_unique_identifier @property + @override def provider_type(self) -> ToolProviderType: """ returns the type of the provider @@ -31,6 +32,7 @@ class PluginToolProviderController(BuiltinToolProviderController): """ return ToolProviderType.PLUGIN + @override def _validate_credentials(self, user_id: str, credentials: dict[str, Any]): """ validate the credentials of the provider diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py index d3a2ad488c..ac17d542bc 100644 --- a/api/core/tools/plugin_tool/tool.py +++ b/api/core/tools/plugin_tool/tool.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Generator -from typing import Any +from typing import Any, override from core.plugin.impl.tool import PluginToolManager from core.plugin.utils.converter import convert_parameters_to_plugin_format @@ -20,9 +20,11 @@ class PluginTool(Tool): self.plugin_unique_identifier = plugin_unique_identifier self.runtime_parameters: list[ToolParameter] | None = None + @override def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.PLUGIN + @override def _invoke( self, user_id: str, @@ -48,6 +50,7 @@ class PluginTool(Tool): message_id=message_id, ) + @override def fork_tool_runtime(self, runtime: ToolRuntime) -> PluginTool: return PluginTool( entity=self.entity, @@ -57,6 +60,7 @@ class PluginTool(Tool): plugin_unique_identifier=self.plugin_unique_identifier, ) + @override def get_runtime_parameters( self, conversation_id: str | None = None, diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index b6890b2611..beb8c5d005 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -1,4 +1,5 @@ import threading +from typing import override from flask import Flask, current_app from pydantic import BaseModel, Field @@ -46,6 +47,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): name=f"dataset_{tenant_id.replace('-', '_')}", tenant_id=tenant_id, dataset_ids=dataset_ids, **kwargs ) + @override def _run(self, query: str) -> str: threads = [] all_documents: list[RagDocument] = [] diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index 0d1dc7273b..85a6e57b4c 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -1,4 +1,4 @@ -from typing import Any, cast +from typing import Any, cast, override from pydantic import BaseModel, Field from sqlalchemy import select @@ -56,6 +56,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): **kwargs, ) + @override def _run(self, query: str) -> str: dataset_stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id) dataset = db.session.scalar(dataset_stmt) diff --git a/api/core/tools/utils/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever_tool.py index 0bdc3df869..d34bfe0aa1 100644 --- a/api/core/tools/utils/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever_tool.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any +from typing import Any, override from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import InvokeFrom @@ -85,6 +85,7 @@ class DatasetRetrieverTool(Tool): return tools + @override def get_runtime_parameters( self, conversation_id: str | None = None, @@ -105,9 +106,11 @@ class DatasetRetrieverTool(Tool): ), ] + @override def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.DATASET_RETRIEVAL + @override def _invoke( self, user_id: str, diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index 5905fd919e..41212bcec8 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Mapping +from typing import override from pydantic import Field from sqlalchemy import select @@ -80,6 +81,7 @@ class WorkflowToolProviderController(ToolProviderController): return controller @property + @override def provider_type(self) -> ToolProviderType: return ToolProviderType.WORKFLOW diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 58686b08ce..7e7b1e3300 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -3,7 +3,7 @@ from __future__ import annotations import json import logging from collections.abc import Generator, Mapping, Sequence -from typing import Any, cast +from typing import Any, cast, override from sqlalchemy import select @@ -67,6 +67,7 @@ class WorkflowTool(Tool): super().__init__(entity=entity, runtime=runtime) + @override def tool_provider_type(self) -> ToolProviderType: """ get the tool provider type @@ -75,6 +76,7 @@ class WorkflowTool(Tool): """ return ToolProviderType.WORKFLOW + @override def _invoke( self, user_id: str, @@ -206,6 +208,7 @@ class WorkflowTool(Tool): return found return None + @override def fork_tool_runtime(self, runtime: ToolRuntime) -> WorkflowTool: """ fork a new tool with metadata diff --git a/api/core/trigger/debug/event_selectors.py b/api/core/trigger/debug/event_selectors.py index 24c1271488..c5caee0b56 100644 --- a/api/core/trigger/debug/event_selectors.py +++ b/api/core/trigger/debug/event_selectors.py @@ -6,7 +6,7 @@ import time from abc import ABC, abstractmethod from collections.abc import Mapping from datetime import datetime -from typing import Any +from typing import Any, override from pydantic import BaseModel @@ -62,6 +62,7 @@ class TriggerDebugEventPoller(ABC): class PluginTriggerDebugEventPoller(TriggerDebugEventPoller): + @override def poll(self) -> TriggerDebugEvent | None: from services.trigger.trigger_service import TriggerService @@ -103,6 +104,7 @@ class PluginTriggerDebugEventPoller(TriggerDebugEventPoller): class WebhookTriggerDebugEventPoller(TriggerDebugEventPoller): + @override def poll(self) -> TriggerDebugEvent | None: pool_key = build_webhook_pool_key( tenant_id=self.tenant_id, @@ -190,6 +192,7 @@ class ScheduleTriggerDebugEventPoller(TriggerDebugEventPoller): inputs={}, ) + @override def poll(self) -> TriggerDebugEvent | None: schedule_debug_runtime = self.get_or_create_schedule_debug_runtime() if schedule_debug_runtime.next_run_at > naive_utc_now(): diff --git a/api/core/trigger/utils/encryption.py b/api/core/trigger/utils/encryption.py index b12291e299..9b958690e5 100644 --- a/api/core/trigger/utils/encryption.py +++ b/api/core/trigger/utils/encryption.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Union +from typing import Union, override from core.entities.provider_entities import BasicProviderConfig, ProviderConfig from core.helper.provider_cache import ProviderCredentialsCache @@ -16,6 +16,7 @@ class TriggerProviderCredentialsCache(ProviderCredentialsCache): def __init__(self, tenant_id: str, provider_id: str, credential_id: str): super().__init__(tenant_id=tenant_id, provider_id=provider_id, credential_id=credential_id) + @override def _generate_cache_key(self, **kwargs) -> str: tenant_id = kwargs["tenant_id"] provider_id = kwargs["provider_id"] @@ -29,6 +30,7 @@ class TriggerProviderOAuthClientParamsCache(ProviderCredentialsCache): def __init__(self, tenant_id: str, provider_id: str): super().__init__(tenant_id=tenant_id, provider_id=provider_id) + @override def _generate_cache_key(self, **kwargs) -> str: tenant_id = kwargs["tenant_id"] provider_id = kwargs["provider_id"] @@ -41,6 +43,7 @@ class TriggerProviderPropertiesCache(ProviderCredentialsCache): def __init__(self, tenant_id: str, provider_id: str, subscription_id: str): super().__init__(tenant_id=tenant_id, provider_id=provider_id, subscription_id=subscription_id) + @override def _generate_cache_key(self, **kwargs) -> str: tenant_id = kwargs["tenant_id"] provider_id = kwargs["provider_id"] diff --git a/api/core/workflow/human_input_adapter.py b/api/core/workflow/human_input_adapter.py index 731ae2b858..0865365ea6 100644 --- a/api/core/workflow/human_input_adapter.py +++ b/api/core/workflow/human_input_adapter.py @@ -10,7 +10,7 @@ from __future__ import annotations import enum import uuid from collections.abc import Mapping, Sequence -from typing import Annotated, Any, ClassVar, Literal +from typing import Annotated, Any, ClassVar, Literal, override import bleach import markdown @@ -158,6 +158,7 @@ class EmailDeliveryMethod(_DeliveryMethodBase): type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL config: EmailDeliveryConfig + @override def extract_variable_selectors(self) -> Sequence[Sequence[str]]: variable_template_parser = VariableTemplateParser(template=self.config.body) selectors: list[Sequence[str]] = [] diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index d02a45a944..ae578788ea 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -195,13 +195,16 @@ class _LazyNodeTypeClassesMapping(MutableMapping[NodeType, Mapping[str, type[Nod snapshot.update(self._overrides) return snapshot + @override def __getitem__(self, key: NodeType) -> Mapping[str, type[Node]]: return self._snapshot()[key] + @override def __setitem__(self, key: NodeType, value: Mapping[str, type[Node]]) -> None: self._deleted.discard(key) self._overrides[key] = value + @override def __delitem__(self, key: NodeType) -> None: if key in self._overrides: del self._overrides[key] @@ -211,9 +214,11 @@ class _LazyNodeTypeClassesMapping(MutableMapping[NodeType, Mapping[str, type[Nod return raise KeyError(key) + @override def __iter__(self) -> Iterator[NodeType]: return iter(self._snapshot()) + @override def __len__(self) -> int: return len(self._snapshot()) diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py index ef3f3d5a6b..8c7d5c157c 100644 --- a/api/core/workflow/node_runtime.py +++ b/api/core/workflow/node_runtime.py @@ -2,7 +2,7 @@ from __future__ import annotations from collections.abc import Callable, Generator, Mapping, Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Literal, cast, overload +from typing import TYPE_CHECKING, Any, Literal, cast, overload, override from sqlalchemy import select from sqlalchemy.orm import Session @@ -136,6 +136,7 @@ class DifyFileReferenceFactory(FileReferenceFactoryProtocol): def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None: self._run_context = resolve_dify_run_context(run_context) + @override def build_from_mapping(self, *, mapping: Mapping[str, Any]): return file_factory.build_from_mapping( mapping=mapping, @@ -151,25 +152,31 @@ class DifyPreparedLLM(LLMProtocol): self._model_instance = model_instance @property + @override def provider(self) -> str: return self._model_instance.provider @property + @override def model_name(self) -> str: return self._model_instance.model_name @property + @override def parameters(self) -> Mapping[str, Any]: return self._model_instance.parameters @parameters.setter + @override def parameters(self, value: Mapping[str, Any]) -> None: self._model_instance.parameters = value @property + @override def stop(self) -> Sequence[str] | None: return self._model_instance.stop + @override def get_model_schema(self) -> AIModelEntity: model_schema = cast(LargeLanguageModel, self._model_instance.model_type_instance).get_model_schema( self._model_instance.model_name, @@ -179,6 +186,7 @@ class DifyPreparedLLM(LLMProtocol): raise ValueError(f"Model schema not found for {self._model_instance.model_name}") return model_schema + @override def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int: return self._model_instance.get_llm_num_tokens(prompt_messages) @@ -204,6 +212,7 @@ class DifyPreparedLLM(LLMProtocol): stream: Literal[True], ) -> Generator[LLMResultChunk, None, None]: ... + @override def invoke_llm( self, *, @@ -243,6 +252,7 @@ class DifyPreparedLLM(LLMProtocol): stream: Literal[True], ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... + @override def invoke_llm_with_structured_output( self, *, @@ -263,11 +273,13 @@ class DifyPreparedLLM(LLMProtocol): stream=stream, ) + @override def is_structured_output_parse_error(self, error: Exception) -> bool: return isinstance(error, OutputParserError) class DifyPromptMessageSerializer(PromptMessageSerializerProtocol): + @override def serialize( self, *, @@ -294,6 +306,7 @@ class DifyRetrieverAttachmentLoader(RetrieverAttachmentLoaderProtocol): self._file_reference_factory = file_reference_factory self._segment_access_checker = segment_access_checker + @override def load(self, *, segment_id: str) -> Sequence[File]: if not is_retriever_segment_access_granted(segment_id): return [] @@ -341,6 +354,7 @@ class DifyToolFileManager(ToolFileManagerProtocol): self._manager = ToolFileManager() self._conversation_id_getter = conversation_id_getter + @override def create_file_by_raw( self, *, @@ -358,6 +372,7 @@ class DifyToolFileManager(ToolFileManagerProtocol): filename=filename, ) + @override def get_file_generator_by_tool_file_id(self, tool_file_id: str): return self._manager.get_file_generator_by_tool_file_id(tool_file_id) @@ -394,9 +409,11 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): def file_reference_factory(self) -> FileReferenceFactoryProtocol: return self._file_reference_factory + @override def build_file_reference(self, *, mapping: Mapping[str, Any]): return self._file_reference_factory.build_from_mapping(mapping=mapping) + @override def get_runtime( self, *, @@ -447,6 +464,7 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): ) ) + @override def get_runtime_parameters( self, *, @@ -458,6 +476,7 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): for parameter in (tool.get_merged_runtime_parameters() or []) ] + @override def invoke( self, *, @@ -503,6 +522,7 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): return self._adapt_messages(transformed_messages, provider_name=provider_name) + @override def get_usage( self, *, @@ -745,6 +765,7 @@ class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol): form_repository=form_repository, ) + @override def get_form(self, *, node_id: str) -> HumanInputFormStateProtocol | None: repo = self.build_form_repository() return repo.get_form(node_id) @@ -766,6 +787,7 @@ class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol): ) return restored_data + @override def create_form( self, *, diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 17d71668cb..2b6745d46a 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, override from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.workflow.system_variables import SystemVariableKey, get_system_text @@ -56,9 +56,11 @@ class AgentNode(Node[AgentNodeData]): self._message_transformer = message_transformer @classmethod + @override def version(cls) -> str: return "1" + @override def populate_start_event(self, event) -> None: dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) event.extras["agent_strategy"] = { @@ -69,6 +71,7 @@ class AgentNode(Node[AgentNodeData]): ), } + @override def _run(self) -> Generator[NodeEventBase, None, None]: from core.plugin.impl.exc import PluginDaemonClientSideError @@ -167,6 +170,7 @@ class AgentNode(Node[AgentNodeData]): ) @classmethod + @override def _extract_variable_selector_to_variable_mapping( cls, *, diff --git a/api/core/workflow/nodes/agent/plugin_strategy_adapter.py b/api/core/workflow/nodes/agent/plugin_strategy_adapter.py index 1fc427ad6c..4a9a877d75 100644 --- a/api/core/workflow/nodes/agent/plugin_strategy_adapter.py +++ b/api/core/workflow/nodes/agent/plugin_strategy_adapter.py @@ -1,11 +1,14 @@ from __future__ import annotations +from typing import override + from factories.agent_factory import get_plugin_agent_strategy from .strategy_protocols import AgentStrategyPresentationProvider, AgentStrategyResolver, ResolvedAgentStrategy class PluginAgentStrategyResolver(AgentStrategyResolver): + @override def resolve( self, *, @@ -21,6 +24,7 @@ class PluginAgentStrategyResolver(AgentStrategyResolver): class PluginAgentStrategyPresentationProvider(AgentStrategyPresentationProvider): + @override def get_icon(self, *, tenant_id: str, agent_strategy_provider_name: str) -> str | None: from core.plugin.impl.plugin import PluginInstaller diff --git a/api/core/workflow/nodes/agent_v2/agent_node.py b/api/core/workflow/nodes/agent_v2/agent_node.py index a255d3a143..cf5baf251a 100644 --- a/api/core/workflow/nodes/agent_v2/agent_node.py +++ b/api/core/workflow/nodes/agent_v2/agent_node.py @@ -2,7 +2,7 @@ from __future__ import annotations import logging from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, override from agenton.compositor import CompositorSessionSnapshot @@ -101,12 +101,15 @@ class DifyAgentNode(Node[DifyAgentNodeData]): self._session_store = session_store @classmethod + @override def version(cls) -> str: return "2" + @override def populate_start_event(self, event) -> None: event.extras["agent_node"] = {"version": "2", "agent_node_kind": self.node_data.agent_node_kind} + @override def _run(self) -> Generator[NodeEventBase, None, None]: dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) workflow_id = self.graph_init_params.workflow_id @@ -577,6 +580,7 @@ class DifyAgentNode(Node[DifyAgentNodeData]): metadata["agent_backend"] = agent_backend @classmethod + @override def _extract_variable_selector_to_variable_mapping( cls, *, diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index a4ef3d1ea7..f2ded99db3 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, override from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.datasource.datasource_manager import DatasourceManager @@ -49,10 +49,12 @@ class DatasourceNode(Node[DatasourceNodeData]): ) self.datasource_manager = DatasourceManager + @override def populate_start_event(self, event) -> None: event.provider_id = f"{self.node_data.plugin_id}/{self.node_data.provider_name}" event.provider_type = self.node_data.provider_type + @override def _run(self) -> Generator: """ Run the datasource node @@ -183,6 +185,7 @@ class DatasourceNode(Node[DatasourceNodeData]): ) @classmethod + @override def _extract_variable_selector_to_variable_mapping( cls, *, @@ -219,5 +222,6 @@ class DatasourceNode(Node[DatasourceNodeData]): return result @classmethod + @override def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index ac3a194cdd..86854c0182 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -1,6 +1,6 @@ import logging from collections.abc import Mapping -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, override from core.rag.index_processor.index_processor import IndexProcessor from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict @@ -46,6 +46,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): self.index_processor = IndexProcessor() self.summary_index_service = SummaryIndex() + @override def _run(self) -> NodeRunResult: node_data = self.node_data variable_pool = self.graph_runtime_state.variable_pool @@ -145,6 +146,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): return rst @classmethod + @override def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index e161796210..1d938dd04c 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -6,7 +6,7 @@ the workflow node registry. import logging from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, override from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext @@ -87,9 +87,11 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD self._rag_retrieval = DatasetRetrieval() @classmethod + @override def version(cls): return "1" + @override def _run(self) -> NodeRunResult: usage = LLMUsage.empty_usage() if not self._node_data.query_variable_selector and not self._node_data.query_attachment_selector: @@ -327,6 +329,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD ) @classmethod + @override def _extract_variable_selector_to_variable_mapping( cls, *, diff --git a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py index c848a86255..a3641c34f3 100644 --- a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py +++ b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any +from typing import Any, override from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID @@ -15,6 +15,7 @@ class TriggerEventNode(Node[TriggerEventNodeData]): execution_type = NodeExecutionType.ROOT @classmethod + @override def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: return { "type": "plugin", @@ -30,12 +31,15 @@ class TriggerEventNode(Node[TriggerEventNodeData]): } @classmethod + @override def version(cls) -> str: return "1" + @override def populate_start_event(self, event) -> None: event.provider_id = self.node_data.provider_id + @override def _run(self) -> NodeRunResult: """ Run the plugin trigger node. diff --git a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py index b46cc76a6e..d2cdf96d15 100644 --- a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py +++ b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py @@ -1,4 +1,5 @@ from collections.abc import Mapping +from typing import override from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID @@ -14,10 +15,12 @@ class TriggerScheduleNode(Node[TriggerScheduleNodeData]): execution_type = NodeExecutionType.ROOT @classmethod + @override def version(cls) -> str: return "1" @classmethod + @override def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: return { "type": TRIGGER_SCHEDULE_NODE_TYPE, @@ -29,6 +32,7 @@ class TriggerScheduleNode(Node[TriggerScheduleNodeData]): }, } + @override def _run(self) -> NodeRunResult: node_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id)) system_inputs = self.graph_runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py index 13c4f05bfd..6418027570 100644 --- a/api/core/workflow/nodes/trigger_webhook/node.py +++ b/api/core/workflow/nodes/trigger_webhook/node.py @@ -1,6 +1,6 @@ import logging from collections.abc import Mapping -from typing import Any +from typing import Any, override from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE from core.workflow.file_reference import resolve_file_record_id @@ -25,12 +25,14 @@ class TriggerWebhookNode(Node[WebhookData]): _file_reference_factory: FileReferenceFactoryProtocol + @override def post_init(self) -> None: from core.workflow.node_runtime import DifyFileReferenceFactory self._file_reference_factory = DifyFileReferenceFactory(self.run_context) @classmethod + @override def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: return { "type": "webhook", @@ -48,9 +50,11 @@ class TriggerWebhookNode(Node[WebhookData]): } @classmethod + @override def version(cls) -> str: return "1" + @override def _run(self) -> NodeRunResult: """ Run the webhook node. diff --git a/api/core/workflow/template_rendering.py b/api/core/workflow/template_rendering.py index b4ffb37549..5164841ff3 100644 --- a/api/core/workflow/template_rendering.py +++ b/api/core/workflow/template_rendering.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Mapping -from typing import Any +from typing import Any, override from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor from graphon.nodes.code.entities import CodeLanguage @@ -11,6 +11,7 @@ from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderErr class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer): """Sandbox-backed Jinja2 renderer for workflow-owned node composition.""" + @override def render_template(self, template: str, variables: Mapping[str, Any]) -> str: try: result = CodeExecutor.execute_workflow_code_template( diff --git a/api/enterprise/telemetry/id_generator.py b/api/enterprise/telemetry/id_generator.py index f3e5d6d0d6..ec3570d81e 100644 --- a/api/enterprise/telemetry/id_generator.py +++ b/api/enterprise/telemetry/id_generator.py @@ -1,3 +1,5 @@ +from typing import override + """Custom OTEL ID Generator for correlation-based trace/span ID derivation. Uses contextvars for thread-safe correlation_id -> trace_id mapping. @@ -52,6 +54,7 @@ class CorrelationIdGenerator(IdGenerator): parent-child linking), otherwise random """ + @override def generate_trace_id(self) -> int: correlation_id = _correlation_id_context.get() if correlation_id: @@ -61,6 +64,7 @@ class CorrelationIdGenerator(IdGenerator): pass return random.getrandbits(128) + @override def generate_span_id(self) -> int: source = _span_id_source_context.get() if source: diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index 7191933eed..d1a8f0c959 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -1,4 +1,5 @@ import json +from typing import override from flask_restx import fields @@ -7,6 +8,7 @@ from libs.helper import AppIconUrlField, TimestampField class JsonStringField(fields.Raw): + @override def format(self, value): if isinstance(value, str): try: diff --git a/api/fields/raws.py b/api/fields/raws.py index ee6f53b360..c7e047626f 100644 --- a/api/fields/raws.py +++ b/api/fields/raws.py @@ -1,9 +1,12 @@ +from typing import override + from flask_restx import fields from graphon.file import File class FilesContainedField(fields.Raw): + @override def format(self, value): return self._format_file_object(value) diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 6e947858ba..49b7a8be25 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,3 +1,5 @@ +from typing import override + from flask_restx import fields from core.helper import encrypter @@ -11,6 +13,7 @@ ENVIRONMENT_VARIABLE_SUPPORTED_TYPES = (SegmentType.STRING, SegmentType.NUMBER, class EnvironmentVariableField(fields.Raw): + @override def format(self, value): # Mask secret variables values in environment_variables if isinstance(value, SecretVariable): diff --git a/api/models/base.py b/api/models/base.py index 5acdf184f4..b9c25a9150 100644 --- a/api/models/base.py +++ b/api/models/base.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import override from uuid import uuid4 from sqlalchemy import DateTime, func @@ -51,6 +52,7 @@ class DefaultFieldsMixin: onupdate=func.current_timestamp(), ) + @override def __repr__(self) -> str: return f"<{self.__class__.__name__}(id={self.id})>" @@ -87,6 +89,7 @@ class DefaultFieldsDCMixin(MappedAsDataclass): onupdate=func.current_timestamp(), ) + @override def __repr__(self) -> str: return f"<{self.__class__.__name__}(id={self.id})>" diff --git a/api/models/dataset.py b/api/models/dataset.py index 007d24728f..1644551925 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -9,7 +9,7 @@ import re import time from datetime import datetime from json import JSONDecodeError -from typing import Any, ClassVar, TypedDict, cast +from typing import Any, ClassVar, TypedDict, cast, override from uuid import uuid4 import sqlalchemy as sa @@ -1790,5 +1790,6 @@ class DocumentSegmentSummary(TypeBase): init=False, ) + @override def __repr__(self): return f"" diff --git a/api/models/enums.py b/api/models/enums.py index d30d2447db..cdd2b136cf 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -1,4 +1,5 @@ from enum import StrEnum +from typing import override from core.trigger.constants import ( TRIGGER_PLUGIN_NODE_TYPE, @@ -12,6 +13,7 @@ class CreatorUserRole(StrEnum): END_USER = "end_user" @classmethod + @override def _missing_(cls, value): if value == "end-user": return cls.END_USER diff --git a/api/models/model.py b/api/models/model.py index d20089e49a..96ab7b0e6e 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -8,7 +8,7 @@ from datetime import datetime from decimal import Decimal from enum import StrEnum, auto from functools import lru_cache -from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, cast +from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, cast, override from uuid import uuid4 import sqlalchemy as sa @@ -2058,10 +2058,12 @@ class EndUser(Base, UserMixin): ) @property + @override def is_anonymous(self) -> Literal[False]: return False @is_anonymous.setter + @override def is_anonymous(self, value: bool) -> None: self._is_anonymous = value diff --git a/api/models/provider.py b/api/models/provider.py index 0bc2fc2130..1a4c9db997 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -3,6 +3,7 @@ from __future__ import annotations from datetime import datetime from enum import StrEnum, auto from functools import cached_property +from typing import override from uuid import uuid4 import sqlalchemy as sa @@ -73,6 +74,7 @@ class Provider(TypeBase): DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False ) + @override def __repr__(self): return ( f" str: return str(self) + @override def __str__(self) -> str: return f"{self.organization}/{self.plugin_name}/{self.provider_name}" diff --git a/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py index dd323721bb..8a5eef69b9 100644 --- a/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py +++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py @@ -1,5 +1,6 @@ import logging from collections.abc import Sequence +from typing import override from opentelemetry.trace import SpanKind from sqlalchemy.orm import sessionmaker @@ -74,6 +75,7 @@ class AliyunDataTrace(BaseTraceInstance): endpoint = build_endpoint(aliyun_config.endpoint, aliyun_config.license_key) self.trace_client = TraceClient(service_name=aliyun_config.app_name, endpoint=endpoint) + @override def trace(self, trace_info: BaseTraceInfo): match trace_info: case WorkflowTraceInfo(): diff --git a/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py index f2217dc0a2..0f24adfd92 100644 --- a/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py @@ -6,7 +6,7 @@ import traceback from collections.abc import Mapping, Sequence from dataclasses import dataclass, field from datetime import datetime, timedelta -from typing import Any, Protocol, Union, cast +from typing import Any, Protocol, Union, cast, override from urllib.parse import urlparse from openinference.semconv.trace import ( @@ -730,6 +730,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): self.root_span_carriers: dict[str, dict[str, str]] = {} self.carrier: dict[str, str] = {} + @override def trace(self, trace_info: BaseTraceInfo): logger.info("[Arize/Phoenix] Trace Entity Info: %s", trace_info) logger.info("[Arize/Phoenix] Trace Entity Type: %s", type(trace_info)) diff --git a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py index f75bf11530..0c2ead2bcc 100644 --- a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py @@ -2,7 +2,7 @@ import json from collections.abc import Sequence from datetime import UTC, datetime, timedelta from types import SimpleNamespace -from typing import Any, cast +from typing import Any, cast, override from unittest.mock import MagicMock, patch import dify_trace_arize_phoenix.arize_phoenix_trace as arize_phoenix_trace_module @@ -133,10 +133,12 @@ class _CollectingSpanExporter(SpanExporter): def __init__(self): self.spans: list[ReadableSpan] = [] + @override def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult: self.spans.extend(spans) return SpanExportResult.SUCCESS + @override def shutdown(self) -> None: return None @@ -258,9 +260,11 @@ def test_set_span_status(): # repr branch class SilentError: + @override def __str__(self): return "" + @override def __repr__(self): return "SilentErrorRepr" diff --git a/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py index b9756cf00e..caa9b67c66 100644 --- a/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py +++ b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py @@ -2,6 +2,7 @@ import logging import os import uuid from datetime import UTC, datetime, timedelta +from typing import override from langfuse import Langfuse from langfuse.api import ( @@ -106,6 +107,7 @@ class LangFuseDataTrace(BaseTraceInstance): return start_time + timedelta(seconds=ttft_seconds) + @override def trace(self, trace_info: BaseTraceInfo): match trace_info: case WorkflowTraceInfo(): diff --git a/api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py b/api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py index 0580051f54..3bac908deb 100644 --- a/api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py +++ b/api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py @@ -2,6 +2,7 @@ import collections import logging from datetime import UTC, datetime, timedelta from types import SimpleNamespace +from typing import override from unittest.mock import MagicMock import pytest @@ -738,6 +739,7 @@ def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypat node.status = "succeeded" class BadDict(collections.UserDict): + @override def get(self, key, default=None): if key == "usage": raise Exception("Usage extraction failed") diff --git a/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py index 0d3695d681..4260e5d6ba 100644 --- a/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py +++ b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py @@ -2,7 +2,7 @@ import logging import os import uuid from datetime import datetime, timedelta -from typing import cast +from typing import cast, override from langsmith import Client from langsmith.schemas import RunBase @@ -47,6 +47,7 @@ class LangSmithDataTrace(BaseTraceInstance): self.langsmith_client = Client(api_key=langsmith_config.api_key, api_url=langsmith_config.endpoint) self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") + @override def trace(self, trace_info: BaseTraceInfo): match trace_info: case WorkflowTraceInfo(): diff --git a/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py b/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py index edc4aafd87..8336f8f51f 100644 --- a/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py +++ b/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py @@ -1,5 +1,6 @@ import collections from datetime import datetime, timedelta +from typing import override from unittest.mock import MagicMock import pytest @@ -549,6 +550,7 @@ def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch: pyte ) class BadDict(collections.UserDict): + @override def get(self, key, default=None): if key == "usage": raise Exception("Usage extraction failed") diff --git a/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py index e36b3dee40..9b9b4f2c15 100644 --- a/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py +++ b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py @@ -1,7 +1,7 @@ import logging import os from datetime import datetime, timedelta -from typing import Any, cast +from typing import Any, cast, override import mlflow from mlflow.entities import Document, Span, SpanEvent, SpanStatusCode, SpanType @@ -86,6 +86,7 @@ class MLflowDataTrace(BaseTraceInstance): self._project_url = f"{config.tracking_uri}/#/experiments/{config.experiment_id}/traces" + @override def trace(self, trace_info: BaseTraceInfo): """Simple dispatch to trace methods""" try: diff --git a/api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py b/api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py index 33193ac574..c9356ab368 100644 --- a/api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py +++ b/api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py @@ -3,7 +3,7 @@ import logging import os import uuid from datetime import datetime, timedelta -from typing import Any, cast +from typing import Any, cast, override from opik import Opik, Trace from opik.id_helpers import uuid4_to_uuid7 @@ -95,6 +95,7 @@ class OpikDataTrace(BaseTraceInstance): self.project = opik_config.project self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") + @override def trace(self, trace_info: BaseTraceInfo): match trace_info: case WorkflowTraceInfo(): diff --git a/api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py b/api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py index 5daaa7132c..3f1fef2c7d 100644 --- a/api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py +++ b/api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py @@ -2,6 +2,7 @@ import collections import logging from datetime import UTC, datetime, timedelta from types import SimpleNamespace +from typing import override from unittest.mock import MagicMock import pytest @@ -643,6 +644,7 @@ def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch node.status = "succeeded" class BadDict(collections.UserDict): + @override def get(self, key, default=None): if key == "usage": raise Exception("Usage extraction failed") diff --git a/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py index 3e1dc1d9f6..305e7851c8 100644 --- a/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py +++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py @@ -1,3 +1,5 @@ +from typing import override + """Tencent APM tracing with idempotent client cleanup.""" import inspect @@ -56,6 +58,7 @@ class TencentDataTrace(BaseTraceInstance): metrics_export_interval_sec=5, ) + @override def trace(self, trace_info: BaseTraceInfo) -> None: """Main tracing entry point - coordinates different trace types.""" match trace_info: diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py b/api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py index 5a3779de11..03a13a636c 100644 --- a/api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py +++ b/api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py @@ -2,7 +2,7 @@ import logging import os import uuid from datetime import UTC, datetime, timedelta -from typing import Any, cast +from typing import Any, cast, override import wandb import weave @@ -78,6 +78,7 @@ class WeaveDataTrace(BaseTraceInstance): logger.debug("Weave get run url failed: %s", str(e)) raise ValueError(f"Weave get run url failed: {str(e)}") + @override def trace(self, trace_info: BaseTraceInfo): logger.debug("Trace info: %s", trace_info) match trace_info: diff --git a/api/providers/vdb/vdb-alibabacloud-mysql/src/dify_vdb_alibabacloud_mysql/alibabacloud_mysql_vector.py b/api/providers/vdb/vdb-alibabacloud-mysql/src/dify_vdb_alibabacloud_mysql/alibabacloud_mysql_vector.py index 37ffd11063..c6fe8988d1 100644 --- a/api/providers/vdb/vdb-alibabacloud-mysql/src/dify_vdb_alibabacloud_mysql/alibabacloud_mysql_vector.py +++ b/api/providers/vdb/vdb-alibabacloud-mysql/src/dify_vdb_alibabacloud_mysql/alibabacloud_mysql_vector.py @@ -3,7 +3,7 @@ import json import logging import uuid from contextlib import contextmanager -from typing import Any, Literal, cast +from typing import Any, Literal, cast, override import mysql.connector from mysql.connector import Error as MySQLError @@ -81,6 +81,7 @@ class AlibabaCloudMySQLVector(BaseVector): self.hnsw_m = config.hnsw_m self._check_vector_support() + @override def get_type(self) -> str: return VectorType.ALIBABACLOUD_MYSQL @@ -135,11 +136,13 @@ class AlibabaCloudMySQLVector(BaseVector): cur.close() conn.close() + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): dimension = len(embeddings[0]) self._create_collection(dimension) return self.add_texts(texts, embeddings) + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): values = [] pks = [] @@ -165,6 +168,7 @@ class AlibabaCloudMySQLVector(BaseVector): cur.executemany(insert_sql, values) return pks + @override def text_exists(self, id: str) -> bool: with self._get_cursor() as cur: cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,)) @@ -183,6 +187,7 @@ class AlibabaCloudMySQLVector(BaseVector): docs.append(Document(page_content=record["text"], metadata=metadata)) return docs + @override def delete_by_ids(self, ids: list[str]): # Avoiding crashes caused by performing delete operations on empty lists if not ids: @@ -199,12 +204,14 @@ class AlibabaCloudMySQLVector(BaseVector): else: raise e + @override def delete_by_metadata_field(self, key: str, value: str): with self._get_cursor() as cur: cur.execute( f"DELETE FROM {self.table_name} WHERE JSON_UNQUOTE(JSON_EXTRACT(meta, %s)) = %s", (f"$.{key}", value) ) + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: """ Search the nearest neighbors to a vector using RDS MySQL vector distance functions. @@ -274,6 +281,7 @@ class AlibabaCloudMySQLVector(BaseVector): return docs + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 5) if not isinstance(top_k, int) or top_k <= 0: @@ -308,6 +316,7 @@ class AlibabaCloudMySQLVector(BaseVector): docs.append(Document(page_content=record["text"], metadata=metadata)) return docs + @override def delete(self): with self._get_cursor() as cur: cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") @@ -355,6 +364,7 @@ class AlibabaCloudMySQLVectorFactory(AbstractVectorFactory): raise ValueError(f"Invalid distance function: {distance_function}. Must be 'cosine' or 'euclidean'") return cast(Literal["cosine", "euclidean"], distance_function) + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AlibabaCloudMySQLVector: if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] diff --git a/api/providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_vector.py b/api/providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_vector.py index 54eeb78ca9..e554a6f9cb 100644 --- a/api/providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_vector.py +++ b/api/providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_vector.py @@ -1,5 +1,6 @@ import json import unittest +from typing import override from unittest.mock import MagicMock, patch import pytest @@ -22,6 +23,7 @@ except ImportError: class TestAlibabaCloudMySQLVector(unittest.TestCase): + @override def setUp(self): self.config = AlibabaCloudMySQLVectorConfig( host="localhost", diff --git a/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector.py b/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector.py index e56bb74ba3..62d8819fe6 100644 --- a/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector.py +++ b/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector.py @@ -1,5 +1,5 @@ import json -from typing import Any +from typing import Any, override from configs import dify_config from core.rag.datasource.vdb.vector_base import BaseVector @@ -32,38 +32,48 @@ class AnalyticdbVector(BaseVector): raise ValueError("Either api_config or sql_config must be provided") self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config) + @override def get_type(self) -> str: return VectorType.ANALYTICDB + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): dimension = len(embeddings[0]) self.analyticdb_vector.create_collection_if_not_exists(dimension) self.analyticdb_vector.add_texts(texts, embeddings) + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]: self.analyticdb_vector.add_texts(documents, embeddings) return [] + @override def text_exists(self, id: str) -> bool: return self.analyticdb_vector.text_exists(id) + @override def delete_by_ids(self, ids: list[str]): self.analyticdb_vector.delete_by_ids(ids) + @override def delete_by_metadata_field(self, key: str, value: str): self.analyticdb_vector.delete_by_metadata_field(key, value) + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: return self.analyticdb_vector.search_by_vector(query_vector, **kwargs) + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return self.analyticdb_vector.search_by_full_text(query, **kwargs) + @override def delete(self): self.analyticdb_vector.delete() class AnalyticdbVectorFactory(AbstractVectorFactory): + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AnalyticdbVector: if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] diff --git a/api/providers/vdb/vdb-analyticdb/tests/integration_tests/test_analyticdb.py b/api/providers/vdb/vdb-analyticdb/tests/integration_tests/test_analyticdb.py index 2bb413dcc1..d318dd8827 100644 --- a/api/providers/vdb/vdb-analyticdb/tests/integration_tests/test_analyticdb.py +++ b/api/providers/vdb/vdb-analyticdb/tests/integration_tests/test_analyticdb.py @@ -1,3 +1,5 @@ +from typing import override + from dify_vdb_analyticdb.analyticdb_vector import AnalyticdbVector from dify_vdb_analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig from dify_vdb_analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig @@ -40,6 +42,7 @@ class AnalyticdbVectorTest(AbstractVectorTest): ), ) + @override def run_all_tests(self): self.vector.delete() return super().run_all_tests() diff --git a/api/providers/vdb/vdb-baidu/src/dify_vdb_baidu/baidu_vector.py b/api/providers/vdb/vdb-baidu/src/dify_vdb_baidu/baidu_vector.py index bdd5a42c87..2d8caccc9b 100644 --- a/api/providers/vdb/vdb-baidu/src/dify_vdb_baidu/baidu_vector.py +++ b/api/providers/vdb/vdb-baidu/src/dify_vdb_baidu/baidu_vector.py @@ -2,7 +2,7 @@ import json import logging import time import uuid -from typing import Any +from typing import Any, override import numpy as np from pydantic import BaseModel, model_validator @@ -82,6 +82,7 @@ class BaiduVector(BaseVector): self._client = self._init_client(config) self._db = self._init_database() + @override def get_type(self) -> str: return VectorType.BAIDU @@ -92,10 +93,12 @@ class BaiduVector(BaseVector): } return result + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): self._create_table(len(embeddings[0])) self.add_texts(texts, embeddings) + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): total_count = len(documents) batch_size = 1000 @@ -116,23 +119,27 @@ class BaiduVector(BaseVector): rows.append(row) table.upsert(rows=rows) + @override def text_exists(self, id: str) -> bool: res = self._db.table(self._collection_name).query(primary_key={VDBField.PRIMARY_KEY: id}) if res and res.code == 0: return True return False + @override def delete_by_ids(self, ids: list[str]): if not ids: return quoted_ids = [f"'{id}'" for id in ids] self._db.table(self._collection_name).delete(filter=f"{VDBField.PRIMARY_KEY} IN({', '.join(quoted_ids)})") + @override def delete_by_metadata_field(self, key: str, value: str): # Escape double quotes in value to prevent injection escaped_value = value.replace('"', '\\"') self._db.table(self._collection_name).delete(filter=f'metadata["{key}"] = "{escaped_value}"') + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector] document_ids_filter = kwargs.get("document_ids_filter") @@ -154,6 +161,7 @@ class BaiduVector(BaseVector): score_threshold = float(kwargs.get("score_threshold") or 0.0) return self._get_search_res(res, score_threshold) + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: # document ids filter document_ids_filter = kwargs.get("document_ids_filter") @@ -189,6 +197,7 @@ class BaiduVector(BaseVector): docs.append(doc) return docs + @override def delete(self): try: self._db.drop_table(table_name=self._collection_name) @@ -368,6 +377,7 @@ class BaiduVector(BaseVector): class BaiduVectorFactory(AbstractVectorFactory): + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaiduVector: if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] diff --git a/api/providers/vdb/vdb-baidu/tests/integration_tests/test_baidu.py b/api/providers/vdb/vdb-baidu/tests/integration_tests/test_baidu.py index 2c1d0e3554..2138901d2d 100644 --- a/api/providers/vdb/vdb-baidu/tests/integration_tests/test_baidu.py +++ b/api/providers/vdb/vdb-baidu/tests/integration_tests/test_baidu.py @@ -1,3 +1,5 @@ +from typing import override + from dify_vdb_baidu.baidu_vector import BaiduConfig, BaiduVector from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text @@ -18,10 +20,12 @@ class BaiduVectorTest(AbstractVectorTest): ), ) + @override def search_by_vector(self): hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding) assert len(hits_by_vector) == 1 + @override def search_by_full_text(self): hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) assert len(hits_by_full_text) == 0 diff --git a/api/providers/vdb/vdb-chroma/src/dify_vdb_chroma/chroma_vector.py b/api/providers/vdb/vdb-chroma/src/dify_vdb_chroma/chroma_vector.py index 754b1e8a89..e92ac94afa 100644 --- a/api/providers/vdb/vdb-chroma/src/dify_vdb_chroma/chroma_vector.py +++ b/api/providers/vdb/vdb-chroma/src/dify_vdb_chroma/chroma_vector.py @@ -1,5 +1,5 @@ import json -from typing import Any, TypedDict +from typing import Any, TypedDict, override import chromadb from chromadb import QueryResult, Settings @@ -55,9 +55,11 @@ class ChromaVector(BaseVector): self._client_config = config self._client = chromadb.HttpClient(**self._client_config.to_chroma_params()) + @override def get_type(self) -> str: return VectorType.CHROMA + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): if texts: # create collection @@ -74,6 +76,7 @@ class ChromaVector(BaseVector): self._client.get_or_create_collection(collection_name) redis_client.set(collection_exist_cache_key, 1, ex=3600) + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]: uuids = self._get_uuids(documents) texts = [d.page_content for d in documents] @@ -84,25 +87,30 @@ class ChromaVector(BaseVector): collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas) # type: ignore return uuids + @override def delete_by_metadata_field(self, key: str, value: str): collection = self._client.get_or_create_collection(self._collection_name) # FIXME: fix the type error later collection.delete(where={key: {"$eq": value}}) # type: ignore + @override def delete(self): self._client.delete_collection(self._collection_name) + @override def delete_by_ids(self, ids: list[str]): if not ids: return collection = self._client.get_or_create_collection(self._collection_name) collection.delete(ids=ids) + @override def text_exists(self, id: str) -> bool: collection = self._client.get_or_create_collection(self._collection_name) response = collection.get(ids=[id]) return len(response) > 0 + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: collection = self._client.get_or_create_collection(self._collection_name) document_ids_filter = kwargs.get("document_ids_filter") @@ -142,12 +150,14 @@ class ChromaVector(BaseVector): docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True) return docs + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: # chroma does not support BM25 full text searching return [] class ChromaVectorFactory(AbstractVectorFactory): + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector: if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] diff --git a/api/providers/vdb/vdb-chroma/tests/integration_tests/test_chroma.py b/api/providers/vdb/vdb-chroma/tests/integration_tests/test_chroma.py index abd85b885d..cb310c0986 100644 --- a/api/providers/vdb/vdb-chroma/tests/integration_tests/test_chroma.py +++ b/api/providers/vdb/vdb-chroma/tests/integration_tests/test_chroma.py @@ -1,3 +1,5 @@ +from typing import override + import chromadb from dify_vdb_chroma.chroma_vector import ChromaConfig, ChromaVector @@ -22,6 +24,7 @@ class ChromaVectorTest(AbstractVectorTest): ), ) + @override def search_by_full_text(self): # chroma dos not support full text searching hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) diff --git a/api/providers/vdb/vdb-clickzetta/src/dify_vdb_clickzetta/clickzetta_vector.py b/api/providers/vdb/vdb-clickzetta/src/dify_vdb_clickzetta/clickzetta_vector.py index 72b8c5e9eb..6231b9a9fa 100644 --- a/api/providers/vdb/vdb-clickzetta/src/dify_vdb_clickzetta/clickzetta_vector.py +++ b/api/providers/vdb/vdb-clickzetta/src/dify_vdb_clickzetta/clickzetta_vector.py @@ -8,7 +8,7 @@ import re import threading import time import uuid -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, override import clickzetta # type: ignore from pydantic import BaseModel, model_validator @@ -436,6 +436,7 @@ class ClickzettaVector(BaseVector): raise result return result + @override def get_type(self) -> str: """Return the vector database type.""" return "clickzetta" @@ -469,6 +470,7 @@ class ClickzettaVector(BaseVector): ) return False + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): """Create the collection and add initial documents.""" # Execute table creation through write queue to avoid concurrent conflicts @@ -606,6 +608,7 @@ class ClickzettaVector(BaseVector): logger.warning("Failed to create inverted index: %s", e) # Continue without inverted index - full-text search will fall back to LIKE + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]: """Add documents with embeddings to the collection.""" if not documents: @@ -716,6 +719,7 @@ class ClickzettaVector(BaseVector): logger.exception("Sample data row: %s", data_rows[0] if data_rows else "None") raise + @override def text_exists(self, id: str) -> bool: """Check if a document exists by ID.""" # Check if table exists first @@ -732,6 +736,7 @@ class ClickzettaVector(BaseVector): result = cursor.fetchone() return result[0] > 0 if result else False + @override def delete_by_ids(self, ids: list[str]): """Delete documents by IDs.""" if not ids: @@ -757,6 +762,7 @@ class ClickzettaVector(BaseVector): with connection.cursor() as cursor: cursor.execute(sql, binding_params=safe_ids) + @override def delete_by_metadata_field(self, key: str, value: str): """Delete documents by metadata field.""" # Check if table exists before attempting delete @@ -780,6 +786,7 @@ class ClickzettaVector(BaseVector): ) cursor.execute(sql, binding_params=[value]) + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: """Search for documents by vector similarity.""" # Check if table exists first @@ -865,6 +872,7 @@ class ClickzettaVector(BaseVector): return documents + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: """Search for documents using full-text search with inverted index.""" if not self._config.enable_inverted_index: @@ -1031,6 +1039,7 @@ class ClickzettaVector(BaseVector): return documents + @override def delete(self): """Delete the entire collection.""" with self.get_connection_context() as connection: @@ -1057,6 +1066,7 @@ class ClickzettaVector(BaseVector): class ClickzettaVectorFactory(AbstractVectorFactory): """Factory for creating Clickzetta vector instances.""" + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector: """Initialize a Clickzetta vector instance.""" # Get configuration from environment variables or dataset config diff --git a/api/providers/vdb/vdb-couchbase/src/dify_vdb_couchbase/couchbase_vector.py b/api/providers/vdb/vdb-couchbase/src/dify_vdb_couchbase/couchbase_vector.py index 5deac59dc9..7ac7edc6a0 100644 --- a/api/providers/vdb/vdb-couchbase/src/dify_vdb_couchbase/couchbase_vector.py +++ b/api/providers/vdb/vdb-couchbase/src/dify_vdb_couchbase/couchbase_vector.py @@ -3,7 +3,7 @@ import logging import time import uuid from datetime import timedelta -from typing import Any +from typing import Any, override from couchbase import search # type: ignore from couchbase.auth import PasswordAuthenticator # type: ignore @@ -68,6 +68,7 @@ class CouchbaseVector(BaseVector): # Wait until the cluster is ready for use. self._cluster.wait_until_ready(timedelta(seconds=5)) + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): index_id = str(uuid.uuid4()).replace("-", "") self._create_collection(uuid=index_id, vector_length=len(embeddings[0])) @@ -200,9 +201,11 @@ class CouchbaseVector(BaseVector): # Check if the collection exists in the scope return self._collection_name in scope_collection_map[self._scope_name] + @override def get_type(self) -> str: return VectorType.COUCHBASE + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): uuids = self._get_uuids(documents) texts = [d.page_content for d in documents] @@ -221,6 +224,7 @@ class CouchbaseVector(BaseVector): return doc_ids + @override def text_exists(self, id: str) -> bool: # Use a parameterized query for safety and correctness query = f""" @@ -234,6 +238,7 @@ class CouchbaseVector(BaseVector): return bool(row["count"] > 0) return False # Return False if no rows are returned + @override def delete_by_ids(self, ids: list[str]): query = f""" DELETE FROM `{self._bucket_name}`.{self._client_config.scope_name}.{self._collection_name} @@ -261,6 +266,7 @@ class CouchbaseVector(BaseVector): # result = self._cluster.query(query, named_parameters={'value':value}) # return [row['id'] for row in result.rows()] + @override def delete_by_metadata_field(self, key: str, value: str): query = f""" DELETE FROM `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name} @@ -268,6 +274,7 @@ class CouchbaseVector(BaseVector): """ self._cluster.query(query, named_parameters={"value": value}).execute() + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 5) score_threshold = kwargs.get("score_threshold") or 0.0 @@ -303,6 +310,7 @@ class CouchbaseVector(BaseVector): return docs + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 4) try: @@ -325,6 +333,7 @@ class CouchbaseVector(BaseVector): return docs + @override def delete(self): manager = self._bucket.collections() scopes = manager.get_all_scopes() @@ -356,6 +365,7 @@ class CouchbaseVector(BaseVector): class CouchbaseVectorFactory(AbstractVectorFactory): + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> CouchbaseVector: if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] diff --git a/api/providers/vdb/vdb-couchbase/tests/integration_tests/test_couchbase.py b/api/providers/vdb/vdb-couchbase/tests/integration_tests/test_couchbase.py index 918dae328f..8b45734281 100644 --- a/api/providers/vdb/vdb-couchbase/tests/integration_tests/test_couchbase.py +++ b/api/providers/vdb/vdb-couchbase/tests/integration_tests/test_couchbase.py @@ -1,6 +1,7 @@ import logging import subprocess import time +from typing import override from dify_vdb_couchbase.couchbase_vector import CouchbaseConfig, CouchbaseVector @@ -40,6 +41,7 @@ class CouchbaseTest(AbstractVectorTest): ), ) + @override def search_by_vector(self): # brief sleep to ensure document is indexed time.sleep(5) diff --git a/api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_ja_vector.py b/api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_ja_vector.py index e2f390402a..cff39830f8 100644 --- a/api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_ja_vector.py +++ b/api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_ja_vector.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any +from typing import Any, override from flask import current_app @@ -19,6 +19,7 @@ logger = logging.getLogger(__name__) class ElasticSearchJaVector(ElasticSearchVector): + @override def create_collection( self, embeddings: list[list[float]], @@ -82,6 +83,7 @@ class ElasticSearchJaVector(ElasticSearchVector): class ElasticSearchJaVectorFactory(ElasticSearchVectorFactory): + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchJaVector: if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] diff --git a/api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_vector.py b/api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_vector.py index 11463b6c58..b83cddb057 100644 --- a/api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_vector.py +++ b/api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_vector.py @@ -1,7 +1,7 @@ import json import logging import math -from typing import Any, cast +from typing import Any, cast, override from urllib.parse import urlparse from elasticsearch import ConnectionError as ElasticsearchConnectionError @@ -154,9 +154,11 @@ class ElasticSearchVector(BaseVector): if parse_version(self._version) < parse_version("8.0.0"): raise ValueError("Elasticsearch vector database version must be greater than 8.0.0") + @override def get_type(self) -> str: return VectorType.ELASTICSEARCH + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): uuids = self._get_uuids(documents) for i in range(len(documents)): @@ -172,15 +174,18 @@ class ElasticSearchVector(BaseVector): self._client.indices.refresh(index=self._collection_name) return uuids + @override def text_exists(self, id: str) -> bool: return bool(self._client.exists(index=self._collection_name, id=id)) + @override def delete_by_ids(self, ids: list[str]): if not ids: return for id in ids: self._client.delete(index=self._collection_name, id=id) + @override def delete_by_metadata_field(self, key: str, value: str): query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}} results = self._client.search(index=self._collection_name, body=query_str) @@ -188,9 +193,11 @@ class ElasticSearchVector(BaseVector): if ids: self.delete_by_ids(ids) + @override def delete(self): self._client.indices.delete(index=self._collection_name) + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 4) num_candidates = math.ceil(top_k * 1.5) @@ -224,6 +231,7 @@ class ElasticSearchVector(BaseVector): return docs + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: query_str: dict[str, Any] = {"match": {Field.CONTENT_KEY: query}} document_ids_filter = kwargs.get("document_ids_filter") @@ -249,6 +257,7 @@ class ElasticSearchVector(BaseVector): return docs + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): metadatas = [d.metadata if d.metadata is not None else {} for d in texts] self.create_collection(embeddings, metadatas) @@ -297,6 +306,7 @@ class ElasticSearchVector(BaseVector): class ElasticSearchVectorFactory(AbstractVectorFactory): + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchVector: if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] diff --git a/api/providers/vdb/vdb-hologres/src/dify_vdb_hologres/hologres_vector.py b/api/providers/vdb/vdb-hologres/src/dify_vdb_hologres/hologres_vector.py index 80c0ed582e..abaacec35a 100644 --- a/api/providers/vdb/vdb-hologres/src/dify_vdb_hologres/hologres_vector.py +++ b/api/providers/vdb/vdb-hologres/src/dify_vdb_hologres/hologres_vector.py @@ -1,7 +1,7 @@ import json import logging import time -from typing import Any, cast +from typing import Any, cast, override import holo_search_sdk as holo # type: ignore from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType @@ -81,15 +81,18 @@ class HologresVector(BaseVector): client.connect() return client + @override def get_type(self) -> str: return VectorType.HOLOGRES + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): """Create collection table with vector and full-text indexes, then add texts.""" dimension = len(embeddings[0]) self._create_collection(dimension) self.add_texts(texts, embeddings) + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): """Add texts with embeddings to the collection using batch upsert.""" if not documents: @@ -127,6 +130,7 @@ class HologresVector(BaseVector): return pks + @override def text_exists(self, id: str) -> bool: """Check if a text with the given doc_id exists in the collection.""" if not self._client.check_table_exist(self.table_name): @@ -140,6 +144,7 @@ class HologresVector(BaseVector): ) return bool(result) + @override def get_ids_by_metadata_field(self, key: str, value: str) -> list[str] | None: """Get document IDs by metadata field key and value.""" result = self._client.execute( @@ -152,6 +157,7 @@ class HologresVector(BaseVector): return [row[0] for row in result] return None + @override def delete_by_ids(self, ids: list[str]): """Delete documents by their doc_id list.""" if not ids: @@ -166,6 +172,7 @@ class HologresVector(BaseVector): ) ) + @override def delete_by_metadata_field(self, key: str, value: str): """Delete documents by metadata field key and value.""" if not self._client.check_table_exist(self.table_name): @@ -177,6 +184,7 @@ class HologresVector(BaseVector): ) ) + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: """Search for documents by vector similarity.""" if not self._client.check_table_exist(self.table_name): @@ -229,6 +237,7 @@ class HologresVector(BaseVector): return docs + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: """Search for documents by full-text search.""" if not self._client.check_table_exist(self.table_name): @@ -272,6 +281,7 @@ class HologresVector(BaseVector): return docs + @override def delete(self): """Delete the entire collection table.""" if self._client.check_table_exist(self.table_name): @@ -333,6 +343,7 @@ class HologresVector(BaseVector): class HologresVectorFactory(AbstractVectorFactory): """Factory class for creating HologresVector instances.""" + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> HologresVector: if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] diff --git a/api/providers/vdb/vdb-hologres/tests/integration_tests/test_hologres.py b/api/providers/vdb/vdb-hologres/tests/integration_tests/test_hologres.py index 04024be4ae..894066cae4 100644 --- a/api/providers/vdb/vdb-hologres/tests/integration_tests/test_hologres.py +++ b/api/providers/vdb/vdb-hologres/tests/integration_tests/test_hologres.py @@ -1,6 +1,6 @@ import os import uuid -from typing import cast +from typing import cast, override from dify_vdb_hologres.hologres_vector import HologresVector, HologresVectorConfig from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType @@ -35,6 +35,7 @@ class HologresVectorTest(AbstractVectorTest): ), ) + @override def search_by_full_text(self): """Override: full-text index may not be immediately ready in real mode.""" hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) @@ -97,12 +98,14 @@ class HologresVectorTest(AbstractVectorTest): assert len(hits) == 1 assert hits[0].metadata["doc_id"] == self.example_doc_id + @override def get_ids_by_metadata_field(self): """Override: Hologres implements this method via JSONB query.""" ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert ids is not None assert len(ids) == 1 + @override def run_all_tests(self): # Clean up before running tests self.vector.delete() diff --git a/api/providers/vdb/vdb-huawei-cloud/src/dify_vdb_huawei_cloud/huawei_cloud_vector.py b/api/providers/vdb/vdb-huawei-cloud/src/dify_vdb_huawei_cloud/huawei_cloud_vector.py index d51075d2e8..6f42ce1dbe 100644 --- a/api/providers/vdb/vdb-huawei-cloud/src/dify_vdb_huawei_cloud/huawei_cloud_vector.py +++ b/api/providers/vdb/vdb-huawei-cloud/src/dify_vdb_huawei_cloud/huawei_cloud_vector.py @@ -1,7 +1,7 @@ import json import logging import ssl -from typing import Any +from typing import Any, override from elasticsearch import Elasticsearch from pydantic import BaseModel, model_validator @@ -68,9 +68,11 @@ class HuaweiCloudVector(BaseVector): super().__init__(index_name.lower()) self._client = Elasticsearch(**config.to_elasticsearch_params()) + @override def get_type(self) -> str: return VectorType.HUAWEI_CLOUD + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): uuids = self._get_uuids(documents) for i in range(len(documents)): @@ -86,15 +88,18 @@ class HuaweiCloudVector(BaseVector): self._client.indices.refresh(index=self._collection_name) return uuids + @override def text_exists(self, id: str) -> bool: return bool(self._client.exists(index=self._collection_name, id=id)) + @override def delete_by_ids(self, ids: list[str]): if not ids: return for id in ids: self._client.delete(index=self._collection_name, id=id) + @override def delete_by_metadata_field(self, key: str, value: str): query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}} results = self._client.search(index=self._collection_name, body=query_str) @@ -102,9 +107,11 @@ class HuaweiCloudVector(BaseVector): if ids: self.delete_by_ids(ids) + @override def delete(self): self._client.indices.delete(index=self._collection_name) + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 4) @@ -145,6 +152,7 @@ class HuaweiCloudVector(BaseVector): return docs + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: query_str = {"match": {Field.CONTENT_KEY: query}} results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4)) @@ -160,6 +168,7 @@ class HuaweiCloudVector(BaseVector): return docs + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): metadatas = [d.metadata if d.metadata is not None else {} for d in texts] self.create_collection(embeddings, metadatas) @@ -207,6 +216,7 @@ class HuaweiCloudVector(BaseVector): class HuaweiCloudVectorFactory(AbstractVectorFactory): + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> HuaweiCloudVector: if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] diff --git a/api/providers/vdb/vdb-huawei-cloud/tests/integration_tests/test_huawei_cloud.py b/api/providers/vdb/vdb-huawei-cloud/tests/integration_tests/test_huawei_cloud.py index bb5f5b72ef..8a7214231f 100644 --- a/api/providers/vdb/vdb-huawei-cloud/tests/integration_tests/test_huawei_cloud.py +++ b/api/providers/vdb/vdb-huawei-cloud/tests/integration_tests/test_huawei_cloud.py @@ -1,3 +1,5 @@ +from typing import override + from dify_vdb_huawei_cloud.huawei_cloud_vector import HuaweiCloudVector, HuaweiCloudVectorConfig from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text @@ -15,10 +17,12 @@ class HuaweiCloudVectorTest(AbstractVectorTest): ), ) + @override def search_by_vector(self): hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding) assert len(hits_by_vector) == 3 + @override def search_by_full_text(self): hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) assert len(hits_by_full_text) == 3 diff --git a/api/providers/vdb/vdb-iris/src/dify_vdb_iris/iris_vector.py b/api/providers/vdb/vdb-iris/src/dify_vdb_iris/iris_vector.py index aae445e6ff..0326cc646a 100644 --- a/api/providers/vdb/vdb-iris/src/dify_vdb_iris/iris_vector.py +++ b/api/providers/vdb/vdb-iris/src/dify_vdb_iris/iris_vector.py @@ -11,7 +11,7 @@ import logging import threading import uuid from contextlib import contextmanager -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, override from configs import dify_config from configs.middleware.vdb.iris_config import IrisVectorConfig @@ -188,6 +188,7 @@ class IrisVector(BaseVector): self.schema = config.IRIS_SCHEMA or "dify" self.pool = get_iris_pool(config) + @override def get_type(self) -> str: return VectorType.IRIS @@ -206,11 +207,13 @@ class IrisVector(BaseVector): cursor.close() self.pool.return_connection(conn) + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]: dimension = len(embeddings[0]) self._create_collection(dimension) return self.add_texts(texts, embeddings) + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **_kwargs) -> list[str]: """Add documents with embeddings to the collection.""" added_ids = [] @@ -226,6 +229,7 @@ class IrisVector(BaseVector): return added_ids + @override def text_exists(self, id: str) -> bool: # pylint: disable=redefined-builtin try: with self._get_cursor() as cursor: @@ -235,6 +239,7 @@ class IrisVector(BaseVector): except (OSError, RuntimeError, ValueError): return False + @override def delete_by_ids(self, ids: list[str]) -> None: if not ids: return @@ -244,6 +249,7 @@ class IrisVector(BaseVector): sql = f"DELETE FROM {self.schema}.{self.table_name} WHERE id IN ({placeholders})" cursor.execute(sql, ids) + @override def delete_by_metadata_field(self, key: str, value: str) -> None: """Delete documents by metadata field (JSON LIKE pattern matching).""" with self._get_cursor() as cursor: @@ -251,6 +257,7 @@ class IrisVector(BaseVector): sql = f"DELETE FROM {self.schema}.{self.table_name} WHERE meta LIKE ?" cursor.execute(sql, (pattern,)) + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: """Search similar documents using VECTOR_COSINE with HNSW index.""" top_k = kwargs.get("top_k", 4) @@ -275,6 +282,7 @@ class IrisVector(BaseVector): docs.append(Document(page_content=text, metadata=metadata)) return docs + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: """Search documents by full-text using iFind index with BM25 relevance scoring. @@ -404,6 +412,7 @@ class IrisVector(BaseVector): return docs + @override def delete(self) -> None: """Delete the entire collection (drop table - permanent).""" with self._get_cursor() as cursor: @@ -481,6 +490,7 @@ class IrisVector(BaseVector): class IrisVectorFactory(AbstractVectorFactory): """Factory for creating IrisVector instances.""" + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> IrisVector: if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] diff --git a/api/providers/vdb/vdb-lindorm/src/dify_vdb_lindorm/lindorm_vector.py b/api/providers/vdb/vdb-lindorm/src/dify_vdb_lindorm/lindorm_vector.py index 9187ca943d..0e37cf6e4a 100644 --- a/api/providers/vdb/vdb-lindorm/src/dify_vdb_lindorm/lindorm_vector.py +++ b/api/providers/vdb/vdb-lindorm/src/dify_vdb_lindorm/lindorm_vector.py @@ -1,7 +1,7 @@ import json import logging import time -from typing import Any +from typing import Any, override from opensearchpy import OpenSearch, helpers from opensearchpy.helpers import BulkIndexError @@ -79,9 +79,11 @@ class LindormVectorStore(BaseVector): self._using_ugc = using_ugc self.kwargs = kwargs + @override def get_type(self) -> str: return VectorType.LINDORM + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): metadatas = [d.metadata if d.metadata is not None else {} for d in texts] self.create_collection(embeddings, metadatas) @@ -90,6 +92,7 @@ class LindormVectorStore(BaseVector): def refresh(self): self._client.indices.refresh(index=self._collection_name) + @override def add_texts( self, documents: list[Document], @@ -156,6 +159,7 @@ class LindormVectorStore(BaseVector): logger.exception("Failed to process batch %s", batch_num + 1) raise + @override def get_ids_by_metadata_field(self, key: str, value: str): query: dict[str, Any] = { "query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY}.{key}.keyword": value}}]}} @@ -168,11 +172,13 @@ class LindormVectorStore(BaseVector): else: return None + @override def delete_by_metadata_field(self, key: str, value: str): ids = self.get_ids_by_metadata_field(key, value) if ids: self.delete_by_ids(ids) + @override def delete_by_ids(self, ids: list[str]): """Delete documents by their IDs in batch. @@ -219,6 +225,7 @@ class LindormVectorStore(BaseVector): else: logger.exception("Error deleting document: %s", error) + @override def delete(self): if self._using_ugc: routing_filter_query = { @@ -233,6 +240,7 @@ class LindormVectorStore(BaseVector): else: logger.warning("Index '%s' does not exist. No deletion performed.", self._collection_name) + @override def text_exists(self, id: str) -> bool: try: params: dict[str, Any] = {} @@ -243,6 +251,7 @@ class LindormVectorStore(BaseVector): except: return False + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: if not isinstance(query_vector, list): raise ValueError("query_vector should be a list of floats") @@ -305,6 +314,7 @@ class LindormVectorStore(BaseVector): return docs + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: full_text_query = {"query": {"bool": {"must": [{"match": {Field.CONTENT_KEY.value: query}}]}}} filters = [] @@ -377,6 +387,7 @@ class LindormVectorStore(BaseVector): class LindormVectorStoreFactory(AbstractVectorFactory): + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore: lindorm_config = LindormVectorStoreConfig( hosts=dify_config.LINDORM_URL, diff --git a/api/providers/vdb/vdb-lindorm/tests/integration_tests/test_lindorm.py b/api/providers/vdb/vdb-lindorm/tests/integration_tests/test_lindorm.py index 0a0c2d2d59..e8fd4b29fe 100644 --- a/api/providers/vdb/vdb-lindorm/tests/integration_tests/test_lindorm.py +++ b/api/providers/vdb/vdb-lindorm/tests/integration_tests/test_lindorm.py @@ -1,4 +1,5 @@ import os +from typing import override from dify_vdb_lindorm.lindorm_vector import LindormVectorStore, LindormVectorStoreConfig @@ -26,6 +27,7 @@ class TestLindormVectorStore(AbstractVectorTest): ), ) + @override def get_ids_by_metadata_field(self): ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id) assert ids is not None @@ -47,6 +49,7 @@ class TestLindormVectorStoreUGC(AbstractVectorTest): routing_value=self.collection_name, ) + @override def get_ids_by_metadata_field(self): ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id) assert ids is not None diff --git a/api/providers/vdb/vdb-matrixone/src/dify_vdb_matrixone/matrixone_vector.py b/api/providers/vdb/vdb-matrixone/src/dify_vdb_matrixone/matrixone_vector.py index 75fb54e6f4..f9fc7f45b7 100644 --- a/api/providers/vdb/vdb-matrixone/src/dify_vdb_matrixone/matrixone_vector.py +++ b/api/providers/vdb/vdb-matrixone/src/dify_vdb_matrixone/matrixone_vector.py @@ -3,7 +3,7 @@ import logging import uuid from collections.abc import Callable from functools import wraps -from typing import Any, Concatenate +from typing import Any, Concatenate, override from mo_vector.client import MoVectorClient # type: ignore from pydantic import BaseModel, model_validator @@ -69,16 +69,20 @@ class MatrixoneVector(BaseVector): self.client = None @property + @override def collection_name(self): return self._collection_name @collection_name.setter + @override def collection_name(self, value): self._collection_name = value + @override def get_type(self) -> str: return VectorType.MATRIXONE + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): if self.client is None: self.client = self._get_client(len(embeddings[0]), True) @@ -108,6 +112,7 @@ class MatrixoneVector(BaseVector): logger.exception("Failed to create full text index") return client + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): if self.client is None: self.client = self._get_client(len(embeddings[0]), True) @@ -126,12 +131,14 @@ class MatrixoneVector(BaseVector): return ids @ensure_client + @override def text_exists(self, id: str) -> bool: assert self.client is not None result = self.client.get(ids=[id]) return len(result) > 0 @ensure_client + @override def delete_by_ids(self, ids: list[str]): assert self.client is not None if not ids: @@ -139,17 +146,20 @@ class MatrixoneVector(BaseVector): self.client.delete(ids=ids) @ensure_client + @override def get_ids_by_metadata_field(self, key: str, value: str): assert self.client is not None results = self.client.query_by_metadata(filter={key: value}) return [result.id for result in results] @ensure_client + @override def delete_by_metadata_field(self, key: str, value: str): assert self.client is not None self.client.delete(filter={key: value}) @ensure_client + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: assert self.client is not None top_k = kwargs.get("top_k", 5) @@ -177,6 +187,7 @@ class MatrixoneVector(BaseVector): return docs @ensure_client + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: assert self.client is not None top_k = kwargs.get("top_k", 5) @@ -207,12 +218,14 @@ class MatrixoneVector(BaseVector): return docs @ensure_client + @override def delete(self): assert self.client is not None self.client.delete() class MatrixoneVectorFactory(AbstractVectorFactory): + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MatrixoneVector: if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] diff --git a/api/providers/vdb/vdb-matrixone/tests/integration_tests/test_matrixone.py b/api/providers/vdb/vdb-matrixone/tests/integration_tests/test_matrixone.py index d6f4781e65..cb11a86536 100644 --- a/api/providers/vdb/vdb-matrixone/tests/integration_tests/test_matrixone.py +++ b/api/providers/vdb/vdb-matrixone/tests/integration_tests/test_matrixone.py @@ -1,3 +1,5 @@ +from typing import override + from dify_vdb_matrixone.matrixone_vector import MatrixoneConfig, MatrixoneVector from core.rag.datasource.vdb.vector_integration_test_support import ( @@ -15,6 +17,7 @@ class MatrixoneVectorTest(AbstractVectorTest): ), ) + @override def get_ids_by_metadata_field(self): ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 1 diff --git a/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py b/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py index ac47be7a37..5a8a1b7cd3 100644 --- a/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py +++ b/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, TypedDict, cast +from typing import Any, TypedDict, cast, override from packaging import version from pydantic import BaseModel, model_validator @@ -120,12 +120,14 @@ class MilvusVector(BaseVector): logger.warning("Failed to check Milvus version: %s. Disabling hybrid search.", str(e)) return False + @override def get_type(self) -> str: """ Get the type of vector storage (Milvus). """ return VectorType.MILVUS + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): """ Create a collection and add texts with embeddings. @@ -135,6 +137,7 @@ class MilvusVector(BaseVector): self.create_collection(embeddings, metadatas, index_params) self.add_texts(texts, embeddings) + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): """ Add texts and their embeddings to the collection. @@ -164,6 +167,7 @@ class MilvusVector(BaseVector): raise e return pks + @override def get_ids_by_metadata_field(self, key: str, value: str): """ Get document IDs by metadata field key and value. @@ -176,6 +180,7 @@ class MilvusVector(BaseVector): else: return None + @override def delete_by_metadata_field(self, key: str, value: str): """ Delete documents by metadata field key and value. @@ -185,6 +190,7 @@ class MilvusVector(BaseVector): if ids: self._client.delete(collection_name=self._collection_name, pks=ids) + @override def delete_by_ids(self, ids: list[str]): """ Delete documents by their IDs. @@ -197,6 +203,7 @@ class MilvusVector(BaseVector): ids = [item["id"] for item in result] self._client.delete(collection_name=self._collection_name, pks=ids) + @override def delete(self): """ Delete the entire collection. @@ -204,6 +211,7 @@ class MilvusVector(BaseVector): if self._client.has_collection(self._collection_name): self._client.drop_collection(self._collection_name, None) + @override def text_exists(self, id: str) -> bool: """ Check if a text with the given ID exists in the collection. @@ -245,6 +253,7 @@ class MilvusVector(BaseVector): return docs + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: """ Search for documents by vector similarity. @@ -269,6 +278,7 @@ class MilvusVector(BaseVector): score_threshold=float(kwargs.get("score_threshold") or 0.0), ) + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: """ Search for documents by full-text search (if hybrid search is enabled). @@ -411,6 +421,7 @@ class MilvusVectorFactory(AbstractVectorFactory): Factory class for creating MilvusVector instances. """ + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector: """ Initialize a MilvusVector instance for the given dataset. diff --git a/api/providers/vdb/vdb-milvus/tests/integration_tests/test_milvus.py b/api/providers/vdb/vdb-milvus/tests/integration_tests/test_milvus.py index 084d808bed..8992b0ce8c 100644 --- a/api/providers/vdb/vdb-milvus/tests/integration_tests/test_milvus.py +++ b/api/providers/vdb/vdb-milvus/tests/integration_tests/test_milvus.py @@ -1,3 +1,5 @@ +from typing import override + from dify_vdb_milvus.milvus_vector import MilvusConfig, MilvusVector from core.rag.datasource.vdb.vector_integration_test_support import ( @@ -18,11 +20,13 @@ class MilvusVectorTest(AbstractVectorTest): ), ) + @override def search_by_full_text(self): # milvus support BM25 full text search after version 2.5.0-beta hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) assert len(hits_by_full_text) >= 0 + @override def get_ids_by_metadata_field(self): ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 1 diff --git a/api/providers/vdb/vdb-myscale/src/dify_vdb_myscale/myscale_vector.py b/api/providers/vdb/vdb-myscale/src/dify_vdb_myscale/myscale_vector.py index 6c62671380..dadd1bb77a 100644 --- a/api/providers/vdb/vdb-myscale/src/dify_vdb_myscale/myscale_vector.py +++ b/api/providers/vdb/vdb-myscale/src/dify_vdb_myscale/myscale_vector.py @@ -2,7 +2,7 @@ import json import logging import uuid from enum import StrEnum -from typing import Any +from typing import Any, override from clickhouse_connect import get_client # type: ignore[import-untyped] from pydantic import BaseModel @@ -46,9 +46,11 @@ class MyScaleVector(BaseVector): ) self._client.command("SET allow_experimental_object_type=1") + @override def get_type(self) -> str: return VectorType.MYSCALE + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): dimension = len(embeddings[0]) self._create_collection(dimension) @@ -71,6 +73,7 @@ class MyScaleVector(BaseVector): """ self._client.command(sql) + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): ids = [] columns = ["id", "text", "vector", "metadata"] @@ -97,10 +100,12 @@ class MyScaleVector(BaseVector): def escape_str(value: Any) -> str: return "".join(" " if c in {"\\", "'"} else c for c in str(value)) + @override def text_exists(self, id: str) -> bool: results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'") return results.row_count > 0 + @override def delete_by_ids(self, ids: list[str]): if not ids: return @@ -108,20 +113,24 @@ class MyScaleVector(BaseVector): f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}" ) + @override def get_ids_by_metadata_field(self, key: str, value: str): rows = self._client.query( f"SELECT DISTINCT id FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'" ).result_rows return [row[0] for row in rows] + @override def delete_by_metadata_field(self, key: str, value: str): self._client.command( f"DELETE FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'" ) + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: return self._search(f"distance(vector, {str(query_vector)})", self._vec_order, **kwargs) + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return self._search(f"TextSearch('enable_nlq=false')(text, '{query}')", SortOrder.DESC, **kwargs) @@ -156,11 +165,13 @@ class MyScaleVector(BaseVector): logger.exception("Vector search operation failed") return [] + @override def delete(self): self._client.command(f"DROP TABLE IF EXISTS {self._config.database}.{self._collection_name}") class MyScaleVectorFactory(AbstractVectorFactory): + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MyScaleVector: if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] diff --git a/api/providers/vdb/vdb-myscale/tests/integration_tests/test_myscale.py b/api/providers/vdb/vdb-myscale/tests/integration_tests/test_myscale.py index 8ea42d5f45..2ea77ad000 100644 --- a/api/providers/vdb/vdb-myscale/tests/integration_tests/test_myscale.py +++ b/api/providers/vdb/vdb-myscale/tests/integration_tests/test_myscale.py @@ -1,3 +1,5 @@ +from typing import override + from dify_vdb_myscale.myscale_vector import MyScaleConfig, MyScaleVector from core.rag.datasource.vdb.vector_integration_test_support import ( @@ -20,6 +22,7 @@ class MyScaleVectorTest(AbstractVectorTest): ), ) + @override def get_ids_by_metadata_field(self): ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 1 diff --git a/api/providers/vdb/vdb-oceanbase/src/dify_vdb_oceanbase/oceanbase_vector.py b/api/providers/vdb/vdb-oceanbase/src/dify_vdb_oceanbase/oceanbase_vector.py index 69dc42169a..93be92a62f 100644 --- a/api/providers/vdb/vdb-oceanbase/src/dify_vdb_oceanbase/oceanbase_vector.py +++ b/api/providers/vdb/vdb-oceanbase/src/dify_vdb_oceanbase/oceanbase_vector.py @@ -1,7 +1,7 @@ import json import logging import re -from typing import Any, Literal +from typing import Any, Literal, override from pydantic import BaseModel, model_validator from pyobvector import VECTOR, ObVecClient, cosine_distance, inner_product, l2_distance # type: ignore @@ -86,6 +86,7 @@ class OceanBaseVector(BaseVector): self._load_collection_fields() self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported + @override def get_type(self) -> str: return VectorType.OCEANBASE @@ -114,6 +115,7 @@ class OceanBaseVector(BaseVector): """ return field in self._fields + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): self._vec_dim = len(embeddings[0]) self._create_collection() @@ -237,6 +239,7 @@ class OceanBaseVector(BaseVector): logger.warning("Failed to check OceanBase version: %s. Disabling hybrid search.", str(e)) return False + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): ids = self._get_uuids(documents) batch_size = self._config.batch_size @@ -283,6 +286,7 @@ class OceanBaseVector(BaseVector): self._collection_name, ) + @override def text_exists(self, id: str) -> bool: try: cur = self._client.get(table_name=self._collection_name, ids=id) @@ -295,6 +299,7 @@ class OceanBaseVector(BaseVector): ) raise Exception(f"Failed to check text existence for id '{id}'") from e + @override def delete_by_ids(self, ids: list[str]): if not ids: return @@ -309,6 +314,7 @@ class OceanBaseVector(BaseVector): ) raise Exception(f"Failed to delete documents from collection '{self._collection_name}'") from e + @override def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]: try: import re @@ -343,6 +349,7 @@ class OceanBaseVector(BaseVector): ) raise Exception(f"Failed to query documents by metadata field '{key}'") from e + @override def delete_by_metadata_field(self, key: str, value: str): ids = self.get_ids_by_metadata_field(key, value) if ids: @@ -381,6 +388,7 @@ class OceanBaseVector(BaseVector): return docs + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: if not self._hybrid_search_enabled: logger.warning( @@ -438,6 +446,7 @@ class OceanBaseVector(BaseVector): ) raise Exception(f"Full-text search failed for collection '{self._collection_name}'") from e + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: from sqlalchemy import text @@ -508,6 +517,7 @@ class OceanBaseVector(BaseVector): return -distance raise ValueError(f"Unsupported metric_type '{metric}'") + @override def delete(self): try: self._client.drop_table_if_exist(self._collection_name) @@ -518,6 +528,7 @@ class OceanBaseVector(BaseVector): class OceanBaseVectorFactory(AbstractVectorFactory): + @override def init_vector( self, dataset: Dataset, diff --git a/api/providers/vdb/vdb-oceanbase/tests/integration_tests/test_oceanbase.py b/api/providers/vdb/vdb-oceanbase/tests/integration_tests/test_oceanbase.py index 28f22d3cbc..1d2a3a919f 100644 --- a/api/providers/vdb/vdb-oceanbase/tests/integration_tests/test_oceanbase.py +++ b/api/providers/vdb/vdb-oceanbase/tests/integration_tests/test_oceanbase.py @@ -1,3 +1,5 @@ +from typing import override + import pytest from dify_vdb_oceanbase.oceanbase_vector import ( OceanBaseVector, @@ -30,6 +32,7 @@ class OceanBaseVectorTest(AbstractVectorTest): super().__init__() self.vector = vector + @override def get_ids_by_metadata_field(self): ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 1 diff --git a/api/providers/vdb/vdb-opengauss/src/dify_vdb_opengauss/opengauss.py b/api/providers/vdb/vdb-opengauss/src/dify_vdb_opengauss/opengauss.py index acd2471cf6..70bf5f2a62 100644 --- a/api/providers/vdb/vdb-opengauss/src/dify_vdb_opengauss/opengauss.py +++ b/api/providers/vdb/vdb-opengauss/src/dify_vdb_opengauss/opengauss.py @@ -1,7 +1,7 @@ import json import uuid from contextlib import contextmanager -from typing import Any +from typing import Any, override import psycopg2.extras import psycopg2.pool @@ -76,6 +76,7 @@ class OpenGauss(BaseVector): self.table_name = f"embedding_{collection_name}" self.pq_enabled = config.enable_pq + @override def get_type(self) -> str: return VectorType.OPENGAUSS @@ -101,6 +102,7 @@ class OpenGauss(BaseVector): conn.commit() self.pool.putconn(conn) + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): dimension = len(embeddings[0]) self._create_collection(dimension) @@ -125,6 +127,7 @@ class OpenGauss(BaseVector): cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name)) redis_client.set(index_exist_cache_key, 1, ex=3600) + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): values = [] pks = [] @@ -146,6 +149,7 @@ class OpenGauss(BaseVector): ) return pks + @override def text_exists(self, id: str) -> bool: with self._get_cursor() as cur: cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,)) @@ -159,6 +163,7 @@ class OpenGauss(BaseVector): docs.append(Document(page_content=record[1], metadata=record[0])) return docs + @override def delete_by_ids(self, ids: list[str]): # Avoiding crashes caused by performing delete operations on empty lists in certain scenarios # Scenario 1: extract a document fails, resulting in a table not being created. @@ -168,10 +173,12 @@ class OpenGauss(BaseVector): with self._get_cursor() as cur: cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) + @override def delete_by_metadata_field(self, key: str, value: str): with self._get_cursor() as cur: cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value)) + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: """ Search the nearest neighbors to a vector. @@ -198,6 +205,7 @@ class OpenGauss(BaseVector): docs.append(Document(page_content=text, metadata=metadata)) return docs + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 5) if not isinstance(top_k, int) or top_k <= 0: @@ -222,6 +230,7 @@ class OpenGauss(BaseVector): return docs + @override def delete(self): with self._get_cursor() as cur: cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") @@ -240,6 +249,7 @@ class OpenGauss(BaseVector): class OpenGaussFactory(AbstractVectorFactory): + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> OpenGauss: if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] diff --git a/api/providers/vdb/vdb-opensearch/src/dify_vdb_opensearch/opensearch_vector.py b/api/providers/vdb/vdb-opensearch/src/dify_vdb_opensearch/opensearch_vector.py index d6998f6672..07652eaebf 100644 --- a/api/providers/vdb/vdb-opensearch/src/dify_vdb_opensearch/opensearch_vector.py +++ b/api/providers/vdb/vdb-opensearch/src/dify_vdb_opensearch/opensearch_vector.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any +from typing import Any, override from uuid import uuid4 from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers @@ -100,14 +100,17 @@ class OpenSearchVector(BaseVector): self._client_config = config self._client = OpenSearch(**config.to_opensearch_params()) + @override def get_type(self) -> str: return VectorType.OPENSEARCH + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): metadatas = [d.metadata if d.metadata is not None else {} for d in texts] self.create_collection(embeddings, metadatas) self.add_texts(texts, embeddings) + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): actions = [] for i in range(len(documents)): @@ -132,6 +135,7 @@ class OpenSearchVector(BaseVector): max_retries=3, ) + @override def get_ids_by_metadata_field(self, key: str, value: str): query = {"query": {"term": {f"{Field.METADATA_KEY}.{key}": value}}} response = self._client.search(index=self._collection_name.lower(), body=query) @@ -140,11 +144,13 @@ class OpenSearchVector(BaseVector): else: return None + @override def delete_by_metadata_field(self, key: str, value: str): ids = self.get_ids_by_metadata_field(key, value) if ids: self.delete_by_ids(ids) + @override def delete_by_ids(self, ids: list[str]): index_name = self._collection_name.lower() if not self._client.indices.exists(index=index_name): @@ -176,9 +182,11 @@ class OpenSearchVector(BaseVector): else: logger.exception("Error deleting document: %s", error) + @override def delete(self): self._client.indices.delete(index=self._collection_name.lower(), ignore_unavailable=True) + @override def text_exists(self, id: str) -> bool: try: self._client.get(index=self._collection_name.lower(), id=id) @@ -186,6 +194,7 @@ class OpenSearchVector(BaseVector): except: return False + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: # Make sure query_vector is a list if not isinstance(query_vector, list): @@ -234,6 +243,7 @@ class OpenSearchVector(BaseVector): return docs + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: full_text_query = {"query": {"bool": {"must": [{"match": {Field.CONTENT_KEY.value: query}}]}}} document_ids_filter = kwargs.get("document_ids_filter") @@ -299,6 +309,7 @@ class OpenSearchVector(BaseVector): class OpenSearchVectorFactory(AbstractVectorFactory): + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> OpenSearchVector: if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] diff --git a/api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/oraclevector.py b/api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/oraclevector.py index 5d9ab38529..b8639dae61 100644 --- a/api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/oraclevector.py +++ b/api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/oraclevector.py @@ -3,7 +3,7 @@ import json import logging import re import uuid -from typing import Any, TypedDict +from typing import Any, TypedDict, override import jieba.posseg as pseg # type: ignore import numpy @@ -87,6 +87,7 @@ class OracleVector(BaseVector): self.table_name = f"embedding_{collection_name}" self.config = config + @override def get_type(self) -> str: return VectorType.ORACLE @@ -153,11 +154,13 @@ class OracleVector(BaseVector): pool_params["wallet_password"] = config.wallet_password return oracledb.create_pool(**pool_params) + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): dimension = len(embeddings[0]) self._create_collection(dimension) return self.add_texts(texts, embeddings) + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): values = [] pks = [] @@ -196,6 +199,7 @@ class OracleVector(BaseVector): conn.close() return pks + @override def text_exists(self, id: str) -> bool: with self._get_connection() as conn: with conn.cursor() as cur: @@ -217,6 +221,7 @@ class OracleVector(BaseVector): conn.close() return docs + @override def delete_by_ids(self, ids: list[str]): if not ids: return @@ -227,6 +232,7 @@ class OracleVector(BaseVector): conn.commit() conn.close() + @override def delete_by_metadata_field(self, key: str, value: str): with self._get_connection() as conn: with conn.cursor() as cur: @@ -234,6 +240,7 @@ class OracleVector(BaseVector): conn.commit() conn.close() + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: """ Search the nearest neighbors to a vector. @@ -277,6 +284,7 @@ class OracleVector(BaseVector): conn.close() return docs + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: # lazy import import nltk # type: ignore @@ -347,6 +355,7 @@ class OracleVector(BaseVector): else: return [Document(page_content="", metadata={})] + @override def delete(self): with self._get_connection() as conn: with conn.cursor() as cur: @@ -373,6 +382,7 @@ class OracleVector(BaseVector): class OracleVectorFactory(AbstractVectorFactory): + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> OracleVector: if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] diff --git a/api/providers/vdb/vdb-oracle/tests/integration_tests/test_oraclevector.py b/api/providers/vdb/vdb-oracle/tests/integration_tests/test_oraclevector.py index aceb41289c..467dd04ef9 100644 --- a/api/providers/vdb/vdb-oracle/tests/integration_tests/test_oraclevector.py +++ b/api/providers/vdb/vdb-oracle/tests/integration_tests/test_oraclevector.py @@ -1,3 +1,5 @@ +from typing import override + from dify_vdb_oracle.oraclevector import OracleVector, OracleVectorConfig from core.rag.datasource.vdb.vector_integration_test_support import ( @@ -19,6 +21,7 @@ class OracleVectorTest(AbstractVectorTest): ), ) + @override def search_by_full_text(self): hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text()) assert len(hits_by_full_text) == 0 diff --git a/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/pgvecto_rs.py b/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/pgvecto_rs.py index 9c721c8bde..006dcd1185 100644 --- a/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/pgvecto_rs.py +++ b/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/pgvecto_rs.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any +from typing import Any, override from uuid import UUID, uuid4 from numpy import ndarray @@ -62,20 +62,22 @@ class PGVectoRS(BaseVector): class _Table(CollectionORM): __tablename__ = collection_name __table_args__ = {"extend_existing": True} - id: Mapped[UUID] = mapped_column( + id: Mapped[UUID] = mapped_column( # pyrefly: ignore[missing-override-decorator] postgresql.UUID(as_uuid=True), primary_key=True, ) - text: Mapped[str] - meta: Mapped[dict[str, Any]] = mapped_column(postgresql.JSONB) - vector: Mapped[ndarray] = mapped_column(VECTOR(dim)) + text: Mapped[str] # pyrefly: ignore[missing-override-decorator] + meta: Mapped[dict[str, Any]] = mapped_column(postgresql.JSONB) # pyrefly: ignore[missing-override-decorator] + vector: Mapped[ndarray] = mapped_column(VECTOR(dim)) # pyrefly: ignore[missing-override-decorator] self._table = _Table self._distance_op = "<=>" + @override def get_type(self) -> str: return VectorType.PGVECTO_RS + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): self.create_collection(len(embeddings[0])) self.add_texts(texts, embeddings) @@ -112,6 +114,7 @@ class PGVectoRS(BaseVector): session.execute(index_statement) redis_client.set(collection_exist_cache_key, 1, ex=3600) + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): pks = [] with sessionmaker(bind=self._client).begin() as session: @@ -129,6 +132,7 @@ class PGVectoRS(BaseVector): return pks + @override def get_ids_by_metadata_field(self, key: str, value: str): result = None with Session(self._client) as session: @@ -139,6 +143,7 @@ class PGVectoRS(BaseVector): else: return None + @override def delete_by_metadata_field(self, key: str, value: str): ids = self.get_ids_by_metadata_field(key, value) if ids: @@ -146,6 +151,7 @@ class PGVectoRS(BaseVector): select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)") session.execute(select_statement, {"ids": ids}) + @override def delete_by_ids(self, ids: list[str]): with Session(self._client) as session: select_statement = sql_text( @@ -159,10 +165,12 @@ class PGVectoRS(BaseVector): select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)") session.execute(select_statement, {"ids": ids}) + @override def delete(self): with sessionmaker(bind=self._client).begin() as session: session.execute(sql_text(f"DROP TABLE IF EXISTS {self._collection_name}")) + @override def text_exists(self, id: str) -> bool: with Session(self._client) as session: select_statement = sql_text( @@ -171,6 +179,7 @@ class PGVectoRS(BaseVector): result = session.execute(select_statement, {"doc_id": id}).fetchall() return len(result) > 0 + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: with Session(self._client) as session: stmt = ( @@ -201,11 +210,13 @@ class PGVectoRS(BaseVector): docs.append(doc) return docs + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return [] class PGVectoRSFactory(AbstractVectorFactory): + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PGVectoRS: if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] diff --git a/api/providers/vdb/vdb-pgvecto-rs/tests/integration_tests/test_pgvecto_rs.py b/api/providers/vdb/vdb-pgvecto-rs/tests/integration_tests/test_pgvecto_rs.py index 9fc8627851..1754427811 100644 --- a/api/providers/vdb/vdb-pgvecto-rs/tests/integration_tests/test_pgvecto_rs.py +++ b/api/providers/vdb/vdb-pgvecto-rs/tests/integration_tests/test_pgvecto_rs.py @@ -1,3 +1,5 @@ +from typing import override + from dify_vdb_pgvecto_rs.pgvecto_rs import PGVectoRS, PgvectoRSConfig from core.rag.datasource.vdb.vector_integration_test_support import ( @@ -21,11 +23,13 @@ class PGVectoRSVectorTest(AbstractVectorTest): dim=128, ) + @override def search_by_full_text(self): # pgvecto rs only support english text search, So it’s not open for now hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) assert len(hits_by_full_text) == 0 + @override def get_ids_by_metadata_field(self): ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 1 diff --git a/api/providers/vdb/vdb-pgvector/src/dify_vdb_pgvector/pgvector.py b/api/providers/vdb/vdb-pgvector/src/dify_vdb_pgvector/pgvector.py index b1bdce0ad4..244b288f91 100644 --- a/api/providers/vdb/vdb-pgvector/src/dify_vdb_pgvector/pgvector.py +++ b/api/providers/vdb/vdb-pgvector/src/dify_vdb_pgvector/pgvector.py @@ -3,7 +3,7 @@ import json import logging import uuid from contextlib import contextmanager -from typing import Any +from typing import Any, override import psycopg2.errors import psycopg2.extras @@ -82,6 +82,7 @@ class PGVector(BaseVector): self.index_hash = hashlib.md5(self.table_name.encode()).hexdigest()[:8] self.pg_bigm = config.pg_bigm + @override def get_type(self) -> str: return VectorType.PGVECTOR @@ -107,11 +108,13 @@ class PGVector(BaseVector): conn.commit() self.pool.putconn(conn) + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): dimension = len(embeddings[0]) self._create_collection(dimension) return self.add_texts(texts, embeddings) + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): values = [] pks = [] @@ -133,6 +136,7 @@ class PGVector(BaseVector): ) return pks + @override def text_exists(self, id: str) -> bool: with self._get_cursor() as cur: cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,)) @@ -146,6 +150,7 @@ class PGVector(BaseVector): docs.append(Document(page_content=record[1], metadata=record[0])) return docs + @override def delete_by_ids(self, ids: list[str]): # Avoiding crashes caused by performing delete operations on empty lists in certain scenarios # Scenario 1: extract a document fails, resulting in a table not being created. @@ -162,10 +167,12 @@ class PGVector(BaseVector): except Exception as e: raise e + @override def delete_by_metadata_field(self, key: str, value: str): with self._get_cursor() as cur: cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value)) + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: """ Search the nearest neighbors to a vector. @@ -199,6 +206,7 @@ class PGVector(BaseVector): docs.append(Document(page_content=text, metadata=metadata)) return docs + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 5) if not isinstance(top_k, int) or top_k <= 0: @@ -242,6 +250,7 @@ class PGVector(BaseVector): return docs + @override def delete(self): with self._get_cursor() as cur: cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") @@ -270,6 +279,7 @@ class PGVector(BaseVector): class PGVectorFactory(AbstractVectorFactory): + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PGVector: if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] diff --git a/api/providers/vdb/vdb-qdrant/src/dify_vdb_qdrant/qdrant_vector.py b/api/providers/vdb/vdb-qdrant/src/dify_vdb_qdrant/qdrant_vector.py index 6b0216441b..2e6395bacc 100644 --- a/api/providers/vdb/vdb-qdrant/src/dify_vdb_qdrant/qdrant_vector.py +++ b/api/providers/vdb/vdb-qdrant/src/dify_vdb_qdrant/qdrant_vector.py @@ -3,7 +3,7 @@ import os import uuid from collections.abc import Generator, Iterable, Sequence from itertools import islice -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, cast, override import qdrant_client from flask import current_app @@ -90,6 +90,7 @@ class QdrantVector(BaseVector): self._distance_func = distance_func.upper() self._group_id = group_id + @override def get_type(self) -> str: return VectorType.QDRANT @@ -100,6 +101,7 @@ class QdrantVector(BaseVector): } return result + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): if texts: # get embedding vector size @@ -169,6 +171,7 @@ class QdrantVector(BaseVector): self._client.create_payload_index(collection_name, Field.CONTENT_KEY, field_schema=text_index_params) redis_client.set(collection_exist_cache_key, 1, ex=3600) + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): uuids = self._get_uuids(documents) texts = [d.page_content for d in documents] @@ -251,6 +254,7 @@ class QdrantVector(BaseVector): return payloads + @override def delete_by_metadata_field(self, key: str, value: str): from qdrant_client.http import models from qdrant_client.http.exceptions import UnexpectedResponse @@ -279,6 +283,7 @@ class QdrantVector(BaseVector): else: raise e + @override def delete(self): from qdrant_client.http import models from qdrant_client.http.exceptions import UnexpectedResponse @@ -304,6 +309,7 @@ class QdrantVector(BaseVector): else: raise e + @override def delete_by_ids(self, ids: list[str]): from qdrant_client.http import models from qdrant_client.http.exceptions import UnexpectedResponse @@ -329,6 +335,7 @@ class QdrantVector(BaseVector): else: raise e + @override def text_exists(self, id: str) -> bool: all_collection_name = [] collections_response = self._client.get_collections() @@ -341,6 +348,7 @@ class QdrantVector(BaseVector): return len(response) > 0 + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: from qdrant_client.http import models @@ -393,6 +401,7 @@ class QdrantVector(BaseVector): docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True) return docs + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: """Return docs most similar by full-text search. @@ -488,6 +497,7 @@ class QdrantVector(BaseVector): class QdrantVectorFactory(AbstractVectorFactory): + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector: if dataset.collection_binding_id: stmt = select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == dataset.collection_binding_id) diff --git a/api/providers/vdb/vdb-qdrant/tests/integration_tests/test_qdrant.py b/api/providers/vdb/vdb-qdrant/tests/integration_tests/test_qdrant.py index ea4a3a44a3..acd3222a21 100644 --- a/api/providers/vdb/vdb-qdrant/tests/integration_tests/test_qdrant.py +++ b/api/providers/vdb/vdb-qdrant/tests/integration_tests/test_qdrant.py @@ -1,4 +1,5 @@ import uuid +from typing import override from dify_vdb_qdrant.qdrant_vector import QdrantConfig, QdrantVector @@ -25,6 +26,7 @@ class QdrantVectorTest(AbstractVectorTest): self.doc_banana_id = "" self.doc_both_id = "" + @override def search_by_vector(self): super().search_by_vector() # only test for qdrant, may not work on other vector stores @@ -92,6 +94,7 @@ class QdrantVectorTest(AbstractVectorTest): doc_id_list = [doc.metadata["doc_id"] for doc in hits] assert len(doc_id_list) == len(set(doc_id_list)), "Search results should not contain duplicates" + @override def run_all_tests(self): self.create_vector() self.search_by_vector() diff --git a/api/providers/vdb/vdb-qdrant/tests/unit_tests/test_qdrant_vector.py b/api/providers/vdb/vdb-qdrant/tests/unit_tests/test_qdrant_vector.py index 89ee0a47f1..011c63b804 100644 --- a/api/providers/vdb/vdb-qdrant/tests/unit_tests/test_qdrant_vector.py +++ b/api/providers/vdb/vdb-qdrant/tests/unit_tests/test_qdrant_vector.py @@ -4,6 +4,7 @@ import sys import types from collections import UserDict from types import SimpleNamespace +from typing import override from unittest.mock import MagicMock, patch import pytest @@ -68,6 +69,7 @@ def _build_fake_qdrant_modules(): self.text = text class _Distance(UserDict): + @override def __getitem__(self, key): return key diff --git a/api/providers/vdb/vdb-relyt/src/dify_vdb_relyt/relyt_vector.py b/api/providers/vdb/vdb-relyt/src/dify_vdb_relyt/relyt_vector.py index 336c2d3c8a..423f893f57 100644 --- a/api/providers/vdb/vdb-relyt/src/dify_vdb_relyt/relyt_vector.py +++ b/api/providers/vdb/vdb-relyt/src/dify_vdb_relyt/relyt_vector.py @@ -1,7 +1,7 @@ import json import logging import uuid -from typing import Any +from typing import Any, override from pydantic import BaseModel, model_validator from sqlalchemy import Column, String, Table, create_engine, insert @@ -64,9 +64,11 @@ class RelytVector(BaseVector): self._fields: list[str] = [] self._group_id = group_id + @override def get_type(self) -> str: return VectorType.RELYT + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): self.create_collection(len(embeddings[0])) self.embedding_dimension = len(embeddings[0]) @@ -106,6 +108,7 @@ class RelytVector(BaseVector): session.execute(index_statement) redis_client.set(collection_exist_cache_key, 1, ex=3600) + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): from pgvecto_rs.sqlalchemy import VECTOR # type: ignore @@ -150,6 +153,7 @@ class RelytVector(BaseVector): return ids + @override def get_ids_by_metadata_field(self, key: str, value: str): result = None with Session(self.client) as session: @@ -191,11 +195,13 @@ class RelytVector(BaseVector): logger.exception("Delete operation failed for collection %s", self._collection_name) return False + @override def delete_by_metadata_field(self, key: str, value: str): ids = self.get_ids_by_metadata_field(key, value) if ids: self.delete_by_uuids(ids) + @override def delete_by_ids(self, ids: list[str]): with Session(self.client) as session: select_statement = sql_text( @@ -206,10 +212,12 @@ class RelytVector(BaseVector): ids = [item[0] for item in result] self.delete_by_uuids(ids) + @override def delete(self): with sessionmaker(bind=self.client).begin() as session: session.execute(sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}";""")) + @override def text_exists(self, id: str) -> bool: with Session(self.client) as session: select_statement = sql_text( @@ -218,6 +226,7 @@ class RelytVector(BaseVector): result = session.execute(select_statement, {"doc_id": id}).fetchall() return len(result) > 0 + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: document_ids_filter = kwargs.get("document_ids_filter") filter = kwargs.get("filter", {}) @@ -285,12 +294,14 @@ class RelytVector(BaseVector): ] return documents_with_scores + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: # milvus/zilliz/relyt doesn't support bm25 search return [] class RelytVectorFactory(AbstractVectorFactory): + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> RelytVector: if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] diff --git a/api/providers/vdb/vdb-tablestore/src/dify_vdb_tablestore/tablestore_vector.py b/api/providers/vdb/vdb-tablestore/src/dify_vdb_tablestore/tablestore_vector.py index f9deac11e5..9f1371fdd1 100644 --- a/api/providers/vdb/vdb-tablestore/src/dify_vdb_tablestore/tablestore_vector.py +++ b/api/providers/vdb/vdb-tablestore/src/dify_vdb_tablestore/tablestore_vector.py @@ -2,7 +2,7 @@ import json import logging import math from collections.abc import Iterable -from typing import Any +from typing import Any, override import tablestore # type: ignore from pydantic import BaseModel, model_validator @@ -77,14 +77,17 @@ class TableStoreVector(BaseVector): docs.append(Document(page_content=kv[Field.CONTENT_KEY], metadata=metadata)) return docs + @override def get_type(self) -> str: return VectorType.TABLESTORE + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): dimension = len(embeddings[0]) self._create_collection(dimension) self.add_texts(documents=texts, embeddings=embeddings, **kwargs) + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): uuids = self._get_uuids(documents) @@ -99,6 +102,7 @@ class TableStoreVector(BaseVector): ) return uuids + @override def text_exists(self, id: str) -> bool: result = self._tablestore_client.get_row( table_name=self._table_name, primary_key=[("id", id)], columns_to_get=["id"] @@ -109,19 +113,23 @@ class TableStoreVector(BaseVector): return return_row is not None + @override def delete_by_ids(self, ids: list[str]): if not ids: return for id in ids: self._delete_row(id=id) + @override def get_ids_by_metadata_field(self, key: str, value: str): return self._search_by_metadata(key, value) + @override def delete_by_metadata_field(self, key: str, value: str): ids = self.get_ids_by_metadata_field(key, value) self.delete_by_ids(ids) + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 4) document_ids_filter = kwargs.get("document_ids_filter") @@ -131,6 +139,7 @@ class TableStoreVector(BaseVector): score_threshold = float(kwargs.get("score_threshold") or 0.0) return self._search_by_vector(query_vector, filtered_list, top_k, score_threshold) + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 4) document_ids_filter = kwargs.get("document_ids_filter") @@ -140,6 +149,7 @@ class TableStoreVector(BaseVector): score_threshold = float(kwargs.get("score_threshold") or 0.0) return self._search_by_full_text(query, filtered_list, top_k, score_threshold) + @override def delete(self): self._delete_table_if_exist() @@ -393,6 +403,7 @@ class TableStoreVector(BaseVector): class TableStoreVectorFactory(AbstractVectorFactory): + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TableStoreVector: if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] diff --git a/api/providers/vdb/vdb-tablestore/tests/integration_tests/test_tablestore.py b/api/providers/vdb/vdb-tablestore/tests/integration_tests/test_tablestore.py index 97c9626ee1..0885e20636 100644 --- a/api/providers/vdb/vdb-tablestore/tests/integration_tests/test_tablestore.py +++ b/api/providers/vdb/vdb-tablestore/tests/integration_tests/test_tablestore.py @@ -1,6 +1,7 @@ import logging import os import uuid +from typing import override import tablestore from _pytest.python_api import approx @@ -32,12 +33,14 @@ class TableStoreVectorTest(AbstractVectorTest): ), ) + @override def get_ids_by_metadata_field(self): ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id) assert ids is not None assert len(ids) == 1 assert ids[0] == self.example_doc_id + @override def create_vector(self): self.vector.create( texts=[get_example_document(doc_id=self.example_doc_id)], @@ -53,6 +56,7 @@ class TableStoreVectorTest(AbstractVectorTest): if search_response.total_count == 1: break + @override def search_by_vector(self): super().search_by_vector() docs = self.vector.search_by_vector(self.example_embedding, document_ids_filter=[self.example_doc_id]) @@ -63,6 +67,7 @@ class TableStoreVectorTest(AbstractVectorTest): docs = self.vector.search_by_vector(self.example_embedding, document_ids_filter=[str(uuid.uuid4())]) assert len(docs) == 0 + @override def search_by_full_text(self): super().search_by_full_text() docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[self.example_doc_id]) @@ -87,6 +92,7 @@ class TableStoreVectorTest(AbstractVectorTest): docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[str(uuid.uuid4())]) assert len(docs) == 0 + @override def run_all_tests(self): try: self.vector.delete() diff --git a/api/providers/vdb/vdb-tencent/src/dify_vdb_tencent/tencent_vector.py b/api/providers/vdb/vdb-tencent/src/dify_vdb_tencent/tencent_vector.py index 2f26d6fff3..dc7e7edd4a 100644 --- a/api/providers/vdb/vdb-tencent/src/dify_vdb_tencent/tencent_vector.py +++ b/api/providers/vdb/vdb-tencent/src/dify_vdb_tencent/tencent_vector.py @@ -1,7 +1,7 @@ import json import logging import math -from typing import Any, TypedDict +from typing import Any, TypedDict, override from pydantic import BaseModel from tcvdb_text.encoder import BM25Encoder # type: ignore @@ -93,6 +93,7 @@ class TencentVector(BaseVector): def _init_database(self): return self._client.create_database_if_not_exists(database_name=self._client_config.database) + @override def get_type(self) -> str: return VectorType.TENCENT @@ -181,10 +182,12 @@ class TencentVector(BaseVector): ) redis_client.set(collection_exist_cache_key, 1, ex=3600) + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): self._create_collection(len(embeddings[0])) self.add_texts(texts, embeddings) + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): texts = [doc.page_content for doc in documents] metadatas = [doc.metadata for doc in documents] @@ -215,6 +218,7 @@ class TencentVector(BaseVector): timeout=self._client_config.timeout, ) + @override def text_exists(self, id: str) -> bool: docs = self._client.query( database_name=self._client_config.database, collection_name=self.collection_name, document_ids=[id] @@ -223,6 +227,7 @@ class TencentVector(BaseVector): return True return False + @override def delete_by_ids(self, ids: list[str]): if not ids: return @@ -240,6 +245,7 @@ class TencentVector(BaseVector): database_name=self._client_config.database, collection_name=self.collection_name, document_ids=batch_ids ) + @override def delete_by_metadata_field(self, key: str, value: str): self._client.delete( database_name=self._client_config.database, @@ -247,6 +253,7 @@ class TencentVector(BaseVector): filter=Filter(Filter.In(f"metadata.{key}", [value])), ) + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: document_ids_filter = kwargs.get("document_ids_filter") filter = None @@ -265,6 +272,7 @@ class TencentVector(BaseVector): score_threshold = float(kwargs.get("score_threshold") or 0.0) return self._get_search_res(res, score_threshold) + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: document_ids_filter = kwargs.get("document_ids_filter") filter = None @@ -314,6 +322,7 @@ class TencentVector(BaseVector): docs.append(doc) return docs + @override def delete(self): if self._has_collection(): self._client.drop_collection( @@ -322,6 +331,7 @@ class TencentVector(BaseVector): class TencentVectorFactory(AbstractVectorFactory): + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TencentVector: if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] diff --git a/api/providers/vdb/vdb-tencent/tests/integration_tests/test_tencent.py b/api/providers/vdb/vdb-tencent/tests/integration_tests/test_tencent.py index a53ec87f92..c5242185be 100644 --- a/api/providers/vdb/vdb-tencent/tests/integration_tests/test_tencent.py +++ b/api/providers/vdb/vdb-tencent/tests/integration_tests/test_tencent.py @@ -1,3 +1,4 @@ +from typing import override from unittest.mock import MagicMock from dify_vdb_tencent.tencent_vector import TencentConfig, TencentVector @@ -25,10 +26,12 @@ class TencentVectorTest(AbstractVectorTest): ), ) + @override def search_by_vector(self): hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding) assert len(hits_by_vector) == 1 + @override def search_by_full_text(self): hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) assert len(hits_by_full_text) >= 0 diff --git a/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_on_qdrant_vector.py index abca55f540..9e6dc27203 100644 --- a/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -4,7 +4,7 @@ import os import uuid from collections.abc import Generator, Iterable, Sequence from itertools import islice -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, override import httpx import qdrant_client @@ -91,6 +91,7 @@ class TidbOnQdrantVector(BaseVector): self._distance_func = distance_func.upper() self._group_id = group_id + @override def get_type(self) -> str: return VectorType.TIDB_ON_QDRANT @@ -101,6 +102,7 @@ class TidbOnQdrantVector(BaseVector): } return result + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): if texts: # get embedding vector size @@ -168,6 +170,7 @@ class TidbOnQdrantVector(BaseVector): self._client.create_payload_index(collection_name, Field.CONTENT_KEY, field_schema=text_index_params) redis_client.set(collection_exist_cache_key, 1, ex=3600) + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): uuids = self._get_uuids(documents) texts = [d.page_content for d in documents] @@ -247,6 +250,7 @@ class TidbOnQdrantVector(BaseVector): return payloads + @override def delete_by_metadata_field(self, key: str, value: str): from qdrant_client.http import models from qdrant_client.http.exceptions import UnexpectedResponse @@ -275,6 +279,7 @@ class TidbOnQdrantVector(BaseVector): else: raise e + @override def delete(self): from qdrant_client.http.exceptions import UnexpectedResponse @@ -288,6 +293,7 @@ class TidbOnQdrantVector(BaseVector): else: raise e + @override def delete_by_ids(self, ids: list[str]): from qdrant_client.http import models from qdrant_client.http.exceptions import UnexpectedResponse @@ -317,6 +323,7 @@ class TidbOnQdrantVector(BaseVector): if e.status_code != 404: raise e + @override def text_exists(self, id: str) -> bool: all_collection_name = [] collections_response = self._client.get_collections() @@ -329,6 +336,7 @@ class TidbOnQdrantVector(BaseVector): return len(response) > 0 + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: from qdrant_client.http import models @@ -370,6 +378,7 @@ class TidbOnQdrantVector(BaseVector): docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True) return docs + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: """Return docs most similar by bm25. Returns: @@ -423,6 +432,7 @@ class TidbOnQdrantVector(BaseVector): class TidbOnQdrantVectorFactory(AbstractVectorFactory): + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector: logger.info("init_vector: tenant_id=%s, dataset_id=%s", dataset.tenant_id, dataset.id) stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id) diff --git a/api/providers/vdb/vdb-tidb-vector/src/dify_vdb_tidb_vector/tidb_vector.py b/api/providers/vdb/vdb-tidb-vector/src/dify_vdb_tidb_vector/tidb_vector.py index c696a685dd..9f80ae5a76 100644 --- a/api/providers/vdb/vdb-tidb-vector/src/dify_vdb_tidb_vector/tidb_vector.py +++ b/api/providers/vdb/vdb-tidb-vector/src/dify_vdb_tidb_vector/tidb_vector.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any +from typing import Any, override import sqlalchemy from pydantic import BaseModel, model_validator @@ -46,6 +46,7 @@ class TiDBVectorConfig(BaseModel): class TiDBVector(BaseVector): + @override def get_type(self) -> str: return VectorType.TIDB_VECTOR @@ -82,6 +83,7 @@ class TiDBVector(BaseVector): self._orm_base = declarative_base() self._dimension = 1536 + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): logger.info("create collection and add texts, collection_name: %s", self._collection_name) self._create_collection(len(embeddings[0])) @@ -116,6 +118,7 @@ class TiDBVector(BaseVector): session.execute(create_statement) redis_client.set(collection_exist_cache_key, 1, ex=3600) + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): table = self._table(len(embeddings[0])) ids = self._get_uuids(documents) @@ -138,10 +141,12 @@ class TiDBVector(BaseVector): conn.execute(insert(table).values(chunks_table_data)) return ids + @override def text_exists(self, id: str) -> bool: result = self.get_ids_by_metadata_field("doc_id", id) return bool(result) + @override def delete_by_ids(self, ids: list[str]): with Session(self._engine) as session: ids_str = ",".join(f"'{doc_id}'" for doc_id in ids) @@ -166,6 +171,7 @@ class TiDBVector(BaseVector): logger.exception("Delete operation failed for collection %s", self._collection_name) return False + @override def get_ids_by_metadata_field(self, key: str, value: str): with Session(self._engine) as session: select_statement = sql_text( @@ -177,11 +183,13 @@ class TiDBVector(BaseVector): else: return None + @override def delete_by_metadata_field(self, key: str, value: str): ids = self.get_ids_by_metadata_field(key, value) if ids: self._delete_by_ids(ids) + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 4) score_threshold = float(kwargs.get("score_threshold") or 0.0) @@ -231,10 +239,12 @@ class TiDBVector(BaseVector): docs.append(Document(page_content=text, metadata=metadata)) return docs + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: # tidb doesn't support bm25 search return [] + @override def delete(self): with sessionmaker(bind=self._engine).begin() as session: session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};""")) @@ -251,6 +261,7 @@ class TiDBVector(BaseVector): class TiDBVectorFactory(AbstractVectorFactory): + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TiDBVector: if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] diff --git a/api/providers/vdb/vdb-tidb-vector/tests/integration_tests/test_tidb_vector.py b/api/providers/vdb/vdb-tidb-vector/tests/integration_tests/test_tidb_vector.py index ac854acbf9..cd95156b9f 100644 --- a/api/providers/vdb/vdb-tidb-vector/tests/integration_tests/test_tidb_vector.py +++ b/api/providers/vdb/vdb-tidb-vector/tests/integration_tests/test_tidb_vector.py @@ -1,3 +1,5 @@ +from typing import override + import pytest from dify_vdb_tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig @@ -25,10 +27,12 @@ class TiDBVectorTest(AbstractVectorTest): super().__init__() self.vector = vector + @override def search_by_full_text(self): hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text()) assert len(hits_by_full_text) == 0 + @override def get_ids_by_metadata_field(self): ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id) assert len(ids) == 1 diff --git a/api/providers/vdb/vdb-upstash/src/dify_vdb_upstash/upstash_vector.py b/api/providers/vdb/vdb-upstash/src/dify_vdb_upstash/upstash_vector.py index 75d70a1964..ddf51d7ca4 100644 --- a/api/providers/vdb/vdb-upstash/src/dify_vdb_upstash/upstash_vector.py +++ b/api/providers/vdb/vdb-upstash/src/dify_vdb_upstash/upstash_vector.py @@ -1,5 +1,5 @@ import json -from typing import Any +from typing import Any, override from uuid import uuid4 from pydantic import BaseModel, model_validator @@ -41,9 +41,11 @@ class UpstashVector(BaseVector): else: return 1536 + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): self.add_texts(texts, embeddings) + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): vectors = [ Vector( @@ -56,10 +58,12 @@ class UpstashVector(BaseVector): ] self.index.upsert(vectors=vectors) + @override def text_exists(self, id: str) -> bool: response = self.get_ids_by_metadata_field("doc_id", id) return len(response) > 0 + @override def delete_by_ids(self, ids: list[str]): item_ids = [] for doc_id in ids: @@ -72,6 +76,7 @@ class UpstashVector(BaseVector): if ids: self.index.delete(ids=ids) + @override def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]: query_result = self.index.query( vector=[1.001 * i for i in range(self._get_index_dimension())], @@ -81,11 +86,13 @@ class UpstashVector(BaseVector): ) return [result.id for result in query_result] + @override def delete_by_metadata_field(self, key: str, value: str): ids = self.get_ids_by_metadata_field(key, value) if ids: self._delete_by_ids(ids) + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 4) document_ids_filter = kwargs.get("document_ids_filter") @@ -114,17 +121,21 @@ class UpstashVector(BaseVector): docs.append(Document(page_content=text, metadata=metadata)) return docs + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return [] + @override def delete(self): self.index.reset() + @override def get_type(self) -> str: return VectorType.UPSTASH class UpstashVectorFactory(AbstractVectorFactory): + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> UpstashVector: if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] diff --git a/api/providers/vdb/vdb-upstash/tests/integration_tests/test_upstash_vector.py b/api/providers/vdb/vdb-upstash/tests/integration_tests/test_upstash_vector.py index f4a65030b6..4f6bc6cf8b 100644 --- a/api/providers/vdb/vdb-upstash/tests/integration_tests/test_upstash_vector.py +++ b/api/providers/vdb/vdb-upstash/tests/integration_tests/test_upstash_vector.py @@ -1,3 +1,5 @@ +from typing import override + from dify_vdb_upstash.upstash_vector import UpstashVector, UpstashVectorConfig from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text @@ -15,10 +17,12 @@ class UpstashVectorTest(AbstractVectorTest): ), ) + @override def get_ids_by_metadata_field(self): ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) != 0 + @override def search_by_full_text(self): hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text()) assert len(hits_by_full_text) == 0 diff --git a/api/providers/vdb/vdb-vastbase/src/dify_vdb_vastbase/vastbase_vector.py b/api/providers/vdb/vdb-vastbase/src/dify_vdb_vastbase/vastbase_vector.py index ab00f9db28..c562a75e08 100644 --- a/api/providers/vdb/vdb-vastbase/src/dify_vdb_vastbase/vastbase_vector.py +++ b/api/providers/vdb/vdb-vastbase/src/dify_vdb_vastbase/vastbase_vector.py @@ -1,7 +1,7 @@ import json import uuid from contextlib import contextmanager -from typing import Any +from typing import Any, override import psycopg2.extras import psycopg2.pool @@ -69,6 +69,7 @@ class VastbaseVector(BaseVector): self.pool = self._create_connection_pool(config) self.table_name = f"embedding_{collection_name}" + @override def get_type(self) -> str: return VectorType.VASTBASE @@ -94,11 +95,13 @@ class VastbaseVector(BaseVector): conn.commit() self.pool.putconn(conn) + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): dimension = len(embeddings[0]) self._create_collection(dimension) return self.add_texts(texts, embeddings) + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): values = [] pks = [] @@ -120,6 +123,7 @@ class VastbaseVector(BaseVector): ) return pks + @override def text_exists(self, id: str) -> bool: with self._get_cursor() as cur: cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,)) @@ -133,6 +137,7 @@ class VastbaseVector(BaseVector): docs.append(Document(page_content=record[1], metadata=record[0])) return docs + @override def delete_by_ids(self, ids: list[str]): # Avoiding crashes caused by performing delete operations on empty lists in certain scenarios # Scenario 1: extract a document fails, resulting in a table not being created. @@ -142,10 +147,12 @@ class VastbaseVector(BaseVector): with self._get_cursor() as cur: cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) + @override def delete_by_metadata_field(self, key: str, value: str): with self._get_cursor() as cur: cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value)) + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: """ Search the nearest neighbors to a vector. @@ -174,6 +181,7 @@ class VastbaseVector(BaseVector): docs.append(Document(page_content=text, metadata=metadata)) return docs + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 5) @@ -199,6 +207,7 @@ class VastbaseVector(BaseVector): return docs + @override def delete(self): with self._get_cursor() as cur: cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") @@ -220,6 +229,7 @@ class VastbaseVector(BaseVector): class VastbaseVectorFactory(AbstractVectorFactory): + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> VastbaseVector: if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] diff --git a/api/providers/vdb/vdb-vikingdb/src/dify_vdb_vikingdb/vikingdb_vector.py b/api/providers/vdb/vdb-vikingdb/src/dify_vdb_vikingdb/vikingdb_vector.py index 83fd3626d9..bc5e542cf0 100644 --- a/api/providers/vdb/vdb-vikingdb/src/dify_vdb_vikingdb/vikingdb_vector.py +++ b/api/providers/vdb/vdb-vikingdb/src/dify_vdb_vikingdb/vikingdb_vector.py @@ -1,5 +1,5 @@ import json -from typing import Any +from typing import Any, override from pydantic import BaseModel from volcengine.viking_db import ( # type: ignore @@ -106,14 +106,17 @@ class VikingDBVector(BaseVector): ) redis_client.set(collection_exist_cache_key, 1, ex=3600) + @override def get_type(self) -> str: return VectorType.VIKINGDB + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): dimension = len(embeddings[0]) self._create_collection(dimension) self.add_texts(texts, embeddings, **kwargs) + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): page_contents = [doc.page_content for doc in documents] metadatas = [doc.metadata for doc in documents] @@ -138,6 +141,7 @@ class VikingDBVector(BaseVector): self._client.get_collection(self._collection_name).upsert_data(docs) + @override def text_exists(self, id: str) -> bool: docs = self._client.get_collection(self._collection_name).fetch_data(id) not_exists_str = "data does not exist" @@ -145,9 +149,11 @@ class VikingDBVector(BaseVector): return True return False + @override def delete_by_ids(self, ids: list[str]): self._client.get_collection(self._collection_name).delete_data(ids) + @override def get_ids_by_metadata_field(self, key: str, value: str): # Note: Metadata field value is an dict, but vikingdb field # not support json type @@ -169,10 +175,12 @@ class VikingDBVector(BaseVector): ids.append(result.id) return ids + @override def delete_by_metadata_field(self, key: str, value: str): ids = self.get_ids_by_metadata_field(key, value) self.delete_by_ids(ids) + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: results = self._client.get_index(self._collection_name, self._index_name).search_by_vector( query_vector, limit=kwargs.get("top_k", 4) @@ -198,9 +206,11 @@ class VikingDBVector(BaseVector): docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True) return docs + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return [] + @override def delete(self): if self._has_index(): self._client.drop_index(self._collection_name, self._index_name) @@ -209,6 +219,7 @@ class VikingDBVector(BaseVector): class VikingDBVectorFactory(AbstractVectorFactory): + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> VikingDBVector: if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] diff --git a/api/providers/vdb/vdb-vikingdb/tests/integration_tests/test_vikingdb.py b/api/providers/vdb/vdb-vikingdb/tests/integration_tests/test_vikingdb.py index 5a3908d14b..28bc00b906 100644 --- a/api/providers/vdb/vdb-vikingdb/tests/integration_tests/test_vikingdb.py +++ b/api/providers/vdb/vdb-vikingdb/tests/integration_tests/test_vikingdb.py @@ -1,3 +1,5 @@ +from typing import override + from dify_vdb_vikingdb.vikingdb_vector import VikingDBConfig, VikingDBVector from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text @@ -20,14 +22,17 @@ class VikingDBVectorTest(AbstractVectorTest): ), ) + @override def search_by_vector(self): hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding) assert len(hits_by_vector) == 1 + @override def search_by_full_text(self): hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) assert len(hits_by_full_text) == 0 + @override def get_ids_by_metadata_field(self): ids = self.vector.get_ids_by_metadata_field(key="document_id", value="test_document_id") assert len(ids) > 0 diff --git a/api/providers/vdb/vdb-weaviate/src/dify_vdb_weaviate/weaviate_vector.py b/api/providers/vdb/vdb-weaviate/src/dify_vdb_weaviate/weaviate_vector.py index 902e6a03a8..00824831eb 100644 --- a/api/providers/vdb/vdb-weaviate/src/dify_vdb_weaviate/weaviate_vector.py +++ b/api/providers/vdb/vdb-weaviate/src/dify_vdb_weaviate/weaviate_vector.py @@ -11,7 +11,7 @@ import json import logging import threading import uuid as _uuid -from typing import Any +from typing import Any, override from urllib.parse import urlparse import weaviate @@ -165,6 +165,7 @@ class WeaviateVector(BaseVector): _weaviate_client = client return client + @override def get_type(self) -> str: """Returns the vector database type identifier.""" return VectorType.WEAVIATE @@ -192,6 +193,7 @@ class WeaviateVector(BaseVector): } return result + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): """ Creates a new collection and adds initial documents with embeddings. @@ -275,6 +277,7 @@ class WeaviateVector(BaseVector): except Exception as e: logger.warning("Could not add property %s: %s", prop.name, e) + @override def _get_uuids(self, documents: list[Document]) -> list[str]: """ Generates deterministic UUIDs for documents based on their content. @@ -290,6 +293,7 @@ class WeaviateVector(BaseVector): return uuids + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): """ Adds documents with their embeddings to the collection. @@ -340,6 +344,7 @@ class WeaviateVector(BaseVector): except Exception: return False + @override def delete_by_metadata_field(self, key: str, value: str) -> None: """Deletes all objects matching a specific metadata field value.""" if not self._client.collections.exists(self._collection_name): @@ -348,11 +353,13 @@ class WeaviateVector(BaseVector): col = self._client.collections.use(self._collection_name) col.data.delete_many(where=Filter.by_property(key).equal(value)) + @override def delete(self): """Deletes the entire collection from Weaviate.""" if self._client.collections.exists(self._collection_name): self._client.collections.delete(self._collection_name) + @override def text_exists(self, id: str) -> bool: """Checks if a document with the given doc_id exists in the collection.""" if not self._client.collections.exists(self._collection_name): @@ -367,6 +374,7 @@ class WeaviateVector(BaseVector): return len(res.objects) > 0 + @override def delete_by_ids(self, ids: list[str]) -> None: """ Deletes objects by their UUID identifiers. @@ -385,6 +393,7 @@ class WeaviateVector(BaseVector): if getattr(e, "status_code", None) != 404: raise + @override def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: """ Performs vector similarity search using the provided query vector. @@ -445,6 +454,7 @@ class WeaviateVector(BaseVector): docs.sort(key=lambda d: d.metadata.get("score", 0.0), reverse=True) return docs + @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: """ Performs BM25 full-text search on document content. @@ -506,6 +516,7 @@ class WeaviateVector(BaseVector): class WeaviateVectorFactory(AbstractVectorFactory): """Factory class for creating WeaviateVector instances.""" + @override def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> WeaviateVector: """ Initializes a WeaviateVector instance for the given dataset. diff --git a/api/providers/vdb/vdb-weaviate/tests/unit_tests/test_weaviate_vector.py b/api/providers/vdb/vdb-weaviate/tests/unit_tests/test_weaviate_vector.py index b40f7e52ca..8aea8634fb 100644 --- a/api/providers/vdb/vdb-weaviate/tests/unit_tests/test_weaviate_vector.py +++ b/api/providers/vdb/vdb-weaviate/tests/unit_tests/test_weaviate_vector.py @@ -1,3 +1,5 @@ +from typing import override + """Unit tests for Weaviate vector database implementation. Focuses on verifying that doc_type is properly handled in: @@ -23,6 +25,7 @@ from core.rag.models.document import Document class TestWeaviateVector(unittest.TestCase): """Tests for WeaviateVector class with focus on doc_type metadata handling.""" + @override def setUp(self): weaviate_vector_module._weaviate_client = None self.config = WeaviateConfig( @@ -33,6 +36,7 @@ class TestWeaviateVector(unittest.TestCase): self.collection_name = "Test_Collection_Node" self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"] + @override def tearDown(self): weaviate_vector_module._weaviate_client = None diff --git a/api/pyproject.toml b/api/pyproject.toml index 879e86a602..3fa42b9f5a 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -292,3 +292,4 @@ python-platform = "linux" python-version = "3.12.0" infer-with-first-use = true min-severity = "warn" +errors = { missing-override-decorator = "error" } diff --git a/api/services/api_token_service.py b/api/services/api_token_service.py index 98cb5c0620..e9d887195a 100644 --- a/api/services/api_token_service.py +++ b/api/services/api_token_service.py @@ -7,7 +7,7 @@ Includes Redis cache operations, database queries, and single-flight concurrency import logging from datetime import datetime -from typing import Any +from typing import Any, override from pydantic import BaseModel from sqlalchemy import select @@ -43,6 +43,7 @@ class CachedApiToken(BaseModel): last_used_at: datetime | None created_at: datetime | None + @override def __repr__(self) -> str: return f"" diff --git a/api/services/app_service.py b/api/services/app_service.py index f288b04f3b..58727e658c 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -1,7 +1,7 @@ import json import logging from collections.abc import Sequence -from typing import Any, Literal, TypedDict, cast +from typing import Any, Literal, TypedDict, cast, override import sqlalchemy as sa from flask_sqlalchemy.pagination import Pagination @@ -360,6 +360,7 @@ class AppService: self.__dict__.update(app.__dict__) @property + @override def app_model_config(self): return model_config diff --git a/api/services/auth/firecrawl/firecrawl.py b/api/services/auth/firecrawl/firecrawl.py index c9e5610aea..bfced19128 100644 --- a/api/services/auth/firecrawl/firecrawl.py +++ b/api/services/auth/firecrawl/firecrawl.py @@ -1,4 +1,5 @@ import json +from typing import override import httpx @@ -17,6 +18,7 @@ class FirecrawlAuth(ApiKeyAuthBase): if not self.api_key: raise ValueError("No API key provided") + @override def validate_credentials(self): headers = self._prepare_headers() options = { diff --git a/api/services/auth/jina.py b/api/services/auth/jina.py index e63c9a3a4d..0b49b88e98 100644 --- a/api/services/auth/jina.py +++ b/api/services/auth/jina.py @@ -1,4 +1,5 @@ import json +from typing import override import httpx @@ -22,6 +23,7 @@ class JinaAuth(ApiKeyAuthBase): if not self.api_key: raise ValueError("No API key provided") + @override def validate_credentials(self): headers = self._prepare_headers() options = { diff --git a/api/services/auth/jina/jina.py b/api/services/auth/jina/jina.py index 8ea0b6cd69..d8d2fd51c0 100644 --- a/api/services/auth/jina/jina.py +++ b/api/services/auth/jina/jina.py @@ -1,4 +1,5 @@ import json +from typing import override import httpx @@ -22,6 +23,7 @@ class JinaAuth(ApiKeyAuthBase): if not self.api_key: raise ValueError("No API key provided") + @override def validate_credentials(self): headers = self._prepare_headers() options = { diff --git a/api/services/auth/watercrawl/watercrawl.py b/api/services/auth/watercrawl/watercrawl.py index cbdc908690..d07c2cc318 100644 --- a/api/services/auth/watercrawl/watercrawl.py +++ b/api/services/auth/watercrawl/watercrawl.py @@ -1,4 +1,5 @@ import json +from typing import override from urllib.parse import urljoin import httpx @@ -18,6 +19,7 @@ class WatercrawlAuth(ApiKeyAuthBase): if not self.api_key: raise ValueError("No API key provided") + @override def validate_credentials(self): headers = self._prepare_headers() url = urljoin(self.base_url, "/api/v1/core/crawl-requests/") diff --git a/api/services/document_indexing_proxy/batch_indexing_base.py b/api/services/document_indexing_proxy/batch_indexing_base.py index dd122f34a8..7631e9ca8f 100644 --- a/api/services/document_indexing_proxy/batch_indexing_base.py +++ b/api/services/document_indexing_proxy/batch_indexing_base.py @@ -1,7 +1,7 @@ import logging from collections.abc import Callable, Sequence from dataclasses import asdict -from typing import Any +from typing import Any, override from core.entities.document_task import DocumentTask from core.rag.pipeline.queue import TenantIsolatedTaskQueue @@ -33,6 +33,7 @@ class BatchDocumentIndexingProxy(DocumentTaskProxyBase): self._document_ids = document_ids self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, self.QUEUE_NAME) + @override def _send_to_direct_queue(self, task_func: Callable[[str, str, Sequence[str]], Any]): """ Send batch task to direct queue. @@ -45,6 +46,7 @@ class BatchDocumentIndexingProxy(DocumentTaskProxyBase): tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids ) + @override def _send_to_tenant_queue(self, task_func: Callable[[str, str, Sequence[str]], Any]): """ Send batch task to tenant-isolated queue. diff --git a/api/services/document_indexing_proxy/document_indexing_task_proxy.py b/api/services/document_indexing_proxy/document_indexing_task_proxy.py index fce79a8387..d9295899cb 100644 --- a/api/services/document_indexing_proxy/document_indexing_task_proxy.py +++ b/api/services/document_indexing_proxy/document_indexing_task_proxy.py @@ -8,5 +8,5 @@ class DocumentIndexingTaskProxy(BatchDocumentIndexingProxy): """Proxy for document indexing tasks.""" QUEUE_NAME: ClassVar[str] = "document_indexing" - NORMAL_TASK_FUNC = normal_document_indexing_task - PRIORITY_TASK_FUNC = priority_document_indexing_task + NORMAL_TASK_FUNC = normal_document_indexing_task # pyrefly: ignore[missing-override-decorator] + PRIORITY_TASK_FUNC = priority_document_indexing_task # pyrefly: ignore[missing-override-decorator] diff --git a/api/services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py b/api/services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py index 277cfbdcf1..224cab1e14 100644 --- a/api/services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py +++ b/api/services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py @@ -11,5 +11,5 @@ class DuplicateDocumentIndexingTaskProxy(BatchDocumentIndexingProxy): """Proxy for duplicate document indexing tasks.""" QUEUE_NAME: ClassVar[str] = "duplicate_document_indexing" - NORMAL_TASK_FUNC = normal_duplicate_document_indexing_task - PRIORITY_TASK_FUNC = priority_duplicate_document_indexing_task + NORMAL_TASK_FUNC = normal_duplicate_document_indexing_task # pyrefly: ignore[missing-override-decorator] + PRIORITY_TASK_FUNC = priority_duplicate_document_indexing_task # pyrefly: ignore[missing-override-decorator] diff --git a/api/services/enterprise/plugin_manager_service.py b/api/services/enterprise/plugin_manager_service.py index 23571f2d7d..d423181ff6 100644 --- a/api/services/enterprise/plugin_manager_service.py +++ b/api/services/enterprise/plugin_manager_service.py @@ -1,4 +1,5 @@ import logging +from typing import override from pydantic import BaseModel @@ -15,6 +16,7 @@ class CheckCredentialPolicyComplianceRequest(BaseModel): provider: str credential_type: PluginCredentialType + @override def model_dump(self, **kwargs): data = super().model_dump(**kwargs) data["credential_type"] = self.credential_type.to_number() diff --git a/api/services/errors/llm.py b/api/services/errors/llm.py index 407779d795..159468d271 100644 --- a/api/services/errors/llm.py +++ b/api/services/errors/llm.py @@ -1,3 +1,6 @@ +from typing import override + + class InvokeError(Exception): """Base class for all LLM exceptions.""" @@ -6,6 +9,7 @@ class InvokeError(Exception): def __init__(self, description: str = ""): self.description = description + @override def __str__(self): return self.description or self.__class__.__name__ diff --git a/api/services/legacy_model_type_migration.py b/api/services/legacy_model_type_migration.py index 2de5e7f7f3..1465fc0912 100644 --- a/api/services/legacy_model_type_migration.py +++ b/api/services/legacy_model_type_migration.py @@ -34,7 +34,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import asdict, dataclass from datetime import datetime from enum import IntEnum, StrEnum -from typing import Protocol, cast +from typing import Protocol, cast, override import sqlalchemy as sa from sqlalchemy.exc import OperationalError @@ -337,9 +337,11 @@ class _ThreadSafeLineWriter(io.TextIOBase): self._lock = threading.Lock() self._local = threading.local() + @override def writable(self) -> bool: return True + @override def write(self, text: str) -> int: if not text: return 0 @@ -356,6 +358,7 @@ class _ThreadSafeLineWriter(io.TextIOBase): self._buffer = remainder return len(text) + @override def flush(self) -> None: buffered_text = self._buffer if buffered_text: diff --git a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py index 8c9a81af87..3ba7593be5 100644 --- a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py @@ -1,7 +1,7 @@ import json from os import path from pathlib import Path -from typing import Any +from typing import Any, override from flask import current_app @@ -16,13 +16,16 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): builtin_data: dict[str, Any] | None = None + @override def get_type(self) -> str: return PipelineTemplateType.BUILTIN + @override def get_pipeline_templates(self, language: str) -> dict[str, Any]: result = self.fetch_pipeline_templates_from_builtin(language) return result + @override def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: result = self.fetch_pipeline_template_detail_from_builtin(template_id) return result diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py index 9d446f6d4b..ee73b0328f 100644 --- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -1,4 +1,4 @@ -from typing import Any, TypedDict +from typing import Any, TypedDict, override import yaml from sqlalchemy import select @@ -39,13 +39,16 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): Retrieval recommended app from database """ + @override def get_pipeline_templates(self, language: str) -> dict[str, Any]: _, current_tenant_id = current_account_with_tenant() return self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language) + @override def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: return self.fetch_pipeline_template_detail_from_db(template_id) + @override def get_type(self) -> str: return PipelineTemplateType.CUSTOMIZED diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index 2964537c35..9c94fdee2b 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -1,4 +1,4 @@ -from typing import Any, TypedDict +from typing import Any, TypedDict, override import yaml from sqlalchemy import select @@ -39,12 +39,15 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): Retrieval pipeline template from database """ + @override def get_pipeline_templates(self, language: str) -> dict[str, Any]: return self.fetch_pipeline_templates_from_db(language) + @override def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: return self.fetch_pipeline_template_detail_from_db(template_id) + @override def get_type(self) -> str: return PipelineTemplateType.DATABASE diff --git a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py index 9565ac46cc..1be97c2888 100644 --- a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py @@ -1,5 +1,5 @@ import logging -from typing import Any +from typing import Any, override import httpx @@ -16,6 +16,7 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): Retrieval recommended app from dify official """ + @override def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: try: return self.fetch_pipeline_template_detail_from_dify_official(template_id) @@ -23,6 +24,7 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): logger.warning("fetch recommended app detail from dify official failed: %r, switch to database.", e) return DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id) + @override def get_pipeline_templates(self, language: str) -> dict[str, Any]: try: return self.fetch_pipeline_templates_from_dify_official(language) @@ -30,6 +32,7 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): logger.warning("fetch pipeline templates from dify official failed: %r, switch to database.", e) return DatabasePipelineTemplateRetrieval.fetch_pipeline_templates_from_db(language) + @override def get_type(self) -> str: return PipelineTemplateType.REMOTE diff --git a/api/services/recommend_app/buildin/buildin_retrieval.py b/api/services/recommend_app/buildin/buildin_retrieval.py index 16dc66cd76..e48286303c 100644 --- a/api/services/recommend_app/buildin/buildin_retrieval.py +++ b/api/services/recommend_app/buildin/buildin_retrieval.py @@ -1,7 +1,7 @@ import json from os import path from pathlib import Path -from typing import Any +from typing import Any, override from flask import current_app @@ -16,13 +16,16 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase): builtin_data: dict[str, Any] | None = None + @override def get_type(self) -> str: return RecommendAppType.BUILDIN + @override def get_recommended_apps_and_categories(self, language: str): result = self.fetch_recommended_apps_from_builtin(language) return result + @override def get_recommend_app_detail(self, app_id: str): result = self.fetch_recommended_app_detail_from_builtin(app_id) return result diff --git a/api/services/recommend_app/database/database_retrieval.py b/api/services/recommend_app/database/database_retrieval.py index ac870f0700..d420b33930 100644 --- a/api/services/recommend_app/database/database_retrieval.py +++ b/api/services/recommend_app/database/database_retrieval.py @@ -1,4 +1,4 @@ -from typing import Any, TypedDict +from typing import Any, TypedDict, override from sqlalchemy import select @@ -43,14 +43,17 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): Retrieval recommended app from database """ + @override def get_recommended_apps_and_categories(self, language: str) -> RecommendedAppsResultDict: result = self.fetch_recommended_apps_from_db(language) return result + @override def get_recommend_app_detail(self, app_id: str) -> RecommendedAppDetailDict | None: result = self.fetch_recommended_app_detail_from_db(app_id) return result + @override def get_type(self) -> str: return RecommendAppType.DATABASE diff --git a/api/services/recommend_app/remote/remote_retrieval.py b/api/services/recommend_app/remote/remote_retrieval.py index 0603e4c482..890fb132fa 100644 --- a/api/services/recommend_app/remote/remote_retrieval.py +++ b/api/services/recommend_app/remote/remote_retrieval.py @@ -1,5 +1,5 @@ import logging -from typing import Any +from typing import Any, override import httpx @@ -19,6 +19,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase): Keep the response order intact so Explore matches the template service. """ + @override def get_recommend_app_detail(self, app_id: str): try: result = self.fetch_recommended_app_detail_from_dify_official(app_id) @@ -27,6 +28,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase): result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin(app_id) return result + @override def get_recommended_apps_and_categories(self, language: str): try: result = self.fetch_recommended_apps_from_dify_official(language) @@ -35,6 +37,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase): result = BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin(language) return result + @override def get_type(self) -> str: return RecommendAppType.REMOTE diff --git a/api/services/retention/conversation/messages_clean_policy.py b/api/services/retention/conversation/messages_clean_policy.py index 6e647b983b..23d9231e0b 100644 --- a/api/services/retention/conversation/messages_clean_policy.py +++ b/api/services/retention/conversation/messages_clean_policy.py @@ -3,6 +3,7 @@ import logging from abc import ABC, abstractmethod from collections.abc import Callable, Sequence from dataclasses import dataclass +from typing import override from configs import dify_config from enums.cloud_plan import CloudPlan @@ -51,6 +52,7 @@ class BillingDisabledPolicy(MessagesCleanPolicy): No special filter logic, just return all message ids. """ + @override def filter_message_ids( self, messages: Sequence[SimpleMessage], @@ -82,6 +84,7 @@ class BillingSandboxPolicy(MessagesCleanPolicy): self._plan_provider = plan_provider self._current_timestamp = current_timestamp + @override def filter_message_ids( self, messages: Sequence[SimpleMessage], diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index 5dd5f6873f..93a032498d 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -3,7 +3,7 @@ from __future__ import annotations import dataclasses from abc import ABC, abstractmethod from collections.abc import Mapping -from typing import Any, overload +from typing import Any, overload, override from configs import dify_config from graphon.file import File @@ -112,6 +112,7 @@ class VariableTruncator(BaseTruncator): string_length_limit=dify_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH, ) + @override def truncate_variable_mapping(self, v: Mapping[str, Any]) -> tuple[Mapping[str, Any], bool]: """ `truncate_variable_mapping` is responsible for truncating variable mappings @@ -157,6 +158,7 @@ class VariableTruncator(BaseTruncator): return False return True + @override def truncate(self, segment: Segment) -> TruncationResult: if isinstance(segment, StringSegment): result = self._truncate_segment(segment, self._string_length_limit) @@ -448,6 +450,7 @@ class DummyVariableTruncator(BaseTruncator): to maintain backward compatibility and provide complete data. """ + @override def truncate_variable_mapping(self, v: Mapping[str, Any]) -> tuple[Mapping[str, Any], bool]: """ Return original mapping without truncation. @@ -460,6 +463,7 @@ class DummyVariableTruncator(BaseTruncator): """ return v, False + @override def truncate(self, segment: Segment) -> TruncationResult: """ Return original segment without truncation. diff --git a/api/services/workflow/queue_dispatcher.py b/api/services/workflow/queue_dispatcher.py index cc366482c8..1a79958cc2 100644 --- a/api/services/workflow/queue_dispatcher.py +++ b/api/services/workflow/queue_dispatcher.py @@ -1,3 +1,5 @@ +from typing import override + """ Queue dispatcher system for async workflow execution. @@ -37,9 +39,11 @@ class BaseQueueDispatcher(ABC): class ProfessionalQueueDispatcher(BaseQueueDispatcher): """Dispatcher for professional tier""" + @override def get_queue_name(self) -> str: return QueuePriority.PROFESSIONAL + @override def get_priority(self) -> int: return 100 @@ -47,9 +51,11 @@ class ProfessionalQueueDispatcher(BaseQueueDispatcher): class TeamQueueDispatcher(BaseQueueDispatcher): """Dispatcher for team tier""" + @override def get_queue_name(self) -> str: return QueuePriority.TEAM + @override def get_priority(self) -> int: return 50 @@ -57,9 +63,11 @@ class TeamQueueDispatcher(BaseQueueDispatcher): class SandboxQueueDispatcher(BaseQueueDispatcher): """Dispatcher for free/sandbox tier""" + @override def get_queue_name(self) -> str: return QueuePriority.SANDBOX + @override def get_priority(self) -> int: return 10 diff --git a/api/services/workflow_collaboration_service.py b/api/services/workflow_collaboration_service.py index cf2f509052..80bf51284b 100644 --- a/api/services/workflow_collaboration_service.py +++ b/api/services/workflow_collaboration_service.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging import time from collections.abc import Mapping +from typing import override from sqlalchemy import select @@ -19,6 +20,7 @@ class WorkflowCollaborationService: self._repository = repository self._socketio = socketio + @override def __repr__(self) -> str: return f"{self.__class__.__name__}(repository={self._repository})" diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 5151b7a08c..d7351c5fa5 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -5,7 +5,7 @@ from collections.abc import Mapping, Sequence, Set from concurrent.futures import ThreadPoolExecutor from datetime import datetime from enum import StrEnum -from typing import Any, ClassVar, NotRequired, TypedDict, cast +from typing import Any, ClassVar, NotRequired, TypedDict, cast, override from sqlalchemy import Engine, delete, orm, select from sqlalchemy.dialects.mysql import insert as mysql_insert @@ -107,6 +107,7 @@ class DraftVarLoader(VariableLoader): def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]: return (selector[0], selector[1]) + @override def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]: if not selectors: return [] diff --git a/api/tasks/mail_inner_task.py b/api/tasks/mail_inner_task.py index 294f6c3e25..d4d1ae67ce 100644 --- a/api/tasks/mail_inner_task.py +++ b/api/tasks/mail_inner_task.py @@ -1,7 +1,7 @@ import logging import time from collections.abc import Mapping -from typing import Any +from typing import Any, override import click from celery import shared_task @@ -22,6 +22,7 @@ class SandboxedEnvironment(ImmutableSandboxedEnvironment): self._timeout_time = time.time() + timeout super().__init__(*args, **kwargs) + @override def call(self, context: Context, obj: Any, *args: Any, **kwargs: Any) -> Any: if time.time() > self._timeout_time: raise TimeoutError("Template rendering timeout") diff --git a/api/tasks/workflow_cfs_scheduler/cfs_scheduler.py b/api/tasks/workflow_cfs_scheduler/cfs_scheduler.py index 972501290b..dcf16f462e 100644 --- a/api/tasks/workflow_cfs_scheduler/cfs_scheduler.py +++ b/api/tasks/workflow_cfs_scheduler/cfs_scheduler.py @@ -1,3 +1,5 @@ +from typing import override + from services.workflow.entities import WorkflowScheduleCFSPlanEntity from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue @@ -16,6 +18,7 @@ class AsyncWorkflowCFSPlanScheduler(CFSPlanScheduler[AsyncWorkflowCFSPlanEntity] Trigger workflow CFS plan scheduler. """ + @override def can_schedule(self) -> SchedulerCommand: """ Check if the workflow can be scheduled. diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py index e130644338..737f3ff580 100644 --- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -1,6 +1,7 @@ import json import unittest import uuid +from typing import override import pytest from sqlalchemy import delete, func, select @@ -36,6 +37,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase): _node2_id = "test_node_2" _node_exec_id = str(uuid.uuid4()) + @override def setUp(self): self._test_app_id = str(uuid.uuid4()) self._test_user_id = str(uuid.uuid4()) @@ -102,6 +104,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase): def _get_test_srv(self) -> WorkflowDraftVariableService: return WorkflowDraftVariableService(session=self._session) + @override def tearDown(self): self._session.rollback() @@ -213,7 +216,7 @@ class TestDraftVariableLoader(unittest.TestCase): # @pytest.fixture # def node_var(self, session): # pass - + @override def setUp(self): self._test_app_id = str(uuid.uuid4()) self._test_tenant_id = str(uuid.uuid4()) @@ -255,6 +258,7 @@ class TestDraftVariableLoader(unittest.TestCase): self._sys_var_id = sys_var.id self._conv_var_id = conv_var.id + @override def tearDown(self): with Session(bind=db.engine, expire_on_commit=False) as session: session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == self._test_app_id)) @@ -568,6 +572,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase): _node_exec_id: str _workflow_node_exec_id: str + @override def setUp(self): self._test_app_id = str(uuid.uuid4()) self._test_tenant_id = str(uuid.uuid4()) @@ -676,6 +681,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase): self._node_var_missing_exec_id = self._node_var_missing_exec.id self._conv_var_id = self._conv_var.id + @override def tearDown(self): self._session.rollback() with Session(db.engine) as session, session.begin(): diff --git a/api/tests/unit_tests/clients/agent_backend/test_client.py b/api/tests/unit_tests/clients/agent_backend/test_client.py index 407372d29d..19a1d41af5 100644 --- a/api/tests/unit_tests/clients/agent_backend/test_client.py +++ b/api/tests/unit_tests/clients/agent_backend/test_client.py @@ -1,4 +1,5 @@ from collections.abc import Iterator +from typing import override import pytest from dify_agent.client import DifyAgentHTTPError, DifyAgentStreamError, DifyAgentTimeoutError, DifyAgentValidationError @@ -81,6 +82,7 @@ def test_dify_agent_backend_run_client_delegates_sync_methods(): def test_dify_agent_backend_run_client_maps_validation_error(): class InvalidClient(_SuccessfulClient): + @override def create_run_sync(self, request: CreateRunRequest) -> CreateRunResponse: raise DifyAgentValidationError(detail={"field": "bad"}) @@ -92,6 +94,7 @@ def test_dify_agent_backend_run_client_maps_validation_error(): def test_dify_agent_backend_run_client_maps_http_error(): class HTTPErrorClient(_SuccessfulClient): + @override def create_run_sync(self, request: CreateRunRequest) -> CreateRunResponse: raise DifyAgentHTTPError(status_code=503, detail="unavailable") @@ -104,6 +107,7 @@ def test_dify_agent_backend_run_client_maps_http_error(): def test_dify_agent_backend_run_client_maps_timeout_error(): class TimeoutClient(_SuccessfulClient): + @override def wait_run_sync(self, run_id: str, *, timeout_seconds: float | None = None) -> RunStatusResponse: raise DifyAgentTimeoutError("timeout") @@ -115,6 +119,7 @@ def test_dify_agent_backend_run_client_maps_timeout_error(): def test_dify_agent_backend_run_client_maps_stream_error(): class StreamClient(_SuccessfulClient): + @override def stream_events_sync(self, run_id: str, *, after: str | None = None) -> Iterator[RunEvent]: raise DifyAgentStreamError("bad stream") yield diff --git a/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py b/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py index b2920f93d3..de52c62fdd 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py @@ -8,6 +8,7 @@ from collections.abc import Sequence from dataclasses import dataclass from datetime import UTC, datetime from types import SimpleNamespace +from typing import override from unittest.mock import ANY, MagicMock, Mock import pytest @@ -156,24 +157,30 @@ class _FakePauseEntity(WorkflowPauseEntity): pause_reasons: Sequence[HumanInputRequired] @property + @override def id(self) -> str: return self.pause_id @property + @override def workflow_execution_id(self) -> str: return self.workflow_run_id + @override def get_state(self) -> bytes: raise AssertionError("state is not required for snapshot tests") @property + @override def resumed_at(self) -> datetime | None: return None @property + @override def paused_at(self) -> datetime: return self.paused_at_value + @override def get_pause_reasons(self) -> Sequence[HumanInputRequired]: return self.pause_reasons diff --git a/api/tests/unit_tests/core/app/apps/agent_app/test_app_runner.py b/api/tests/unit_tests/core/app/apps/agent_app/test_app_runner.py index 9a695da9e4..acef632198 100644 --- a/api/tests/unit_tests/core/app/apps/agent_app/test_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/agent_app/test_app_runner.py @@ -5,7 +5,7 @@ saved, using the deterministic fake backend client (no live stack).""" from __future__ import annotations from types import SimpleNamespace -from typing import Any +from typing import Any, override import pytest from agenton.compositor import CompositorSessionSnapshot @@ -48,6 +48,7 @@ class _FakeQueueManager: class _StoppedQueueManager(_FakeQueueManager): + @override def is_stopped(self) -> bool: return True @@ -57,6 +58,7 @@ class _RecordingFakeAgentBackendRunClient(FakeAgentBackendRunClient): super().__init__(**kwargs) self.cancelled_run_ids: list[str] = [] + @override def cancel_run(self, run_id: str, request: CancelRunRequest | None = None) -> CancelRunResponse: self.cancelled_run_ids.append(run_id) return super().cancel_run(run_id, request=request) diff --git a/api/tests/unit_tests/core/external_data_tool/test_base.py b/api/tests/unit_tests/core/external_data_tool/test_base.py index 63e887f904..5c932639b9 100644 --- a/api/tests/unit_tests/core/external_data_tool/test_base.py +++ b/api/tests/unit_tests/core/external_data_tool/test_base.py @@ -1,4 +1,5 @@ -from typing import Any +from collections.abc import Mapping +from typing import Any, override import pytest @@ -14,10 +15,12 @@ class TestExternalDataTool: # Create a concrete subclass to test init class ConcreteTool(ExternalDataTool): @classmethod + @override def validate_config(cls, tenant_id: str, config: dict[str, Any]): return super().validate_config(tenant_id, config) - def query(self, inputs: dict[str, Any], query: str | None = None) -> str: + @override + def query(self, inputs: Mapping[str, Any], query: str | None = None): return super().query(inputs, query) tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1", config={"key": "value"}) @@ -30,10 +33,12 @@ class TestExternalDataTool: # Create a concrete subclass to test init class ConcreteTool(ExternalDataTool): @classmethod + @override def validate_config(cls, tenant_id: str, config: dict[str, Any]): pass - def query(self, inputs: dict[str, Any], query: str | None = None) -> str: + @override + def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str: return "" tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1") @@ -45,10 +50,12 @@ class TestExternalDataTool: def test_validate_config_raises_not_implemented(self): class ConcreteTool(ExternalDataTool): @classmethod + @override def validate_config(cls, tenant_id: str, config: dict[str, Any]): return super().validate_config(tenant_id, config) - def query(self, inputs: dict[str, Any], query: str | None = None) -> str: + @override + def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str: return "" with pytest.raises(NotImplementedError): @@ -57,10 +64,12 @@ class TestExternalDataTool: def test_query_raises_not_implemented(self): class ConcreteTool(ExternalDataTool): @classmethod + @override def validate_config(cls, tenant_id: str, config: dict[str, Any]): pass - def query(self, inputs: dict[str, Any], query: str | None = None) -> str: + @override + def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str: return super().query(inputs, query) tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1") diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_base.py b/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_base.py index 55e22aea0a..12855ed564 100644 --- a/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_base.py +++ b/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_base.py @@ -1,4 +1,5 @@ from types import SimpleNamespace +from typing import override import pytest @@ -7,21 +8,27 @@ from core.rag.models.document import Document class _KeywordThatRaises(BaseKeyword): + @override def create(self, texts: list[Document], **kwargs): return super().create(texts, **kwargs) + @override def add_texts(self, texts: list[Document], **kwargs): return super().add_texts(texts, **kwargs) + @override def text_exists(self, id: str) -> bool: return super().text_exists(id) + @override def delete_by_ids(self, ids: list[str]): return super().delete_by_ids(ids) + @override def delete(self): return super().delete() + @override def search(self, query: str, **kwargs): return super().search(query, **kwargs) @@ -31,21 +38,27 @@ class _KeywordForHelpers(BaseKeyword): super().__init__(dataset) self._existing_ids = existing_ids or set() + @override def create(self, texts: list[Document], **kwargs): return self + @override def add_texts(self, texts: list[Document], **kwargs): return None + @override def text_exists(self, id: str) -> bool: return id in self._existing_ids + @override def delete_by_ids(self, ids: list[str]): return None + @override def delete(self): return None + @override def search(self, query: str, **kwargs): return [] diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_base.py b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_base.py index 369cda39bf..fdbdd1833e 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_base.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_base.py @@ -1,4 +1,5 @@ from types import SimpleNamespace +from typing import override import pytest @@ -11,30 +12,39 @@ class _DummyVector(BaseVector): super().__init__(collection_name) self._existing_ids = existing_ids or set() + @override def get_type(self) -> str: return "dummy" + @override def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): return None + @override def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): return None + @override def text_exists(self, id: str) -> bool: return id in self._existing_ids + @override def delete_by_ids(self, ids: list[str]): return None + @override def delete_by_metadata_field(self, key: str, value: str): return None + @override def search_by_vector(self, query_vector: list[float], **kwargs): return [] + @override def search_by_full_text(self, query: str, **kwargs): return [] + @override def delete(self): return None diff --git a/api/tests/unit_tests/core/rag/embedding/test_embedding_base.py b/api/tests/unit_tests/core/rag/embedding/test_embedding_base.py index 033933e886..493a6f5374 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_embedding_base.py +++ b/api/tests/unit_tests/core/rag/embedding/test_embedding_base.py @@ -2,7 +2,7 @@ import asyncio import inspect -from typing import Any +from typing import Any, override import pytest @@ -12,15 +12,19 @@ from core.rag.embedding.embedding_base import Embeddings class ConcreteEmbeddings(Embeddings): """Concrete implementation of Embeddings for testing.""" + @override def embed_documents(self, texts: list[str]) -> list[list[float]]: return [[1.0] * 10 for _ in texts] + @override def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]: return [[1.0] * 10 for _ in multimodel_documents] + @override def embed_query(self, text: str) -> list[float]: return [1.0] * 10 + @override def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]: return [1.0] * 10 diff --git a/api/tests/unit_tests/core/rag/extractor/blob/test_blob.py b/api/tests/unit_tests/core/rag/extractor/blob/test_blob.py index eb14622d7a..2ce0ed1445 100644 --- a/api/tests/unit_tests/core/rag/extractor/blob/test_blob.py +++ b/api/tests/unit_tests/core/rag/extractor/blob/test_blob.py @@ -1,4 +1,5 @@ from io import BytesIO +from pathlib import Path import pytest @@ -10,7 +11,7 @@ class TestBlob: with pytest.raises(ValueError, match="Either data or path must be provided"): Blob() - def test_source_property_and_repr_include_path(self, tmp_path): + def test_source_property_and_repr_include_path(self, tmp_path: Path): file_path = tmp_path / "sample.txt" file_path.write_text("hello", encoding="utf-8") @@ -23,7 +24,7 @@ class TestBlob: assert Blob.from_data(b"abc").as_string() == "abc" assert Blob.from_data("plain-text").as_string() == "plain-text" - def test_as_string_from_path(self, tmp_path): + def test_as_string_from_path(self, tmp_path: Path): file_path = tmp_path / "sample.txt" file_path.write_text("from-file", encoding="utf-8") @@ -37,7 +38,7 @@ class TestBlob: with pytest.raises(ValueError, match="Unable to get string for blob"): blob.as_string() - def test_as_bytes_from_bytes_str_and_path(self, tmp_path): + def test_as_bytes_from_bytes_str_and_path(self, tmp_path: Path): from_bytes = Blob.from_data(b"abc") from_str = Blob.from_data("abc", encoding="utf-8") @@ -55,7 +56,7 @@ class TestBlob: with pytest.raises(ValueError, match="Unable to get bytes for blob"): blob.as_bytes() - def test_as_bytes_io_for_bytes_and_path(self, tmp_path): + def test_as_bytes_io_for_bytes_and_path(self, tmp_path: Path): data_blob = Blob.from_data(b"bytes-io") with data_blob.as_bytes_io() as stream: assert isinstance(stream, BytesIO) @@ -74,7 +75,7 @@ class TestBlob: with blob.as_bytes_io(): pass - def test_from_path_respects_guessing_and_explicit_mime(self, tmp_path): + def test_from_path_respects_guessing_and_explicit_mime(self, tmp_path: Path): file_path = tmp_path / "example.txt" file_path.write_text("x", encoding="utf-8") diff --git a/api/tests/unit_tests/core/rag/extractor/test_csv_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_csv_extractor.py index 2e1c5715c2..f82ebb4d75 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_csv_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_csv_extractor.py @@ -1,6 +1,8 @@ import csv import io +from pathlib import Path from types import SimpleNamespace +from typing import override import pandas as pd import pytest @@ -10,16 +12,17 @@ from core.rag.extractor.csv_extractor import CSVExtractor class _ManagedStringIO(io.StringIO): + @override def __enter__(self): return self - def __exit__(self, exc_type, exc, tb): + @override + def __exit__(self, exc_type, exc_val, exc_tb): self.close() - return False class TestCSVExtractor: - def test_extract_success_with_source_column(self, tmp_path): + def test_extract_success_with_source_column(self, tmp_path: Path): file_path = tmp_path / "data.csv" file_path.write_text("id,body\nsource-1,hello\n", encoding="utf-8") @@ -30,7 +33,7 @@ class TestCSVExtractor: assert docs[0].page_content == "id: source-1;body: hello" assert docs[0].metadata == {"source": "source-1", "row": 0} - def test_extract_raises_when_source_column_missing(self, tmp_path): + def test_extract_raises_when_source_column_missing(self, tmp_path: Path): file_path = tmp_path / "data.csv" file_path.write_text("id,body\nsource-1,hello\n", encoding="utf-8") diff --git a/api/tests/unit_tests/core/rag/extractor/test_html_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_html_extractor.py index 8bc65e5654..c1c675e0d4 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_html_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_html_extractor.py @@ -1,8 +1,10 @@ +from pathlib import Path + from core.rag.extractor.html_extractor import HtmlExtractor class TestHtmlExtractor: - def test_extract_returns_text_content(self, tmp_path): + def test_extract_returns_text_content(self, tmp_path: Path): file_path = tmp_path / "sample.html" file_path.write_text("

Title

Hello

", encoding="utf-8") @@ -12,7 +14,7 @@ class TestHtmlExtractor: assert len(docs) == 1 assert "".join(docs[0].page_content.split()) == "TitleHello" - def test_load_as_text_strips_whitespace_and_handles_empty(self, tmp_path): + def test_load_as_text_strips_whitespace_and_handles_empty(self, tmp_path: Path): file_path = tmp_path / "sample.html" file_path.write_text(" \n ", encoding="utf-8") diff --git a/api/tests/unit_tests/core/rag/extractor/test_markdown_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_markdown_extractor.py index 8ede44ec04..2ca4c3b0e0 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_markdown_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_markdown_extractor.py @@ -56,7 +56,7 @@ after assert extractor.remove_images(with_images) == "before after" assert extractor.remove_hyperlinks(with_links) == "OpenAI" - def test_parse_tups_reads_file_and_applies_options(self, tmp_path): + def test_parse_tups_reads_file_and_applies_options(self, tmp_path: Path): markdown_file = tmp_path / "doc.md" markdown_file.write_text("# Header\nText with [link](https://example.com) and ![[img.png]]", encoding="utf-8") diff --git a/api/tests/unit_tests/core/rag/extractor/test_text_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_text_extractor.py index 71046d73af..908bf2a908 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_text_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_text_extractor.py @@ -8,7 +8,7 @@ from core.rag.extractor.text_extractor import TextExtractor class TestTextExtractor: - def test_extract_success(self, tmp_path): + def test_extract_success(self, tmp_path: Path): file_path = tmp_path / "data.txt" file_path.write_text("hello world", encoding="utf-8") diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py index 6bb53976fa..40885cfed2 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py @@ -6,6 +6,7 @@ import tempfile from collections import UserDict from pathlib import Path from types import SimpleNamespace +from typing import override from unittest.mock import MagicMock import pytest @@ -339,7 +340,7 @@ def test_init_rejects_invalid_url_status(monkeypatch: pytest.MonkeyPatch): assert fake_response.closed is True -def test_init_expands_home_path_and_invalid_local_path(monkeypatch, tmp_path): +def test_init_expands_home_path_and_invalid_local_path(monkeypatch, tmp_path: Path): target_file = tmp_path / "expanded.docx" target_file.write_bytes(b"docx") @@ -517,6 +518,7 @@ def test_parse_docx_covers_drawing_shapes_hyperlink_error_and_table_branch(monke shape_internal_part = object() class Rels(UserDict): + @override def get(self, key, default=None): if key == "link-bad": raise RuntimeError("cannot resolve relation") diff --git a/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py b/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py index f484e2b25b..21118cc688 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py +++ b/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py @@ -1,4 +1,5 @@ from types import SimpleNamespace +from typing import override from unittest.mock import Mock, patch import httpx @@ -8,15 +9,19 @@ from core.entities.knowledge_entities import PreviewDetail from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.models.document import AttachmentDocument, Document +from core.rag.retrieval.retrieval_methods import RetrievalMethod class _ForwardingBaseIndexProcessor(BaseIndexProcessor): + @override def extract(self, extract_setting, **kwargs): return super().extract(extract_setting, **kwargs) + @override def transform(self, documents, current_user=None, **kwargs): return super().transform(documents, current_user=current_user, **kwargs) + @override def generate_summary_preview(self, tenant_id, preview_texts, summary_index_setting, doc_language=None): return super().generate_summary_preview( tenant_id=tenant_id, @@ -25,6 +30,7 @@ class _ForwardingBaseIndexProcessor(BaseIndexProcessor): doc_language=doc_language, ) + @override def load(self, dataset, documents, multimodal_documents=None, with_keywords=True, **kwargs): return super().load( dataset=dataset, @@ -34,15 +40,19 @@ class _ForwardingBaseIndexProcessor(BaseIndexProcessor): **kwargs, ) + @override def clean(self, dataset, node_ids, with_keywords=True, **kwargs): return super().clean(dataset=dataset, node_ids=node_ids, with_keywords=with_keywords, **kwargs) + @override def index(self, dataset, document, chunks): return super().index(dataset=dataset, document=document, chunks=chunks) + @override def format_preview(self, chunks): return super().format_preview(chunks) + @override def retrieve(self, retrieval_method, query, dataset, top_k, score_threshold, reranking_model): return super().retrieve( retrieval_method=retrieval_method, @@ -75,7 +85,7 @@ class TestBaseIndexProcessor: with pytest.raises(NotImplementedError): processor.format_preview([]) with pytest.raises(NotImplementedError): - processor.retrieve("semantic_search", "q", Mock(), 3, 0.5, {}) + processor.retrieve(RetrievalMethod.SEMANTIC_SEARCH, "q", Mock(), 3, 0.5, {}) def test_get_splitter_validates_custom_length(self, processor: _ForwardingBaseIndexProcessor) -> None: with patch( diff --git a/api/tests/unit_tests/core/schemas/test_registry.py b/api/tests/unit_tests/core/schemas/test_registry.py index 5749e72eb0..6174b66f47 100644 --- a/api/tests/unit_tests/core/schemas/test_registry.py +++ b/api/tests/unit_tests/core/schemas/test_registry.py @@ -1,11 +1,12 @@ import json +from pathlib import Path from unittest.mock import patch from core.schemas.registry import SchemaRegistry class TestSchemaRegistry: - def test_initialization(self, tmp_path): + def test_initialization(self, tmp_path: Path): base_dir = tmp_path / "schemas" base_dir.mkdir() registry = SchemaRegistry(str(base_dir)) @@ -19,13 +20,13 @@ class TestSchemaRegistry: assert registry1 is registry2 assert isinstance(registry1, SchemaRegistry) - def test_load_all_versions_non_existent_dir(self, tmp_path): + def test_load_all_versions_non_existent_dir(self, tmp_path: Path): base_dir = tmp_path / "non_existent" registry = SchemaRegistry(str(base_dir)) registry.load_all_versions() assert registry.versions == {} - def test_load_all_versions_filtering(self, tmp_path): + def test_load_all_versions_filtering(self, tmp_path: Path): base_dir = tmp_path / "schemas" base_dir.mkdir() (base_dir / "not_a_version_dir").mkdir() @@ -38,7 +39,7 @@ class TestSchemaRegistry: mock_load.assert_called_once() assert mock_load.call_args[0][0] == "v1" - def test_load_version_dir_filtering(self, tmp_path): + def test_load_version_dir_filtering(self, tmp_path: Path): version_dir = tmp_path / "v1" version_dir.mkdir() (version_dir / "schema1.json").write_text("{}") @@ -50,13 +51,13 @@ class TestSchemaRegistry: mock_load.assert_called_once() assert mock_load.call_args[0][1] == "schema1" - def test_load_version_dir_non_existent(self, tmp_path): + def test_load_version_dir_non_existent(self, tmp_path: Path): version_dir = tmp_path / "non_existent" registry = SchemaRegistry(str(tmp_path)) registry._load_version_dir("v1", version_dir) assert "v1" not in registry.versions - def test_load_schema_success(self, tmp_path): + def test_load_schema_success(self, tmp_path: Path): schema_path = tmp_path / "test.json" schema_content = {"title": "Test Schema", "description": "A test schema"} schema_path.write_text(json.dumps(schema_content)) diff --git a/api/tests/unit_tests/extensions/storage/test_supabase_storage.py b/api/tests/unit_tests/extensions/storage/test_supabase_storage.py index e264f0befc..1f4146a9ca 100644 --- a/api/tests/unit_tests/extensions/storage/test_supabase_storage.py +++ b/api/tests/unit_tests/extensions/storage/test_supabase_storage.py @@ -1,4 +1,5 @@ from collections.abc import Generator +from pathlib import Path from unittest.mock import Mock, patch import pytest @@ -155,7 +156,7 @@ class TestSupabaseStorage: assert "test-bucket" in [call[0][0] for call in mock_client.storage.from_.call_args_list if call[0]] mock_client.storage.from_().download.assert_called_with("test.txt") - def test_download_writes_bytes_to_disk(self, storage_with_mock_client, tmp_path): + def test_download_writes_bytes_to_disk(self, storage_with_mock_client, tmp_path: Path): """Test download writes expected bytes to disk.""" storage, mock_client = storage_with_mock_client diff --git a/api/tests/unit_tests/services/recommend_app/test_buildin_retrieval.py b/api/tests/unit_tests/services/recommend_app/test_buildin_retrieval.py index 770344aa39..f5f27f7296 100644 --- a/api/tests/unit_tests/services/recommend_app/test_buildin_retrieval.py +++ b/api/tests/unit_tests/services/recommend_app/test_buildin_retrieval.py @@ -1,4 +1,5 @@ import json +from pathlib import Path from unittest.mock import MagicMock, patch import pytest @@ -52,7 +53,7 @@ class TestBuildInRecommendAppRetrieval: mock_fetch.assert_called_once_with("app-1") assert result == {"id": "app-1"} - def test_get_builtin_data_reads_json_and_caches(self, tmp_path): + def test_get_builtin_data_reads_json_and_caches(self, tmp_path: Path): json_file = tmp_path / "constants" / "recommended_apps.json" json_file.parent.mkdir(parents=True) json_file.write_text(json.dumps(SAMPLE_BUILTIN_DATA)) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py index a997ea3583..eafbabe1f9 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py @@ -6,7 +6,7 @@ from datetime import UTC, datetime from itertools import cycle from threading import Event from types import SimpleNamespace -from typing import Any, cast +from typing import Any, cast, override from unittest.mock import MagicMock import pytest @@ -45,24 +45,30 @@ class _FakePauseEntity(WorkflowPauseEntity): pause_reasons: Sequence[HumanInputRequired] @property + @override def id(self) -> str: return self.pause_id @property + @override def workflow_execution_id(self) -> str: return self.workflow_run_id + @override def get_state(self) -> bytes: raise AssertionError("state is not required for snapshot tests") @property + @override def resumed_at(self) -> datetime | None: return None @property + @override def paused_at(self) -> datetime: return self.paused_at_value + @override def get_pause_reasons(self) -> Sequence[HumanInputRequired]: return self.pause_reasons @@ -291,24 +297,30 @@ class _PauseEntity(WorkflowPauseEntity): state: bytes @property + @override def id(self) -> str: return "pause-1" @property + @override def workflow_execution_id(self) -> str: return "run-1" @property + @override def resumed_at(self) -> datetime | None: return None @property + @override def paused_at(self) -> datetime: return datetime(2024, 1, 1, tzinfo=UTC) + @override def get_state(self) -> bytes: return self.state + @override def get_pause_reasons(self) -> list[Any]: return []