feat(api): support explicit TLS for Milvus vector store (#36265)

This commit is contained in:
Amr Sherif 2026-05-27 09:17:27 +04:00 committed by GitHub
parent cab215e209
commit bee21c9f86
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 64 additions and 8 deletions

View File

@ -41,3 +41,21 @@ class MilvusConfig(BaseSettings):
description='Milvus text analyzer parameters, e.g., {"type": "chinese"} for Chinese segmentation support.',
default=None,
)
MILVUS_SECURE: bool = Field(
description="Enable TLS for the Milvus connection (one-way TLS). When True, the client uses gRPC over TLS "
"and verifies the server certificate. Equivalent to passing secure=True to pymilvus.",
default=False,
)
MILVUS_SERVER_PEM_PATH: str | None = Field(
description="Filesystem path inside the container to the Milvus server certificate (PEM). Mount this via "
"a Kubernetes secret. Used as pymilvus's server_pem_path when MILVUS_SECURE is True.",
default=None,
)
MILVUS_SERVER_NAME: str | None = Field(
description="Server name (TLS SNI / certificate CN or SAN) to verify against the Milvus server certificate. "
"Required when MILVUS_SERVER_PEM_PATH is set.",
default=None,
)

View File

@ -42,6 +42,9 @@ class MilvusConfig(BaseModel):
database: str = "default" # Database name
enable_hybrid_search: bool = False # Flag to enable hybrid search
analyzer_params: str | None = None # Analyzer params
secure: bool = False # Enable one-way TLS to Milvus
server_pem_path: str | None = None # Path to server certificate (PEM) for TLS verification
server_name: str | None = None # Server name to verify against the certificate (SNI / CN)
@model_validator(mode="before")
@classmethod
@ -388,16 +391,19 @@ class MilvusVector(BaseVector):
"""
Initialize and return a Milvus client.
"""
kwargs: dict[str, Any] = {"uri": config.uri, "db_name": config.database}
if config.token:
client = MilvusClient(uri=config.uri, token=config.token, db_name=config.database)
kwargs["token"] = config.token
else:
client = MilvusClient(
uri=config.uri,
user=config.user or "",
password=config.password or "",
db_name=config.database,
)
return client
kwargs["user"] = config.user or ""
kwargs["password"] = config.password or ""
if config.secure:
kwargs["secure"] = True
if config.server_pem_path:
kwargs["server_pem_path"] = config.server_pem_path
if config.server_name:
kwargs["server_name"] = config.server_name
return MilvusClient(**kwargs)
class MilvusVectorFactory(AbstractVectorFactory):
@ -427,5 +433,8 @@ class MilvusVectorFactory(AbstractVectorFactory):
database=dify_config.MILVUS_DATABASE or "",
enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False,
analyzer_params=dify_config.MILVUS_ANALYZER_PARAMS or "",
secure=dify_config.MILVUS_SECURE,
server_pem_path=dify_config.MILVUS_SERVER_PEM_PATH,
server_name=dify_config.MILVUS_SERVER_NAME,
),
)

View File

@ -163,6 +163,35 @@ def test_init_client_supports_token_and_user_password(milvus_module):
assert user_client.init_kwargs["password"] == "Milvus"
def test_init_client_passes_tls_kwargs_when_secure(milvus_module):
vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
client = vector._init_client(
milvus_module.MilvusConfig.model_validate(
{
"uri": "https://milvus.example.com:19530",
"token": "abc",
"database": "db",
"secure": True,
"server_pem_path": "/etc/milvus/certs/server.pem",
"server_name": "milvus.example.com",
}
)
)
assert client.init_kwargs["secure"] is True
assert client.init_kwargs["server_pem_path"] == "/etc/milvus/certs/server.pem"
assert client.init_kwargs["server_name"] == "milvus.example.com"
def test_init_client_omits_tls_kwargs_when_not_secure(milvus_module):
vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector)
client = vector._init_client(
milvus_module.MilvusConfig.model_validate({"uri": "http://localhost:19530", "token": "abc", "database": "db"})
)
assert "secure" not in client.init_kwargs
assert "server_pem_path" not in client.init_kwargs
assert "server_name" not in client.init_kwargs
def test_init_loads_fields_when_collection_exists(milvus_module):
client = milvus_module.MilvusClient(uri="http://localhost:19530")
client.has_collection.return_value = True