mirror of
https://github.com/langgenius/dify.git
synced 2026-04-18 12:28:32 +08:00
refactor(api): type OpenSearch/Lindorm/Huawei VDB config params dicts with TypedDicts (#34870)
This commit is contained in:
parent
a31c1d2c69
commit
c5c5c71d15
@ -5,6 +5,7 @@ from typing import Any
|
|||||||
|
|
||||||
from elasticsearch import Elasticsearch
|
from elasticsearch import Elasticsearch
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.rag.datasource.vdb.field import Field
|
from core.rag.datasource.vdb.field import Field
|
||||||
@ -19,6 +20,16 @@ from models.dataset import Dataset
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class HuaweiElasticsearchParamsDict(TypedDict, total=False):
|
||||||
|
hosts: list[str]
|
||||||
|
verify_certs: bool
|
||||||
|
ssl_show_warn: bool
|
||||||
|
request_timeout: int
|
||||||
|
retry_on_timeout: bool
|
||||||
|
max_retries: int
|
||||||
|
basic_auth: tuple[str, str]
|
||||||
|
|
||||||
|
|
||||||
def create_ssl_context() -> ssl.SSLContext:
|
def create_ssl_context() -> ssl.SSLContext:
|
||||||
ssl_context = ssl.create_default_context()
|
ssl_context = ssl.create_default_context()
|
||||||
ssl_context.check_hostname = False
|
ssl_context.check_hostname = False
|
||||||
@ -38,15 +49,15 @@ class HuaweiCloudVectorConfig(BaseModel):
|
|||||||
raise ValueError("config HOSTS is required")
|
raise ValueError("config HOSTS is required")
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def to_elasticsearch_params(self) -> dict[str, Any]:
|
def to_elasticsearch_params(self) -> HuaweiElasticsearchParamsDict:
|
||||||
params = {
|
params = HuaweiElasticsearchParamsDict(
|
||||||
"hosts": self.hosts.split(","),
|
hosts=self.hosts.split(","),
|
||||||
"verify_certs": False,
|
verify_certs=False,
|
||||||
"ssl_show_warn": False,
|
ssl_show_warn=False,
|
||||||
"request_timeout": 30000,
|
request_timeout=30000,
|
||||||
"retry_on_timeout": True,
|
retry_on_timeout=True,
|
||||||
"max_retries": 10,
|
max_retries=10,
|
||||||
}
|
)
|
||||||
if self.username and self.password:
|
if self.username and self.password:
|
||||||
params["basic_auth"] = (self.username, self.password)
|
params["basic_auth"] = (self.username, self.password)
|
||||||
return params
|
return params
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from opensearchpy import OpenSearch, helpers
|
|||||||
from opensearchpy.helpers import BulkIndexError
|
from opensearchpy.helpers import BulkIndexError
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.rag.datasource.vdb.field import Field
|
from core.rag.datasource.vdb.field import Field
|
||||||
@ -26,6 +27,14 @@ ROUTING_FIELD = "routing_field"
|
|||||||
UGC_INDEX_PREFIX = "ugc_index"
|
UGC_INDEX_PREFIX = "ugc_index"
|
||||||
|
|
||||||
|
|
||||||
|
class LindormOpenSearchParamsDict(TypedDict, total=False):
|
||||||
|
hosts: str | None
|
||||||
|
use_ssl: bool
|
||||||
|
pool_maxsize: int
|
||||||
|
timeout: int
|
||||||
|
http_auth: tuple[str, str]
|
||||||
|
|
||||||
|
|
||||||
class LindormVectorStoreConfig(BaseModel):
|
class LindormVectorStoreConfig(BaseModel):
|
||||||
hosts: str | None
|
hosts: str | None
|
||||||
username: str | None = None
|
username: str | None = None
|
||||||
@ -44,13 +53,13 @@ class LindormVectorStoreConfig(BaseModel):
|
|||||||
raise ValueError("config PASSWORD is required")
|
raise ValueError("config PASSWORD is required")
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def to_opensearch_params(self) -> dict[str, Any]:
|
def to_opensearch_params(self) -> LindormOpenSearchParamsDict:
|
||||||
params: dict[str, Any] = {
|
params = LindormOpenSearchParamsDict(
|
||||||
"hosts": self.hosts,
|
hosts=self.hosts,
|
||||||
"use_ssl": False,
|
use_ssl=False,
|
||||||
"pool_maxsize": 128,
|
pool_maxsize=128,
|
||||||
"timeout": 30,
|
timeout=30,
|
||||||
}
|
)
|
||||||
if self.username and self.password:
|
if self.username and self.password:
|
||||||
params["http_auth"] = (self.username, self.password)
|
params["http_auth"] = (self.username, self.password)
|
||||||
return params
|
return params
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from uuid import uuid4
|
|||||||
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
|
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
|
||||||
from opensearchpy.helpers import BulkIndexError
|
from opensearchpy.helpers import BulkIndexError
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from configs.middleware.vdb.opensearch_config import AuthMethod
|
from configs.middleware.vdb.opensearch_config import AuthMethod
|
||||||
@ -21,6 +22,20 @@ from models.dataset import Dataset
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class _OpenSearchHostDict(TypedDict):
|
||||||
|
host: str
|
||||||
|
port: int
|
||||||
|
|
||||||
|
|
||||||
|
class OpenSearchParamsDict(TypedDict, total=False):
|
||||||
|
hosts: list[_OpenSearchHostDict]
|
||||||
|
use_ssl: bool
|
||||||
|
verify_certs: bool
|
||||||
|
connection_class: type
|
||||||
|
pool_maxsize: int
|
||||||
|
http_auth: tuple[str | None, str | None] | Urllib3AWSV4SignerAuth
|
||||||
|
|
||||||
|
|
||||||
class OpenSearchConfig(BaseModel):
|
class OpenSearchConfig(BaseModel):
|
||||||
host: str
|
host: str
|
||||||
port: int
|
port: int
|
||||||
@ -57,14 +72,14 @@ class OpenSearchConfig(BaseModel):
|
|||||||
service=self.aws_service, # type: ignore[arg-type]
|
service=self.aws_service, # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_opensearch_params(self) -> dict[str, Any]:
|
def to_opensearch_params(self) -> OpenSearchParamsDict:
|
||||||
params = {
|
params = OpenSearchParamsDict(
|
||||||
"hosts": [{"host": self.host, "port": self.port}],
|
hosts=[{"host": self.host, "port": self.port}],
|
||||||
"use_ssl": self.secure,
|
use_ssl=self.secure,
|
||||||
"verify_certs": self.verify_certs,
|
verify_certs=self.verify_certs,
|
||||||
"connection_class": Urllib3HttpConnection,
|
connection_class=Urllib3HttpConnection,
|
||||||
"pool_maxsize": 20,
|
pool_maxsize=20,
|
||||||
}
|
)
|
||||||
|
|
||||||
if self.auth_method == "basic":
|
if self.auth_method == "basic":
|
||||||
logger.info("Using basic authentication for OpenSearch Vector DB")
|
logger.info("Using basic authentication for OpenSearch Vector DB")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user