diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index c72797c34a..7cdb2d3a99 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any +from typing import Any, TypedDict from packaging import version from pydantic import BaseModel, model_validator @@ -20,6 +20,15 @@ from models.dataset import Dataset logger = logging.getLogger(__name__) +class MilvusParamsDict(TypedDict): + uri: str + token: str | None + user: str | None + password: str | None + db_name: str + analyzer_params: str | None + + class MilvusConfig(BaseModel): """ Configuration class for Milvus connection. @@ -50,11 +59,11 @@ class MilvusConfig(BaseModel): raise ValueError("config MILVUS_PASSWORD is required") return values - def to_milvus_params(self): + def to_milvus_params(self) -> MilvusParamsDict: """ Convert the configuration to a dictionary of Milvus connection parameters. """ - return { + result: MilvusParamsDict = { "uri": self.uri, "token": self.token, "user": self.user, @@ -62,6 +71,7 @@ class MilvusConfig(BaseModel): "db_name": self.database, "analyzer_params": self.analyzer_params, } + return result class MilvusVector(BaseVector): diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index c6836d9cf9..2f26d6fff3 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -1,7 +1,7 @@ import json import logging import math -from typing import Any +from typing import Any, TypedDict from pydantic import BaseModel from tcvdb_text.encoder import BM25Encoder # type: ignore @@ -23,6 +23,13 @@ from models.dataset import Dataset logger = logging.getLogger(__name__) +class TencentParamsDict(TypedDict): + url: str + username: str | None + key: str | None + timeout: float + + class TencentConfig(BaseModel): url: str api_key: str | None = None @@ -36,8 +43,14 @@ class TencentConfig(BaseModel): max_upsert_batch_size: int = 128 enable_hybrid_search: bool = False # Flag to enable hybrid search - def to_tencent_params(self): - return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout} + def to_tencent_params(self) -> TencentParamsDict: + result: TencentParamsDict = { + "url": self.url, + "username": self.username, + "key": self.api_key, + "timeout": self.timeout, + } + return result bm25 = BM25Encoder.default("zh")