ci: add flag for linter (#37018)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato 2026-06-08 13:53:12 +09:00 committed by GitHub
parent 0c4b36b3f5
commit f15a8f02ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
171 changed files with 885 additions and 135 deletions

View File

@ -7,7 +7,7 @@ import threading
import time
from collections.abc import Mapping
from datetime import timedelta
from typing import TYPE_CHECKING, Any, TypedDict
from typing import TYPE_CHECKING, Any, TypedDict, override
from uuid import UUID, uuid4
from cachetools import LRUCache
@ -221,6 +221,7 @@ class TracingProviderConfigEntry(TypedDict):
class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigEntry]):
@override
def __getitem__(self, key: str) -> TracingProviderConfigEntry:
try:
match key:

View File

@ -1,7 +1,7 @@
import base64
import logging
import pickle
from typing import Any, cast
from typing import Any, cast, override
import numpy as np
from sqlalchemy import select
@ -25,6 +25,7 @@ class CacheEmbedding(Embeddings):
def __init__(self, model_instance: ModelInstance):
self._model_instance = model_instance
@override
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs in batches of 10."""
# use doc embedding cache or store if not exists
@ -106,6 +107,7 @@ class CacheEmbedding(Embeddings):
return text_embeddings
@override
def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]:
"""Embed file documents."""
# use doc embedding cache or store if not exists
@ -189,6 +191,7 @@ class CacheEmbedding(Embeddings):
return multimodel_embeddings
@override
def embed_query(self, text: str) -> list[float]:
"""Embed query text."""
# use doc embedding cache or store if not exists
@ -232,6 +235,7 @@ class CacheEmbedding(Embeddings):
return embedding_results # type: ignore
@override
def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
"""Embed multimodal documents."""
# use doc embedding cache or store if not exists

View File

@ -3,7 +3,7 @@
import logging
import re
import uuid
from typing import Any, TypedDict, cast
from typing import Any, TypedDict, cast, override
logger = logging.getLogger(__name__)
@ -61,6 +61,7 @@ class ParagraphFormatPreviewDict(TypedDict):
class ParagraphIndexProcessor(BaseIndexProcessor):
@override
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(
extract_setting=extract_setting,
@ -71,6 +72,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
return text_docs
@override
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
process_rule = kwargs.get("process_rule")
if not process_rule:
@ -120,6 +122,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
all_documents.extend(split_documents)
return all_documents
@override
def load(
self,
dataset: Dataset,
@ -142,6 +145,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
else:
keyword.add_texts(documents)
@override
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None:
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
@ -178,6 +182,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
else:
keyword.delete()
@override
def retrieve(
self,
retrieval_method: RetrievalMethod,
@ -206,6 +211,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
docs.append(doc)
return docs
@override
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
documents: list[Any] = []
all_multimodal_documents: list[Any] = []
@ -271,6 +277,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
keyword = Keyword(dataset)
keyword.add_texts(documents)
@override
def format_preview(self, chunks: Any) -> ParagraphFormatPreviewDict:
if isinstance(chunks, list):
preview = []
@ -285,6 +292,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
else:
raise ValueError("Chunks is not a list")
@override
def generate_summary_preview(
self,
tenant_id: str,

View File

@ -3,7 +3,7 @@
import json
import logging
import uuid
from typing import Any, TypedDict
from typing import Any, TypedDict, override
from sqlalchemy import delete, select
@ -44,6 +44,7 @@ class ParentChildFormatPreviewDict(TypedDict):
class ParentChildIndexProcessor(BaseIndexProcessor):
@override
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(
extract_setting=extract_setting,
@ -54,6 +55,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
return text_docs
@override
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
process_rule = kwargs.get("process_rule")
if not process_rule:
@ -129,6 +131,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
return all_documents
@override
def load(
self,
dataset: Dataset,
@ -149,6 +152,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
if multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(multimodal_documents)
@override
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None:
# node_ids is segment's node_ids
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
@ -219,6 +223,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
)
db.session.commit()
@override
def retrieve(
self,
retrieval_method: RetrievalMethod,
@ -283,6 +288,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
child_nodes.append(child_document)
return child_nodes
@override
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
parent_childs = ParentChildStructureChunk.model_validate(chunks)
documents = []
@ -356,6 +362,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
if all_multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(all_multimodal_documents)
@override
def format_preview(self, chunks: Any) -> ParentChildFormatPreviewDict:
parent_childs = ParentChildStructureChunk.model_validate(chunks)
preview = []
@ -369,6 +376,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
}
return result
@override
def generate_summary_preview(
self,
tenant_id: str,

View File

@ -4,7 +4,7 @@ import logging
import re
import threading
import uuid
from typing import Any, TypedDict
from typing import Any, TypedDict, override
import pandas as pd
from flask import Flask, current_app
@ -43,6 +43,7 @@ class QAFormatPreviewDict(TypedDict):
class QAIndexProcessor(BaseIndexProcessor):
@override
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(
extract_setting=extract_setting,
@ -52,6 +53,7 @@ class QAIndexProcessor(BaseIndexProcessor):
)
return text_docs
@override
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
preview = kwargs.get("preview")
process_rule = kwargs.get("process_rule")
@ -139,6 +141,7 @@ class QAIndexProcessor(BaseIndexProcessor):
raise ValueError(str(e))
return text_docs
@override
def load(
self,
dataset: Dataset,
@ -153,6 +156,7 @@ class QAIndexProcessor(BaseIndexProcessor):
if multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(multimodal_documents)
@override
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None:
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
@ -183,6 +187,7 @@ class QAIndexProcessor(BaseIndexProcessor):
else:
vector.delete()
@override
def retrieve(
self,
retrieval_method: RetrievalMethod,
@ -211,6 +216,7 @@ class QAIndexProcessor(BaseIndexProcessor):
docs.append(doc)
return docs
@override
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
qa_chunks = QAStructureChunk.model_validate(chunks)
documents = []
@ -234,6 +240,7 @@ class QAIndexProcessor(BaseIndexProcessor):
else:
raise ValueError("Indexing technique must be high quality.")
@override
def format_preview(self, chunks: Any) -> QAFormatPreviewDict:
qa_chunks = QAStructureChunk.model_validate(chunks)
preview = []
@ -246,6 +253,7 @@ class QAIndexProcessor(BaseIndexProcessor):
}
return result
@override
def generate_summary_preview(
self,
tenant_id: str,

View File

@ -1,4 +1,5 @@
import base64
from typing import override
from core.model_manager import ModelInstance, ModelManager
from core.rag.index_processor.constant.doc_type import DocType
@ -16,6 +17,7 @@ class RerankModelRunner(BaseRerankRunner):
def __init__(self, rerank_model_instance: ModelInstance):
self.rerank_model_instance = rerank_model_instance
@override
def run(
self,
query: str,

View File

@ -1,5 +1,6 @@
import math
from collections import Counter
from typing import override
import numpy as np
@ -19,6 +20,7 @@ class WeightRerankRunner(BaseRerankRunner):
self.tenant_id = tenant_id
self.weights = weights
@override
def run(
self,
query: str,

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import codecs
import re
from collections.abc import Set as AbstractSet
from typing import Any, Literal
from typing import Any, Literal, override
from core.model_manager import ModelInstance
from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter
@ -51,6 +51,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
self._fixed_separator = codecs.decode(fixed_separator, "unicode_escape")
self._separators = separators or ["\n\n", "\n", "", ". ", " ", ""]
@override
def split_text(self, text: str) -> list[str]:
"""Split incoming text and return chunks."""
if self._fixed_separator:

View File

@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Sequence
from collections.abc import Set as AbstractSet
from dataclasses import dataclass
from typing import Any, Literal
from typing import Any, Literal, override
from core.rag.models.document import BaseDocumentTransformer, Document
@ -148,10 +148,12 @@ class TextSplitter(BaseDocumentTransformer, ABC):
)
return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs)
@override
def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
"""Transform sequence of documents by splitting them."""
return self.split_documents(list(documents))
@override
async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
"""Asynchronously transform a sequence of documents by splitting them."""
raise NotImplementedError
@ -211,6 +213,7 @@ class TokenTextSplitter(TextSplitter):
self._allowed_special: Literal["all"] | AbstractSet[str] = allowed_special
self._disallowed_special: Literal["all"] | AbstractSet[str] = disallowed_special
@override
def split_text(self, text: str) -> list[str]:
def _encode(_text: str) -> list[int]:
return self._tokenizer.encode(
@ -287,5 +290,6 @@ class RecursiveCharacterTextSplitter(TextSplitter):
return final_chunks
@override
def split_text(self, text: str) -> list[str]:
return self._split_text(text, self._separators)

View File

@ -1,6 +1,6 @@
from abc import abstractmethod
from os import listdir, path
from typing import Any
from typing import Any, override
from core.entities.provider_entities import ProviderConfig
from core.helper.module_import_helper import load_single_subclass_from_source
@ -105,6 +105,7 @@ class BuiltinToolProviderController(ToolProviderController):
"""
return self.tools
@override
def get_credentials_schema(self) -> list[ProviderConfig]:
"""
returns the credentials schema of the provider
@ -182,6 +183,7 @@ class BuiltinToolProviderController(ToolProviderController):
)
@property
@override
def provider_type(self) -> ToolProviderType:
"""
returns the type of the provider

View File

@ -1,8 +1,9 @@
from typing import Any
from typing import Any, override
from core.tools.builtin_tool.provider import BuiltinToolProviderController
class AudioToolProvider(BuiltinToolProviderController):
@override
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
pass

View File

