mirror of
https://github.com/langgenius/dify.git
synced 2026-04-29 12:37:20 +08:00
feat(api): Making WeaviateClient a singleton
Co-authored-by: lijiezhao <lijiezhao@perfect99.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
29b724e23d
commit
3920d67b8e
@ -8,6 +8,7 @@ document embeddings used in retrieval-augmented generation workflows.
|
|||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import threading
|
||||||
import uuid as _uuid
|
import uuid as _uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
@ -32,6 +33,9 @@ from models.dataset import Dataset
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_weaviate_client: weaviate.WeaviateClient | None = None
|
||||||
|
_weaviate_client_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
class WeaviateConfig(BaseModel):
|
class WeaviateConfig(BaseModel):
|
||||||
"""
|
"""
|
||||||
@ -99,43 +103,52 @@ class WeaviateVector(BaseVector):
|
|||||||
|
|
||||||
Configures both HTTP and gRPC connections with proper authentication.
|
Configures both HTTP and gRPC connections with proper authentication.
|
||||||
"""
|
"""
|
||||||
p = urlparse(config.endpoint)
|
global _weaviate_client
|
||||||
host = p.hostname or config.endpoint.replace("https://", "").replace("http://", "")
|
if _weaviate_client and _weaviate_client.is_ready():
|
||||||
http_secure = p.scheme == "https"
|
return _weaviate_client
|
||||||
http_port = p.port or (443 if http_secure else 80)
|
|
||||||
|
|
||||||
# Parse gRPC configuration
|
with _weaviate_client_lock:
|
||||||
if config.grpc_endpoint:
|
if _weaviate_client and _weaviate_client.is_ready():
|
||||||
# Urls without scheme won't be parsed correctly in some python versions,
|
return _weaviate_client
|
||||||
# see https://bugs.python.org/issue27657
|
|
||||||
grpc_endpoint_with_scheme = (
|
p = urlparse(config.endpoint)
|
||||||
config.grpc_endpoint if "://" in config.grpc_endpoint else f"grpc://{config.grpc_endpoint}"
|
host = p.hostname or config.endpoint.replace("https://", "").replace("http://", "")
|
||||||
|
http_secure = p.scheme == "https"
|
||||||
|
http_port = p.port or (443 if http_secure else 80)
|
||||||
|
|
||||||
|
# Parse gRPC configuration
|
||||||
|
if config.grpc_endpoint:
|
||||||
|
# Urls without scheme won't be parsed correctly in some python versions,
|
||||||
|
# see https://bugs.python.org/issue27657
|
||||||
|
grpc_endpoint_with_scheme = (
|
||||||
|
config.grpc_endpoint if "://" in config.grpc_endpoint else f"grpc://{config.grpc_endpoint}"
|
||||||
|
)
|
||||||
|
grpc_p = urlparse(grpc_endpoint_with_scheme)
|
||||||
|
grpc_host = grpc_p.hostname or "localhost"
|
||||||
|
grpc_port = grpc_p.port or (443 if grpc_p.scheme == "grpcs" else 50051)
|
||||||
|
grpc_secure = grpc_p.scheme == "grpcs"
|
||||||
|
else:
|
||||||
|
# Infer from HTTP endpoint as fallback
|
||||||
|
grpc_host = host
|
||||||
|
grpc_secure = http_secure
|
||||||
|
grpc_port = 443 if grpc_secure else 50051
|
||||||
|
|
||||||
|
client = weaviate.connect_to_custom(
|
||||||
|
http_host=host,
|
||||||
|
http_port=http_port,
|
||||||
|
http_secure=http_secure,
|
||||||
|
grpc_host=grpc_host,
|
||||||
|
grpc_port=grpc_port,
|
||||||
|
grpc_secure=grpc_secure,
|
||||||
|
auth_credentials=Auth.api_key(config.api_key) if config.api_key else None,
|
||||||
|
skip_init_checks=True, # Skip PyPI version check to avoid unnecessary HTTP requests
|
||||||
)
|
)
|
||||||
grpc_p = urlparse(grpc_endpoint_with_scheme)
|
|
||||||
grpc_host = grpc_p.hostname or "localhost"
|
|
||||||
grpc_port = grpc_p.port or (443 if grpc_p.scheme == "grpcs" else 50051)
|
|
||||||
grpc_secure = grpc_p.scheme == "grpcs"
|
|
||||||
else:
|
|
||||||
# Infer from HTTP endpoint as fallback
|
|
||||||
grpc_host = host
|
|
||||||
grpc_secure = http_secure
|
|
||||||
grpc_port = 443 if grpc_secure else 50051
|
|
||||||
|
|
||||||
client = weaviate.connect_to_custom(
|
if not client.is_ready():
|
||||||
http_host=host,
|
raise ConnectionError("Vector database is not ready")
|
||||||
http_port=http_port,
|
|
||||||
http_secure=http_secure,
|
|
||||||
grpc_host=grpc_host,
|
|
||||||
grpc_port=grpc_port,
|
|
||||||
grpc_secure=grpc_secure,
|
|
||||||
auth_credentials=Auth.api_key(config.api_key) if config.api_key else None,
|
|
||||||
skip_init_checks=True, # Skip PyPI version check to avoid unnecessary HTTP requests
|
|
||||||
)
|
|
||||||
|
|
||||||
if not client.is_ready():
|
_weaviate_client = client
|
||||||
raise ConnectionError("Vector database is not ready")
|
return client
|
||||||
|
|
||||||
return client
|
|
||||||
|
|
||||||
def get_type(self) -> str:
|
def get_type(self) -> str:
|
||||||
"""Returns the vector database type identifier."""
|
"""Returns the vector database type identifier."""
|
||||||
|
|||||||
@ -0,0 +1,33 @@
|
|||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_client_with_valid_config():
|
||||||
|
"""Test successful client initialization with valid configuration."""
|
||||||
|
config = WeaviateConfig(
|
||||||
|
endpoint="http://localhost:8080",
|
||||||
|
api_key="WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("weaviate.connect_to_custom") as mock_connect:
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.is_ready.return_value = True
|
||||||
|
mock_connect.return_value = mock_client
|
||||||
|
|
||||||
|
vector = WeaviateVector(
|
||||||
|
collection_name="test_collection",
|
||||||
|
config=config,
|
||||||
|
attributes=["doc_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert vector._client == mock_client
|
||||||
|
mock_connect.assert_called_once()
|
||||||
|
call_kwargs = mock_connect.call_args[1]
|
||||||
|
assert call_kwargs["http_host"] == "localhost"
|
||||||
|
assert call_kwargs["http_port"] == 8080
|
||||||
|
assert call_kwargs["http_secure"] is False
|
||||||
|
assert call_kwargs["grpc_host"] == "localhost"
|
||||||
|
assert call_kwargs["grpc_port"] == 50051
|
||||||
|
assert call_kwargs["grpc_secure"] is False
|
||||||
|
assert call_kwargs["auth_credentials"] is not None
|
||||||
Loading…
Reference in New Issue
Block a user