mirror of
https://github.com/langgenius/dify.git
synced 2026-06-10 18:24:09 +08:00
ci: add flag for linter (#37018)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
0c4b36b3f5
commit
f15a8f02ef
@ -7,7 +7,7 @@ import threading
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
from typing import TYPE_CHECKING, Any, TypedDict, override
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from cachetools import LRUCache
|
||||
@ -221,6 +221,7 @@ class TracingProviderConfigEntry(TypedDict):
|
||||
|
||||
|
||||
class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigEntry]):
|
||||
@override
|
||||
def __getitem__(self, key: str) -> TracingProviderConfigEntry:
|
||||
try:
|
||||
match key:
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import base64
|
||||
import logging
|
||||
import pickle
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
|
||||
import numpy as np
|
||||
from sqlalchemy import select
|
||||
@ -25,6 +25,7 @@ class CacheEmbedding(Embeddings):
|
||||
def __init__(self, model_instance: ModelInstance):
|
||||
self._model_instance = model_instance
|
||||
|
||||
@override
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed search docs in batches of 10."""
|
||||
# use doc embedding cache or store if not exists
|
||||
@ -106,6 +107,7 @@ class CacheEmbedding(Embeddings):
|
||||
|
||||
return text_embeddings
|
||||
|
||||
@override
|
||||
def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]:
|
||||
"""Embed file documents."""
|
||||
# use doc embedding cache or store if not exists
|
||||
@ -189,6 +191,7 @@ class CacheEmbedding(Embeddings):
|
||||
|
||||
return multimodel_embeddings
|
||||
|
||||
@override
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Embed query text."""
|
||||
# use doc embedding cache or store if not exists
|
||||
@ -232,6 +235,7 @@ class CacheEmbedding(Embeddings):
|
||||
|
||||
return embedding_results # type: ignore
|
||||
|
||||
@override
|
||||
def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
|
||||
"""Embed multimodal documents."""
|
||||
# use doc embedding cache or store if not exists
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, TypedDict, cast
|
||||
from typing import Any, TypedDict, cast, override
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -61,6 +61,7 @@ class ParagraphFormatPreviewDict(TypedDict):
|
||||
|
||||
|
||||
class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
@override
|
||||
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
||||
text_docs = ExtractProcessor.extract(
|
||||
extract_setting=extract_setting,
|
||||
@ -71,6 +72,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
|
||||
return text_docs
|
||||
|
||||
@override
|
||||
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
|
||||
process_rule = kwargs.get("process_rule")
|
||||
if not process_rule:
|
||||
@ -120,6 +122,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
all_documents.extend(split_documents)
|
||||
return all_documents
|
||||
|
||||
@override
|
||||
def load(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
@ -142,6 +145,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
else:
|
||||
keyword.add_texts(documents)
|
||||
|
||||
@override
|
||||
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None:
|
||||
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
|
||||
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
|
||||
@ -178,6 +182,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
else:
|
||||
keyword.delete()
|
||||
|
||||
@override
|
||||
def retrieve(
|
||||
self,
|
||||
retrieval_method: RetrievalMethod,
|
||||
@ -206,6 +211,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
@override
|
||||
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
|
||||
documents: list[Any] = []
|
||||
all_multimodal_documents: list[Any] = []
|
||||
@ -271,6 +277,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
keyword = Keyword(dataset)
|
||||
keyword.add_texts(documents)
|
||||
|
||||
@override
|
||||
def format_preview(self, chunks: Any) -> ParagraphFormatPreviewDict:
|
||||
if isinstance(chunks, list):
|
||||
preview = []
|
||||
@ -285,6 +292,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
else:
|
||||
raise ValueError("Chunks is not a list")
|
||||
|
||||
@override
|
||||
def generate_summary_preview(
|
||||
self,
|
||||
tenant_id: str,
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, TypedDict, override
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
@ -44,6 +44,7 @@ class ParentChildFormatPreviewDict(TypedDict):
|
||||
|
||||
|
||||
class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
@override
|
||||
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
||||
text_docs = ExtractProcessor.extract(
|
||||
extract_setting=extract_setting,
|
||||
@ -54,6 +55,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
|
||||
return text_docs
|
||||
|
||||
@override
|
||||
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
|
||||
process_rule = kwargs.get("process_rule")
|
||||
if not process_rule:
|
||||
@ -129,6 +131,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
|
||||
return all_documents
|
||||
|
||||
@override
|
||||
def load(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
@ -149,6 +152,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
if multimodal_documents and dataset.is_multimodal:
|
||||
vector.create_multimodal(multimodal_documents)
|
||||
|
||||
@override
|
||||
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None:
|
||||
# node_ids is segment's node_ids
|
||||
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
|
||||
@ -219,6 +223,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
)
|
||||
db.session.commit()
|
||||
|
||||
@override
|
||||
def retrieve(
|
||||
self,
|
||||
retrieval_method: RetrievalMethod,
|
||||
@ -283,6 +288,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
child_nodes.append(child_document)
|
||||
return child_nodes
|
||||
|
||||
@override
|
||||
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
|
||||
parent_childs = ParentChildStructureChunk.model_validate(chunks)
|
||||
documents = []
|
||||
@ -356,6 +362,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
if all_multimodal_documents and dataset.is_multimodal:
|
||||
vector.create_multimodal(all_multimodal_documents)
|
||||
|
||||
@override
|
||||
def format_preview(self, chunks: Any) -> ParentChildFormatPreviewDict:
|
||||
parent_childs = ParentChildStructureChunk.model_validate(chunks)
|
||||
preview = []
|
||||
@ -369,6 +376,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
}
|
||||
return result
|
||||
|
||||
@override
|
||||
def generate_summary_preview(
|
||||
self,
|
||||
tenant_id: str,
|
||||
|
||||
@ -4,7 +4,7 @@ import logging
|
||||
import re
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, TypedDict, override
|
||||
|
||||
import pandas as pd
|
||||
from flask import Flask, current_app
|
||||
@ -43,6 +43,7 @@ class QAFormatPreviewDict(TypedDict):
|
||||
|
||||
|
||||
class QAIndexProcessor(BaseIndexProcessor):
|
||||
@override
|
||||
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
||||
text_docs = ExtractProcessor.extract(
|
||||
extract_setting=extract_setting,
|
||||
@ -52,6 +53,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
)
|
||||
return text_docs
|
||||
|
||||
@override
|
||||
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
|
||||
preview = kwargs.get("preview")
|
||||
process_rule = kwargs.get("process_rule")
|
||||
@ -139,6 +141,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
raise ValueError(str(e))
|
||||
return text_docs
|
||||
|
||||
@override
|
||||
def load(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
@ -153,6 +156,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
if multimodal_documents and dataset.is_multimodal:
|
||||
vector.create_multimodal(multimodal_documents)
|
||||
|
||||
@override
|
||||
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None:
|
||||
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
|
||||
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
|
||||
@ -183,6 +187,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
else:
|
||||
vector.delete()
|
||||
|
||||
@override
|
||||
def retrieve(
|
||||
self,
|
||||
retrieval_method: RetrievalMethod,
|
||||
@ -211,6 +216,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
@override
|
||||
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
|
||||
qa_chunks = QAStructureChunk.model_validate(chunks)
|
||||
documents = []
|
||||
@ -234,6 +240,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
else:
|
||||
raise ValueError("Indexing technique must be high quality.")
|
||||
|
||||
@override
|
||||
def format_preview(self, chunks: Any) -> QAFormatPreviewDict:
|
||||
qa_chunks = QAStructureChunk.model_validate(chunks)
|
||||
preview = []
|
||||
@ -246,6 +253,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
}
|
||||
return result
|
||||
|
||||
@override
|
||||
def generate_summary_preview(
|
||||
self,
|
||||
tenant_id: str,
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import base64
|
||||
from typing import override
|
||||
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
@ -16,6 +17,7 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
def __init__(self, rerank_model_instance: ModelInstance):
|
||||
self.rerank_model_instance = rerank_model_instance
|
||||
|
||||
@override
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import math
|
||||
from collections import Counter
|
||||
from typing import override
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -19,6 +20,7 @@ class WeightRerankRunner(BaseRerankRunner):
|
||||
self.tenant_id = tenant_id
|
||||
self.weights = weights
|
||||
|
||||
@override
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
|
||||
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
import codecs
|
||||
import re
|
||||
from collections.abc import Set as AbstractSet
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, override
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter
|
||||
@ -51,6 +51,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
|
||||
self._fixed_separator = codecs.decode(fixed_separator, "unicode_escape")
|
||||
self._separators = separators or ["\n\n", "\n", "。", ". ", " ", ""]
|
||||
|
||||
@override
|
||||
def split_text(self, text: str) -> list[str]:
|
||||
"""Split incoming text and return chunks."""
|
||||
if self._fixed_separator:
|
||||
|
||||
@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from collections.abc import Set as AbstractSet
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, override
|
||||
|
||||
from core.rag.models.document import BaseDocumentTransformer, Document
|
||||
|
||||
@ -148,10 +148,12 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
)
|
||||
return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs)
|
||||
|
||||
@override
|
||||
def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
|
||||
"""Transform sequence of documents by splitting them."""
|
||||
return self.split_documents(list(documents))
|
||||
|
||||
@override
|
||||
async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
|
||||
"""Asynchronously transform a sequence of documents by splitting them."""
|
||||
raise NotImplementedError
|
||||
@ -211,6 +213,7 @@ class TokenTextSplitter(TextSplitter):
|
||||
self._allowed_special: Literal["all"] | AbstractSet[str] = allowed_special
|
||||
self._disallowed_special: Literal["all"] | AbstractSet[str] = disallowed_special
|
||||
|
||||
@override
|
||||
def split_text(self, text: str) -> list[str]:
|
||||
def _encode(_text: str) -> list[int]:
|
||||
return self._tokenizer.encode(
|
||||
@ -287,5 +290,6 @@ class RecursiveCharacterTextSplitter(TextSplitter):
|
||||
|
||||
return final_chunks
|
||||
|
||||
@override
|
||||
def split_text(self, text: str) -> list[str]:
|
||||
return self._split_text(text, self._separators)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from abc import abstractmethod
|
||||
from os import listdir, path
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
@ -105,6 +105,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
"""
|
||||
return self.tools
|
||||
|
||||
@override
|
||||
def get_credentials_schema(self) -> list[ProviderConfig]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
@ -182,6 +183,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
)
|
||||
|
||||
@property
|
||||
@override
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class AudioToolProvider(BuiltinToolProviderController):
|
||||
@override
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
pass
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import io
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.plugin.entities.parameters import PluginParameterOption
|
||||
@ -14,6 +14,7 @@ from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
||||
class ASRTool(BuiltinTool):
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
@ -56,6 +57,7 @@ class ASRTool(BuiltinTool):
|
||||
items.append((provider, model.model))
|
||||
return items
|
||||
|
||||
@override
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: str | None = None,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import io
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.plugin.entities.parameters import PluginParameterOption
|
||||
@ -12,6 +12,7 @@ from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
||||
class TTSTool(BuiltinTool):
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
@ -66,6 +67,7 @@ class TTSTool(BuiltinTool):
|
||||
items.append((provider, model.model, voices))
|
||||
return items
|
||||
|
||||
@override
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: str | None = None,
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class CodeToolProvider(BuiltinToolProviderController):
|
||||
@override
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
pass
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
@ -8,6 +8,7 @@ from core.tools.errors import ToolInvokeError
|
||||
|
||||
|
||||
class SimpleCode(BuiltinTool):
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class WikiPediaProvider(BuiltinToolProviderController):
|
||||
@override
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
pass
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from pytz import timezone as pytz_timezone # type: ignore[import-untyped]
|
||||
|
||||
@ -9,6 +9,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
|
||||
class CurrentTimeTool(BuiltinTool):
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
import pytz # type: ignore[import-untyped]
|
||||
|
||||
@ -10,6 +10,7 @@ from core.tools.errors import ToolInvokeError
|
||||
|
||||
|
||||
class LocaltimeToTimestampTool(BuiltinTool):
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
import pytz # type: ignore[import-untyped]
|
||||
|
||||
@ -10,6 +10,7 @@ from core.tools.errors import ToolInvokeError
|
||||
|
||||
|
||||
class TimestampToLocaltimeTool(BuiltinTool):
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
import pytz # type: ignore[import-untyped]
|
||||
|
||||
@ -10,6 +10,7 @@ from core.tools.errors import ToolInvokeError
|
||||
|
||||
|
||||
class TimezoneConversionTool(BuiltinTool):
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
|
||||
@ -1,13 +1,14 @@
|
||||
import calendar
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
|
||||
class WeekdayTool(BuiltinTool):
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
@ -8,6 +8,7 @@ from core.tools.utils.web_reader_tool import get_url
|
||||
|
||||
|
||||
class WebscraperTool(BuiltinTool):
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class WebscraperProvider(BuiltinToolProviderController):
|
||||
@override
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
"""
|
||||
Validate credentials
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import override
|
||||
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
@ -26,6 +28,7 @@ class BuiltinTool(Tool):
|
||||
super().__init__(**kwargs)
|
||||
self.provider = provider
|
||||
|
||||
@override
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> BuiltinTool:
|
||||
"""
|
||||
fork a new tool with metadata
|
||||
@ -56,6 +59,7 @@ class BuiltinTool(Tool):
|
||||
caller_user_id=self.runtime.user_id,
|
||||
)
|
||||
|
||||
@override
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import override
|
||||
|
||||
from pydantic import Field
|
||||
from sqlalchemy import select
|
||||
|
||||
@ -122,6 +124,7 @@ class ApiToolProviderController(ToolProviderController):
|
||||
)
|
||||
|
||||
@property
|
||||
@override
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.API
|
||||
|
||||
@ -194,6 +197,7 @@ class ApiToolProviderController(ToolProviderController):
|
||||
self.tools = tools
|
||||
return tools
|
||||
|
||||
@override
|
||||
def get_tool(self, tool_name: str) -> ApiTool:
|
||||
"""
|
||||
get tool by name
|
||||
|
||||
@ -2,7 +2,7 @@ import json
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass
|
||||
from os import getenv
|
||||
from typing import Any, Union
|
||||
from typing import Any, Union, override
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
@ -45,6 +45,7 @@ class ApiTool(Tool):
|
||||
self.api_bundle = api_bundle
|
||||
self.provider_id = provider_id
|
||||
|
||||
@override
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime):
|
||||
"""
|
||||
fork a new tool with metadata
|
||||
@ -77,6 +78,7 @@ class ApiTool(Tool):
|
||||
# For credential validation, always return as string
|
||||
return parsed_response.to_string()
|
||||
|
||||
@override
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.API
|
||||
|
||||
@ -373,6 +375,7 @@ class ApiTool(Tool):
|
||||
except ValueError:
|
||||
return value
|
||||
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any, Self
|
||||
from typing import Any, Self, override
|
||||
|
||||
from core.entities.mcp_provider import IdentityMode, MCPProviderEntity
|
||||
from core.mcp.types import Tool as RemoteMCPTool
|
||||
@ -41,6 +41,7 @@ class MCPToolProviderController(ToolProviderController):
|
||||
self.identity_mode: IdentityMode = identity_mode
|
||||
|
||||
@property
|
||||
@override
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
@ -116,6 +117,7 @@ class MCPToolProviderController(ToolProviderController):
|
||||
"""
|
||||
pass
|
||||
|
||||
@override
|
||||
def get_tool(self, tool_name: str) -> MCPTool:
|
||||
"""
|
||||
return tool with given name
|
||||
|
||||
@ -4,7 +4,7 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.mcp_provider import IdentityMode
|
||||
@ -58,9 +58,11 @@ class MCPTool(Tool):
|
||||
self.identity_mode: IdentityMode = identity_mode
|
||||
self._latest_usage = LLMUsage.empty_usage()
|
||||
|
||||
@override
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.MCP
|
||||
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
@ -232,6 +234,7 @@ class MCPTool(Tool):
|
||||
return found
|
||||
return None
|
||||
|
||||
@override
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> MCPTool:
|
||||
return MCPTool(
|
||||
entity=self.entity,
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
@ -23,6 +23,7 @@ class PluginToolProviderController(BuiltinToolProviderController):
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
@override
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
@ -31,6 +32,7 @@ class PluginToolProviderController(BuiltinToolProviderController):
|
||||
"""
|
||||
return ToolProviderType.PLUGIN
|
||||
|
||||
@override
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.plugin.utils.converter import convert_parameters_to_plugin_format
|
||||
@ -20,9 +20,11 @@ class PluginTool(Tool):
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
self.runtime_parameters: list[ToolParameter] | None = None
|
||||
|
||||
@override
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.PLUGIN
|
||||
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
@ -48,6 +50,7 @@ class PluginTool(Tool):
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
@override
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> PluginTool:
|
||||
return PluginTool(
|
||||
entity=self.entity,
|
||||
@ -57,6 +60,7 @@ class PluginTool(Tool):
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: str | None = None,
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import threading
|
||||
from typing import override
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import BaseModel, Field
|
||||
@ -46,6 +47,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
name=f"dataset_{tenant_id.replace('-', '_')}", tenant_id=tenant_id, dataset_ids=dataset_ids, **kwargs
|
||||
)
|
||||
|
||||
@override
|
||||
def _run(self, query: str) -> str:
|
||||
threads = []
|
||||
all_documents: list[RagDocument] = []
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
@ -56,6 +56,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@override
|
||||
def _run(self, query: str) -> str:
|
||||
dataset_stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id)
|
||||
dataset = db.session.scalar(dataset_stmt)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -85,6 +85,7 @@ class DatasetRetrieverTool(Tool):
|
||||
|
||||
return tools
|
||||
|
||||
@override
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: str | None = None,
|
||||
@ -105,9 +106,11 @@ class DatasetRetrieverTool(Tool):
|
||||
),
|
||||
]
|
||||
|
||||
@override
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.DATASET_RETRIEVAL
|
||||
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import override
|
||||
|
||||
from pydantic import Field
|
||||
from sqlalchemy import select
|
||||
@ -80,6 +81,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
return controller
|
||||
|
||||
@property
|
||||
@override
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.WORKFLOW
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@ -67,6 +67,7 @@ class WorkflowTool(Tool):
|
||||
|
||||
super().__init__(entity=entity, runtime=runtime)
|
||||
|
||||
@override
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
get the tool provider type
|
||||
@ -75,6 +76,7 @@ class WorkflowTool(Tool):
|
||||
"""
|
||||
return ToolProviderType.WORKFLOW
|
||||
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
@ -206,6 +208,7 @@ class WorkflowTool(Tool):
|
||||
return found
|
||||
return None
|
||||
|
||||
@override
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> WorkflowTool:
|
||||
"""
|
||||
fork a new tool with metadata
|
||||
|
||||
@ -6,7 +6,7 @@ import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -62,6 +62,7 @@ class TriggerDebugEventPoller(ABC):
|
||||
|
||||
|
||||
class PluginTriggerDebugEventPoller(TriggerDebugEventPoller):
|
||||
@override
|
||||
def poll(self) -> TriggerDebugEvent | None:
|
||||
from services.trigger.trigger_service import TriggerService
|
||||
|
||||
@ -103,6 +104,7 @@ class PluginTriggerDebugEventPoller(TriggerDebugEventPoller):
|
||||
|
||||
|
||||
class WebhookTriggerDebugEventPoller(TriggerDebugEventPoller):
|
||||
@override
|
||||
def poll(self) -> TriggerDebugEvent | None:
|
||||
pool_key = build_webhook_pool_key(
|
||||
tenant_id=self.tenant_id,
|
||||
@ -190,6 +192,7 @@ class ScheduleTriggerDebugEventPoller(TriggerDebugEventPoller):
|
||||
inputs={},
|
||||
)
|
||||
|
||||
@override
|
||||
def poll(self) -> TriggerDebugEvent | None:
|
||||
schedule_debug_runtime = self.get_or_create_schedule_debug_runtime()
|
||||
if schedule_debug_runtime.next_run_at > naive_utc_now():
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Union
|
||||
from typing import Union, override
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig, ProviderConfig
|
||||
from core.helper.provider_cache import ProviderCredentialsCache
|
||||
@ -16,6 +16,7 @@ class TriggerProviderCredentialsCache(ProviderCredentialsCache):
|
||||
def __init__(self, tenant_id: str, provider_id: str, credential_id: str):
|
||||
super().__init__(tenant_id=tenant_id, provider_id=provider_id, credential_id=credential_id)
|
||||
|
||||
@override
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
tenant_id = kwargs["tenant_id"]
|
||||
provider_id = kwargs["provider_id"]
|
||||
@ -29,6 +30,7 @@ class TriggerProviderOAuthClientParamsCache(ProviderCredentialsCache):
|
||||
def __init__(self, tenant_id: str, provider_id: str):
|
||||
super().__init__(tenant_id=tenant_id, provider_id=provider_id)
|
||||
|
||||
@override
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
tenant_id = kwargs["tenant_id"]
|
||||
provider_id = kwargs["provider_id"]
|
||||
@ -41,6 +43,7 @@ class TriggerProviderPropertiesCache(ProviderCredentialsCache):
|
||||
def __init__(self, tenant_id: str, provider_id: str, subscription_id: str):
|
||||
super().__init__(tenant_id=tenant_id, provider_id=provider_id, subscription_id=subscription_id)
|
||||
|
||||
@override
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
tenant_id = kwargs["tenant_id"]
|
||||
provider_id = kwargs["provider_id"]
|
||||
|
||||
@ -10,7 +10,7 @@ from __future__ import annotations
|
||||
import enum
|
||||
import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Annotated, Any, ClassVar, Literal
|
||||
from typing import Annotated, Any, ClassVar, Literal, override
|
||||
|
||||
import bleach
|
||||
import markdown
|
||||
@ -158,6 +158,7 @@ class EmailDeliveryMethod(_DeliveryMethodBase):
|
||||
type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL
|
||||
config: EmailDeliveryConfig
|
||||
|
||||
@override
|
||||
def extract_variable_selectors(self) -> Sequence[Sequence[str]]:
|
||||
variable_template_parser = VariableTemplateParser(template=self.config.body)
|
||||
selectors: list[Sequence[str]] = []
|
||||
|
||||
@ -195,13 +195,16 @@ class _LazyNodeTypeClassesMapping(MutableMapping[NodeType, Mapping[str, type[Nod
|
||||
snapshot.update(self._overrides)
|
||||
return snapshot
|
||||
|
||||
@override
|
||||
def __getitem__(self, key: NodeType) -> Mapping[str, type[Node]]:
|
||||
return self._snapshot()[key]
|
||||
|
||||
@override
|
||||
def __setitem__(self, key: NodeType, value: Mapping[str, type[Node]]) -> None:
|
||||
self._deleted.discard(key)
|
||||
self._overrides[key] = value
|
||||
|
||||
@override
|
||||
def __delitem__(self, key: NodeType) -> None:
|
||||
if key in self._overrides:
|
||||
del self._overrides[key]
|
||||
@ -211,9 +214,11 @@ class _LazyNodeTypeClassesMapping(MutableMapping[NodeType, Mapping[str, type[Nod
|
||||
return
|
||||
raise KeyError(key)
|
||||
|
||||
@override
|
||||
def __iter__(self) -> Iterator[NodeType]:
|
||||
return iter(self._snapshot())
|
||||
|
||||
@override
|
||||
def __len__(self) -> int:
|
||||
return len(self._snapshot())
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast, overload
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast, overload, override
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@ -136,6 +136,7 @@ class DifyFileReferenceFactory(FileReferenceFactoryProtocol):
|
||||
def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None:
|
||||
self._run_context = resolve_dify_run_context(run_context)
|
||||
|
||||
@override
|
||||
def build_from_mapping(self, *, mapping: Mapping[str, Any]):
|
||||
return file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
@ -151,25 +152,31 @@ class DifyPreparedLLM(LLMProtocol):
|
||||
self._model_instance = model_instance
|
||||
|
||||
@property
|
||||
@override
|
||||
def provider(self) -> str:
|
||||
return self._model_instance.provider
|
||||
|
||||
@property
|
||||
@override
|
||||
def model_name(self) -> str:
|
||||
return self._model_instance.model_name
|
||||
|
||||
@property
|
||||
@override
|
||||
def parameters(self) -> Mapping[str, Any]:
|
||||
return self._model_instance.parameters
|
||||
|
||||
@parameters.setter
|
||||
@override
|
||||
def parameters(self, value: Mapping[str, Any]) -> None:
|
||||
self._model_instance.parameters = value
|
||||
|
||||
@property
|
||||
@override
|
||||
def stop(self) -> Sequence[str] | None:
|
||||
return self._model_instance.stop
|
||||
|
||||
@override
|
||||
def get_model_schema(self) -> AIModelEntity:
|
||||
model_schema = cast(LargeLanguageModel, self._model_instance.model_type_instance).get_model_schema(
|
||||
self._model_instance.model_name,
|
||||
@ -179,6 +186,7 @@ class DifyPreparedLLM(LLMProtocol):
|
||||
raise ValueError(f"Model schema not found for {self._model_instance.model_name}")
|
||||
return model_schema
|
||||
|
||||
@override
|
||||
def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int:
|
||||
return self._model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
@ -204,6 +212,7 @@ class DifyPreparedLLM(LLMProtocol):
|
||||
stream: Literal[True],
|
||||
) -> Generator[LLMResultChunk, None, None]: ...
|
||||
|
||||
@override
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
@ -243,6 +252,7 @@ class DifyPreparedLLM(LLMProtocol):
|
||||
stream: Literal[True],
|
||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
|
||||
@override
|
||||
def invoke_llm_with_structured_output(
|
||||
self,
|
||||
*,
|
||||
@ -263,11 +273,13 @@ class DifyPreparedLLM(LLMProtocol):
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
@override
|
||||
def is_structured_output_parse_error(self, error: Exception) -> bool:
|
||||
return isinstance(error, OutputParserError)
|
||||
|
||||
|
||||
class DifyPromptMessageSerializer(PromptMessageSerializerProtocol):
|
||||
@override
|
||||
def serialize(
|
||||
self,
|
||||
*,
|
||||
@ -294,6 +306,7 @@ class DifyRetrieverAttachmentLoader(RetrieverAttachmentLoaderProtocol):
|
||||
self._file_reference_factory = file_reference_factory
|
||||
self._segment_access_checker = segment_access_checker
|
||||
|
||||
@override
|
||||
def load(self, *, segment_id: str) -> Sequence[File]:
|
||||
if not is_retriever_segment_access_granted(segment_id):
|
||||
return []
|
||||
@ -341,6 +354,7 @@ class DifyToolFileManager(ToolFileManagerProtocol):
|
||||
self._manager = ToolFileManager()
|
||||
self._conversation_id_getter = conversation_id_getter
|
||||
|
||||
@override
|
||||
def create_file_by_raw(
|
||||
self,
|
||||
*,
|
||||
@ -358,6 +372,7 @@ class DifyToolFileManager(ToolFileManagerProtocol):
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_file_generator_by_tool_file_id(self, tool_file_id: str):
|
||||
return self._manager.get_file_generator_by_tool_file_id(tool_file_id)
|
||||
|
||||
@ -394,9 +409,11 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
|
||||
def file_reference_factory(self) -> FileReferenceFactoryProtocol:
|
||||
return self._file_reference_factory
|
||||
|
||||
@override
|
||||
def build_file_reference(self, *, mapping: Mapping[str, Any]):
|
||||
return self._file_reference_factory.build_from_mapping(mapping=mapping)
|
||||
|
||||
@override
|
||||
def get_runtime(
|
||||
self,
|
||||
*,
|
||||
@ -447,6 +464,7 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
|
||||
)
|
||||
)
|
||||
|
||||
@override
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
*,
|
||||
@ -458,6 +476,7 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
|
||||
for parameter in (tool.get_merged_runtime_parameters() or [])
|
||||
]
|
||||
|
||||
@override
|
||||
def invoke(
|
||||
self,
|
||||
*,
|
||||
@ -503,6 +522,7 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
|
||||
|
||||
return self._adapt_messages(transformed_messages, provider_name=provider_name)
|
||||
|
||||
@override
|
||||
def get_usage(
|
||||
self,
|
||||
*,
|
||||
@ -745,6 +765,7 @@ class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol):
|
||||
form_repository=form_repository,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_form(self, *, node_id: str) -> HumanInputFormStateProtocol | None:
|
||||
repo = self.build_form_repository()
|
||||
return repo.get_form(node_id)
|
||||
@ -766,6 +787,7 @@ class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol):
|
||||
)
|
||||
return restored_data
|
||||
|
||||
@override
|
||||
def create_form(
|
||||
self,
|
||||
*,
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, override
|
||||
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_text
|
||||
@ -56,9 +56,11 @@ class AgentNode(Node[AgentNodeData]):
|
||||
self._message_transformer = message_transformer
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
@override
|
||||
def populate_start_event(self, event) -> None:
|
||||
dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
|
||||
event.extras["agent_strategy"] = {
|
||||
@ -69,6 +71,7 @@ class AgentNode(Node[AgentNodeData]):
|
||||
),
|
||||
}
|
||||
|
||||
@override
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
|
||||
@ -167,6 +170,7 @@ class AgentNode(Node[AgentNodeData]):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
|
||||
@ -1,11 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import override
|
||||
|
||||
from factories.agent_factory import get_plugin_agent_strategy
|
||||
|
||||
from .strategy_protocols import AgentStrategyPresentationProvider, AgentStrategyResolver, ResolvedAgentStrategy
|
||||
|
||||
|
||||
class PluginAgentStrategyResolver(AgentStrategyResolver):
|
||||
@override
|
||||
def resolve(
|
||||
self,
|
||||
*,
|
||||
@ -21,6 +24,7 @@ class PluginAgentStrategyResolver(AgentStrategyResolver):
|
||||
|
||||
|
||||
class PluginAgentStrategyPresentationProvider(AgentStrategyPresentationProvider):
|
||||
@override
|
||||
def get_icon(self, *, tenant_id: str, agent_strategy_provider_name: str) -> str | None:
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, override
|
||||
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
|
||||
@ -101,12 +101,15 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
|
||||
self._session_store = session_store
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def version(cls) -> str:
|
||||
return "2"
|
||||
|
||||
@override
|
||||
def populate_start_event(self, event) -> None:
|
||||
event.extras["agent_node"] = {"version": "2", "agent_node_kind": self.node_data.agent_node_kind}
|
||||
|
||||
@override
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||
dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
|
||||
workflow_id = self.graph_init_params.workflow_id
|
||||
@ -577,6 +580,7 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
|
||||
metadata["agent_backend"] = agent_backend
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, override
|
||||
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
@ -49,10 +49,12 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
)
|
||||
self.datasource_manager = DatasourceManager
|
||||
|
||||
@override
|
||||
def populate_start_event(self, event) -> None:
|
||||
event.provider_id = f"{self.node_data.plugin_id}/{self.node_data.provider_name}"
|
||||
event.provider_type = self.node_data.provider_type
|
||||
|
||||
@override
|
||||
def _run(self) -> Generator:
|
||||
"""
|
||||
Run the datasource node
|
||||
@ -183,6 +185,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
@ -219,5 +222,6 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, override
|
||||
|
||||
from core.rag.index_processor.index_processor import IndexProcessor
|
||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||
@ -46,6 +46,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
self.index_processor = IndexProcessor()
|
||||
self.summary_index_service = SummaryIndex()
|
||||
|
||||
@override
|
||||
def _run(self) -> NodeRunResult:
|
||||
node_data = self.node_data
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
@ -145,6 +146,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
return rst
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ the workflow node registry.
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
from typing import TYPE_CHECKING, Any, Literal, override
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||
@ -87,9 +87,11 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
self._rag_retrieval = DatasetRetrieval()
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def version(cls):
|
||||
return "1"
|
||||
|
||||
@override
|
||||
def _run(self) -> NodeRunResult:
|
||||
usage = LLMUsage.empty_usage()
|
||||
if not self._node_data.query_variable_selector and not self._node_data.query_attachment_selector:
|
||||
@ -327,6 +329,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
||||
from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID
|
||||
@ -15,6 +15,7 @@ class TriggerEventNode(Node[TriggerEventNodeData]):
|
||||
execution_type = NodeExecutionType.ROOT
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {
|
||||
"type": "plugin",
|
||||
@ -30,12 +31,15 @@ class TriggerEventNode(Node[TriggerEventNodeData]):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
@override
|
||||
def populate_start_event(self, event) -> None:
|
||||
event.provider_id = self.node_data.provider_id
|
||||
|
||||
@override
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run the plugin trigger node.
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import override
|
||||
|
||||
from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE
|
||||
from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID
|
||||
@ -14,10 +15,12 @@ class TriggerScheduleNode(Node[TriggerScheduleNodeData]):
|
||||
execution_type = NodeExecutionType.ROOT
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {
|
||||
"type": TRIGGER_SCHEDULE_NODE_TYPE,
|
||||
@ -29,6 +32,7 @@ class TriggerScheduleNode(Node[TriggerScheduleNodeData]):
|
||||
},
|
||||
}
|
||||
|
||||
@override
|
||||
def _run(self) -> NodeRunResult:
|
||||
node_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id))
|
||||
system_inputs = self.graph_runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE
|
||||
from core.workflow.file_reference import resolve_file_record_id
|
||||
@ -25,12 +25,14 @@ class TriggerWebhookNode(Node[WebhookData]):
|
||||
|
||||
_file_reference_factory: FileReferenceFactoryProtocol
|
||||
|
||||
@override
|
||||
def post_init(self) -> None:
|
||||
from core.workflow.node_runtime import DifyFileReferenceFactory
|
||||
|
||||
self._file_reference_factory = DifyFileReferenceFactory(self.run_context)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {
|
||||
"type": "webhook",
|
||||
@ -48,9 +50,11 @@ class TriggerWebhookNode(Node[WebhookData]):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
@override
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run the webhook node.
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
|
||||
from graphon.nodes.code.entities import CodeLanguage
|
||||
@ -11,6 +11,7 @@ from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderErr
|
||||
class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer):
|
||||
"""Sandbox-backed Jinja2 renderer for workflow-owned node composition."""
|
||||
|
||||
@override
|
||||
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
|
||||
try:
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
"""Custom OTEL ID Generator for correlation-based trace/span ID derivation.
|
||||
|
||||
Uses contextvars for thread-safe correlation_id -> trace_id mapping.
|
||||
@ -52,6 +54,7 @@ class CorrelationIdGenerator(IdGenerator):
|
||||
parent-child linking), otherwise random
|
||||
"""
|
||||
|
||||
@override
|
||||
def generate_trace_id(self) -> int:
|
||||
correlation_id = _correlation_id_context.get()
|
||||
if correlation_id:
|
||||
@ -61,6 +64,7 @@ class CorrelationIdGenerator(IdGenerator):
|
||||
pass
|
||||
return random.getrandbits(128)
|
||||
|
||||
@override
|
||||
def generate_span_id(self) -> int:
|
||||
source = _span_id_source_context.get()
|
||||
if source:
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import json
|
||||
from typing import override
|
||||
|
||||
from flask_restx import fields
|
||||
|
||||
@ -7,6 +8,7 @@ from libs.helper import AppIconUrlField, TimestampField
|
||||
|
||||
|
||||
class JsonStringField(fields.Raw):
|
||||
@override
|
||||
def format(self, value):
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
from typing import override
|
||||
|
||||
from flask_restx import fields
|
||||
|
||||
from graphon.file import File
|
||||
|
||||
|
||||
class FilesContainedField(fields.Raw):
|
||||
@override
|
||||
def format(self, value):
|
||||
return self._format_file_object(value)
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from flask_restx import fields
|
||||
|
||||
from core.helper import encrypter
|
||||
@ -11,6 +13,7 @@ ENVIRONMENT_VARIABLE_SUPPORTED_TYPES = (SegmentType.STRING, SegmentType.NUMBER,
|
||||
|
||||
|
||||
class EnvironmentVariableField(fields.Raw):
|
||||
@override
|
||||
def format(self, value):
|
||||
# Mask secret variables values in environment_variables
|
||||
if isinstance(value, SecretVariable):
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import override
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import DateTime, func
|
||||
@ -51,6 +52,7 @@ class DefaultFieldsMixin:
|
||||
onupdate=func.current_timestamp(),
|
||||
)
|
||||
|
||||
@override
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}(id={self.id})>"
|
||||
|
||||
@ -87,6 +89,7 @@ class DefaultFieldsDCMixin(MappedAsDataclass):
|
||||
onupdate=func.current_timestamp(),
|
||||
)
|
||||
|
||||
@override
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}(id={self.id})>"
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@ import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, ClassVar, TypedDict, cast
|
||||
from typing import Any, ClassVar, TypedDict, cast, override
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
@ -1790,5 +1790,6 @@ class DocumentSegmentSummary(TypeBase):
|
||||
init=False,
|
||||
)
|
||||
|
||||
@override
|
||||
def __repr__(self):
|
||||
return f"<DocumentSegmentSummary id={self.id} chunk_id={self.chunk_id} status={self.status}>"
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from enum import StrEnum
|
||||
from typing import override
|
||||
|
||||
from core.trigger.constants import (
|
||||
TRIGGER_PLUGIN_NODE_TYPE,
|
||||
@ -12,6 +13,7 @@ class CreatorUserRole(StrEnum):
|
||||
END_USER = "end_user"
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def _missing_(cls, value):
|
||||
if value == "end-user":
|
||||
return cls.END_USER
|
||||
|
||||
@ -8,7 +8,7 @@ from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from enum import StrEnum, auto
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, cast, override
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
@ -2058,10 +2058,12 @@ class EndUser(Base, UserMixin):
|
||||
)
|
||||
|
||||
@property
|
||||
@override
|
||||
def is_anonymous(self) -> Literal[False]:
|
||||
return False
|
||||
|
||||
@is_anonymous.setter
|
||||
@override
|
||||
def is_anonymous(self, value: bool) -> None:
|
||||
self._is_anonymous = value
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
from datetime import datetime
|
||||
from enum import StrEnum, auto
|
||||
from functools import cached_property
|
||||
from typing import override
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
@ -73,6 +74,7 @@ class Provider(TypeBase):
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
@override
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"<Provider(id={self.id}, tenant_id={self.tenant_id}, provider_name='{self.provider_name}',"
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
"""Provider ID entities for plugin system."""
|
||||
|
||||
import re
|
||||
@ -14,6 +16,7 @@ class GenericProviderID:
|
||||
def to_string(self) -> str:
|
||||
return str(self)
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
return f"{self.organization}/{self.plugin_name}/{self.provider_name}"
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import override
|
||||
|
||||
from opentelemetry.trace import SpanKind
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
@ -74,6 +75,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
endpoint = build_endpoint(aliyun_config.endpoint, aliyun_config.license_key)
|
||||
self.trace_client = TraceClient(service_name=aliyun_config.app_name, endpoint=endpoint)
|
||||
|
||||
@override
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
match trace_info:
|
||||
case WorkflowTraceInfo():
|
||||
|
||||
@ -6,7 +6,7 @@ import traceback
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Protocol, Union, cast
|
||||
from typing import Any, Protocol, Union, cast, override
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from openinference.semconv.trace import (
|
||||
@ -730,6 +730,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
self.root_span_carriers: dict[str, dict[str, str]] = {}
|
||||
self.carrier: dict[str, str] = {}
|
||||
|
||||
@override
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
logger.info("[Arize/Phoenix] Trace Entity Info: %s", trace_info)
|
||||
logger.info("[Arize/Phoenix] Trace Entity Type: %s", type(trace_info))
|
||||
|
||||
@ -2,7 +2,7 @@ import json
|
||||
from collections.abc import Sequence
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import dify_trace_arize_phoenix.arize_phoenix_trace as arize_phoenix_trace_module
|
||||
@ -133,10 +133,12 @@ class _CollectingSpanExporter(SpanExporter):
|
||||
def __init__(self):
|
||||
self.spans: list[ReadableSpan] = []
|
||||
|
||||
@override
|
||||
def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
|
||||
self.spans.extend(spans)
|
||||
return SpanExportResult.SUCCESS
|
||||
|
||||
@override
|
||||
def shutdown(self) -> None:
|
||||
return None
|
||||
|
||||
@ -258,9 +260,11 @@ def test_set_span_status():
|
||||
|
||||
# repr branch
|
||||
class SilentError:
|
||||
@override
|
||||
def __str__(self):
|
||||
return ""
|
||||
|
||||
@override
|
||||
def __repr__(self):
|
||||
return "SilentErrorRepr"
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@ import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import override
|
||||
|
||||
from langfuse import Langfuse
|
||||
from langfuse.api import (
|
||||
@ -106,6 +107,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
|
||||
return start_time + timedelta(seconds=ttft_seconds)
|
||||
|
||||
@override
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
match trace_info:
|
||||
case WorkflowTraceInfo():
|
||||
|
||||
@ -2,6 +2,7 @@ import collections
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from typing import override
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@ -738,6 +739,7 @@ def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypat
|
||||
node.status = "succeeded"
|
||||
|
||||
class BadDict(collections.UserDict):
|
||||
@override
|
||||
def get(self, key, default=None):
|
||||
if key == "usage":
|
||||
raise Exception("Usage extraction failed")
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import cast
|
||||
from typing import cast, override
|
||||
|
||||
from langsmith import Client
|
||||
from langsmith.schemas import RunBase
|
||||
@ -47,6 +47,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
self.langsmith_client = Client(api_key=langsmith_config.api_key, api_url=langsmith_config.endpoint)
|
||||
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
|
||||
|
||||
@override
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
match trace_info:
|
||||
case WorkflowTraceInfo():
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import collections
|
||||
from datetime import datetime, timedelta
|
||||
from typing import override
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@ -549,6 +550,7 @@ def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch: pyte
|
||||
)
|
||||
|
||||
class BadDict(collections.UserDict):
|
||||
@override
|
||||
def get(self, key, default=None):
|
||||
if key == "usage":
|
||||
raise Exception("Usage extraction failed")
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
|
||||
import mlflow
|
||||
from mlflow.entities import Document, Span, SpanEvent, SpanStatusCode, SpanType
|
||||
@ -86,6 +86,7 @@ class MLflowDataTrace(BaseTraceInstance):
|
||||
|
||||
self._project_url = f"{config.tracking_uri}/#/experiments/{config.experiment_id}/traces"
|
||||
|
||||
@override
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
"""Simple dispatch to trace methods"""
|
||||
try:
|
||||
|
||||
@ -3,7 +3,7 @@ import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
|
||||
from opik import Opik, Trace
|
||||
from opik.id_helpers import uuid4_to_uuid7
|
||||
@ -95,6 +95,7 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
self.project = opik_config.project
|
||||
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
|
||||
|
||||
@override
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
match trace_info:
|
||||
case WorkflowTraceInfo():
|
||||
|
||||
@ -2,6 +2,7 @@ import collections
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from typing import override
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@ -643,6 +644,7 @@ def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch
|
||||
node.status = "succeeded"
|
||||
|
||||
class BadDict(collections.UserDict):
|
||||
@override
|
||||
def get(self, key, default=None):
|
||||
if key == "usage":
|
||||
raise Exception("Usage extraction failed")
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
"""Tencent APM tracing with idempotent client cleanup."""
|
||||
|
||||
import inspect
|
||||
@ -56,6 +58,7 @@ class TencentDataTrace(BaseTraceInstance):
|
||||
metrics_export_interval_sec=5,
|
||||
)
|
||||
|
||||
@override
|
||||
def trace(self, trace_info: BaseTraceInfo) -> None:
|
||||
"""Main tracing entry point - coordinates different trace types."""
|
||||
match trace_info:
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
|
||||
import wandb
|
||||
import weave
|
||||
@ -78,6 +78,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
logger.debug("Weave get run url failed: %s", str(e))
|
||||
raise ValueError(f"Weave get run url failed: {str(e)}")
|
||||
|
||||
@override
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
logger.debug("Trace info: %s", trace_info)
|
||||
match trace_info:
|
||||
|
||||
@ -3,7 +3,7 @@ import json
|
||||
import logging
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Literal, cast
|
||||
from typing import Any, Literal, cast, override
|
||||
|
||||
import mysql.connector
|
||||
from mysql.connector import Error as MySQLError
|
||||
@ -81,6 +81,7 @@ class AlibabaCloudMySQLVector(BaseVector):
|
||||
self.hnsw_m = config.hnsw_m
|
||||
self._check_vector_support()
|
||||
|
||||
@override
|
||||
def get_type(self) -> str:
|
||||
return VectorType.ALIBABACLOUD_MYSQL
|
||||
|
||||
@ -135,11 +136,13 @@ class AlibabaCloudMySQLVector(BaseVector):
|
||||
cur.close()
|
||||
conn.close()
|
||||
|
||||
@override
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
self._create_collection(dimension)
|
||||
return self.add_texts(texts, embeddings)
|
||||
|
||||
@override
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
values = []
|
||||
pks = []
|
||||
@ -165,6 +168,7 @@ class AlibabaCloudMySQLVector(BaseVector):
|
||||
cur.executemany(insert_sql, values)
|
||||
return pks
|
||||
|
||||
@override
|
||||
def text_exists(self, id: str) -> bool:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,))
|
||||
@ -183,6 +187,7 @@ class AlibabaCloudMySQLVector(BaseVector):
|
||||
docs.append(Document(page_content=record["text"], metadata=metadata))
|
||||
return docs
|
||||
|
||||
@override
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
# Avoiding crashes caused by performing delete operations on empty lists
|
||||
if not ids:
|
||||
@ -199,12 +204,14 @@ class AlibabaCloudMySQLVector(BaseVector):
|
||||
else:
|
||||
raise e
|
||||
|
||||
@override
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"DELETE FROM {self.table_name} WHERE JSON_UNQUOTE(JSON_EXTRACT(meta, %s)) = %s", (f"$.{key}", value)
|
||||
)
|
||||
|
||||
@override
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Search the nearest neighbors to a vector using RDS MySQL vector distance functions.
|
||||
@ -274,6 +281,7 @@ class AlibabaCloudMySQLVector(BaseVector):
|
||||
|
||||
return docs
|
||||
|
||||
@override
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
if not isinstance(top_k, int) or top_k <= 0:
|
||||
@ -308,6 +316,7 @@ class AlibabaCloudMySQLVector(BaseVector):
|
||||
docs.append(Document(page_content=record["text"], metadata=metadata))
|
||||
return docs
|
||||
|
||||
@override
|
||||
def delete(self):
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
||||
@ -355,6 +364,7 @@ class AlibabaCloudMySQLVectorFactory(AbstractVectorFactory):
|
||||
raise ValueError(f"Invalid distance function: {distance_function}. Must be 'cosine' or 'euclidean'")
|
||||
return cast(Literal["cosine", "euclidean"], distance_function)
|
||||
|
||||
@override
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AlibabaCloudMySQLVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import json
|
||||
import unittest
|
||||
from typing import override
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -22,6 +23,7 @@ except ImportError:
|
||||
|
||||
|
||||
class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
@override
|
||||
def setUp(self):
|
||||
self.config = AlibabaCloudMySQLVectorConfig(
|
||||
host="localhost",
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
@ -32,38 +32,48 @@ class AnalyticdbVector(BaseVector):
|
||||
raise ValueError("Either api_config or sql_config must be provided")
|
||||
self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config)
|
||||
|
||||
@override
|
||||
def get_type(self) -> str:
|
||||
return VectorType.ANALYTICDB
|
||||
|
||||
@override
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
self.analyticdb_vector.create_collection_if_not_exists(dimension)
|
||||
self.analyticdb_vector.add_texts(texts, embeddings)
|
||||
|
||||
@override
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
|
||||
self.analyticdb_vector.add_texts(documents, embeddings)
|
||||
return []
|
||||
|
||||
@override
|
||||
def text_exists(self, id: str) -> bool:
|
||||
return self.analyticdb_vector.text_exists(id)
|
||||
|
||||
@override
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
self.analyticdb_vector.delete_by_ids(ids)
|
||||
|
||||
@override
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
self.analyticdb_vector.delete_by_metadata_field(key, value)
|
||||
|
||||
@override
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
return self.analyticdb_vector.search_by_vector(query_vector, **kwargs)
|
||||
|
||||
@override
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return self.analyticdb_vector.search_by_full_text(query, **kwargs)
|
||||
|
||||
@override
|
||||
def delete(self):
|
||||
self.analyticdb_vector.delete()
|
||||
|
||||
|
||||
class AnalyticdbVectorFactory(AbstractVectorFactory):
|
||||
@override
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AnalyticdbVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from dify_vdb_analyticdb.analyticdb_vector import AnalyticdbVector
|
||||
from dify_vdb_analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
|
||||
from dify_vdb_analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
|
||||
@ -40,6 +42,7 @@ class AnalyticdbVectorTest(AbstractVectorTest):
|
||||
),
|
||||
)
|
||||
|
||||
@override
|
||||
def run_all_tests(self):
|
||||
self.vector.delete()
|
||||
return super().run_all_tests()
|
||||
|
||||
@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, model_validator
|
||||
@ -82,6 +82,7 @@ class BaiduVector(BaseVector):
|
||||
self._client = self._init_client(config)
|
||||
self._db = self._init_database()
|
||||
|
||||
@override
|
||||
def get_type(self) -> str:
|
||||
return VectorType.BAIDU
|
||||
|
||||
@ -92,10 +93,12 @@ class BaiduVector(BaseVector):
|
||||
}
|
||||
return result
|
||||
|
||||
@override
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
self._create_table(len(embeddings[0]))
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
@override
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
total_count = len(documents)
|
||||
batch_size = 1000
|
||||
@ -116,23 +119,27 @@ class BaiduVector(BaseVector):
|
||||
rows.append(row)
|
||||
table.upsert(rows=rows)
|
||||
|
||||
@override
|
||||
def text_exists(self, id: str) -> bool:
|
||||
res = self._db.table(self._collection_name).query(primary_key={VDBField.PRIMARY_KEY: id})
|
||||
if res and res.code == 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
@override
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
if not ids:
|
||||
return
|
||||
quoted_ids = [f"'{id}'" for id in ids]
|
||||
self._db.table(self._collection_name).delete(filter=f"{VDBField.PRIMARY_KEY} IN({', '.join(quoted_ids)})")
|
||||
|
||||
@override
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
# Escape double quotes in value to prevent injection
|
||||
escaped_value = value.replace('"', '\\"')
|
||||
self._db.table(self._collection_name).delete(filter=f'metadata["{key}"] = "{escaped_value}"')
|
||||
|
||||
@override
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector]
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
@ -154,6 +161,7 @@ class BaiduVector(BaseVector):
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
return self._get_search_res(res, score_threshold)
|
||||
|
||||
@override
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
# document ids filter
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
@ -189,6 +197,7 @@ class BaiduVector(BaseVector):
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
@override
|
||||
def delete(self):
|
||||
try:
|
||||
self._db.drop_table(table_name=self._collection_name)
|
||||
@ -368,6 +377,7 @@ class BaiduVector(BaseVector):
|
||||
|
||||
|
||||
class BaiduVectorFactory(AbstractVectorFactory):
|
||||
@override
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaiduVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from dify_vdb_baidu.baidu_vector import BaiduConfig, BaiduVector
|
||||
|
||||
from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text
|
||||
@ -18,10 +20,12 @@ class BaiduVectorTest(AbstractVectorTest):
|
||||
),
|
||||
)
|
||||
|
||||
@override
|
||||
def search_by_vector(self):
|
||||
hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
|
||||
assert len(hits_by_vector) == 1
|
||||
|
||||
@override
|
||||
def search_by_full_text(self):
|
||||
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
|
||||
assert len(hits_by_full_text) == 0
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, TypedDict, override
|
||||
|
||||
import chromadb
|
||||
from chromadb import QueryResult, Settings
|
||||
@ -55,9 +55,11 @@ class ChromaVector(BaseVector):
|
||||
self._client_config = config
|
||||
self._client = chromadb.HttpClient(**self._client_config.to_chroma_params())
|
||||
|
||||
@override
|
||||
def get_type(self) -> str:
|
||||
return VectorType.CHROMA
|
||||
|
||||
@override
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
if texts:
|
||||
# create collection
|
||||
@ -74,6 +76,7 @@ class ChromaVector(BaseVector):
|
||||
self._client.get_or_create_collection(collection_name)
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
@override
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
|
||||
uuids = self._get_uuids(documents)
|
||||
texts = [d.page_content for d in documents]
|
||||
@ -84,25 +87,30 @@ class ChromaVector(BaseVector):
|
||||
collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas) # type: ignore
|
||||
return uuids
|
||||
|
||||
@override
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
# FIXME: fix the type error later
|
||||
collection.delete(where={key: {"$eq": value}}) # type: ignore
|
||||
|
||||
@override
|
||||
def delete(self):
|
||||
self._client.delete_collection(self._collection_name)
|
||||
|
||||
@override
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
if not ids:
|
||||
return
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
collection.delete(ids=ids)
|
||||
|
||||
@override
|
||||
def text_exists(self, id: str) -> bool:
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
response = collection.get(ids=[id])
|
||||
return len(response) > 0
|
||||
|
||||
@override
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
@ -142,12 +150,14 @@ class ChromaVector(BaseVector):
|
||||
docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True)
|
||||
return docs
|
||||
|
||||
@override
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
# chroma does not support BM25 full text searching
|
||||
return []
|
||||
|
||||
|
||||
class ChromaVectorFactory(AbstractVectorFactory):
|
||||
@override
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
import chromadb
|
||||
from dify_vdb_chroma.chroma_vector import ChromaConfig, ChromaVector
|
||||
|
||||
@ -22,6 +24,7 @@ class ChromaVectorTest(AbstractVectorTest):
|
||||
),
|
||||
)
|
||||
|
||||
@override
|
||||
def search_by_full_text(self):
|
||||
# chroma does not support full text searching
|
||||
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
|
||||
|
||||
@ -8,7 +8,7 @@ import re
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, override
|
||||
|
||||
import clickzetta # type: ignore
|
||||
from pydantic import BaseModel, model_validator
|
||||
@ -436,6 +436,7 @@ class ClickzettaVector(BaseVector):
|
||||
raise result
|
||||
return result
|
||||
|
||||
@override
|
||||
def get_type(self) -> str:
|
||||
"""Return the vector database type."""
|
||||
return "clickzetta"
|
||||
@ -469,6 +470,7 @@ class ClickzettaVector(BaseVector):
|
||||
)
|
||||
return False
|
||||
|
||||
@override
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
"""Create the collection and add initial documents."""
|
||||
# Execute table creation through write queue to avoid concurrent conflicts
|
||||
@ -606,6 +608,7 @@ class ClickzettaVector(BaseVector):
|
||||
logger.warning("Failed to create inverted index: %s", e)
|
||||
# Continue without inverted index - full-text search will fall back to LIKE
|
||||
|
||||
@override
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
|
||||
"""Add documents with embeddings to the collection."""
|
||||
if not documents:
|
||||
@ -716,6 +719,7 @@ class ClickzettaVector(BaseVector):
|
||||
logger.exception("Sample data row: %s", data_rows[0] if data_rows else "None")
|
||||
raise
|
||||
|
||||
@override
|
||||
def text_exists(self, id: str) -> bool:
|
||||
"""Check if a document exists by ID."""
|
||||
# Check if table exists first
|
||||
@ -732,6 +736,7 @@ class ClickzettaVector(BaseVector):
|
||||
result = cursor.fetchone()
|
||||
return result[0] > 0 if result else False
|
||||
|
||||
@override
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
"""Delete documents by IDs."""
|
||||
if not ids:
|
||||
@ -757,6 +762,7 @@ class ClickzettaVector(BaseVector):
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(sql, binding_params=safe_ids)
|
||||
|
||||
@override
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
"""Delete documents by metadata field."""
|
||||
# Check if table exists before attempting delete
|
||||
@ -780,6 +786,7 @@ class ClickzettaVector(BaseVector):
|
||||
)
|
||||
cursor.execute(sql, binding_params=[value])
|
||||
|
||||
@override
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
"""Search for documents by vector similarity."""
|
||||
# Check if table exists first
|
||||
@ -865,6 +872,7 @@ class ClickzettaVector(BaseVector):
|
||||
|
||||
return documents
|
||||
|
||||
@override
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
"""Search for documents using full-text search with inverted index."""
|
||||
if not self._config.enable_inverted_index:
|
||||
@ -1031,6 +1039,7 @@ class ClickzettaVector(BaseVector):
|
||||
|
||||
return documents
|
||||
|
||||
@override
|
||||
def delete(self):
|
||||
"""Delete the entire collection."""
|
||||
with self.get_connection_context() as connection:
|
||||
@ -1057,6 +1066,7 @@ class ClickzettaVector(BaseVector):
|
||||
class ClickzettaVectorFactory(AbstractVectorFactory):
|
||||
"""Factory for creating Clickzetta vector instances."""
|
||||
|
||||
@override
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
|
||||
"""Initialize a Clickzetta vector instance."""
|
||||
# Get configuration from environment variables or dataset config
|
||||
|
||||
@ -3,7 +3,7 @@ import logging
|
||||
import time
|
||||
import uuid
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from couchbase import search # type: ignore
|
||||
from couchbase.auth import PasswordAuthenticator # type: ignore
|
||||
@ -68,6 +68,7 @@ class CouchbaseVector(BaseVector):
|
||||
# Wait until the cluster is ready for use.
|
||||
self._cluster.wait_until_ready(timedelta(seconds=5))
|
||||
|
||||
@override
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
index_id = str(uuid.uuid4()).replace("-", "")
|
||||
self._create_collection(uuid=index_id, vector_length=len(embeddings[0]))
|
||||
@ -200,9 +201,11 @@ class CouchbaseVector(BaseVector):
|
||||
# Check if the collection exists in the scope
|
||||
return self._collection_name in scope_collection_map[self._scope_name]
|
||||
|
||||
@override
|
||||
def get_type(self) -> str:
|
||||
return VectorType.COUCHBASE
|
||||
|
||||
@override
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
uuids = self._get_uuids(documents)
|
||||
texts = [d.page_content for d in documents]
|
||||
@ -221,6 +224,7 @@ class CouchbaseVector(BaseVector):
|
||||
|
||||
return doc_ids
|
||||
|
||||
@override
|
||||
def text_exists(self, id: str) -> bool:
|
||||
# Use a parameterized query for safety and correctness
|
||||
query = f"""
|
||||
@ -234,6 +238,7 @@ class CouchbaseVector(BaseVector):
|
||||
return bool(row["count"] > 0)
|
||||
return False # Return False if no rows are returned
|
||||
|
||||
@override
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
query = f"""
|
||||
DELETE FROM `{self._bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
|
||||
@ -261,6 +266,7 @@ class CouchbaseVector(BaseVector):
|
||||
# result = self._cluster.query(query, named_parameters={'value':value})
|
||||
# return [row['id'] for row in result.rows()]
|
||||
|
||||
@override
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
query = f"""
|
||||
DELETE FROM `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
|
||||
@ -268,6 +274,7 @@ class CouchbaseVector(BaseVector):
|
||||
"""
|
||||
self._cluster.query(query, named_parameters={"value": value}).execute()
|
||||
|
||||
@override
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
score_threshold = kwargs.get("score_threshold") or 0.0
|
||||
@ -303,6 +310,7 @@ class CouchbaseVector(BaseVector):
|
||||
|
||||
return docs
|
||||
|
||||
@override
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
try:
|
||||
@ -325,6 +333,7 @@ class CouchbaseVector(BaseVector):
|
||||
|
||||
return docs
|
||||
|
||||
@override
|
||||
def delete(self):
|
||||
manager = self._bucket.collections()
|
||||
scopes = manager.get_all_scopes()
|
||||
@ -356,6 +365,7 @@ class CouchbaseVector(BaseVector):
|
||||
|
||||
|
||||
class CouchbaseVectorFactory(AbstractVectorFactory):
|
||||
@override
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> CouchbaseVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import subprocess
|
||||
import time
|
||||
from typing import override
|
||||
|
||||
from dify_vdb_couchbase.couchbase_vector import CouchbaseConfig, CouchbaseVector
|
||||
|
||||
@ -40,6 +41,7 @@ class CouchbaseTest(AbstractVectorTest):
|
||||
),
|
||||
)
|
||||
|
||||
@override
|
||||
def search_by_vector(self):
|
||||
# brief sleep to ensure document is indexed
|
||||
time.sleep(5)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from flask import current_app
|
||||
|
||||
@ -19,6 +19,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ElasticSearchJaVector(ElasticSearchVector):
|
||||
@override
|
||||
def create_collection(
|
||||
self,
|
||||
embeddings: list[list[float]],
|
||||
@ -82,6 +83,7 @@ class ElasticSearchJaVector(ElasticSearchVector):
|
||||
|
||||
|
||||
class ElasticSearchJaVectorFactory(ElasticSearchVectorFactory):
|
||||
@override
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchJaVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from elasticsearch import ConnectionError as ElasticsearchConnectionError
|
||||
@ -154,9 +154,11 @@ class ElasticSearchVector(BaseVector):
|
||||
if parse_version(self._version) < parse_version("8.0.0"):
|
||||
raise ValueError("Elasticsearch vector database version must be greater than 8.0.0")
|
||||
|
||||
@override
|
||||
def get_type(self) -> str:
|
||||
return VectorType.ELASTICSEARCH
|
||||
|
||||
@override
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
uuids = self._get_uuids(documents)
|
||||
for i in range(len(documents)):
|
||||
@ -172,15 +174,18 @@ class ElasticSearchVector(BaseVector):
|
||||
self._client.indices.refresh(index=self._collection_name)
|
||||
return uuids
|
||||
|
||||
@override
|
||||
def text_exists(self, id: str) -> bool:
|
||||
return bool(self._client.exists(index=self._collection_name, id=id))
|
||||
|
||||
@override
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
if not ids:
|
||||
return
|
||||
for id in ids:
|
||||
self._client.delete(index=self._collection_name, id=id)
|
||||
|
||||
@override
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
|
||||
results = self._client.search(index=self._collection_name, body=query_str)
|
||||
@ -188,9 +193,11 @@ class ElasticSearchVector(BaseVector):
|
||||
if ids:
|
||||
self.delete_by_ids(ids)
|
||||
|
||||
@override
|
||||
def delete(self):
|
||||
self._client.indices.delete(index=self._collection_name)
|
||||
|
||||
@override
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
num_candidates = math.ceil(top_k * 1.5)
|
||||
@ -224,6 +231,7 @@ class ElasticSearchVector(BaseVector):
|
||||
|
||||
return docs
|
||||
|
||||
@override
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
query_str: dict[str, Any] = {"match": {Field.CONTENT_KEY: query}}
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
@ -249,6 +257,7 @@ class ElasticSearchVector(BaseVector):
|
||||
|
||||
return docs
|
||||
|
||||
@override
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
|
||||
self.create_collection(embeddings, metadatas)
|
||||
@ -297,6 +306,7 @@ class ElasticSearchVector(BaseVector):
|
||||
|
||||
|
||||
class ElasticSearchVectorFactory(AbstractVectorFactory):
|
||||
@override
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
|
||||
import holo_search_sdk as holo # type: ignore
|
||||
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
|
||||
@ -81,15 +81,18 @@ class HologresVector(BaseVector):
|
||||
client.connect()
|
||||
return client
|
||||
|
||||
@override
|
||||
def get_type(self) -> str:
|
||||
return VectorType.HOLOGRES
|
||||
|
||||
@override
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
"""Create collection table with vector and full-text indexes, then add texts."""
|
||||
dimension = len(embeddings[0])
|
||||
self._create_collection(dimension)
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
@override
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
"""Add texts with embeddings to the collection using batch upsert."""
|
||||
if not documents:
|
||||
@ -127,6 +130,7 @@ class HologresVector(BaseVector):
|
||||
|
||||
return pks
|
||||
|
||||
@override
|
||||
def text_exists(self, id: str) -> bool:
|
||||
"""Check if a text with the given doc_id exists in the collection."""
|
||||
if not self._client.check_table_exist(self.table_name):
|
||||
@ -140,6 +144,7 @@ class HologresVector(BaseVector):
|
||||
)
|
||||
return bool(result)
|
||||
|
||||
@override
|
||||
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str] | None:
|
||||
"""Get document IDs by metadata field key and value."""
|
||||
result = self._client.execute(
|
||||
@ -152,6 +157,7 @@ class HologresVector(BaseVector):
|
||||
return [row[0] for row in result]
|
||||
return None
|
||||
|
||||
@override
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
"""Delete documents by their doc_id list."""
|
||||
if not ids:
|
||||
@ -166,6 +172,7 @@ class HologresVector(BaseVector):
|
||||
)
|
||||
)
|
||||
|
||||
@override
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
"""Delete documents by metadata field key and value."""
|
||||
if not self._client.check_table_exist(self.table_name):
|
||||
@ -177,6 +184,7 @@ class HologresVector(BaseVector):
|
||||
)
|
||||
)
|
||||
|
||||
@override
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
"""Search for documents by vector similarity."""
|
||||
if not self._client.check_table_exist(self.table_name):
|
||||
@ -229,6 +237,7 @@ class HologresVector(BaseVector):
|
||||
|
||||
return docs
|
||||
|
||||
@override
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
"""Search for documents by full-text search."""
|
||||
if not self._client.check_table_exist(self.table_name):
|
||||
@ -272,6 +281,7 @@ class HologresVector(BaseVector):
|
||||
|
||||
return docs
|
||||
|
||||
@override
|
||||
def delete(self):
|
||||
"""Delete the entire collection table."""
|
||||
if self._client.check_table_exist(self.table_name):
|
||||
@ -333,6 +343,7 @@ class HologresVector(BaseVector):
|
||||
class HologresVectorFactory(AbstractVectorFactory):
|
||||
"""Factory class for creating HologresVector instances."""
|
||||
|
||||
@override
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> HologresVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import os
|
||||
import uuid
|
||||
from typing import cast
|
||||
from typing import cast, override
|
||||
|
||||
from dify_vdb_hologres.hologres_vector import HologresVector, HologresVectorConfig
|
||||
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
|
||||
@ -35,6 +35,7 @@ class HologresVectorTest(AbstractVectorTest):
|
||||
),
|
||||
)
|
||||
|
||||
@override
|
||||
def search_by_full_text(self):
|
||||
"""Override: full-text index may not be immediately ready in real mode."""
|
||||
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
|
||||
@ -97,12 +98,14 @@ class HologresVectorTest(AbstractVectorTest):
|
||||
assert len(hits) == 1
|
||||
assert hits[0].metadata["doc_id"] == self.example_doc_id
|
||||
|
||||
@override
|
||||
def get_ids_by_metadata_field(self):
|
||||
"""Override: Hologres implements this method via JSONB query."""
|
||||
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
|
||||
assert ids is not None
|
||||
assert len(ids) == 1
|
||||
|
||||
@override
|
||||
def run_all_tests(self):
|
||||
# Clean up before running tests
|
||||
self.vector.delete()
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import ssl
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from pydantic import BaseModel, model_validator
|
||||
@ -68,9 +68,11 @@ class HuaweiCloudVector(BaseVector):
|
||||
super().__init__(index_name.lower())
|
||||
self._client = Elasticsearch(**config.to_elasticsearch_params())
|
||||
|
||||
@override
|
||||
def get_type(self) -> str:
|
||||
return VectorType.HUAWEI_CLOUD
|
||||
|
||||
@override
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
uuids = self._get_uuids(documents)
|
||||
for i in range(len(documents)):
|
||||
@ -86,15 +88,18 @@ class HuaweiCloudVector(BaseVector):
|
||||
self._client.indices.refresh(index=self._collection_name)
|
||||
return uuids
|
||||
|
||||
@override
|
||||
def text_exists(self, id: str) -> bool:
|
||||
return bool(self._client.exists(index=self._collection_name, id=id))
|
||||
|
||||
@override
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
if not ids:
|
||||
return
|
||||
for id in ids:
|
||||
self._client.delete(index=self._collection_name, id=id)
|
||||
|
||||
@override
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
|
||||
results = self._client.search(index=self._collection_name, body=query_str)
|
||||
@ -102,9 +107,11 @@ class HuaweiCloudVector(BaseVector):
|
||||
if ids:
|
||||
self.delete_by_ids(ids)
|
||||
|
||||
@override
|
||||
def delete(self):
|
||||
self._client.indices.delete(index=self._collection_name)
|
||||
|
||||
@override
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
|
||||
@ -145,6 +152,7 @@ class HuaweiCloudVector(BaseVector):
|
||||
|
||||
return docs
|
||||
|
||||
@override
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
query_str = {"match": {Field.CONTENT_KEY: query}}
|
||||
results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
|
||||
@ -160,6 +168,7 @@ class HuaweiCloudVector(BaseVector):
|
||||
|
||||
return docs
|
||||
|
||||
@override
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
|
||||
self.create_collection(embeddings, metadatas)
|
||||
@ -207,6 +216,7 @@ class HuaweiCloudVector(BaseVector):
|
||||
|
||||
|
||||
class HuaweiCloudVectorFactory(AbstractVectorFactory):
|
||||
@override
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> HuaweiCloudVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from dify_vdb_huawei_cloud.huawei_cloud_vector import HuaweiCloudVector, HuaweiCloudVectorConfig
|
||||
|
||||
from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text
|
||||
@ -15,10 +17,12 @@ class HuaweiCloudVectorTest(AbstractVectorTest):
|
||||
),
|
||||
)
|
||||
|
||||
@override
|
||||
def search_by_vector(self):
|
||||
hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
|
||||
assert len(hits_by_vector) == 3
|
||||
|
||||
@override
|
||||
def search_by_full_text(self):
|
||||
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
|
||||
assert len(hits_by_full_text) == 3
|
||||
|
||||
@ -11,7 +11,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, override
|
||||
|
||||
from configs import dify_config
|
||||
from configs.middleware.vdb.iris_config import IrisVectorConfig
|
||||
@ -188,6 +188,7 @@ class IrisVector(BaseVector):
|
||||
self.schema = config.IRIS_SCHEMA or "dify"
|
||||
self.pool = get_iris_pool(config)
|
||||
|
||||
@override
|
||||
def get_type(self) -> str:
|
||||
return VectorType.IRIS
|
||||
|
||||
@ -206,11 +207,13 @@ class IrisVector(BaseVector):
|
||||
cursor.close()
|
||||
self.pool.return_connection(conn)
|
||||
|
||||
@override
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
|
||||
dimension = len(embeddings[0])
|
||||
self._create_collection(dimension)
|
||||
return self.add_texts(texts, embeddings)
|
||||
|
||||
@override
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **_kwargs) -> list[str]:
|
||||
"""Add documents with embeddings to the collection."""
|
||||
added_ids = []
|
||||
@ -226,6 +229,7 @@ class IrisVector(BaseVector):
|
||||
|
||||
return added_ids
|
||||
|
||||
@override
|
||||
def text_exists(self, id: str) -> bool: # pylint: disable=redefined-builtin
|
||||
try:
|
||||
with self._get_cursor() as cursor:
|
||||
@ -235,6 +239,7 @@ class IrisVector(BaseVector):
|
||||
except (OSError, RuntimeError, ValueError):
|
||||
return False
|
||||
|
||||
@override
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
@ -244,6 +249,7 @@ class IrisVector(BaseVector):
|
||||
sql = f"DELETE FROM {self.schema}.{self.table_name} WHERE id IN ({placeholders})"
|
||||
cursor.execute(sql, ids)
|
||||
|
||||
@override
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
"""Delete documents by metadata field (JSON LIKE pattern matching)."""
|
||||
with self._get_cursor() as cursor:
|
||||
@ -251,6 +257,7 @@ class IrisVector(BaseVector):
|
||||
sql = f"DELETE FROM {self.schema}.{self.table_name} WHERE meta LIKE ?"
|
||||
cursor.execute(sql, (pattern,))
|
||||
|
||||
@override
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
"""Search similar documents using VECTOR_COSINE with HNSW index."""
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
@ -275,6 +282,7 @@ class IrisVector(BaseVector):
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
return docs
|
||||
|
||||
@override
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
"""Search documents by full-text using iFind index with BM25 relevance scoring.
|
||||
|
||||
@ -404,6 +412,7 @@ class IrisVector(BaseVector):
|
||||
|
||||
return docs
|
||||
|
||||
@override
|
||||
def delete(self) -> None:
|
||||
"""Delete the entire collection (drop table - permanent)."""
|
||||
with self._get_cursor() as cursor:
|
||||
@ -481,6 +490,7 @@ class IrisVector(BaseVector):
|
||||
class IrisVectorFactory(AbstractVectorFactory):
|
||||
"""Factory for creating IrisVector instances."""
|
||||
|
||||
@override
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> IrisVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from opensearchpy import OpenSearch, helpers
|
||||
from opensearchpy.helpers import BulkIndexError
|
||||
@ -79,9 +79,11 @@ class LindormVectorStore(BaseVector):
|
||||
self._using_ugc = using_ugc
|
||||
self.kwargs = kwargs
|
||||
|
||||
@override
|
||||
def get_type(self) -> str:
|
||||
return VectorType.LINDORM
|
||||
|
||||
@override
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
|
||||
self.create_collection(embeddings, metadatas)
|
||||
@ -90,6 +92,7 @@ class LindormVectorStore(BaseVector):
|
||||
def refresh(self):
|
||||
self._client.indices.refresh(index=self._collection_name)
|
||||
|
||||
@override
|
||||
def add_texts(
|
||||
self,
|
||||
documents: list[Document],
|
||||
@ -156,6 +159,7 @@ class LindormVectorStore(BaseVector):
|
||||
logger.exception("Failed to process batch %s", batch_num + 1)
|
||||
raise
|
||||
|
||||
@override
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
query: dict[str, Any] = {
|
||||
"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY}.{key}.keyword": value}}]}}
|
||||
@ -168,11 +172,13 @@ class LindormVectorStore(BaseVector):
|
||||
else:
|
||||
return None
|
||||
|
||||
@override
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
ids = self.get_ids_by_metadata_field(key, value)
|
||||
if ids:
|
||||
self.delete_by_ids(ids)
|
||||
|
||||
@override
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
"""Delete documents by their IDs in batch.
|
||||
|
||||
@ -219,6 +225,7 @@ class LindormVectorStore(BaseVector):
|
||||
else:
|
||||
logger.exception("Error deleting document: %s", error)
|
||||
|
||||
@override
|
||||
def delete(self):
|
||||
if self._using_ugc:
|
||||
routing_filter_query = {
|
||||
@ -233,6 +240,7 @@ class LindormVectorStore(BaseVector):
|
||||
else:
|
||||
logger.warning("Index '%s' does not exist. No deletion performed.", self._collection_name)
|
||||
|
||||
@override
|
||||
def text_exists(self, id: str) -> bool:
|
||||
try:
|
||||
params: dict[str, Any] = {}
|
||||
@ -243,6 +251,7 @@ class LindormVectorStore(BaseVector):
|
||||
except:
|
||||
return False
|
||||
|
||||
@override
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
if not isinstance(query_vector, list):
|
||||
raise ValueError("query_vector should be a list of floats")
|
||||
@ -305,6 +314,7 @@ class LindormVectorStore(BaseVector):
|
||||
|
||||
return docs
|
||||
|
||||
@override
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
full_text_query = {"query": {"bool": {"must": [{"match": {Field.CONTENT_KEY.value: query}}]}}}
|
||||
filters = []
|
||||
@ -377,6 +387,7 @@ class LindormVectorStore(BaseVector):
|
||||
|
||||
|
||||
class LindormVectorStoreFactory(AbstractVectorFactory):
|
||||
@override
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore:
|
||||
lindorm_config = LindormVectorStoreConfig(
|
||||
hosts=dify_config.LINDORM_URL,
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import os
|
||||
from typing import override
|
||||
|
||||
from dify_vdb_lindorm.lindorm_vector import LindormVectorStore, LindormVectorStoreConfig
|
||||
|
||||
@ -26,6 +27,7 @@ class TestLindormVectorStore(AbstractVectorTest):
|
||||
),
|
||||
)
|
||||
|
||||
@override
|
||||
def get_ids_by_metadata_field(self):
|
||||
ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
|
||||
assert ids is not None
|
||||
@ -47,6 +49,7 @@ class TestLindormVectorStoreUGC(AbstractVectorTest):
|
||||
routing_value=self.collection_name,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_ids_by_metadata_field(self):
|
||||
ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
|
||||
assert ids is not None
|
||||
|
||||
@ -3,7 +3,7 @@ import logging
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any, Concatenate
|
||||
from typing import Any, Concatenate, override
|
||||
|
||||
from mo_vector.client import MoVectorClient # type: ignore
|
||||
from pydantic import BaseModel, model_validator
|
||||
@ -69,16 +69,20 @@ class MatrixoneVector(BaseVector):
|
||||
self.client = None
|
||||
|
||||
@property
|
||||
@override
|
||||
def collection_name(self):
|
||||
return self._collection_name
|
||||
|
||||
@collection_name.setter
|
||||
@override
|
||||
def collection_name(self, value):
|
||||
self._collection_name = value
|
||||
|
||||
@override
|
||||
def get_type(self) -> str:
|
||||
return VectorType.MATRIXONE
|
||||
|
||||
@override
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
if self.client is None:
|
||||
self.client = self._get_client(len(embeddings[0]), True)
|
||||
@ -108,6 +112,7 @@ class MatrixoneVector(BaseVector):
|
||||
logger.exception("Failed to create full text index")
|
||||
return client
|
||||
|
||||
@override
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
if self.client is None:
|
||||
self.client = self._get_client(len(embeddings[0]), True)
|
||||
@ -126,12 +131,14 @@ class MatrixoneVector(BaseVector):
|
||||
return ids
|
||||
|
||||
@ensure_client
|
||||
@override
|
||||
def text_exists(self, id: str) -> bool:
|
||||
assert self.client is not None
|
||||
result = self.client.get(ids=[id])
|
||||
return len(result) > 0
|
||||
|
||||
@ensure_client
|
||||
@override
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
assert self.client is not None
|
||||
if not ids:
|
||||
@ -139,17 +146,20 @@ class MatrixoneVector(BaseVector):
|
||||
self.client.delete(ids=ids)
|
||||
|
||||
@ensure_client
|
||||
@override
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
assert self.client is not None
|
||||
results = self.client.query_by_metadata(filter={key: value})
|
||||
return [result.id for result in results]
|
||||
|
||||
@ensure_client
|
||||
@override
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
assert self.client is not None
|
||||
self.client.delete(filter={key: value})
|
||||
|
||||
@ensure_client
|
||||
@override
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
assert self.client is not None
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
@ -177,6 +187,7 @@ class MatrixoneVector(BaseVector):
|
||||
return docs
|
||||
|
||||
@ensure_client
|
||||
@override
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
assert self.client is not None
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
@ -207,12 +218,14 @@ class MatrixoneVector(BaseVector):
|
||||
return docs
|
||||
|
||||
@ensure_client
|
||||
@override
|
||||
def delete(self):
|
||||
assert self.client is not None
|
||||
self.client.delete()
|
||||
|
||||
|
||||
class MatrixoneVectorFactory(AbstractVectorFactory):
|
||||
@override
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MatrixoneVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from dify_vdb_matrixone.matrixone_vector import MatrixoneConfig, MatrixoneVector
|
||||
|
||||
from core.rag.datasource.vdb.vector_integration_test_support import (
|
||||
@ -15,6 +17,7 @@ class MatrixoneVectorTest(AbstractVectorTest):
|
||||
),
|
||||
)
|
||||
|
||||
@override
|
||||
def get_ids_by_metadata_field(self):
|
||||
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
|
||||
assert len(ids) == 1
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, TypedDict, cast
|
||||
from typing import Any, TypedDict, cast, override
|
||||
|
||||
from packaging import version
|
||||
from pydantic import BaseModel, model_validator
|
||||
@ -120,12 +120,14 @@ class MilvusVector(BaseVector):
|
||||
logger.warning("Failed to check Milvus version: %s. Disabling hybrid search.", str(e))
|
||||
return False
|
||||
|
||||
@override
|
||||
def get_type(self) -> str:
|
||||
"""
|
||||
Get the type of vector storage (Milvus).
|
||||
"""
|
||||
return VectorType.MILVUS
|
||||
|
||||
@override
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
"""
|
||||
Create a collection and add texts with embeddings.
|
||||
@ -135,6 +137,7 @@ class MilvusVector(BaseVector):
|
||||
self.create_collection(embeddings, metadatas, index_params)
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
@override
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
"""
|
||||
Add texts and their embeddings to the collection.
|
||||
@ -164,6 +167,7 @@ class MilvusVector(BaseVector):
|
||||
raise e
|
||||
return pks
|
||||
|
||||
@override
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
"""
|
||||
Get document IDs by metadata field key and value.
|
||||
@ -176,6 +180,7 @@ class MilvusVector(BaseVector):
|
||||
else:
|
||||
return None
|
||||
|
||||
@override
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
"""
|
||||
Delete documents by metadata field key and value.
|
||||
@ -185,6 +190,7 @@ class MilvusVector(BaseVector):
|
||||
if ids:
|
||||
self._client.delete(collection_name=self._collection_name, pks=ids)
|
||||
|
||||
@override
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
"""
|
||||
Delete documents by their IDs.
|
||||
@ -197,6 +203,7 @@ class MilvusVector(BaseVector):
|
||||
ids = [item["id"] for item in result]
|
||||
self._client.delete(collection_name=self._collection_name, pks=ids)
|
||||
|
||||
@override
|
||||
def delete(self):
|
||||
"""
|
||||
Delete the entire collection.
|
||||
@ -204,6 +211,7 @@ class MilvusVector(BaseVector):
|
||||
if self._client.has_collection(self._collection_name):
|
||||
self._client.drop_collection(self._collection_name, None)
|
||||
|
||||
@override
|
||||
def text_exists(self, id: str) -> bool:
|
||||
"""
|
||||
Check if a text with the given ID exists in the collection.
|
||||
@ -245,6 +253,7 @@ class MilvusVector(BaseVector):
|
||||
|
||||
return docs
|
||||
|
||||
@override
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Search for documents by vector similarity.
|
||||
@ -269,6 +278,7 @@ class MilvusVector(BaseVector):
|
||||
score_threshold=float(kwargs.get("score_threshold") or 0.0),
|
||||
)
|
||||
|
||||
@override
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Search for documents by full-text search (if hybrid search is enabled).
|
||||
@ -411,6 +421,7 @@ class MilvusVectorFactory(AbstractVectorFactory):
|
||||
Factory class for creating MilvusVector instances.
|
||||
"""
|
||||
|
||||
@override
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
|
||||
"""
|
||||
Initialize a MilvusVector instance for the given dataset.
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from dify_vdb_milvus.milvus_vector import MilvusConfig, MilvusVector
|
||||
|
||||
from core.rag.datasource.vdb.vector_integration_test_support import (
|
||||
@ -18,11 +20,13 @@ class MilvusVectorTest(AbstractVectorTest):
|
||||
),
|
||||
)
|
||||
|
||||
@override
|
||||
def search_by_full_text(self):
|
||||
# milvus support BM25 full text search after version 2.5.0-beta
|
||||
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
|
||||
assert len(hits_by_full_text) >= 0
|
||||
|
||||
@override
|
||||
def get_ids_by_metadata_field(self):
|
||||
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
|
||||
assert len(ids) == 1
|
||||
|
||||
@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
import uuid
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from clickhouse_connect import get_client # type: ignore[import-untyped]
|
||||
from pydantic import BaseModel
|
||||
@ -46,9 +46,11 @@ class MyScaleVector(BaseVector):
|
||||
)
|
||||
self._client.command("SET allow_experimental_object_type=1")
|
||||
|
||||
@override
|
||||
def get_type(self) -> str:
|
||||
return VectorType.MYSCALE
|
||||
|
||||
@override
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
self._create_collection(dimension)
|
||||
@ -71,6 +73,7 @@ class MyScaleVector(BaseVector):
|
||||
"""
|
||||
self._client.command(sql)
|
||||
|
||||
@override
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
ids = []
|
||||
columns = ["id", "text", "vector", "metadata"]
|
||||
@ -97,10 +100,12 @@ class MyScaleVector(BaseVector):
|
||||
def escape_str(value: Any) -> str:
|
||||
return "".join(" " if c in {"\\", "'"} else c for c in str(value))
|
||||
|
||||
@override
|
||||
def text_exists(self, id: str) -> bool:
|
||||
results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")
|
||||
return results.row_count > 0
|
||||
|
||||
@override
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
if not ids:
|
||||
return
|
||||
@ -108,20 +113,24 @@ class MyScaleVector(BaseVector):
|
||||
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
|
||||
)
|
||||
|
||||
@override
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
rows = self._client.query(
|
||||
f"SELECT DISTINCT id FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'"
|
||||
).result_rows
|
||||
return [row[0] for row in rows]
|
||||
|
||||
@override
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
self._client.command(
|
||||
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'"
|
||||
)
|
||||
|
||||
@override
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
return self._search(f"distance(vector, {str(query_vector)})", self._vec_order, **kwargs)
|
||||
|
||||
@override
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return self._search(f"TextSearch('enable_nlq=false')(text, '{query}')", SortOrder.DESC, **kwargs)
|
||||
|
||||
@ -156,11 +165,13 @@ class MyScaleVector(BaseVector):
|
||||
logger.exception("Vector search operation failed")
|
||||
return []
|
||||
|
||||
@override
|
||||
def delete(self):
|
||||
self._client.command(f"DROP TABLE IF EXISTS {self._config.database}.{self._collection_name}")
|
||||
|
||||
|
||||
class MyScaleVectorFactory(AbstractVectorFactory):
|
||||
@override
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MyScaleVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from dify_vdb_myscale.myscale_vector import MyScaleConfig, MyScaleVector
|
||||
|
||||
from core.rag.datasource.vdb.vector_integration_test_support import (
|
||||
@ -20,6 +22,7 @@ class MyScaleVectorTest(AbstractVectorTest):
|
||||
),
|
||||
)
|
||||
|
||||
@override
|
||||
def get_ids_by_metadata_field(self):
|
||||
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
|
||||
assert len(ids) == 1
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, override
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pyobvector import VECTOR, ObVecClient, cosine_distance, inner_product, l2_distance # type: ignore
|
||||
@ -86,6 +86,7 @@ class OceanBaseVector(BaseVector):
|
||||
self._load_collection_fields()
|
||||
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported
|
||||
|
||||
@override
|
||||
def get_type(self) -> str:
|
||||
return VectorType.OCEANBASE
|
||||
|
||||
@ -114,6 +115,7 @@ class OceanBaseVector(BaseVector):
|
||||
"""
|
||||
return field in self._fields
|
||||
|
||||
@override
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
self._vec_dim = len(embeddings[0])
|
||||
self._create_collection()
|
||||
@ -237,6 +239,7 @@ class OceanBaseVector(BaseVector):
|
||||
logger.warning("Failed to check OceanBase version: %s. Disabling hybrid search.", str(e))
|
||||
return False
|
||||
|
||||
@override
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
ids = self._get_uuids(documents)
|
||||
batch_size = self._config.batch_size
|
||||
@ -283,6 +286,7 @@ class OceanBaseVector(BaseVector):
|
||||
self._collection_name,
|
||||
)
|
||||
|
||||
@override
|
||||
def text_exists(self, id: str) -> bool:
|
||||
try:
|
||||
cur = self._client.get(table_name=self._collection_name, ids=id)
|
||||
@ -295,6 +299,7 @@ class OceanBaseVector(BaseVector):
|
||||
)
|
||||
raise Exception(f"Failed to check text existence for id '{id}'") from e
|
||||
|
||||
@override
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
if not ids:
|
||||
return
|
||||
@ -309,6 +314,7 @@ class OceanBaseVector(BaseVector):
|
||||
)
|
||||
raise Exception(f"Failed to delete documents from collection '{self._collection_name}'") from e
|
||||
|
||||
@override
|
||||
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
|
||||
try:
|
||||
import re
|
||||
@ -343,6 +349,7 @@ class OceanBaseVector(BaseVector):
|
||||
)
|
||||
raise Exception(f"Failed to query documents by metadata field '{key}'") from e
|
||||
|
||||
@override
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
ids = self.get_ids_by_metadata_field(key, value)
|
||||
if ids:
|
||||
@ -381,6 +388,7 @@ class OceanBaseVector(BaseVector):
|
||||
|
||||
return docs
|
||||
|
||||
@override
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
if not self._hybrid_search_enabled:
|
||||
logger.warning(
|
||||
@ -438,6 +446,7 @@ class OceanBaseVector(BaseVector):
|
||||
)
|
||||
raise Exception(f"Full-text search failed for collection '{self._collection_name}'") from e
|
||||
|
||||
@override
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
from sqlalchemy import text
|
||||
|
||||
@ -508,6 +517,7 @@ class OceanBaseVector(BaseVector):
|
||||
return -distance
|
||||
raise ValueError(f"Unsupported metric_type '{metric}'")
|
||||
|
||||
@override
|
||||
def delete(self):
|
||||
try:
|
||||
self._client.drop_table_if_exist(self._collection_name)
|
||||
@ -518,6 +528,7 @@ class OceanBaseVector(BaseVector):
|
||||
|
||||
|
||||
class OceanBaseVectorFactory(AbstractVectorFactory):
|
||||
@override
|
||||
def init_vector(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
import pytest
|
||||
from dify_vdb_oceanbase.oceanbase_vector import (
|
||||
OceanBaseVector,
|
||||
@ -30,6 +32,7 @@ class OceanBaseVectorTest(AbstractVectorTest):
|
||||
super().__init__()
|
||||
self.vector = vector
|
||||
|
||||
@override
|
||||
def get_ids_by_metadata_field(self):
|
||||
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
|
||||
assert len(ids) == 1
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user