diff --git a/api/configs/middleware/vdb/milvus_config.py b/api/configs/middleware/vdb/milvus_config.py index eb9b0ac2ab..2f3a3ed2bd 100644 --- a/api/configs/middleware/vdb/milvus_config.py +++ b/api/configs/middleware/vdb/milvus_config.py @@ -41,3 +41,21 @@ class MilvusConfig(BaseSettings): description='Milvus text analyzer parameters, e.g., {"type": "chinese"} for Chinese segmentation support.', default=None, ) + + MILVUS_SECURE: bool = Field( + description="Enable TLS for the Milvus connection (one-way TLS). When True, the client uses gRPC over TLS " + "and verifies the server certificate. Equivalent to passing secure=True to pymilvus.", + default=False, + ) + + MILVUS_SERVER_PEM_PATH: str | None = Field( + description="Filesystem path inside the container to the Milvus server certificate (PEM). Mount this via " + "a Kubernetes secret. Used as pymilvus's server_pem_path when MILVUS_SECURE is True.", + default=None, + ) + + MILVUS_SERVER_NAME: str | None = Field( + description="Server name (TLS SNI / certificate CN or SAN) to verify against the Milvus server certificate. " + "Required when MILVUS_SERVER_PEM_PATH is set.", + default=None, + ) diff --git a/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py b/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py index 823b877707..ac47be7a37 100644 --- a/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py +++ b/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py @@ -42,6 +42,9 @@ class MilvusConfig(BaseModel): database: str = "default" # Database name enable_hybrid_search: bool = False # Flag to enable hybrid search analyzer_params: str | None = None # Analyzer params + secure: bool = False # Enable one-way TLS to Milvus + server_pem_path: str | None = None # Path to server certificate (PEM) for TLS verification + server_name: str | None = None # Server name to verify against the certificate (SNI / CN) @model_validator(mode="before") @classmethod @@ -388,16 +391,19 @@ class MilvusVector(BaseVector): """ Initialize and return a Milvus client. """ + kwargs: dict[str, Any] = {"uri": config.uri, "db_name": config.database} if config.token: - client = MilvusClient(uri=config.uri, token=config.token, db_name=config.database) + kwargs["token"] = config.token else: - client = MilvusClient( - uri=config.uri, - user=config.user or "", - password=config.password or "", - db_name=config.database, - ) - return client + kwargs["user"] = config.user or "" + kwargs["password"] = config.password or "" + if config.secure: + kwargs["secure"] = True + if config.server_pem_path: + kwargs["server_pem_path"] = config.server_pem_path + if config.server_name: + kwargs["server_name"] = config.server_name + return MilvusClient(**kwargs) class MilvusVectorFactory(AbstractVectorFactory): @@ -427,5 +433,8 @@ class MilvusVectorFactory(AbstractVectorFactory): database=dify_config.MILVUS_DATABASE or "", enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False, analyzer_params=dify_config.MILVUS_ANALYZER_PARAMS or "", + secure=dify_config.MILVUS_SECURE, + server_pem_path=dify_config.MILVUS_SERVER_PEM_PATH, + server_name=dify_config.MILVUS_SERVER_NAME, ), ) diff --git a/api/providers/vdb/vdb-milvus/tests/unit_tests/test_milvus.py b/api/providers/vdb/vdb-milvus/tests/unit_tests/test_milvus.py index 730ff9f296..028842a7d6 100644 --- a/api/providers/vdb/vdb-milvus/tests/unit_tests/test_milvus.py +++ b/api/providers/vdb/vdb-milvus/tests/unit_tests/test_milvus.py @@ -163,6 +163,35 @@ def test_init_client_supports_token_and_user_password(milvus_module): assert user_client.init_kwargs["password"] == "Milvus" +def test_init_client_passes_tls_kwargs_when_secure(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + client = vector._init_client( + milvus_module.MilvusConfig.model_validate( + { + "uri": "https://milvus.example.com:19530", + "token": "abc", + "database": "db", + "secure": True, + "server_pem_path": "/etc/milvus/certs/server.pem", + "server_name": "milvus.example.com", + } + ) + ) + assert client.init_kwargs["secure"] is True + assert client.init_kwargs["server_pem_path"] == "/etc/milvus/certs/server.pem" + assert client.init_kwargs["server_name"] == "milvus.example.com" + + +def test_init_client_omits_tls_kwargs_when_not_secure(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + client = vector._init_client( + milvus_module.MilvusConfig.model_validate({"uri": "http://localhost:19530", "token": "abc", "database": "db"}) + ) + assert "secure" not in client.init_kwargs + assert "server_pem_path" not in client.init_kwargs + assert "server_name" not in client.init_kwargs + + def test_init_loads_fields_when_collection_exists(milvus_module): client = milvus_module.MilvusClient(uri="http://localhost:19530") client.has_collection.return_value = True