refactor(api): type VDB config params dicts with TypedDicts (#34677)

This commit is contained in:
Statxc 2026-04-07 10:23:42 -03:00 committed by GitHub
parent 485fc2c416
commit e645cbd8f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 29 additions and 6 deletions

View File

@ -1,6 +1,6 @@
import json
import logging
from typing import Any
from typing import Any, TypedDict
from packaging import version
from pydantic import BaseModel, model_validator
@ -20,6 +20,15 @@ from models.dataset import Dataset
logger = logging.getLogger(__name__)
class MilvusParamsDict(TypedDict):
uri: str
token: str | None
user: str | None
password: str | None
db_name: str
analyzer_params: str | None
class MilvusConfig(BaseModel):
"""
Configuration class for Milvus connection.
@ -50,11 +59,11 @@ class MilvusConfig(BaseModel):
raise ValueError("config MILVUS_PASSWORD is required")
return values
def to_milvus_params(self):
def to_milvus_params(self) -> MilvusParamsDict:
"""
Convert the configuration to a dictionary of Milvus connection parameters.
"""
return {
result: MilvusParamsDict = {
"uri": self.uri,
"token": self.token,
"user": self.user,
@ -62,6 +71,7 @@ class MilvusConfig(BaseModel):
"db_name": self.database,
"analyzer_params": self.analyzer_params,
}
return result
class MilvusVector(BaseVector):

View File

@ -1,7 +1,7 @@
import json
import logging
import math
from typing import Any
from typing import Any, TypedDict
from pydantic import BaseModel
from tcvdb_text.encoder import BM25Encoder # type: ignore
@ -23,6 +23,13 @@ from models.dataset import Dataset
logger = logging.getLogger(__name__)
class TencentParamsDict(TypedDict):
url: str
username: str | None
key: str | None
timeout: float
class TencentConfig(BaseModel):
url: str
api_key: str | None = None
@ -36,8 +43,14 @@ class TencentConfig(BaseModel):
max_upsert_batch_size: int = 128
enable_hybrid_search: bool = False # Flag to enable hybrid search
def to_tencent_params(self):
return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}
def to_tencent_params(self) -> TencentParamsDict:
result: TencentParamsDict = {
"url": self.url,
"username": self.username,
"key": self.api_key,
"timeout": self.timeout,
}
return result
bm25 = BM25Encoder.default("zh")