refactor(api): type OpenSearch/Lindorm/Huawei VDB config params dicts with TypedDicts (#34870)

This commit is contained in:
dataCenter430 2026-04-09 17:34:34 -07:00 committed by GitHub
parent a31c1d2c69
commit c5c5c71d15
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 59 additions and 24 deletions

View File

@ -5,6 +5,7 @@ from typing import Any
from elasticsearch import Elasticsearch from elasticsearch import Elasticsearch
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
from typing_extensions import TypedDict
from configs import dify_config from configs import dify_config
from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.field import Field
@ -19,6 +20,16 @@ from models.dataset import Dataset
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class HuaweiElasticsearchParamsDict(TypedDict, total=False):
    """Typed shape of the dict returned by ``to_elasticsearch_params``.

    ``total=False`` makes every key optional, which lets ``basic_auth`` be
    left out when no credentials are configured.
    """

    # Host entries obtained by splitting the comma-separated hosts setting.
    hosts: list[str]
    # Certificate verification flag.
    verify_certs: bool
    # Whether SSL-related warnings are shown.
    ssl_show_warn: bool
    # Per-request timeout value handed to the client.
    request_timeout: int
    # Whether a timed-out request is retried.
    retry_on_timeout: bool
    # Upper bound on retries.
    max_retries: int
    # (username, password) pair; only set when both values are configured.
    basic_auth: tuple[str, str]
def create_ssl_context() -> ssl.SSLContext: def create_ssl_context() -> ssl.SSLContext:
ssl_context = ssl.create_default_context() ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False ssl_context.check_hostname = False
@ -38,15 +49,15 @@ class HuaweiCloudVectorConfig(BaseModel):
raise ValueError("config HOSTS is required") raise ValueError("config HOSTS is required")
return values return values
def to_elasticsearch_params(self) -> HuaweiElasticsearchParamsDict:
    """Assemble the connection parameters described by this config.

    Returns:
        A dict holding the host list (the comma-separated ``hosts`` string
        split into entries) plus fixed connection settings. ``basic_auth``
        is included only when both ``username`` and ``password`` are truthy.
    """
    # NOTE(review): request_timeout is 30000 here; confirm the unit the
    # Elasticsearch client expects for this option before relying on it.
    client_kwargs: HuaweiElasticsearchParamsDict = {
        "hosts": self.hosts.split(","),
        "verify_certs": False,
        "ssl_show_warn": False,
        "request_timeout": 30000,
        "retry_on_timeout": True,
        "max_retries": 10,
    }
    # Without a complete credential pair the auth entry is left out entirely.
    if not (self.username and self.password):
        return client_kwargs
    client_kwargs["basic_auth"] = (self.username, self.password)
    return client_kwargs

View File

@ -7,6 +7,7 @@ from opensearchpy import OpenSearch, helpers
from opensearchpy.helpers import BulkIndexError from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
from tenacity import retry, stop_after_attempt, wait_exponential from tenacity import retry, stop_after_attempt, wait_exponential
from typing_extensions import TypedDict
from configs import dify_config from configs import dify_config
from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.field import Field
@ -26,6 +27,14 @@ ROUTING_FIELD = "routing_field"
UGC_INDEX_PREFIX = "ugc_index" UGC_INDEX_PREFIX = "ugc_index"
class LindormOpenSearchParamsDict(TypedDict, total=False):
hosts: str | None
use_ssl: bool
pool_maxsize: int
timeout: int
http_auth: tuple[str, str]
class LindormVectorStoreConfig(BaseModel): class LindormVectorStoreConfig(BaseModel):
hosts: str | None hosts: str | None
username: str | None = None username: str | None = None
@ -44,13 +53,13 @@ class LindormVectorStoreConfig(BaseModel):
raise ValueError("config PASSWORD is required") raise ValueError("config PASSWORD is required")
return values return values
def to_opensearch_params(self) -> LindormOpenSearchParamsDict:
    """Assemble the connection parameters described by this config.

    Returns:
        A dict holding the configured ``hosts`` value plus fixed connection
        settings. ``http_auth`` is included only when both ``username`` and
        ``password`` are truthy.
    """
    client_kwargs: LindormOpenSearchParamsDict = {
        "hosts": self.hosts,
        "use_ssl": False,
        "pool_maxsize": 128,
        "timeout": 30,
    }
    # Without a complete credential pair the auth entry is left out entirely.
    if not (self.username and self.password):
        return client_kwargs
    client_kwargs["http_auth"] = (self.username, self.password)
    return client_kwargs

View File

@ -6,6 +6,7 @@ from uuid import uuid4
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
from opensearchpy.helpers import BulkIndexError from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
from typing_extensions import TypedDict
from configs import dify_config from configs import dify_config
from configs.middleware.vdb.opensearch_config import AuthMethod from configs.middleware.vdb.opensearch_config import AuthMethod
@ -21,6 +22,20 @@ from models.dataset import Dataset
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class _OpenSearchHostDict(TypedDict):
    """One host entry of the ``hosts`` list: both keys are required."""

    # Host name or address.
    host: str
    # Port number.
    port: int
class OpenSearchParamsDict(TypedDict, total=False):
    """Typed shape of the dict returned by ``to_opensearch_params``.

    ``total=False`` makes every key optional, so ``http_auth`` can be added
    after the base parameters have been built.
    """

    # List of {"host": ..., "port": ...} entries.
    hosts: list[_OpenSearchHostDict]
    # Whether SSL is used for the connection.
    use_ssl: bool
    # Certificate verification flag.
    verify_certs: bool
    # Connection class object handed to the client.
    connection_class: type
    # Connection pool size.
    pool_maxsize: int
    # Either a (username, password) pair or an AWS SigV4 signer object.
    http_auth: tuple[str | None, str | None] | Urllib3AWSV4SignerAuth
class OpenSearchConfig(BaseModel): class OpenSearchConfig(BaseModel):
host: str host: str
port: int port: int
@ -57,14 +72,14 @@ class OpenSearchConfig(BaseModel):
service=self.aws_service, # type: ignore[arg-type] service=self.aws_service, # type: ignore[arg-type]
) )
def to_opensearch_params(self) -> dict[str, Any]: def to_opensearch_params(self) -> OpenSearchParamsDict:
params = { params = OpenSearchParamsDict(
"hosts": [{"host": self.host, "port": self.port}], hosts=[{"host": self.host, "port": self.port}],
"use_ssl": self.secure, use_ssl=self.secure,
"verify_certs": self.verify_certs, verify_certs=self.verify_certs,
"connection_class": Urllib3HttpConnection, connection_class=Urllib3HttpConnection,
"pool_maxsize": 20, pool_maxsize=20,
} )
if self.auth_method == "basic": if self.auth_method == "basic":
logger.info("Using basic authentication for OpenSearch Vector DB") logger.info("Using basic authentication for OpenSearch Vector DB")