@ -1,6 +1,6 @@
import io
from collections.abc import Generator
from typing import Any
from typing import Any, override
from core.model_manager import ModelManager
from core.plugin.entities.parameters import PluginParameterOption
@ -14,6 +14,7 @@ from services.model_provider_service import ModelProviderService
class ASRTool(BuiltinTool):
@override
def _invoke(
self,
user_id: str,
@ -56,6 +57,7 @@ class ASRTool(BuiltinTool):
items.append((provider, model.model))
return items
@override
def get_runtime_parameters(
self,
conversation_id: str | None = None,

View File

@ -1,6 +1,6 @@
import io
from collections.abc import Generator
from typing import Any
from typing import Any, override
from core.model_manager import ModelManager
from core.plugin.entities.parameters import PluginParameterOption
@ -12,6 +12,7 @@ from services.model_provider_service import ModelProviderService
class TTSTool(BuiltinTool):
@override
def _invoke(
self,
user_id: str,
@ -66,6 +67,7 @@ class TTSTool(BuiltinTool):
items.append((provider, model.model, voices))
return items
@override
def get_runtime_parameters(
self,
conversation_id: str | None = None,

View File

@ -1,8 +1,9 @@
from typing import Any
from typing import Any, override
from core.tools.builtin_tool.provider import BuiltinToolProviderController
class CodeToolProvider(BuiltinToolProviderController):
@override
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
pass

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any
from typing import Any, override
from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
from core.tools.builtin_tool.tool import BuiltinTool
@ -8,6 +8,7 @@ from core.tools.errors import ToolInvokeError
class SimpleCode(BuiltinTool):
@override
def _invoke(
self,
user_id: str,

View File

@ -1,8 +1,9 @@
from typing import Any
from typing import Any, override
from core.tools.builtin_tool.provider import BuiltinToolProviderController
class WikiPediaProvider(BuiltinToolProviderController):
@override
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
pass

View File

@ -1,6 +1,6 @@
from collections.abc import Generator
from datetime import UTC, datetime
from typing import Any
from typing import Any, override
from pytz import timezone as pytz_timezone # type: ignore[import-untyped]
@ -9,6 +9,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage
class CurrentTimeTool(BuiltinTool):
@override
def _invoke(
self,
user_id: str,

View File

@ -1,6 +1,6 @@
from collections.abc import Generator
from datetime import datetime
from typing import Any
from typing import Any, override
import pytz # type: ignore[import-untyped]
@ -10,6 +10,7 @@ from core.tools.errors import ToolInvokeError
class LocaltimeToTimestampTool(BuiltinTool):
@override
def _invoke(
self,
user_id: str,

View File

@ -1,6 +1,6 @@
from collections.abc import Generator
from datetime import datetime
from typing import Any
from typing import Any, override
import pytz # type: ignore[import-untyped]
@ -10,6 +10,7 @@ from core.tools.errors import ToolInvokeError
class TimestampToLocaltimeTool(BuiltinTool):
@override
def _invoke(
self,
user_id: str,

View File

@ -1,6 +1,6 @@
from collections.abc import Generator
from datetime import datetime
from typing import Any
from typing import Any, override
import pytz # type: ignore[import-untyped]
@ -10,6 +10,7 @@ from core.tools.errors import ToolInvokeError
class TimezoneConversionTool(BuiltinTool):
@override
def _invoke(
self,
user_id: str,

View File

@ -1,13 +1,14 @@
import calendar
from collections.abc import Generator
from datetime import datetime
from typing import Any
from typing import Any, override
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
class WeekdayTool(BuiltinTool):
@override
def _invoke(
self,
user_id: str,

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any
from typing import Any, override
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
@ -8,6 +8,7 @@ from core.tools.utils.web_reader_tool import get_url
class WebscraperTool(BuiltinTool):
@override
def _invoke(
self,
user_id: str,

View File

@ -1,9 +1,10 @@
from typing import Any
from typing import Any, override
from core.tools.builtin_tool.provider import BuiltinToolProviderController
class WebscraperProvider(BuiltinToolProviderController):
@override
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
"""
Validate credentials

View File

@ -1,5 +1,7 @@
from __future__ import annotations
from typing import override
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolProviderType
@ -26,6 +28,7 @@ class BuiltinTool(Tool):
super().__init__(**kwargs)
self.provider = provider
@override
def fork_tool_runtime(self, runtime: ToolRuntime) -> BuiltinTool:
"""
fork a new tool with metadata
@ -56,6 +59,7 @@ class BuiltinTool(Tool):
caller_user_id=self.runtime.user_id,
)
@override
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.BUILT_IN

View File

@ -1,5 +1,7 @@
from __future__ import annotations
from typing import override
from pydantic import Field
from sqlalchemy import select
@ -122,6 +124,7 @@ class ApiToolProviderController(ToolProviderController):
)
@property
@override
def provider_type(self) -> ToolProviderType:
return ToolProviderType.API
@ -194,6 +197,7 @@ class ApiToolProviderController(ToolProviderController):
self.tools = tools
return tools
@override
def get_tool(self, tool_name: str) -> ApiTool:
"""
get tool by name

View File

@ -2,7 +2,7 @@ import json
from collections.abc import Generator
from dataclasses import dataclass
from os import getenv
from typing import Any, Union
from typing import Any, Union, override
from urllib.parse import urlencode
import httpx
@ -45,6 +45,7 @@ class ApiTool(Tool):
self.api_bundle = api_bundle
self.provider_id = provider_id
@override
def fork_tool_runtime(self, runtime: ToolRuntime):
"""
fork a new tool with metadata
@ -77,6 +78,7 @@ class ApiTool(Tool):
# For credential validation, always return as string
return parsed_response.to_string()
@override
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.API
@ -373,6 +375,7 @@ class ApiTool(Tool):
except ValueError:
return value
@override
def _invoke(
self,
user_id: str,

View File

@ -1,4 +1,4 @@
from typing import Any, Self
from typing import Any, Self, override
from core.entities.mcp_provider import IdentityMode, MCPProviderEntity
from core.mcp.types import Tool as RemoteMCPTool
@ -41,6 +41,7 @@ class MCPToolProviderController(ToolProviderController):
self.identity_mode: IdentityMode = identity_mode
@property
@override
def provider_type(self) -> ToolProviderType:
"""
returns the type of the provider
@ -116,6 +117,7 @@ class MCPToolProviderController(ToolProviderController):
"""
pass
@override
def get_tool(self, tool_name: str) -> MCPTool:
"""
return tool with given name

View File

@ -4,7 +4,7 @@ import base64
import json
import logging
from collections.abc import Generator, Mapping
from typing import Any, cast
from typing import Any, cast, override
from configs import dify_config
from core.entities.mcp_provider import IdentityMode
@ -58,9 +58,11 @@ class MCPTool(Tool):
self.identity_mode: IdentityMode = identity_mode
self._latest_usage = LLMUsage.empty_usage()
@override
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.MCP
@override
def _invoke(
self,
user_id: str,
@ -232,6 +234,7 @@ class MCPTool(Tool):
return found
return None
@override
def fork_tool_runtime(self, runtime: ToolRuntime) -> MCPTool:
return MCPTool(
entity=self.entity,

View File

@ -1,4 +1,4 @@
from typing import Any
from typing import Any, override
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_runtime import ToolRuntime
@ -23,6 +23,7 @@ class PluginToolProviderController(BuiltinToolProviderController):
self.plugin_unique_identifier = plugin_unique_identifier
@property
@override
def provider_type(self) -> ToolProviderType:
"""
returns the type of the provider
@ -31,6 +32,7 @@ class PluginToolProviderController(BuiltinToolProviderController):
"""
return ToolProviderType.PLUGIN
@override
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
"""
validate the credentials of the provider

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from collections.abc import Generator
from typing import Any
from typing import Any, override
from core.plugin.impl.tool import PluginToolManager
from core.plugin.utils.converter import convert_parameters_to_plugin_format
@ -20,9 +20,11 @@ class PluginTool(Tool):
self.plugin_unique_identifier = plugin_unique_identifier
self.runtime_parameters: list[ToolParameter] | None = None
@override
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.PLUGIN
@override
def _invoke(
self,
user_id: str,
@ -48,6 +50,7 @@ class PluginTool(Tool):
message_id=message_id,
)
@override
def fork_tool_runtime(self, runtime: ToolRuntime) -> PluginTool:
return PluginTool(
entity=self.entity,
@ -57,6 +60,7 @@ class PluginTool(Tool):
plugin_unique_identifier=self.plugin_unique_identifier,
)
@override
def get_runtime_parameters(
self,
conversation_id: str | None = None,

View File

@ -1,4 +1,5 @@
import threading
from typing import override
from flask import Flask, current_app
from pydantic import BaseModel, Field
@ -46,6 +47,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
name=f"dataset_{tenant_id.replace('-', '_')}", tenant_id=tenant_id, dataset_ids=dataset_ids, **kwargs
)
@override
def _run(self, query: str) -> str:
threads = []
all_documents: list[RagDocument] = []

View File

@ -1,4 +1,4 @@
from typing import Any, cast
from typing import Any, cast, override
from pydantic import BaseModel, Field
from sqlalchemy import select
@ -56,6 +56,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
**kwargs,
)
@override
def _run(self, query: str) -> str:
dataset_stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id)
dataset = db.session.scalar(dataset_stmt)

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any
from typing import Any, override
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import InvokeFrom
@ -85,6 +85,7 @@ class DatasetRetrieverTool(Tool):
return tools
@override
def get_runtime_parameters(
self,
conversation_id: str | None = None,
@ -105,9 +106,11 @@ class DatasetRetrieverTool(Tool):
),
]
@override
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.DATASET_RETRIEVAL
@override
def _invoke(
self,
user_id: str,

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import override
from pydantic import Field
from sqlalchemy import select
@ -80,6 +81,7 @@ class WorkflowToolProviderController(ToolProviderController):
return controller
@property
@override
def provider_type(self) -> ToolProviderType:
return ToolProviderType.WORKFLOW

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
from typing import Any, cast, override
from sqlalchemy import select
@ -67,6 +67,7 @@ class WorkflowTool(Tool):
super().__init__(entity=entity, runtime=runtime)
@override
def tool_provider_type(self) -> ToolProviderType:
"""
get the tool provider type
@ -75,6 +76,7 @@ class WorkflowTool(Tool):
"""
return ToolProviderType.WORKFLOW
@override
def _invoke(
self,
user_id: str,
@ -206,6 +208,7 @@ class WorkflowTool(Tool):
return found
return None
@override
def fork_tool_runtime(self, runtime: ToolRuntime) -> WorkflowTool:
"""
fork a new tool with metadata

View File

@ -6,7 +6,7 @@ import time
from abc import ABC, abstractmethod
from collections.abc import Mapping
from datetime import datetime
from typing import Any
from typing import Any, override
from pydantic import BaseModel
@ -62,6 +62,7 @@ class TriggerDebugEventPoller(ABC):
class PluginTriggerDebugEventPoller(TriggerDebugEventPoller):
@override
def poll(self) -> TriggerDebugEvent | None:
from services.trigger.trigger_service import TriggerService
@ -103,6 +104,7 @@ class PluginTriggerDebugEventPoller(TriggerDebugEventPoller):
class WebhookTriggerDebugEventPoller(TriggerDebugEventPoller):
@override
def poll(self) -> TriggerDebugEvent | None:
pool_key = build_webhook_pool_key(
tenant_id=self.tenant_id,
@ -190,6 +192,7 @@ class ScheduleTriggerDebugEventPoller(TriggerDebugEventPoller):
inputs={},
)
@override
def poll(self) -> TriggerDebugEvent | None:
schedule_debug_runtime = self.get_or_create_schedule_debug_runtime()
if schedule_debug_runtime.next_run_at > naive_utc_now():

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Union
from typing import Union, override
from core.entities.provider_entities import BasicProviderConfig, ProviderConfig
from core.helper.provider_cache import ProviderCredentialsCache
@ -16,6 +16,7 @@ class TriggerProviderCredentialsCache(ProviderCredentialsCache):
def __init__(self, tenant_id: str, provider_id: str, credential_id: str):
super().__init__(tenant_id=tenant_id, provider_id=provider_id, credential_id=credential_id)
@override
def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"]
provider_id = kwargs["provider_id"]
@ -29,6 +30,7 @@ class TriggerProviderOAuthClientParamsCache(ProviderCredentialsCache):
def __init__(self, tenant_id: str, provider_id: str):
super().__init__(tenant_id=tenant_id, provider_id=provider_id)
@override
def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"]
provider_id = kwargs["provider_id"]
@ -41,6 +43,7 @@ class TriggerProviderPropertiesCache(ProviderCredentialsCache):
def __init__(self, tenant_id: str, provider_id: str, subscription_id: str):
super().__init__(tenant_id=tenant_id, provider_id=provider_id, subscription_id=subscription_id)
@override
def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"]
provider_id = kwargs["provider_id"]

View File

@ -10,7 +10,7 @@ from __future__ import annotations
import enum
import uuid
from collections.abc import Mapping, Sequence
from typing import Annotated, Any, ClassVar, Literal
from typing import Annotated, Any, ClassVar, Literal, override
import bleach
import markdown
@ -158,6 +158,7 @@ class EmailDeliveryMethod(_DeliveryMethodBase):
type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL
config: EmailDeliveryConfig
@override
def extract_variable_selectors(self) -> Sequence[Sequence[str]]:
variable_template_parser = VariableTemplateParser(template=self.config.body)
selectors: list[Sequence[str]] = []

View File

@ -195,13 +195,16 @@ class _LazyNodeTypeClassesMapping(MutableMapping[NodeType, Mapping[str, type[Nod
snapshot.update(self._overrides)
return snapshot
@override
def __getitem__(self, key: NodeType) -> Mapping[str, type[Node]]:
return self._snapshot()[key]
@override
def __setitem__(self, key: NodeType, value: Mapping[str, type[Node]]) -> None:
self._deleted.discard(key)
self._overrides[key] = value
@override
def __delitem__(self, key: NodeType) -> None:
if key in self._overrides:
del self._overrides[key]
@ -211,9 +214,11 @@ class _LazyNodeTypeClassesMapping(MutableMapping[NodeType, Mapping[str, type[Nod
return
raise KeyError(key)
@override
def __iter__(self) -> Iterator[NodeType]:
return iter(self._snapshot())
@override
def __len__(self) -> int:
return len(self._snapshot())

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from collections.abc import Callable, Generator, Mapping, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, cast, overload
from typing import TYPE_CHECKING, Any, Literal, cast, overload, override
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -136,6 +136,7 @@ class DifyFileReferenceFactory(FileReferenceFactoryProtocol):
def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None:
self._run_context = resolve_dify_run_context(run_context)
@override
def build_from_mapping(self, *, mapping: Mapping[str, Any]):
return file_factory.build_from_mapping(
mapping=mapping,
@ -151,25 +152,31 @@ class DifyPreparedLLM(LLMProtocol):
self._model_instance = model_instance
@property
@override
def provider(self) -> str:
return self._model_instance.provider
@property
@override
def model_name(self) -> str:
return self._model_instance.model_name
@property
@override
def parameters(self) -> Mapping[str, Any]:
return self._model_instance.parameters
@parameters.setter
@override
def parameters(self, value: Mapping[str, Any]) -> None:
self._model_instance.parameters = value
@property
@override
def stop(self) -> Sequence[str] | None:
return self._model_instance.stop
@override
def get_model_schema(self) -> AIModelEntity:
model_schema = cast(LargeLanguageModel, self._model_instance.model_type_instance).get_model_schema(
self._model_instance.model_name,
@ -179,6 +186,7 @@ class DifyPreparedLLM(LLMProtocol):
raise ValueError(f"Model schema not found for {self._model_instance.model_name}")
return model_schema
@override
def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int:
return self._model_instance.get_llm_num_tokens(prompt_messages)
@ -204,6 +212,7 @@ class DifyPreparedLLM(LLMProtocol):
stream: Literal[True],
) -> Generator[LLMResultChunk, None, None]: ...
@override
def invoke_llm(
self,
*,
@ -243,6 +252,7 @@ class DifyPreparedLLM(LLMProtocol):
stream: Literal[True],
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
@override
def invoke_llm_with_structured_output(
self,
*,
@ -263,11 +273,13 @@ class DifyPreparedLLM(LLMProtocol):
stream=stream,
)
@override
def is_structured_output_parse_error(self, error: Exception) -> bool:
return isinstance(error, OutputParserError)
class DifyPromptMessageSerializer(PromptMessageSerializerProtocol):
@override
def serialize(
self,
*,
@ -294,6 +306,7 @@ class DifyRetrieverAttachmentLoader(RetrieverAttachmentLoaderProtocol):
self._file_reference_factory = file_reference_factory
self._segment_access_checker = segment_access_checker
@override
def load(self, *, segment_id: str) -> Sequence[File]:
if not is_retriever_segment_access_granted(segment_id):
return []
@ -341,6 +354,7 @@ class DifyToolFileManager(ToolFileManagerProtocol):
self._manager = ToolFileManager()
self._conversation_id_getter = conversation_id_getter
@override
def create_file_by_raw(
self,
*,
@ -358,6 +372,7 @@ class DifyToolFileManager(ToolFileManagerProtocol):
filename=filename,
)
@override
def get_file_generator_by_tool_file_id(self, tool_file_id: str):
return self._manager.get_file_generator_by_tool_file_id(tool_file_id)
@ -394,9 +409,11 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
def file_reference_factory(self) -> FileReferenceFactoryProtocol:
return self._file_reference_factory
@override
def build_file_reference(self, *, mapping: Mapping[str, Any]):
return self._file_reference_factory.build_from_mapping(mapping=mapping)
@override
def get_runtime(
self,
*,
@ -447,6 +464,7 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
)
)
@override
def get_runtime_parameters(
self,
*,
@ -458,6 +476,7 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
for parameter in (tool.get_merged_runtime_parameters() or [])
]
@override
def invoke(
self,
*,
@ -503,6 +522,7 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
return self._adapt_messages(transformed_messages, provider_name=provider_name)
@override
def get_usage(
self,
*,
@ -745,6 +765,7 @@ class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol):
form_repository=form_repository,
)
@override
def get_form(self, *, node_id: str) -> HumanInputFormStateProtocol | None:
repo = self.build_form_repository()
return repo.get_form(node_id)
@ -766,6 +787,7 @@ class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol):
)
return restored_data
@override
def create_form(
self,
*,

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, override
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.workflow.system_variables import SystemVariableKey, get_system_text
@ -56,9 +56,11 @@ class AgentNode(Node[AgentNodeData]):
self._message_transformer = message_transformer
@classmethod
@override
def version(cls) -> str:
return "1"
@override
def populate_start_event(self, event) -> None:
dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
event.extras["agent_strategy"] = {
@ -69,6 +71,7 @@ class AgentNode(Node[AgentNodeData]):
),
}
@override
def _run(self) -> Generator[NodeEventBase, None, None]:
from core.plugin.impl.exc import PluginDaemonClientSideError
@ -167,6 +170,7 @@ class AgentNode(Node[AgentNodeData]):
)
@classmethod
@override
def _extract_variable_selector_to_variable_mapping(
cls,
*,

View File

@ -1,11 +1,14 @@
from __future__ import annotations
from typing import override
from factories.agent_factory import get_plugin_agent_strategy
from .strategy_protocols import AgentStrategyPresentationProvider, AgentStrategyResolver, ResolvedAgentStrategy
class PluginAgentStrategyResolver(AgentStrategyResolver):
@override
def resolve(
self,
*,
@ -21,6 +24,7 @@ class PluginAgentStrategyResolver(AgentStrategyResolver):
class PluginAgentStrategyPresentationProvider(AgentStrategyPresentationProvider):
@override
def get_icon(self, *, tenant_id: str, agent_strategy_provider_name: str) -> str | None:
from core.plugin.impl.plugin import PluginInstaller

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, override
from agenton.compositor import CompositorSessionSnapshot
@ -101,12 +101,15 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
self._session_store = session_store
@classmethod
@override
def version(cls) -> str:
return "2"
@override
def populate_start_event(self, event) -> None:
event.extras["agent_node"] = {"version": "2", "agent_node_kind": self.node_data.agent_node_kind}
@override
def _run(self) -> Generator[NodeEventBase, None, None]:
dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
workflow_id = self.graph_init_params.workflow_id
@ -577,6 +580,7 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
metadata["agent_backend"] = agent_backend
@classmethod
@override
def _extract_variable_selector_to_variable_mapping(
cls,
*,

View File

@ -1,5 +1,5 @@
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, override
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.datasource.datasource_manager import DatasourceManager
@ -49,10 +49,12 @@ class DatasourceNode(Node[DatasourceNodeData]):
)
self.datasource_manager = DatasourceManager
@override
def populate_start_event(self, event) -> None:
event.provider_id = f"{self.node_data.plugin_id}/{self.node_data.provider_name}"
event.provider_type = self.node_data.provider_type
@override
def _run(self) -> Generator:
"""
Run the datasource node
@ -183,6 +185,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
)
@classmethod
@override
def _extract_variable_selector_to_variable_mapping(
cls,
*,
@ -219,5 +222,6 @@ class DatasourceNode(Node[DatasourceNodeData]):
return result
@classmethod
@override
def version(cls) -> str:
return "1"

View File

@ -1,6 +1,6 @@
import logging
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, override
from core.rag.index_processor.index_processor import IndexProcessor
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
@ -46,6 +46,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
self.index_processor = IndexProcessor()
self.summary_index_service = SummaryIndex()
@override
def _run(self) -> NodeRunResult:
node_data = self.node_data
variable_pool = self.graph_runtime_state.variable_pool
@ -145,6 +146,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
return rst
@classmethod
@override
def version(cls) -> str:
return "1"

View File

@ -6,7 +6,7 @@ the workflow node registry.
import logging
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, override
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
@ -87,9 +87,11 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
self._rag_retrieval = DatasetRetrieval()
@classmethod
@override
def version(cls):
return "1"
@override
def _run(self) -> NodeRunResult:
usage = LLMUsage.empty_usage()
if not self._node_data.query_variable_selector and not self._node_data.query_attachment_selector:
@ -327,6 +329,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
)
@classmethod
@override
def _extract_variable_selector_to_variable_mapping(
cls,
*,

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any
from typing import Any, override
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID
@ -15,6 +15,7 @@ class TriggerEventNode(Node[TriggerEventNodeData]):
execution_type = NodeExecutionType.ROOT
@classmethod
@override
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
"type": "plugin",
@ -30,12 +31,15 @@ class TriggerEventNode(Node[TriggerEventNodeData]):
}
@classmethod
@override
def version(cls) -> str:
return "1"
@override
def populate_start_event(self, event) -> None:
event.provider_id = self.node_data.provider_id
@override
def _run(self) -> NodeRunResult:
"""
Run the plugin trigger node.

View File

@ -1,4 +1,5 @@
from collections.abc import Mapping
from typing import override
from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE
from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID
@ -14,10 +15,12 @@ class TriggerScheduleNode(Node[TriggerScheduleNodeData]):
execution_type = NodeExecutionType.ROOT
@classmethod
@override
def version(cls) -> str:
return "1"
@classmethod
@override
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
"type": TRIGGER_SCHEDULE_NODE_TYPE,
@ -29,6 +32,7 @@ class TriggerScheduleNode(Node[TriggerScheduleNodeData]):
},
}
@override
def _run(self) -> NodeRunResult:
node_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id))
system_inputs = self.graph_runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)

View File

@ -1,6 +1,6 @@
import logging
from collections.abc import Mapping
from typing import Any
from typing import Any, override
from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE
from core.workflow.file_reference import resolve_file_record_id
@ -25,12 +25,14 @@ class TriggerWebhookNode(Node[WebhookData]):
_file_reference_factory: FileReferenceFactoryProtocol
@override
def post_init(self) -> None:
from core.workflow.node_runtime import DifyFileReferenceFactory
self._file_reference_factory = DifyFileReferenceFactory(self.run_context)
@classmethod
@override
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
"type": "webhook",
@ -48,9 +50,11 @@ class TriggerWebhookNode(Node[WebhookData]):
}
@classmethod
@override
def version(cls) -> str:
return "1"
@override
def _run(self) -> NodeRunResult:
"""
Run the webhook node.

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import Any
from typing import Any, override
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
from graphon.nodes.code.entities import CodeLanguage
@ -11,6 +11,7 @@ from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderErr
class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer):
"""Sandbox-backed Jinja2 renderer for workflow-owned node composition."""
@override
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
try:
result = CodeExecutor.execute_workflow_code_template(

View File

@ -1,3 +1,5 @@
from typing import override
"""Custom OTEL ID Generator for correlation-based trace/span ID derivation.
Uses contextvars for thread-safe correlation_id -> trace_id mapping.
@ -52,6 +54,7 @@ class CorrelationIdGenerator(IdGenerator):
parent-child linking), otherwise random
"""
@override
def generate_trace_id(self) -> int:
correlation_id = _correlation_id_context.get()
if correlation_id:
@ -61,6 +64,7 @@ class CorrelationIdGenerator(IdGenerator):
pass
return random.getrandbits(128)
@override
def generate_span_id(self) -> int:
source = _span_id_source_context.get()
if source:

View File

@ -1,4 +1,5 @@
import json
from typing import override
from flask_restx import fields
@ -7,6 +8,7 @@ from libs.helper import AppIconUrlField, TimestampField
class JsonStringField(fields.Raw):
@override
def format(self, value):
if isinstance(value, str):
try:

View File

@ -1,9 +1,12 @@
from typing import override
from flask_restx import fields
from graphon.file import File
class FilesContainedField(fields.Raw):
@override
def format(self, value):
return self._format_file_object(value)

View File

@ -1,3 +1,5 @@
from typing import override
from flask_restx import fields
from core.helper import encrypter
@ -11,6 +13,7 @@ ENVIRONMENT_VARIABLE_SUPPORTED_TYPES = (SegmentType.STRING, SegmentType.NUMBER,
class EnvironmentVariableField(fields.Raw):
@override
def format(self, value):
# Mask secret variables values in environment_variables
if isinstance(value, SecretVariable):

View File

@ -1,4 +1,5 @@
from datetime import datetime
from typing import override
from uuid import uuid4
from sqlalchemy import DateTime, func
@ -51,6 +52,7 @@ class DefaultFieldsMixin:
onupdate=func.current_timestamp(),
)
@override
def __repr__(self) -> str:
return f"<{self.__class__.__name__}(id={self.id})>"
@ -87,6 +89,7 @@ class DefaultFieldsDCMixin(MappedAsDataclass):
onupdate=func.current_timestamp(),
)
@override
def __repr__(self) -> str:
return f"<{self.__class__.__name__}(id={self.id})>"

View File

@ -9,7 +9,7 @@ import re
import time
from datetime import datetime
from json import JSONDecodeError
from typing import Any, ClassVar, TypedDict, cast
from typing import Any, ClassVar, TypedDict, cast, override
from uuid import uuid4
import sqlalchemy as sa
@ -1790,5 +1790,6 @@ class DocumentSegmentSummary(TypeBase):
init=False,
)
@override
def __repr__(self):
return f"<DocumentSegmentSummary id={self.id} chunk_id={self.chunk_id} status={self.status}>"

View File

@ -1,4 +1,5 @@
from enum import StrEnum
from typing import override
from core.trigger.constants import (
TRIGGER_PLUGIN_NODE_TYPE,
@ -12,6 +13,7 @@ class CreatorUserRole(StrEnum):
END_USER = "end_user"
@classmethod
@override
def _missing_(cls, value):
if value == "end-user":
return cls.END_USER

View File

@ -8,7 +8,7 @@ from datetime import datetime
from decimal import Decimal
from enum import StrEnum, auto
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, cast
from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, cast, override
from uuid import uuid4
import sqlalchemy as sa
@ -2058,10 +2058,12 @@ class EndUser(Base, UserMixin):
)
@property
@override
def is_anonymous(self) -> Literal[False]:
return False
@is_anonymous.setter
@override
def is_anonymous(self, value: bool) -> None:
self._is_anonymous = value

View File

@ -3,6 +3,7 @@ from __future__ import annotations
from datetime import datetime
from enum import StrEnum, auto
from functools import cached_property
from typing import override
from uuid import uuid4
import sqlalchemy as sa
@ -73,6 +74,7 @@ class Provider(TypeBase):
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False
)
@override
def __repr__(self):
return (
f"<Provider(id={self.id}, tenant_id={self.tenant_id}, provider_name='{self.provider_name}',"

View File

@ -1,3 +1,5 @@
from typing import override
"""Provider ID entities for plugin system."""
import re
@ -14,6 +16,7 @@ class GenericProviderID:
def to_string(self) -> str:
return str(self)
@override
def __str__(self) -> str:
return f"{self.organization}/{self.plugin_name}/{self.provider_name}"

View File

@ -1,5 +1,6 @@
import logging
from collections.abc import Sequence
from typing import override
from opentelemetry.trace import SpanKind
from sqlalchemy.orm import sessionmaker
@ -74,6 +75,7 @@ class AliyunDataTrace(BaseTraceInstance):
endpoint = build_endpoint(aliyun_config.endpoint, aliyun_config.license_key)
self.trace_client = TraceClient(service_name=aliyun_config.app_name, endpoint=endpoint)
@override
def trace(self, trace_info: BaseTraceInfo):
match trace_info:
case WorkflowTraceInfo():

View File

@ -6,7 +6,7 @@ import traceback
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any, Protocol, Union, cast
from typing import Any, Protocol, Union, cast, override
from urllib.parse import urlparse
from openinference.semconv.trace import (
@ -730,6 +730,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
self.root_span_carriers: dict[str, dict[str, str]] = {}
self.carrier: dict[str, str] = {}
@override
def trace(self, trace_info: BaseTraceInfo):
logger.info("[Arize/Phoenix] Trace Entity Info: %s", trace_info)
logger.info("[Arize/Phoenix] Trace Entity Type: %s", type(trace_info))

View File

@ -2,7 +2,7 @@ import json
from collections.abc import Sequence
from datetime import UTC, datetime, timedelta
from types import SimpleNamespace
from typing import Any, cast
from typing import Any, cast, override
from unittest.mock import MagicMock, patch
import dify_trace_arize_phoenix.arize_phoenix_trace as arize_phoenix_trace_module
@ -133,10 +133,12 @@ class _CollectingSpanExporter(SpanExporter):
def __init__(self):
self.spans: list[ReadableSpan] = []
@override
def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
self.spans.extend(spans)
return SpanExportResult.SUCCESS
@override
def shutdown(self) -> None:
return None
@ -258,9 +260,11 @@ def test_set_span_status():
# repr branch
class SilentError:
@override
def __str__(self):
return ""
@override
def __repr__(self):
return "SilentErrorRepr"

View File

@ -2,6 +2,7 @@ import logging
import os
import uuid
from datetime import UTC, datetime, timedelta
from typing import override
from langfuse import Langfuse
from langfuse.api import (
@ -106,6 +107,7 @@ class LangFuseDataTrace(BaseTraceInstance):
return start_time + timedelta(seconds=ttft_seconds)
@override
def trace(self, trace_info: BaseTraceInfo):
match trace_info:
case WorkflowTraceInfo():

View File

@ -2,6 +2,7 @@ import collections
import logging
from datetime import UTC, datetime, timedelta
from types import SimpleNamespace
from typing import override
from unittest.mock import MagicMock
import pytest
@ -738,6 +739,7 @@ def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypat
node.status = "succeeded"
class BadDict(collections.UserDict):
@override
def get(self, key, default=None):
if key == "usage":
raise Exception("Usage extraction failed")

View File

@ -2,7 +2,7 @@ import logging
import os
import uuid
from datetime import datetime, timedelta
from typing import cast
from typing import cast, override
from langsmith import Client
from langsmith.schemas import RunBase
@ -47,6 +47,7 @@ class LangSmithDataTrace(BaseTraceInstance):
self.langsmith_client = Client(api_key=langsmith_config.api_key, api_url=langsmith_config.endpoint)
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
@override
def trace(self, trace_info: BaseTraceInfo):
match trace_info:
case WorkflowTraceInfo():

View File

@ -1,5 +1,6 @@
import collections
from datetime import datetime, timedelta
from typing import override
from unittest.mock import MagicMock
import pytest
@ -549,6 +550,7 @@ def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch: pyte
)
class BadDict(collections.UserDict):
@override
def get(self, key, default=None):
if key == "usage":
raise Exception("Usage extraction failed")

View File

@ -1,7 +1,7 @@
import logging
import os
from datetime import datetime, timedelta
from typing import Any, cast
from typing import Any, cast, override
import mlflow
from mlflow.entities import Document, Span, SpanEvent, SpanStatusCode, SpanType
@ -86,6 +86,7 @@ class MLflowDataTrace(BaseTraceInstance):
self._project_url = f"{config.tracking_uri}/#/experiments/{config.experiment_id}/traces"
@override
def trace(self, trace_info: BaseTraceInfo):
"""Simple dispatch to trace methods"""
try:

View File

@ -3,7 +3,7 @@ import logging
import os
import uuid
from datetime import datetime, timedelta
from typing import Any, cast
from typing import Any, cast, override
from opik import Opik, Trace
from opik.id_helpers import uuid4_to_uuid7
@ -95,6 +95,7 @@ class OpikDataTrace(BaseTraceInstance):
self.project = opik_config.project
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
@override
def trace(self, trace_info: BaseTraceInfo):
match trace_info:
case WorkflowTraceInfo():

View File

@ -2,6 +2,7 @@ import collections
import logging
from datetime import UTC, datetime, timedelta
from types import SimpleNamespace
from typing import override
from unittest.mock import MagicMock
import pytest
@ -643,6 +644,7 @@ def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch
node.status = "succeeded"
class BadDict(collections.UserDict):
@override
def get(self, key, default=None):
if key == "usage":
raise Exception("Usage extraction failed")

View File

@ -1,3 +1,5 @@
from typing import override
"""Tencent APM tracing with idempotent client cleanup."""
import inspect
@ -56,6 +58,7 @@ class TencentDataTrace(BaseTraceInstance):
metrics_export_interval_sec=5,
)
@override
def trace(self, trace_info: BaseTraceInfo) -> None:
"""Main tracing entry point - coordinates different trace types."""
match trace_info:

View File

@ -2,7 +2,7 @@ import logging
import os
import uuid
from datetime import UTC, datetime, timedelta
from typing import Any, cast
from typing import Any, cast, override
import wandb
import weave
@ -78,6 +78,7 @@ class WeaveDataTrace(BaseTraceInstance):
logger.debug("Weave get run url failed: %s", str(e))
raise ValueError(f"Weave get run url failed: {str(e)}")
@override
def trace(self, trace_info: BaseTraceInfo):
logger.debug("Trace info: %s", trace_info)
match trace_info:

View File

@ -3,7 +3,7 @@ import json
import logging
import uuid
from contextlib import contextmanager
from typing import Any, Literal, cast
from typing import Any, Literal, cast, override
import mysql.connector
from mysql.connector import Error as MySQLError
@ -81,6 +81,7 @@ class AlibabaCloudMySQLVector(BaseVector):
self.hnsw_m = config.hnsw_m
self._check_vector_support()
@override
def get_type(self) -> str:
return VectorType.ALIBABACLOUD_MYSQL
@ -135,11 +136,13 @@ class AlibabaCloudMySQLVector(BaseVector):
cur.close()
conn.close()
@override
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection(dimension)
return self.add_texts(texts, embeddings)
@override
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
values = []
pks = []
@ -165,6 +168,7 @@ class AlibabaCloudMySQLVector(BaseVector):
cur.executemany(insert_sql, values)
return pks
@override
def text_exists(self, id: str) -> bool:
with self._get_cursor() as cur:
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,))
@ -183,6 +187,7 @@ class AlibabaCloudMySQLVector(BaseVector):
docs.append(Document(page_content=record["text"], metadata=metadata))
return docs
@override
def delete_by_ids(self, ids: list[str]):
# Avoiding crashes caused by performing delete operations on empty lists
if not ids:
@ -199,12 +204,14 @@ class AlibabaCloudMySQLVector(BaseVector):
else:
raise e
@override
def delete_by_metadata_field(self, key: str, value: str):
with self._get_cursor() as cur:
cur.execute(
f"DELETE FROM {self.table_name} WHERE JSON_UNQUOTE(JSON_EXTRACT(meta, %s)) = %s", (f"$.{key}", value)
)
@override
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""
Search the nearest neighbors to a vector using RDS MySQL vector distance functions.
@ -274,6 +281,7 @@ class AlibabaCloudMySQLVector(BaseVector):
return docs
@override
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
if not isinstance(top_k, int) or top_k <= 0:
@ -308,6 +316,7 @@ class AlibabaCloudMySQLVector(BaseVector):
docs.append(Document(page_content=record["text"], metadata=metadata))
return docs
@override
def delete(self):
with self._get_cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
@ -355,6 +364,7 @@ class AlibabaCloudMySQLVectorFactory(AbstractVectorFactory):
raise ValueError(f"Invalid distance function: {distance_function}. Must be 'cosine' or 'euclidean'")
return cast(Literal["cosine", "euclidean"], distance_function)
@override
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AlibabaCloudMySQLVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]

View File

@ -1,5 +1,6 @@
import json
import unittest
from typing import override
from unittest.mock import MagicMock, patch
import pytest
@ -22,6 +23,7 @@ except ImportError:
class TestAlibabaCloudMySQLVector(unittest.TestCase):
@override
def setUp(self):
self.config = AlibabaCloudMySQLVectorConfig(
host="localhost",

View File

@ -1,5 +1,5 @@
import json
from typing import Any
from typing import Any, override
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
@ -32,38 +32,48 @@ class AnalyticdbVector(BaseVector):
raise ValueError("Either api_config or sql_config must be provided")
self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config)
@override
def get_type(self) -> str:
return VectorType.ANALYTICDB
@override
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self.analyticdb_vector.create_collection_if_not_exists(dimension)
self.analyticdb_vector.add_texts(texts, embeddings)
@override
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
self.analyticdb_vector.add_texts(documents, embeddings)
return []
@override
def text_exists(self, id: str) -> bool:
return self.analyticdb_vector.text_exists(id)
@override
def delete_by_ids(self, ids: list[str]):
self.analyticdb_vector.delete_by_ids(ids)
@override
def delete_by_metadata_field(self, key: str, value: str):
self.analyticdb_vector.delete_by_metadata_field(key, value)
@override
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
return self.analyticdb_vector.search_by_vector(query_vector, **kwargs)
@override
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self.analyticdb_vector.search_by_full_text(query, **kwargs)
@override
def delete(self):
self.analyticdb_vector.delete()
class AnalyticdbVectorFactory(AbstractVectorFactory):
@override
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AnalyticdbVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]

View File

@ -1,3 +1,5 @@
from typing import override
from dify_vdb_analyticdb.analyticdb_vector import AnalyticdbVector
from dify_vdb_analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
from dify_vdb_analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
@ -40,6 +42,7 @@ class AnalyticdbVectorTest(AbstractVectorTest):
),
)
@override
def run_all_tests(self):
self.vector.delete()
return super().run_all_tests()

