diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index 2b220fc04d..99ab0d82f2 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -30,7 +30,7 @@ from pymochow.model.table import AnnSearch, BM25SearchRequest, HNSWSearchParams, from configs import dify_config from core.rag.datasource.vdb.field import Field as VDBField from core.rag.datasource.vdb.field import parse_metadata_json -from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType from core.rag.embedding.embedding_base import Embeddings @@ -85,8 +85,12 @@ class BaiduVector(BaseVector): def get_type(self) -> str: return VectorType.BAIDU - def to_index_struct(self): - return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} + def to_index_struct(self) -> VectorIndexStructDict: + result: VectorIndexStructDict = { + "type": self.get_type(), + "vector_store": {"class_prefix": self._collection_name}, + } + return result def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): self._create_table(len(embeddings[0])) diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index a9f946dd43..f4fcb975c3 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -22,7 +22,7 @@ from sqlalchemy import select from configs import dify_config from core.rag.datasource.vdb.field import Field -from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType from core.rag.embedding.embedding_base import Embeddings @@ -94,8 +94,12 @@ class QdrantVector(BaseVector): def get_type(self) -> str: return VectorType.QDRANT - def to_index_struct(self): - return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} + def to_index_struct(self) -> VectorIndexStructDict: + result: VectorIndexStructDict = { + "type": self.get_type(), + "vector_store": {"class_prefix": self._collection_name}, + } + return result def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): if texts: diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 829db9db20..c6836d9cf9 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -12,7 +12,7 @@ from tcvectordb.model.document import AnnSearch, Filter, KeywordSearch, Weighted from configs import dify_config from core.rag.datasource.vdb.field import parse_metadata_json -from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType from core.rag.embedding.embedding_base import Embeddings @@ -83,8 +83,12 @@ class TencentVector(BaseVector): def get_type(self) -> str: return VectorType.TENCENT - def to_index_struct(self): - return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} + def to_index_struct(self) -> VectorIndexStructDict: + result: VectorIndexStructDict = { + "type": self.get_type(), + "vector_store": {"class_prefix": self._collection_name}, + } + return result def _has_collection(self) -> bool: return bool( diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index 499a48ac76..605cc5a08f 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -25,7 +25,7 @@ from sqlalchemy import select from configs import dify_config from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService -from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType from core.rag.embedding.embedding_base import Embeddings @@ -91,8 +91,12 @@ class TidbOnQdrantVector(BaseVector): def get_type(self) -> str: return VectorType.TIDB_ON_QDRANT - def to_index_struct(self): - return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} + def to_index_struct(self) -> VectorIndexStructDict: + result: VectorIndexStructDict = { + "type": self.get_type(), + "vector_store": {"class_prefix": self._collection_name}, + } + return result def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): if texts: diff --git a/api/core/rag/datasource/vdb/vector_base.py b/api/core/rag/datasource/vdb/vector_base.py index f29b270e40..6fbd802a10 100644 --- a/api/core/rag/datasource/vdb/vector_base.py +++ b/api/core/rag/datasource/vdb/vector_base.py @@ -1,11 +1,20 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any +from typing import Any, TypedDict from core.rag.models.document import Document +class VectorStoreDict(TypedDict): + class_prefix: str + + +class VectorIndexStructDict(TypedDict): + type: str + vector_store: VectorStoreDict + + class BaseVector(ABC): def __init__(self, collection_name: str): self._collection_name = collection_name diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index d29d62c93f..25b65b82a9 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -24,7 +24,7 @@ from weaviate.exceptions import UnexpectedStatusCodeError from configs import dify_config from core.rag.datasource.vdb.field import Field -from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType from core.rag.embedding.embedding_base import Embeddings @@ -184,9 +184,13 @@ class WeaviateVector(BaseVector): dataset_id = dataset.id return Dataset.gen_collection_name_by_id(dataset_id) - def to_index_struct(self) -> dict: + def to_index_struct(self) -> VectorIndexStructDict: """Returns the index structure dictionary for persistence.""" - return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} + result: VectorIndexStructDict = { + "type": self.get_type(), + "vector_store": {"class_prefix": self._collection_name}, + } + return result def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): """