mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 18:06:36 +08:00
refactor(api): type VDB config params dicts with TypedDicts (#34677)
This commit is contained in:
parent
485fc2c416
commit
e645cbd8f8
@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from packaging import version
|
||||
from pydantic import BaseModel, model_validator
|
||||
@ -20,6 +20,15 @@ from models.dataset import Dataset
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MilvusParamsDict(TypedDict):
|
||||
uri: str
|
||||
token: str | None
|
||||
user: str | None
|
||||
password: str | None
|
||||
db_name: str
|
||||
analyzer_params: str | None
|
||||
|
||||
|
||||
class MilvusConfig(BaseModel):
|
||||
"""
|
||||
Configuration class for Milvus connection.
|
||||
@ -50,11 +59,11 @@ class MilvusConfig(BaseModel):
|
||||
raise ValueError("config MILVUS_PASSWORD is required")
|
||||
return values
|
||||
|
||||
def to_milvus_params(self):
|
||||
def to_milvus_params(self) -> MilvusParamsDict:
|
||||
"""
|
||||
Convert the configuration to a dictionary of Milvus connection parameters.
|
||||
"""
|
||||
return {
|
||||
result: MilvusParamsDict = {
|
||||
"uri": self.uri,
|
||||
"token": self.token,
|
||||
"user": self.user,
|
||||
@ -62,6 +71,7 @@ class MilvusConfig(BaseModel):
|
||||
"db_name": self.database,
|
||||
"analyzer_params": self.analyzer_params,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
class MilvusVector(BaseVector):
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
from tcvdb_text.encoder import BM25Encoder # type: ignore
|
||||
@ -23,6 +23,13 @@ from models.dataset import Dataset
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TencentParamsDict(TypedDict):
|
||||
url: str
|
||||
username: str | None
|
||||
key: str | None
|
||||
timeout: float
|
||||
|
||||
|
||||
class TencentConfig(BaseModel):
|
||||
url: str
|
||||
api_key: str | None = None
|
||||
@ -36,8 +43,14 @@ class TencentConfig(BaseModel):
|
||||
max_upsert_batch_size: int = 128
|
||||
enable_hybrid_search: bool = False # Flag to enable hybrid search
|
||||
|
||||
def to_tencent_params(self):
|
||||
return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}
|
||||
def to_tencent_params(self) -> TencentParamsDict:
|
||||
result: TencentParamsDict = {
|
||||
"url": self.url,
|
||||
"username": self.username,
|
||||
"key": self.api_key,
|
||||
"timeout": self.timeout,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
bm25 = BM25Encoder.default("zh")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user