View File

@ -2,7 +2,7 @@ import json
import logging
import time
import uuid
from typing import Any
from typing import Any, override
import numpy as np
from pydantic import BaseModel, model_validator
@ -82,6 +82,7 @@ class BaiduVector(BaseVector):
self._client = self._init_client(config)
self._db = self._init_database()
@override
def get_type(self) -> str:
return VectorType.BAIDU
@ -92,10 +93,12 @@ class BaiduVector(BaseVector):
}
return result
@override
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self._create_table(len(embeddings[0]))
self.add_texts(texts, embeddings)
@override
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
total_count = len(documents)
batch_size = 1000
@ -116,23 +119,27 @@ class BaiduVector(BaseVector):
rows.append(row)
table.upsert(rows=rows)
@override
def text_exists(self, id: str) -> bool:
res = self._db.table(self._collection_name).query(primary_key={VDBField.PRIMARY_KEY: id})
if res and res.code == 0:
return True
return False
@override
def delete_by_ids(self, ids: list[str]):
if not ids:
return
quoted_ids = [f"'{id}'" for id in ids]
self._db.table(self._collection_name).delete(filter=f"{VDBField.PRIMARY_KEY} IN({', '.join(quoted_ids)})")
@override
def delete_by_metadata_field(self, key: str, value: str):
# Escape double quotes in value to prevent injection
escaped_value = value.replace('"', '\\"')
self._db.table(self._collection_name).delete(filter=f'metadata["{key}"] = "{escaped_value}"')
@override
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector]
document_ids_filter = kwargs.get("document_ids_filter")
@ -154,6 +161,7 @@ class BaiduVector(BaseVector):
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(res, score_threshold)
@override
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# document ids filter
document_ids_filter = kwargs.get("document_ids_filter")
@ -189,6 +197,7 @@ class BaiduVector(BaseVector):
docs.append(doc)
return docs
@override
def delete(self):
try:
self._db.drop_table(table_name=self._collection_name)
@ -368,6 +377,7 @@ class BaiduVector(BaseVector):
class BaiduVectorFactory(AbstractVectorFactory):
@override
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaiduVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]

View File

@ -1,3 +1,5 @@
from typing import override
from dify_vdb_baidu.baidu_vector import BaiduConfig, BaiduVector
from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text
@ -18,10 +20,12 @@ class BaiduVectorTest(AbstractVectorTest):
),
)
@override
def search_by_vector(self):
hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
assert len(hits_by_vector) == 1
@override
def search_by_full_text(self):
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == 0

View File

@ -1,5 +1,5 @@
import json
from typing import Any, TypedDict
from typing import Any, TypedDict, override
import chromadb
from chromadb import QueryResult, Settings
@ -55,9 +55,11 @@ class ChromaVector(BaseVector):
self._client_config = config
self._client = chromadb.HttpClient(**self._client_config.to_chroma_params())
@override
def get_type(self) -> str:
return VectorType.CHROMA
@override
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
if texts:
# create collection
@ -74,6 +76,7 @@ class ChromaVector(BaseVector):
self._client.get_or_create_collection(collection_name)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
@override
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
@ -84,25 +87,30 @@ class ChromaVector(BaseVector):
collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas) # type: ignore
return uuids
@override
def delete_by_metadata_field(self, key: str, value: str):
collection = self._client.get_or_create_collection(self._collection_name)
# FIXME: fix the type error later
collection.delete(where={key: {"$eq": value}}) # type: ignore
@override
def delete(self):
self._client.delete_collection(self._collection_name)
@override
def delete_by_ids(self, ids: list[str]):
if not ids:
return
collection = self._client.get_or_create_collection(self._collection_name)
collection.delete(ids=ids)
@override
def text_exists(self, id: str) -> bool:
collection = self._client.get_or_create_collection(self._collection_name)
response = collection.get(ids=[id])
return len(response) > 0
@override
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
collection = self._client.get_or_create_collection(self._collection_name)
document_ids_filter = kwargs.get("document_ids_filter")
@ -142,12 +150,14 @@ class ChromaVector(BaseVector):
docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True)
return docs
@override
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# chroma does not support BM25 full text searching
return []
class ChromaVectorFactory(AbstractVectorFactory):
@override
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]

View File

@ -1,3 +1,5 @@
from typing import override
import chromadb
from dify_vdb_chroma.chroma_vector import ChromaConfig, ChromaVector
@ -22,6 +24,7 @@ class ChromaVectorTest(AbstractVectorTest):
),
)
@override
def search_by_full_text(self):
# chroma dos not support full text searching
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())

View File

@ -8,7 +8,7 @@ import re
import threading
import time
import uuid
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, override
import clickzetta # type: ignore
from pydantic import BaseModel, model_validator
@ -436,6 +436,7 @@ class ClickzettaVector(BaseVector):
raise result
return result
@override
def get_type(self) -> str:
"""Return the vector database type."""
return "clickzetta"
@ -469,6 +470,7 @@ class ClickzettaVector(BaseVector):
)
return False
@override
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
"""Create the collection and add initial documents."""
# Execute table creation through write queue to avoid concurrent conflicts
@ -606,6 +608,7 @@ class ClickzettaVector(BaseVector):
logger.warning("Failed to create inverted index: %s", e)
# Continue without inverted index - full-text search will fall back to LIKE
@override
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
"""Add documents with embeddings to the collection."""
if not documents:
@ -716,6 +719,7 @@ class ClickzettaVector(BaseVector):
logger.exception("Sample data row: %s", data_rows[0] if data_rows else "None")
raise
@override
def text_exists(self, id: str) -> bool:
"""Check if a document exists by ID."""
# Check if table exists first
@ -732,6 +736,7 @@ class ClickzettaVector(BaseVector):
result = cursor.fetchone()
return result[0] > 0 if result else False
@override
def delete_by_ids(self, ids: list[str]):
"""Delete documents by IDs."""
if not ids:
@ -757,6 +762,7 @@ class ClickzettaVector(BaseVector):
with connection.cursor() as cursor:
cursor.execute(sql, binding_params=safe_ids)
@override
def delete_by_metadata_field(self, key: str, value: str):
"""Delete documents by metadata field."""
# Check if table exists before attempting delete
@ -780,6 +786,7 @@ class ClickzettaVector(BaseVector):
)
cursor.execute(sql, binding_params=[value])
@override
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""Search for documents by vector similarity."""
# Check if table exists first
@ -865,6 +872,7 @@ class ClickzettaVector(BaseVector):
return documents
@override
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Search for documents using full-text search with inverted index."""
if not self._config.enable_inverted_index:
@ -1031,6 +1039,7 @@ class ClickzettaVector(BaseVector):
return documents
@override
def delete(self):
"""Delete the entire collection."""
with self.get_connection_context() as connection:
@ -1057,6 +1066,7 @@ class ClickzettaVector(BaseVector):
class ClickzettaVectorFactory(AbstractVectorFactory):
"""Factory for creating Clickzetta vector instances."""
@override
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
"""Initialize a Clickzetta vector instance."""
# Get configuration from environment variables or dataset config

View File

@ -3,7 +3,7 @@ import logging
import time
import uuid
from datetime import timedelta
from typing import Any
from typing import Any, override
from couchbase import search # type: ignore
from couchbase.auth import PasswordAuthenticator # type: ignore
@ -68,6 +68,7 @@ class CouchbaseVector(BaseVector):
# Wait until the cluster is ready for use.
self._cluster.wait_until_ready(timedelta(seconds=5))
@override
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
index_id = str(uuid.uuid4()).replace("-", "")
self._create_collection(uuid=index_id, vector_length=len(embeddings[0]))
@ -200,9 +201,11 @@ class CouchbaseVector(BaseVector):
# Check if the collection exists in the scope
return self._collection_name in scope_collection_map[self._scope_name]
@override
def get_type(self) -> str:
return VectorType.COUCHBASE
@override
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
@ -221,6 +224,7 @@ class CouchbaseVector(BaseVector):
return doc_ids
@override
def text_exists(self, id: str) -> bool:
# Use a parameterized query for safety and correctness
query = f"""
@ -234,6 +238,7 @@ class CouchbaseVector(BaseVector):
return bool(row["count"] > 0)
return False # Return False if no rows are returned
@override
def delete_by_ids(self, ids: list[str]):
query = f"""
DELETE FROM `{self._bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
@ -261,6 +266,7 @@ class CouchbaseVector(BaseVector):
# result = self._cluster.query(query, named_parameters={'value':value})
# return [row['id'] for row in result.rows()]
@override
def delete_by_metadata_field(self, key: str, value: str):
query = f"""
DELETE FROM `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
@ -268,6 +274,7 @@ class CouchbaseVector(BaseVector):
"""
self._cluster.query(query, named_parameters={"value": value}).execute()
@override
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
score_threshold = kwargs.get("score_threshold") or 0.0
@ -303,6 +310,7 @@ class CouchbaseVector(BaseVector):
return docs
@override
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
try:
@ -325,6 +333,7 @@ class CouchbaseVector(BaseVector):
return docs
@override
def delete(self):
manager = self._bucket.collections()
scopes = manager.get_all_scopes()
@ -356,6 +365,7 @@ class CouchbaseVector(BaseVector):
class CouchbaseVectorFactory(AbstractVectorFactory):
@override
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> CouchbaseVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]

View File

@ -1,6 +1,7 @@
import logging
import subprocess
import time
from typing import override
from dify_vdb_couchbase.couchbase_vector import CouchbaseConfig, CouchbaseVector
@ -40,6 +41,7 @@ class CouchbaseTest(AbstractVectorTest):
),
)
@override
def search_by_vector(self):
# brief sleep to ensure document is indexed
time.sleep(5)

View File

@ -1,6 +1,6 @@
import json
import logging
from typing import Any
from typing import Any, override
from flask import current_app
@ -19,6 +19,7 @@ logger = logging.getLogger(__name__)
class ElasticSearchJaVector(ElasticSearchVector):
@override
def create_collection(
self,
embeddings: list[list[float]],
@ -82,6 +83,7 @@ class ElasticSearchJaVector(ElasticSearchVector):
class ElasticSearchJaVectorFactory(ElasticSearchVectorFactory):
@override
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchJaVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]

View File

@ -1,7 +1,7 @@
import json
import logging
import math
from typing import Any, cast
from typing import Any, cast, override
from urllib.parse import urlparse
from elasticsearch import ConnectionError as ElasticsearchConnectionError
@ -154,9 +154,11 @@ class ElasticSearchVector(BaseVector):
if parse_version(self._version) < parse_version("8.0.0"):
raise ValueError("Elasticsearch vector database version must be greater than 8.0.0")
@override
def get_type(self) -> str:
return VectorType.ELASTICSEARCH
@override
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
for i in range(len(documents)):
@ -172,15 +174,18 @@ class ElasticSearchVector(BaseVector):
self._client.indices.refresh(index=self._collection_name)
return uuids
@override
def text_exists(self, id: str) -> bool:
return bool(self._client.exists(index=self._collection_name, id=id))
@override
def delete_by_ids(self, ids: list[str]):
if not ids:
return
for id in ids:
self._client.delete(index=self._collection_name, id=id)
@override
def delete_by_metadata_field(self, key: str, value: str):
query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
results = self._client.search(index=self._collection_name, body=query_str)
@ -188,9 +193,11 @@ class ElasticSearchVector(BaseVector):
if ids:
self.delete_by_ids(ids)
@override
def delete(self):
self._client.indices.delete(index=self._collection_name)
@override
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
num_candidates = math.ceil(top_k * 1.5)
@ -224,6 +231,7 @@ class ElasticSearchVector(BaseVector):
return docs
@override
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
query_str: dict[str, Any] = {"match": {Field.CONTENT_KEY: query}}
document_ids_filter = kwargs.get("document_ids_filter")
@ -249,6 +257,7 @@ class ElasticSearchVector(BaseVector):
return docs
@override
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas)
@ -297,6 +306,7 @@ class ElasticSearchVector(BaseVector):
class ElasticSearchVectorFactory(AbstractVectorFactory):
@override
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]

View File

@ -1,7 +1,7 @@
import json
import logging
import time
from typing import Any, cast
from typing import Any, cast, override
import holo_search_sdk as holo # type: ignore
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
@ -81,15 +81,18 @@ class HologresVector(BaseVector):
client.connect()
return client
@override
def get_type(self) -> str:
return VectorType.HOLOGRES
@override
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
"""Create collection table with vector and full-text indexes, then add texts."""
dimension = len(embeddings[0])
self._create_collection(dimension)
self.add_texts(texts, embeddings)
@override
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
"""Add texts with embeddings to the collection using batch upsert."""
if not documents:
@ -127,6 +130,7 @@ class HologresVector(BaseVector):
return pks
@override
def text_exists(self, id: str) -> bool:
"""Check if a text with the given doc_id exists in the collection."""
if not self._client.check_table_exist(self.table_name):
@ -140,6 +144,7 @@ class HologresVector(BaseVector):
)
return bool(result)
@override
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str] | None:
"""Get document IDs by metadata field key and value."""
result = self._client.execute(
@ -152,6 +157,7 @@ class HologresVector(BaseVector):
return [row[0] for row in result]
return None
@override
def delete_by_ids(self, ids: list[str]):
"""Delete documents by their doc_id list."""
if not ids:
@ -166,6 +172,7 @@ class HologresVector(BaseVector):
)
)
@override
def delete_by_metadata_field(self, key: str, value: str):
"""Delete documents by metadata field key and value."""
if not self._client.check_table_exist(self.table_name):
@ -177,6 +184,7 @@ class HologresVector(BaseVector):
)
)
@override
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""Search for documents by vector similarity."""
if not self._client.check_table_exist(self.table_name):
@ -229,6 +237,7 @@ class HologresVector(BaseVector):
return docs
@override
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Search for documents by full-text search."""
if not self._client.check_table_exist(self.table_name):
@ -272,6 +281,7 @@ class HologresVector(BaseVector):
return docs
@override
def delete(self):
"""Delete the entire collection table."""
if self._client.check_table_exist(self.table_name):
@ -333,6 +343,7 @@ class HologresVector(BaseVector):
class HologresVectorFactory(AbstractVectorFactory):
"""Factory class for creating HologresVector instances."""
@override
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> HologresVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]

View File

@ -1,6 +1,6 @@
import os
import uuid
from typing import cast
from typing import cast, override
from dify_vdb_hologres.hologres_vector import HologresVector, HologresVectorConfig
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
@ -35,6 +35,7 @@ class HologresVectorTest(AbstractVectorTest):
),
)
@override
def search_by_full_text(self):
"""Override: full-text index may not be immediately ready in real mode."""
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
@ -97,12 +98,14 @@ class HologresVectorTest(AbstractVectorTest):
assert len(hits) == 1
assert hits[0].metadata["doc_id"] == self.example_doc_id
@override
def get_ids_by_metadata_field(self):
"""Override: Hologres implements this method via JSONB query."""
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
assert ids is not None
assert len(ids) == 1
@override
def run_all_tests(self):
# Clean up before running tests
self.vector.delete()

View File

@ -1,7 +1,7 @@
import json
import logging
import ssl
from typing import Any
from typing import Any, override
from elasticsearch import Elasticsearch
from pydantic import BaseModel, model_validator
@ -68,9 +68,11 @@ class HuaweiCloudVector(BaseVector):
super().__init__(index_name.lower())
self._client = Elasticsearch(**config.to_elasticsearch_params())
@override
def get_type(self) -> str:
return VectorType.HUAWEI_CLOUD
@override
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
for i in range(len(documents)):
@ -86,15 +88,18 @@ class HuaweiCloudVector(BaseVector):
self._client.indices.refresh(index=self._collection_name)
return uuids
@override
def text_exists(self, id: str) -> bool:
return bool(self._client.exists(index=self._collection_name, id=id))
@override
def delete_by_ids(self, ids: list[str]):
if not ids:
return
for id in ids:
self._client.delete(index=self._collection_name, id=id)
@override
def delete_by_metadata_field(self, key: str, value: str):
query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
results = self._client.search(index=self._collection_name, body=query_str)
@ -102,9 +107,11 @@ class HuaweiCloudVector(BaseVector):
if ids:
self.delete_by_ids(ids)
@override
def delete(self):
self._client.indices.delete(index=self._collection_name)
@override
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
@ -145,6 +152,7 @@ class HuaweiCloudVector(BaseVector):
return docs
@override
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
query_str = {"match": {Field.CONTENT_KEY: query}}
results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
@ -160,6 +168,7 @@ class HuaweiCloudVector(BaseVector):
return docs
@override
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas)
@ -207,6 +216,7 @@ class HuaweiCloudVector(BaseVector):
class HuaweiCloudVectorFactory(AbstractVectorFactory):
@override
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> HuaweiCloudVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]

View File

@ -1,3 +1,5 @@
from typing import override
from dify_vdb_huawei_cloud.huawei_cloud_vector import HuaweiCloudVector, HuaweiCloudVectorConfig
from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text
@ -15,10 +17,12 @@ class HuaweiCloudVectorTest(AbstractVectorTest):
),
)
@override
def search_by_vector(self):
hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
assert len(hits_by_vector) == 3
@override
def search_by_full_text(self):
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == 3

View File

@ -11,7 +11,7 @@ import logging
import threading
import uuid
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, override
from configs import dify_config
from configs.middleware.vdb.iris_config import IrisVectorConfig
@ -188,6 +188,7 @@ class IrisVector(BaseVector):
self.schema = config.IRIS_SCHEMA or "dify"
self.pool = get_iris_pool(config)
@override
def get_type(self) -> str:
return VectorType.IRIS
@ -206,11 +207,13 @@ class IrisVector(BaseVector):
cursor.close()
self.pool.return_connection(conn)
@override
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
dimension = len(embeddings[0])
self._create_collection(dimension)
return self.add_texts(texts, embeddings)
@override
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **_kwargs) -> list[str]:
"""Add documents with embeddings to the collection."""
added_ids = []
@ -226,6 +229,7 @@ class IrisVector(BaseVector):
return added_ids
@override
def text_exists(self, id: str) -> bool: # pylint: disable=redefined-builtin
try:
with self._get_cursor() as cursor:
@ -235,6 +239,7 @@ class IrisVector(BaseVector):
except (OSError, RuntimeError, ValueError):
return False
@override
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
@ -244,6 +249,7 @@ class IrisVector(BaseVector):
sql = f"DELETE FROM {self.schema}.{self.table_name} WHERE id IN ({placeholders})"
cursor.execute(sql, ids)
@override
def delete_by_metadata_field(self, key: str, value: str) -> None:
"""Delete documents by metadata field (JSON LIKE pattern matching)."""
with self._get_cursor() as cursor:
@ -251,6 +257,7 @@ class IrisVector(BaseVector):
sql = f"DELETE FROM {self.schema}.{self.table_name} WHERE meta LIKE ?"
cursor.execute(sql, (pattern,))
@override
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""Search similar documents using VECTOR_COSINE with HNSW index."""
top_k = kwargs.get("top_k", 4)
@ -275,6 +282,7 @@ class IrisVector(BaseVector):
docs.append(Document(page_content=text, metadata=metadata))
return docs
@override
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Search documents by full-text using iFind index with BM25 relevance scoring.
@ -404,6 +412,7 @@ class IrisVector(BaseVector):
return docs
@override
def delete(self) -> None:
"""Delete the entire collection (drop table - permanent)."""
with self._get_cursor() as cursor:
@ -481,6 +490,7 @@ class IrisVector(BaseVector):
class IrisVectorFactory(AbstractVectorFactory):
"""Factory for creating IrisVector instances."""
@override
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> IrisVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]

View File

@ -1,7 +1,7 @@
import json
import logging
import time
from typing import Any
from typing import Any, override
from opensearchpy import OpenSearch, helpers
from opensearchpy.helpers import BulkIndexError
@ -79,9 +79,11 @@ class LindormVectorStore(BaseVector):
self._using_ugc = using_ugc
self.kwargs = kwargs
@override
def get_type(self) -> str:
return VectorType.LINDORM
@override
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas)
@ -90,6 +92,7 @@ class LindormVectorStore(BaseVector):
def refresh(self):
self._client.indices.refresh(index=self._collection_name)
@override
def add_texts(
self,
documents: list[Document],
@ -156,6 +159,7 @@ class LindormVectorStore(BaseVector):
logger.exception("Failed to process batch %s", batch_num + 1)
raise
@override
def get_ids_by_metadata_field(self, key: str, value: str):
query: dict[str, Any] = {
"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY}.{key}.keyword": value}}]}}
@ -168,11 +172,13 @@ class LindormVectorStore(BaseVector):
else:
return None
@override
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self.delete_by_ids(ids)
@override
def delete_by_ids(self, ids: list[str]):
"""Delete documents by their IDs in batch.
@ -219,6 +225,7 @@ class LindormVectorStore(BaseVector):
else:
logger.exception("Error deleting document: %s", error)
@override
def delete(self):
if self._using_ugc:
routing_filter_query = {
@ -233,6 +240,7 @@ class LindormVectorStore(BaseVector):
else:
logger.warning("Index '%s' does not exist. No deletion performed.", self._collection_name)
@override
def text_exists(self, id: str) -> bool:
try:
params: dict[str, Any] = {}
@ -243,6 +251,7 @@ class LindormVectorStore(BaseVector):
except:
return False
@override
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
if not isinstance(query_vector, list):
raise ValueError("query_vector should be a list of floats")
@ -305,6 +314,7 @@ class LindormVectorStore(BaseVector):
return docs
@override
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
full_text_query = {"query": {"bool": {"must": [{"match": {Field.CONTENT_KEY.value: query}}]}}}
filters = []
@ -377,6 +387,7 @@ class LindormVectorStore(BaseVector):
class LindormVectorStoreFactory(AbstractVectorFactory):
@override
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore:
lindorm_config = LindormVectorStoreConfig(
hosts=dify_config.LINDORM_URL,

View File

@ -1,4 +1,5 @@
import os
from typing import override
from dify_vdb_lindorm.lindorm_vector import LindormVectorStore, LindormVectorStoreConfig
@ -26,6 +27,7 @@ class TestLindormVectorStore(AbstractVectorTest):
),
)
@override
def get_ids_by_metadata_field(self):
ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
assert ids is not None
@ -47,6 +49,7 @@ class TestLindormVectorStoreUGC(AbstractVectorTest):
routing_value=self.collection_name,
)
@override
def get_ids_by_metadata_field(self):
ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
assert ids is not None

View File

@ -3,7 +3,7 @@ import logging
import uuid
from collections.abc import Callable
from functools import wraps
from typing import Any, Concatenate
from typing import Any, Concatenate, override
from mo_vector.client import MoVectorClient # type: ignore
from pydantic import BaseModel, model_validator
@ -69,16 +69,20 @@ class MatrixoneVector(BaseVector):
self.client = None
@property
@override
def collection_name(self):
return self._collection_name
@collection_name.setter
@override
def collection_name(self, value):
self._collection_name = value
@override
def get_type(self) -> str:
return VectorType.MATRIXONE
@override
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
if self.client is None:
self.client = self._get_client(len(embeddings[0]), True)
@ -108,6 +112,7 @@ class MatrixoneVector(BaseVector):
logger.exception("Failed to create full text index")
return client
@override
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
if self.client is None:
self.client = self._get_client(len(embeddings[0]), True)
@ -126,12 +131,14 @@ class MatrixoneVector(BaseVector):
return ids
@ensure_client
@override
def text_exists(self, id: str) -> bool:
assert self.client is not None
result = self.client.get(ids=[id])
return len(result) > 0
@ensure_client
@override
def delete_by_ids(self, ids: list[str]):
assert self.client is not None
if not ids:
@ -139,17 +146,20 @@ class MatrixoneVector(BaseVector):
self.client.delete(ids=ids)
@ensure_client
@override
def get_ids_by_metadata_field(self, key: str, value: str):
assert self.client is not None
results = self.client.query_by_metadata(filter={key: value})
return [result.id for result in results]
@ensure_client
@override
def delete_by_metadata_field(self, key: str, value: str):
assert self.client is not None
self.client.delete(filter={key: value})
@ensure_client
@override
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
assert self.client is not None
top_k = kwargs.get("top_k", 5)
@ -177,6 +187,7 @@ class MatrixoneVector(BaseVector):
return docs
@ensure_client
@override
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
assert self.client is not None
top_k = kwargs.get("top_k", 5)
@ -207,12 +218,14 @@ class MatrixoneVector(BaseVector):
return docs
@ensure_client
@override
def delete(self):
assert self.client is not None
self.client.delete()
class MatrixoneVectorFactory(AbstractVectorFactory):
@override
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MatrixoneVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]

View File

@ -1,3 +1,5 @@
from typing import override
from dify_vdb_matrixone.matrixone_vector import MatrixoneConfig, MatrixoneVector
from core.rag.datasource.vdb.vector_integration_test_support import (
@ -15,6 +17,7 @@ class MatrixoneVectorTest(AbstractVectorTest):
),
)
@override
def get_ids_by_metadata_field(self):
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
assert len(ids) == 1

View File

@ -1,6 +1,6 @@
import json
import logging
from typing import Any, TypedDict, cast
from typing import Any, TypedDict, cast, override
from packaging import version
from pydantic import BaseModel, model_validator
@ -120,12 +120,14 @@ class MilvusVector(BaseVector):
logger.warning("Failed to check Milvus version: %s. Disabling hybrid search.", str(e))
return False
@override
def get_type(self) -> str:
"""
Get the type of vector storage (Milvus).
"""
return VectorType.MILVUS
@override
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
"""
Create a collection and add texts with embeddings.
@ -135,6 +137,7 @@ class MilvusVector(BaseVector):
self.create_collection(embeddings, metadatas, index_params)
self.add_texts(texts, embeddings)
@override
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
"""
Add texts and their embeddings to the collection.
@ -164,6 +167,7 @@ class MilvusVector(BaseVector):
raise e
return pks
@override
def get_ids_by_metadata_field(self, key: str, value: str):
"""
Get document IDs by metadata field key and value.
@ -176,6 +180,7 @@ class MilvusVector(BaseVector):
else:
return None
@override
def delete_by_metadata_field(self, key: str, value: str):
"""
Delete documents by metadata field key and value.
@ -185,6 +190,7 @@ class MilvusVector(BaseVector):
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)
@override
def delete_by_ids(self, ids: list[str]):
"""
Delete documents by their IDs.
@ -197,6 +203,7 @@ class MilvusVector(BaseVector):
ids = [item["id"] for item in result]
self._client.delete(collection_name=self._collection_name, pks=ids)
@override
def delete(self):
"""
Delete the entire collection.
@ -204,6 +211,7 @@ class MilvusVector(BaseVector):
if self._client.has_collection(self._collection_name):
self._client.drop_collection(self._collection_name, None)
@override
def text_exists(self, id: str) -> bool:
"""
Check if a text with the given ID exists in the collection.
@ -245,6 +253,7 @@ class MilvusVector(BaseVector):
return docs
@override
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""
Search for documents by vector similarity.
@ -269,6 +278,7 @@ class MilvusVector(BaseVector):
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)
@override
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""
Search for documents by full-text search (if hybrid search is enabled).
@ -411,6 +421,7 @@ class MilvusVectorFactory(AbstractVectorFactory):
Factory class for creating MilvusVector instances.
"""
@override
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
"""
Initialize a MilvusVector instance for the given dataset.

View File

@ -1,3 +1,5 @@
from typing import override
from dify_vdb_milvus.milvus_vector import MilvusConfig, MilvusVector
from core.rag.datasource.vdb.vector_integration_test_support import (
@ -18,11 +20,13 @@ class MilvusVectorTest(AbstractVectorTest):
),
)
@override
def search_by_full_text(self):
# milvus support BM25 full text search after version 2.5.0-beta
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) >= 0
@override
def get_ids_by_metadata_field(self):
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
assert len(ids) == 1

View File

@ -2,7 +2,7 @@ import json
import logging
import uuid
from enum import StrEnum
from typing import Any
from typing import Any, override
from clickhouse_connect import get_client # type: ignore[import-untyped]
from pydantic import BaseModel
@ -46,9 +46,11 @@ class MyScaleVector(BaseVector):
)
self._client.command("SET allow_experimental_object_type=1")
@override
def get_type(self) -> str:
return VectorType.MYSCALE
@override
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection(dimension)
@ -71,6 +73,7 @@ class MyScaleVector(BaseVector):
"""
self._client.command(sql)
@override
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
ids = []
columns = ["id", "text", "vector", "metadata"]
@ -97,10 +100,12 @@ class MyScaleVector(BaseVector):
def escape_str(value: Any) -> str:
return "".join(" " if c in {"\\", "'"} else c for c in str(value))
@override
def text_exists(self, id: str) -> bool:
results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")
return results.row_count > 0
@override
def delete_by_ids(self, ids: list[str]):
if not ids:
return
@ -108,20 +113,24 @@ class MyScaleVector(BaseVector):
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
)
@override
def get_ids_by_metadata_field(self, key: str, value: str):
rows = self._client.query(
f"SELECT DISTINCT id FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'"
).result_rows
return [row[0] for row in rows]
@override
def delete_by_metadata_field(self, key: str, value: str):
self._client.command(
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'"
)
@override
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
return self._search(f"distance(vector, {str(query_vector)})", self._vec_order, **kwargs)
@override
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self._search(f"TextSearch('enable_nlq=false')(text, '{query}')", SortOrder.DESC, **kwargs)
@ -156,11 +165,13 @@ class MyScaleVector(BaseVector):
logger.exception("Vector search operation failed")
return []
@override
def delete(self):
self._client.command(f"DROP TABLE IF EXISTS {self._config.database}.{self._collection_name}")
class MyScaleVectorFactory(AbstractVectorFactory):
@override
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MyScaleVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]

View File

@ -1,3 +1,5 @@
from typing import override
from dify_vdb_myscale.myscale_vector import MyScaleConfig, MyScaleVector
from core.rag.datasource.vdb.vector_integration_test_support import (
@ -20,6 +22,7 @@ class MyScaleVectorTest(AbstractVectorTest):
),
)
@override
def get_ids_by_metadata_field(self):
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
assert len(ids) == 1

View File

@ -1,7 +1,7 @@
import json
import logging
import re
from typing import Any, Literal
from typing import Any, Literal, override
from pydantic import BaseModel, model_validator
from pyobvector import VECTOR, ObVecClient, cosine_distance, inner_product, l2_distance # type: ignore
@ -86,6 +86,7 @@ class OceanBaseVector(BaseVector):
self._load_collection_fields()
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported
@override
def get_type(self) -> str:
return VectorType.OCEANBASE
@ -114,6 +115,7 @@ class OceanBaseVector(BaseVector):
"""
return field in self._fields
@override
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self._vec_dim = len(embeddings[0])
self._create_collection()
@ -237,6 +239,7 @@ class OceanBaseVector(BaseVector):
logger.warning("Failed to check OceanBase version: %s. Disabling hybrid search.", str(e))
return False
@override
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
ids = self._get_uuids(documents)
batch_size = self._config.batch_size
@ -283,6 +286,7 @@ class OceanBaseVector(BaseVector):
self._collection_name,
)
@override
def text_exists(self, id: str) -> bool:
try:
cur = self._client.get(table_name=self._collection_name, ids=id)
@ -295,6 +299,7 @@ class OceanBaseVector(BaseVector):
)
raise Exception(f"Failed to check text existence for id '{id}'") from e
@override
def delete_by_ids(self, ids: list[str]):
if not ids:
return
@ -309,6 +314,7 @@ class OceanBaseVector(BaseVector):
)
raise Exception(f"Failed to delete documents from collection '{self._collection_name}'") from e
@override
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
try:
import re
@ -343,6 +349,7 @@ class OceanBaseVector(BaseVector):
)
raise Exception(f"Failed to query documents by metadata field '{key}'") from e
@override
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
@ -381,6 +388,7 @@ class OceanBaseVector(BaseVector):
return docs
@override
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
if not self._hybrid_search_enabled:
logger.warning(
@ -438,6 +446,7 @@ class OceanBaseVector(BaseVector):
)
raise Exception(f"Full-text search failed for collection '{self._collection_name}'") from e
@override
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from sqlalchemy import text
@ -508,6 +517,7 @@ class OceanBaseVector(BaseVector):
return -distance
raise ValueError(f"Unsupported metric_type '{metric}'")
@override
def delete(self):
try:
self._client.drop_table_if_exist(self._collection_name)
@ -518,6 +528,7 @@ class OceanBaseVector(BaseVector):
class OceanBaseVectorFactory(AbstractVectorFactory):
@override
def init_vector(
self,
dataset: Dataset,

View File

@ -1,3 +1,5 @@
from typing import override
import pytest
from dify_vdb_oceanbase.oceanbase_vector import (
OceanBaseVector,
@ -30,6 +32,7 @@ class OceanBaseVectorTest(AbstractVectorTest):
super().__init__()
self.vector = vector
@override
def get_ids_by_metadata_field(self):
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
assert len(ids) == 1

Some files were not shown because too many files have changed in this diff Show More