refactor: move vdb implementations to workspaces (#34900)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: wangxiaolei <fatelei@gmail.com>
This commit is contained in:
Yunlu Wen 2026-04-13 16:56:43 +08:00 committed by GitHub
parent c34f67495c
commit ae898652b2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
223 changed files with 2009 additions and 984 deletions

View File

@ -92,6 +92,7 @@ jobs:
vdb: vdb:
- 'api/core/rag/datasource/**' - 'api/core/rag/datasource/**'
- 'api/tests/integration_tests/vdb/**' - 'api/tests/integration_tests/vdb/**'
- 'api/providers/vdb/*/tests/**'
- '.github/workflows/vdb-tests.yml' - '.github/workflows/vdb-tests.yml'
- '.github/workflows/expose_service_ports.sh' - '.github/workflows/expose_service_ports.sh'
- 'docker/.env.example' - 'docker/.env.example'

View File

@ -89,7 +89,7 @@ jobs:
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
# - name: Check VDB Ready (TiDB) # - name: Check VDB Ready (TiDB)
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py # run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
- name: Test Vector Stores - name: Test Vector Stores
run: uv run --project api bash dev/pytest/pytest_vdb.sh run: uv run --project api bash dev/pytest/pytest_vdb.sh

View File

@ -81,12 +81,12 @@ jobs:
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
# - name: Check VDB Ready (TiDB) # - name: Check VDB Ready (TiDB)
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py # run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
- name: Test Vector Stores - name: Test Vector Stores
run: | run: |
uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \ uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \
api/tests/integration_tests/vdb/chroma \ api/providers/vdb/vdb-chroma/tests/integration_tests \
api/tests/integration_tests/vdb/pgvector \ api/providers/vdb/vdb-pgvector/tests/integration_tests \
api/tests/integration_tests/vdb/qdrant \ api/providers/vdb/vdb-qdrant/tests/integration_tests \
api/tests/integration_tests/vdb/weaviate api/providers/vdb/vdb-weaviate/tests/integration_tests

View File

@ -21,8 +21,9 @@ RUN apt-get update \
# for building gmpy2 # for building gmpy2
libmpfr-dev libmpc-dev libmpfr-dev libmpc-dev
# Install Python dependencies # Install Python dependencies (workspace members under providers/vdb/)
COPY pyproject.toml uv.lock ./ COPY pyproject.toml uv.lock ./
COPY providers ./providers
RUN uv sync --locked --no-dev RUN uv sync --locked --no-dev
# production stage # production stage

View File

@ -341,11 +341,10 @@ def add_qdrant_index(field: str):
click.echo(click.style("No dataset collection bindings found.", fg="red")) click.echo(click.style("No dataset collection bindings found.", fg="red"))
return return
import qdrant_client import qdrant_client
from dify_vdb_qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
from qdrant_client.http.exceptions import UnexpectedResponse from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import PayloadSchemaType from qdrant_client.http.models import PayloadSchemaType
from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
for binding in bindings: for binding in bindings:
if dify_config.QDRANT_URL is None: if dify_config.QDRANT_URL is None:
raise ValueError("Qdrant URL is required.") raise ValueError("Qdrant URL is required.")

View File

@ -1,4 +1,3 @@
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
from pydantic import Field from pydantic import Field
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
@ -42,17 +41,17 @@ class HologresConfig(BaseSettings):
default="public", default="public",
) )
HOLOGRES_TOKENIZER: TokenizerType = Field( HOLOGRES_TOKENIZER: str = Field(
description="Tokenizer for full-text search index (e.g., 'jieba', 'ik', 'standard', 'simple').", description="Tokenizer for full-text search index (e.g., 'jieba', 'ik', 'standard', 'simple').",
default="jieba", default="jieba",
) )
HOLOGRES_DISTANCE_METHOD: DistanceType = Field( HOLOGRES_DISTANCE_METHOD: str = Field(
description="Distance method for vector index (e.g., 'Cosine', 'Euclidean', 'InnerProduct').", description="Distance method for vector index (e.g., 'Cosine', 'Euclidean', 'InnerProduct').",
default="Cosine", default="Cosine",
) )
HOLOGRES_BASE_QUANTIZATION_TYPE: BaseQuantizationType = Field( HOLOGRES_BASE_QUANTIZATION_TYPE: str = Field(
description="Base quantization type for vector index (e.g., 'rabitq', 'sq8', 'fp16', 'fp32').", description="Base quantization type for vector index (e.g., 'rabitq', 'sq8', 'fp16', 'fp32').",
default="rabitq", default="rabitq",
) )

View File

@ -0,0 +1,87 @@
"""Vector store backend discovery.
Backends live in workspace packages under ``api/providers/vdb/*/src/dify_vdb_*``. Each package
declares third-party dependencies and registers ``importlib`` entry points in group
``dify.vector_backends`` (see each package's ``pyproject.toml``).
Shared types and the :class:`~core.rag.datasource.vdb.vector_factory.AbstractVectorFactory` protocol
remain in this package (``vector_base``, ``vector_factory``, ``vector_type``, ``field``).
Optional **built-in** targets in ``_BUILTIN_VECTOR_FACTORY_TARGETS`` (normally empty) load without a
distribution; entry points take precedence when both exist.
After changing packages, run ``uv sync`` so installed dist-info entry points match ``pyproject.toml``.
"""
from __future__ import annotations
import importlib
import logging
from importlib.metadata import entry_points
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
logger = logging.getLogger(__name__)
_VECTOR_FACTORY_CACHE: dict[str, type[AbstractVectorFactory]] = {}
# module_path:class_name — optional fallback when no distribution registers the backend.
_BUILTIN_VECTOR_FACTORY_TARGETS: dict[str, str] = {}
def clear_vector_factory_cache() -> None:
    """Reset the process-wide factory cache.

    After this call every backend is re-resolved (entry points first, then
    builtins) on its next lookup. Intended for tests and plugin reloads.
    """
    _VECTOR_FACTORY_CACHE.clear()
def _vector_backend_entry_points():
return entry_points().select(group="dify.vector_backends")
def _load_plugin_factory(vector_type: str) -> type[AbstractVectorFactory] | None:
    """Load the factory class registered under *vector_type*, or ``None``.

    Scans the ``dify.vector_backends`` entry points for one whose name equals
    *vector_type*. A failing ``ep.load()`` is logged and re-raised so broken
    plugin installs surface instead of silently falling through to builtins.
    """
    matches = (ep for ep in _vector_backend_entry_points() if ep.name == vector_type)
    ep = next(matches, None)
    if ep is None:
        return None
    try:
        return ep.load()  # type: ignore[return-value]
    except Exception:
        logger.exception("Failed to load vector backend entry point %s", ep.name)
        raise
def _unsupported(vector_type: str) -> ValueError:
    """Build (not raise) the error for an unknown backend.

    The message lists the currently installed backend names so operators can
    see at a glance whether the plugin is missing or merely misnamed.
    """
    names = sorted(ep.name for ep in _vector_backend_entry_points())
    if names:
        available_msg = f" Installed backends: {', '.join(names)}."
    else:
        available_msg = " No backends installed."
    message = (
        f"Vector store {vector_type!r} is not supported.{available_msg} "
        "Install a plugin (uv sync --group vdb-all, or vdb-<backend> per api/pyproject.toml), "
        "or register a dify.vector_backends entry point."
    )
    return ValueError(message)
def _load_builtin_factory(vector_type: str) -> type[AbstractVectorFactory]:
    """Resolve *vector_type* via the builtin ``module:class`` target map.

    Raises the standard unsupported-backend ``ValueError`` when no builtin
    target is registered for *vector_type*.
    """
    spec = _BUILTIN_VECTOR_FACTORY_TARGETS.get(vector_type)
    if not spec:
        raise _unsupported(vector_type)
    module_name, _, class_name = spec.partition(":")
    return getattr(importlib.import_module(module_name), class_name)  # type: ignore[no-any-return]
def get_vector_factory_class(vector_type: str) -> type[AbstractVectorFactory]:
    """Resolve the :class:`AbstractVectorFactory` for a ``VectorType`` string value.

    Resolution order: process cache, then ``dify.vector_backends`` entry
    points, then the builtin target map (which raises for unknown types).
    Resolved classes are cached for the lifetime of the process.
    """
    cached = _VECTOR_FACTORY_CACHE.get(vector_type)
    if cached is not None:
        return cached
    factory = _load_plugin_factory(vector_type)
    if factory is None:
        factory = _load_builtin_factory(vector_type)
    _VECTOR_FACTORY_CACHE[vector_type] = factory
    return factory

View File

@ -9,6 +9,7 @@ from sqlalchemy import select
from configs import dify_config from configs import dify_config
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.rag.datasource.vdb.vector_backend_registry import get_vector_factory_class
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.cached_embedding import CacheEmbedding from core.rag.embedding.cached_embedding import CacheEmbedding
@ -85,137 +86,7 @@ class Vector:
@staticmethod @staticmethod
def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]: def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
match vector_type: return get_vector_factory_class(vector_type)
case VectorType.CHROMA:
from core.rag.datasource.vdb.chroma.chroma_vector import ChromaVectorFactory
return ChromaVectorFactory
case VectorType.MILVUS:
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory
return MilvusVectorFactory
case VectorType.ALIBABACLOUD_MYSQL:
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import (
AlibabaCloudMySQLVectorFactory,
)
return AlibabaCloudMySQLVectorFactory
case VectorType.MYSCALE:
from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory
return MyScaleVectorFactory
case VectorType.PGVECTOR:
from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory
return PGVectorFactory
case VectorType.VASTBASE:
from core.rag.datasource.vdb.pyvastbase.vastbase_vector import VastbaseVectorFactory
return VastbaseVectorFactory
case VectorType.PGVECTO_RS:
from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory
return PGVectoRSFactory
case VectorType.QDRANT:
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory
return QdrantVectorFactory
case VectorType.RELYT:
from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory
return RelytVectorFactory
case VectorType.ELASTICSEARCH:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
return ElasticSearchVectorFactory
case VectorType.ELASTICSEARCH_JA:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector import (
ElasticSearchJaVectorFactory,
)
return ElasticSearchJaVectorFactory
case VectorType.TIDB_VECTOR:
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
return TiDBVectorFactory
case VectorType.WEAVIATE:
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory
return WeaviateVectorFactory
case VectorType.TENCENT:
from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory
return TencentVectorFactory
case VectorType.ORACLE:
from core.rag.datasource.vdb.oracle.oraclevector import OracleVectorFactory
return OracleVectorFactory
case VectorType.OPENSEARCH:
from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory
return OpenSearchVectorFactory
case VectorType.ANALYTICDB:
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory
return AnalyticdbVectorFactory
case VectorType.COUCHBASE:
from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseVectorFactory
return CouchbaseVectorFactory
case VectorType.BAIDU:
from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory
return BaiduVectorFactory
case VectorType.VIKINGDB:
from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBVectorFactory
return VikingDBVectorFactory
case VectorType.UPSTASH:
from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVectorFactory
return UpstashVectorFactory
case VectorType.TIDB_ON_QDRANT:
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory
return TidbOnQdrantVectorFactory
case VectorType.LINDORM:
from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStoreFactory
return LindormVectorStoreFactory
case VectorType.OCEANBASE | VectorType.SEEKDB:
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory
return OceanBaseVectorFactory
case VectorType.OPENGAUSS:
from core.rag.datasource.vdb.opengauss.opengauss import OpenGaussFactory
return OpenGaussFactory
case VectorType.TABLESTORE:
from core.rag.datasource.vdb.tablestore.tablestore_vector import TableStoreVectorFactory
return TableStoreVectorFactory
case VectorType.HUAWEI_CLOUD:
from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVectorFactory
return HuaweiCloudVectorFactory
case VectorType.MATRIXONE:
from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneVectorFactory
return MatrixoneVectorFactory
case VectorType.CLICKZETTA:
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory
return ClickzettaVectorFactory
case VectorType.IRIS:
from core.rag.datasource.vdb.iris.iris_vector import IrisVectorFactory
return IrisVectorFactory
case VectorType.HOLOGRES:
from core.rag.datasource.vdb.hologres.hologres_vector import HologresVectorFactory
return HologresVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")
def create(self, texts: list | None = None, **kwargs): def create(self, texts: list | None = None, **kwargs):
if texts: if texts:

View File

@ -1,10 +1,19 @@
"""Shared helpers for vector DB integration tests (used by workspace packages under ``api/providers``).
:class:`AbstractVectorTest` and helper functions live here so package tests can import
``core.rag.datasource.vdb.vector_integration_test_support`` without relying on the
``tests.*`` package.
The ``setup_mock_redis`` fixture lives in ``api/providers/conftest.py`` and is
auto-discovered by pytest for all package tests.
"""
import uuid import uuid
from unittest.mock import MagicMock
import pytest import pytest
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions import ext_redis
from models.dataset import Dataset from models.dataset import Dataset
@ -25,24 +34,10 @@ def get_example_document(doc_id: str) -> Document:
return doc return doc
@pytest.fixture
def setup_mock_redis():
# get
ext_redis.redis_client.get = MagicMock(return_value=None)
# set
ext_redis.redis_client.set = MagicMock(return_value=None)
# lock
mock_redis_lock = MagicMock()
mock_redis_lock.__enter__ = MagicMock()
mock_redis_lock.__exit__ = MagicMock()
ext_redis.redis_client.lock = mock_redis_lock
class AbstractVectorTest: class AbstractVectorTest:
vector: BaseVector
def __init__(self): def __init__(self):
self.vector = None
self.dataset_id = str(uuid.uuid4()) self.dataset_id = str(uuid.uuid4())
self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + "_test" self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + "_test"
self.example_doc_id = str(uuid.uuid4()) self.example_doc_id = str(uuid.uuid4())

12
api/providers/README.md Normal file
View File

@ -0,0 +1,12 @@
# Providers
This directory holds **optional workspace packages** that plug into Dify's API core. Providers are responsible for implementing the interfaces and registering themselves to the API core. The provider mechanism allows building the software with a selected set of providers, enhancing the security and flexibility of distributions.
## Developing Providers
- [VDB Providers](vdb/README.md)
## Tests
Provider tests often live next to the package, e.g. `providers/<type>/<backend>/tests/unit_tests/`. Shared fixtures may live under `providers/` (e.g. `conftest.py`).

View File

@ -0,0 +1,58 @@
# VDB providers
This directory contains all VDB providers.
## Architecture
1. **Core** (`api/core/rag/datasource/vdb/`) defines the contracts and loads plugins.
2. **Each provider** (`api/providers/vdb/<backend>/`) implements those contracts and registers an entry point.
3. At runtime, **`importlib.metadata.entry_points`** resolves the backend name (e.g. `pgvector`) to a factory class. The registry caches loaded classes (see `vector_backend_registry.py`).
### Interfaces
| Piece | Role |
|--------|----------|
| `AbstractVectorFactory` | You subclass this. Implement `init_vector(dataset, attributes, embeddings) -> BaseVector`. Optionally use `gen_index_struct_dict()` for new datasets. |
| `BaseVector` | Your store class subclasses this: `create`, `add_texts`, `search_by_vector`, `delete`, etc. |
| `VectorType` | `StrEnum` of supported backend **string ids**. Add a member when you introduce a new backend that should be selectable like existing ones. |
| Discovery | Loads `dify.vector_backends` entry points and caches `get_vector_factory_class(vector_type)`. |
The high-level caller is `Vector` in `vector_factory.py`: it reads the configured or dataset-specific vector type, calls `get_vector_factory_class`, instantiates the factory, and uses the returned `BaseVector` implementation.
### Entry point name must match the vector type string
Entry points are registered under the group **`dify.vector_backends`**. The **entry point name** (left-hand side) must be exactly the string used as `vector_type` everywhere else—typically the **`VectorType` enum value** (e.g. `PGVECTOR = "pgvector"` → entry point name `pgvector`; `TIDB_ON_QDRANT = "tidb_on_qdrant"` → `tidb_on_qdrant`).
In `pyproject.toml`:
```toml
[project.entry-points."dify.vector_backends"]
pgvector = "dify_vdb_pgvector.pgvector:PGVectorFactory"
```
The value is **`module:attribute`**: an importable module path and the name of the class implementing `AbstractVectorFactory`.
### How registration works
1. On first use, `get_vector_factory_class(vector_type)` looks up `vector_type` in a process cache.
2. If missing, it scans **`entry_points().select(group="dify.vector_backends")`** for an entry whose **`name` equals `vector_type`**.
3. It loads that entry (`ep.load()`), which must return the **factory class** (not an instance).
4. There is an optional internal map `_BUILTIN_VECTOR_FACTORY_TARGETS` for non-distribution builtins; **normal VDB plugins use entry points only**.
After you change a provider's `pyproject.toml` (entry points or dependencies), run **`uv sync`** in `api/` so the installed environment's dist-info matches the project metadata.
### Package layout (VDB)
Each backend usually follows:
- `api/providers/vdb/<backend>/pyproject.toml` — project name `dify-vdb-<backend>`, dependencies, entry points.
- `api/providers/vdb/<backend>/src/dify_vdb_<python_package>/` — implementation (e.g. `PGVector`, `PGVectorFactory`).
See `vdb/pgvector/` as a reference implementation.
### Wiring a new backend into the API workspace
The API uses a **uv workspace** (`api/pyproject.toml`):
1. **`[tool.uv.workspace]`** — `members = ["providers/vdb/*"]` already includes every subdirectory under `vdb/`; new folders there are workspace members.
2. **`[tool.uv.sources]`** — add a line for your package: `dify-vdb-mine = { workspace = true }`.
3. **`[project.optional-dependencies]`** — add a group such as `vdb-mine = ["dify-vdb-mine"]`, and list `dify-vdb-mine` under `vdb-all` if it should install with the default bundle.

View File

@ -0,0 +1,22 @@
from unittest.mock import MagicMock
import pytest
from extensions import ext_redis
@pytest.fixture(autouse=True)
def _init_mock_redis():
    """Ensure redis_client has a backing client so __getattr__ never raises."""
    # Runs automatically before every test in this tree. If no real client has
    # been wired up, back the proxy with a MagicMock so attribute access on
    # redis_client never fails on a missing client.
    # NOTE(review): assumes the proxy exposes a private `_client` attribute and
    # an `initialize()` method — confirm against extensions.ext_redis.
    if ext_redis.redis_client._client is None:
        ext_redis.redis_client.initialize(MagicMock())
@pytest.fixture
def setup_mock_redis(monkeypatch: pytest.MonkeyPatch):
    """Stub the shared redis client's get/set/lock for vector-store tests.

    Uses monkeypatch so the original attributes are restored automatically
    after each test.
    """
    # get/set return None so callers treat the cache as empty.
    monkeypatch.setattr(ext_redis.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(ext_redis.redis_client, "set", MagicMock(return_value=None))
    # lock() is used as a context manager; give the mock explicit
    # __enter__/__exit__ so `with redis_client.lock(...):` works.
    mock_redis_lock = MagicMock()
    mock_redis_lock.__enter__ = MagicMock()
    mock_redis_lock.__exit__ = MagicMock()
    monkeypatch.setattr(ext_redis.redis_client, "lock", mock_redis_lock)

View File

@ -0,0 +1,13 @@
[project]
name = "dify-vdb-alibabacloud-mysql"
version = "0.0.1"
dependencies = [
"mysql-connector-python>=9.3.0",
]
description = "Dify vector store backend (dify-vdb-alibabacloud-mysql)."
[project.entry-points."dify.vector_backends"]
alibabacloud_mysql = "dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector:AlibabaCloudMySQLVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@ -1,10 +1,9 @@
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector as alibaba_module
import pytest import pytest
from dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector import AlibabaCloudMySQLVectorFactory
import core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector as alibaba_module
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import AlibabaCloudMySQLVectorFactory
def test_validate_distance_function_accepts_supported_values(): def test_validate_distance_function_accepts_supported_values():

View File

@ -3,11 +3,11 @@ import unittest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
from dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector import (
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import (
AlibabaCloudMySQLVector, AlibabaCloudMySQLVector,
AlibabaCloudMySQLVectorConfig, AlibabaCloudMySQLVectorConfig,
) )
from core.rag.models.document import Document from core.rag.models.document import Document
try: try:
@ -49,9 +49,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
# Sample embeddings # Sample embeddings
self.sample_embeddings = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]] self.sample_embeddings = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]
@patch( @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_init(self, mock_pool_class): def test_init(self, mock_pool_class):
"""Test AlibabaCloudMySQLVector initialization.""" """Test AlibabaCloudMySQLVector initialization."""
# Mock the connection pool # Mock the connection pool
@ -76,10 +74,8 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert alibabacloud_mysql_vector.distance_function == "cosine" assert alibabacloud_mysql_vector.distance_function == "cosine"
assert alibabacloud_mysql_vector.pool is not None assert alibabacloud_mysql_vector.pool is not None
@patch( @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
)
@patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
def test_create_collection(self, mock_redis, mock_pool_class): def test_create_collection(self, mock_redis, mock_pool_class):
"""Test collection creation.""" """Test collection creation."""
# Mock Redis operations # Mock Redis operations
@ -110,9 +106,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert mock_cursor.execute.call_count >= 3 # CREATE TABLE + 2 indexes assert mock_cursor.execute.call_count >= 3 # CREATE TABLE + 2 indexes
mock_redis.set.assert_called_once() mock_redis.set.assert_called_once()
@patch( @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_vector_support_check_success(self, mock_pool_class): def test_vector_support_check_success(self, mock_pool_class):
"""Test successful vector support check.""" """Test successful vector support check."""
# Mock the connection pool # Mock the connection pool
@ -129,9 +123,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config) vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
assert vector_store is not None assert vector_store is not None
@patch( @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_vector_support_check_failure(self, mock_pool_class): def test_vector_support_check_failure(self, mock_pool_class):
"""Test vector support check failure.""" """Test vector support check failure."""
# Mock the connection pool # Mock the connection pool
@ -149,9 +141,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert "RDS MySQL Vector functions are not available" in str(context.value) assert "RDS MySQL Vector functions are not available" in str(context.value)
@patch( @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_vector_support_check_function_error(self, mock_pool_class): def test_vector_support_check_function_error(self, mock_pool_class):
"""Test vector support check with function not found error.""" """Test vector support check with function not found error."""
# Mock the connection pool # Mock the connection pool
@ -170,10 +160,8 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert "RDS MySQL Vector functions are not available" in str(context.value) assert "RDS MySQL Vector functions are not available" in str(context.value)
@patch( @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
)
@patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
def test_create_documents(self, mock_redis, mock_pool_class): def test_create_documents(self, mock_redis, mock_pool_class):
"""Test creating documents with embeddings.""" """Test creating documents with embeddings."""
# Setup mocks # Setup mocks
@ -186,9 +174,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert "doc1" in result assert "doc1" in result
assert "doc2" in result assert "doc2" in result
@patch( @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_add_texts(self, mock_pool_class): def test_add_texts(self, mock_pool_class):
"""Test adding texts to the vector store.""" """Test adding texts to the vector store."""
# Mock the connection pool # Mock the connection pool
@ -207,9 +193,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert len(result) == 2 assert len(result) == 2
mock_cursor.executemany.assert_called_once() mock_cursor.executemany.assert_called_once()
@patch( @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_text_exists(self, mock_pool_class): def test_text_exists(self, mock_pool_class):
"""Test checking if text exists.""" """Test checking if text exists."""
# Mock the connection pool # Mock the connection pool
@ -236,9 +220,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert "SELECT id FROM" in last_call[0][0] assert "SELECT id FROM" in last_call[0][0]
assert last_call[0][1] == ("doc1",) assert last_call[0][1] == ("doc1",)
@patch( @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_text_not_exists(self, mock_pool_class): def test_text_not_exists(self, mock_pool_class):
"""Test checking if text does not exist.""" """Test checking if text does not exist."""
# Mock the connection pool # Mock the connection pool
@ -260,9 +242,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert not exists assert not exists
@patch( @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_get_by_ids(self, mock_pool_class): def test_get_by_ids(self, mock_pool_class):
"""Test getting documents by IDs.""" """Test getting documents by IDs."""
# Mock the connection pool # Mock the connection pool
@ -288,9 +268,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert docs[0].page_content == "Test document 1" assert docs[0].page_content == "Test document 1"
assert docs[1].page_content == "Test document 2" assert docs[1].page_content == "Test document 2"
@patch( @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_get_by_ids_empty_list(self, mock_pool_class): def test_get_by_ids_empty_list(self, mock_pool_class):
"""Test getting documents with empty ID list.""" """Test getting documents with empty ID list."""
# Mock the connection pool # Mock the connection pool
@ -308,9 +286,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert len(docs) == 0 assert len(docs) == 0
@patch( @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_delete_by_ids(self, mock_pool_class): def test_delete_by_ids(self, mock_pool_class):
"""Test deleting documents by IDs.""" """Test deleting documents by IDs."""
# Mock the connection pool # Mock the connection pool
@ -334,9 +310,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert "DELETE FROM" in delete_call[0][0] assert "DELETE FROM" in delete_call[0][0]
assert delete_call[0][1] == ["doc1", "doc2"] assert delete_call[0][1] == ["doc1", "doc2"]
@patch( @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_delete_by_ids_empty_list(self, mock_pool_class): def test_delete_by_ids_empty_list(self, mock_pool_class):
"""Test deleting with empty ID list.""" """Test deleting with empty ID list."""
# Mock the connection pool # Mock the connection pool
@ -357,9 +331,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
delete_calls = [call for call in execute_calls if "DELETE" in str(call)] delete_calls = [call for call in execute_calls if "DELETE" in str(call)]
assert len(delete_calls) == 0 assert len(delete_calls) == 0
@patch( @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_delete_by_ids_table_not_exists(self, mock_pool_class): def test_delete_by_ids_table_not_exists(self, mock_pool_class):
"""Test deleting when table doesn't exist.""" """Test deleting when table doesn't exist."""
# Mock the connection pool # Mock the connection pool
@ -384,9 +356,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
# Should not raise an exception # Should not raise an exception
vector_store.delete_by_ids(["doc1"]) vector_store.delete_by_ids(["doc1"])
@patch( @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_delete_by_metadata_field(self, mock_pool_class): def test_delete_by_metadata_field(self, mock_pool_class):
"""Test deleting documents by metadata field.""" """Test deleting documents by metadata field."""
# Mock the connection pool # Mock the connection pool
@ -410,9 +380,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert "JSON_UNQUOTE(JSON_EXTRACT(meta" in delete_call[0][0] assert "JSON_UNQUOTE(JSON_EXTRACT(meta" in delete_call[0][0]
assert delete_call[0][1] == ("$.document_id", "dataset1") assert delete_call[0][1] == ("$.document_id", "dataset1")
@patch( @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_search_by_vector_cosine(self, mock_pool_class): def test_search_by_vector_cosine(self, mock_pool_class):
"""Test vector search with cosine distance.""" """Test vector search with cosine distance."""
# Mock the connection pool # Mock the connection pool
@ -437,9 +405,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert abs(docs[0].metadata["score"] - 0.9) < 0.1 # 1 - 0.1 = 0.9 assert abs(docs[0].metadata["score"] - 0.9) < 0.1 # 1 - 0.1 = 0.9
assert docs[0].metadata["distance"] == 0.1 assert docs[0].metadata["distance"] == 0.1
@patch( @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_search_by_vector_euclidean(self, mock_pool_class): def test_search_by_vector_euclidean(self, mock_pool_class):
"""Test vector search with euclidean distance.""" """Test vector search with euclidean distance."""
config = AlibabaCloudMySQLVectorConfig( config = AlibabaCloudMySQLVectorConfig(
@ -472,9 +438,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert len(docs) == 1 assert len(docs) == 1
assert abs(docs[0].metadata["score"] - 1.0 / 3.0) < 0.01 # 1/(1+2) = 1/3 assert abs(docs[0].metadata["score"] - 1.0 / 3.0) < 0.01 # 1/(1+2) = 1/3
@patch( @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_search_by_vector_with_filter(self, mock_pool_class): def test_search_by_vector_with_filter(self, mock_pool_class):
"""Test vector search with document ID filter.""" """Test vector search with document ID filter."""
# Mock the connection pool # Mock the connection pool
@ -499,9 +463,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
search_call = search_calls[0] search_call = search_calls[0]
assert "WHERE JSON_UNQUOTE" in search_call[0][0] assert "WHERE JSON_UNQUOTE" in search_call[0][0]
@patch( @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_search_by_vector_with_score_threshold(self, mock_pool_class): def test_search_by_vector_with_score_threshold(self, mock_pool_class):
"""Test vector search with score threshold.""" """Test vector search with score threshold."""
# Mock the connection pool # Mock the connection pool
@ -536,9 +498,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert len(docs) == 1 assert len(docs) == 1
assert docs[0].page_content == "High similarity document" assert docs[0].page_content == "High similarity document"
@patch( @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_search_by_vector_invalid_top_k(self, mock_pool_class): def test_search_by_vector_invalid_top_k(self, mock_pool_class):
"""Test vector search with invalid top_k.""" """Test vector search with invalid top_k."""
# Mock the connection pool # Mock the connection pool
@ -560,9 +520,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
with pytest.raises(ValueError): with pytest.raises(ValueError):
vector_store.search_by_vector(query_vector, top_k="invalid") vector_store.search_by_vector(query_vector, top_k="invalid")
@patch( @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_search_by_full_text(self, mock_pool_class): def test_search_by_full_text(self, mock_pool_class):
"""Test full-text search.""" """Test full-text search."""
# Mock the connection pool # Mock the connection pool
@ -591,9 +549,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert docs[0].page_content == "This document contains machine learning content" assert docs[0].page_content == "This document contains machine learning content"
assert docs[0].metadata["score"] == 1.5 assert docs[0].metadata["score"] == 1.5
@patch( @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_search_by_full_text_with_filter(self, mock_pool_class): def test_search_by_full_text_with_filter(self, mock_pool_class):
"""Test full-text search with document ID filter.""" """Test full-text search with document ID filter."""
# Mock the connection pool # Mock the connection pool
@ -617,9 +573,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
search_call = search_calls[0] search_call = search_calls[0]
assert "AND JSON_UNQUOTE" in search_call[0][0] assert "AND JSON_UNQUOTE" in search_call[0][0]
@patch( @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_search_by_full_text_invalid_top_k(self, mock_pool_class): def test_search_by_full_text_invalid_top_k(self, mock_pool_class):
"""Test full-text search with invalid top_k.""" """Test full-text search with invalid top_k."""
# Mock the connection pool # Mock the connection pool
@ -640,9 +594,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
with pytest.raises(ValueError): with pytest.raises(ValueError):
vector_store.search_by_full_text("test", top_k="invalid") vector_store.search_by_full_text("test", top_k="invalid")
@patch( @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_delete_collection(self, mock_pool_class): def test_delete_collection(self, mock_pool_class):
"""Test deleting the entire collection.""" """Test deleting the entire collection."""
# Mock the connection pool # Mock the connection pool
@ -665,9 +617,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
drop_call = drop_calls[0] drop_call = drop_calls[0]
assert f"DROP TABLE IF EXISTS {self.collection_name.lower()}" in drop_call[0][0] assert f"DROP TABLE IF EXISTS {self.collection_name.lower()}" in drop_call[0][0]
@patch( @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
def test_unsupported_distance_function(self, mock_pool_class): def test_unsupported_distance_function(self, mock_pool_class):
"""Test that Pydantic validation rejects unsupported distance functions.""" """Test that Pydantic validation rejects unsupported distance functions."""
# Test that creating config with unsupported distance function raises ValidationError # Test that creating config with unsupported distance function raises ValidationError

View File

@ -0,0 +1,15 @@
[project]
name = "dify-vdb-analyticdb"
version = "0.0.1"
dependencies = [
"alibabacloud_gpdb20160503~=5.2.0",
"alibabacloud_tea_openapi~=0.4.3",
"clickhouse-connect~=0.15.0",
]
description = "Dify vector store backend (dify-vdb-analyticdb)."
[project.entry-points."dify.vector_backends"]
analyticdb = "dify_vdb_analyticdb.analyticdb_vector:AnalyticdbVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@ -2,16 +2,16 @@ import json
from typing import Any from typing import Any
from configs import dify_config from configs import dify_config
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import (
AnalyticdbVectorOpenAPI,
AnalyticdbVectorOpenAPIConfig,
)
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document from core.rag.models.document import Document
from dify_vdb_analyticdb.analyticdb_vector_openapi import (
AnalyticdbVectorOpenAPI,
AnalyticdbVectorOpenAPIConfig,
)
from dify_vdb_analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig
from models.dataset import Dataset from models.dataset import Dataset

View File

@ -1,9 +1,8 @@
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector from dify_vdb_analyticdb.analyticdb_vector import AnalyticdbVector
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig from dify_vdb_analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig from dify_vdb_analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest
class AnalyticdbVectorTest(AbstractVectorTest): class AnalyticdbVectorTest(AbstractVectorTest):

View File

@ -1,12 +1,12 @@
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import dify_vdb_analyticdb.analyticdb_vector as analyticdb_module
import pytest import pytest
from dify_vdb_analyticdb.analyticdb_vector import AnalyticdbVector, AnalyticdbVectorFactory
from dify_vdb_analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
from dify_vdb_analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
import core.rag.datasource.vdb.analyticdb.analyticdb_vector as analyticdb_module
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector, AnalyticdbVectorFactory
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
from core.rag.models.document import Document from core.rag.models.document import Document

View File

@ -4,13 +4,13 @@ import types
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import MagicMock from unittest.mock import MagicMock
import dify_vdb_analyticdb.analyticdb_vector_openapi as openapi_module
import pytest import pytest
from dify_vdb_analyticdb.analyticdb_vector_openapi import (
import core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi as openapi_module
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import (
AnalyticdbVectorOpenAPI, AnalyticdbVectorOpenAPI,
AnalyticdbVectorOpenAPIConfig, AnalyticdbVectorOpenAPIConfig,
) )
from core.rag.models.document import Document from core.rag.models.document import Document

View File

@ -2,14 +2,14 @@ from contextlib import contextmanager
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import MagicMock from unittest.mock import MagicMock
import dify_vdb_analyticdb.analyticdb_vector_sql as sql_module
import psycopg2.errors import psycopg2.errors
import pytest import pytest
from dify_vdb_analyticdb.analyticdb_vector_sql import (
import core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql as sql_module
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import (
AnalyticdbVectorBySql, AnalyticdbVectorBySql,
AnalyticdbVectorBySqlConfig, AnalyticdbVectorBySqlConfig,
) )
from core.rag.models.document import Document from core.rag.models.document import Document

View File

@ -0,0 +1,13 @@
[project]
name = "dify-vdb-baidu"
version = "0.0.1"
dependencies = [
"pymochow==2.4.0",
]
description = "Dify vector store backend (dify-vdb-baidu)."
[project.entry-points."dify.vector_backends"]
baidu = "dify_vdb_baidu.baidu_vector:BaiduVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@ -1,10 +1,6 @@
from core.rag.datasource.vdb.baidu.baidu_vector import BaiduConfig, BaiduVector from dify_vdb_baidu.baidu_vector import BaiduConfig, BaiduVector
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text
pytest_plugins = ( from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text
"tests.integration_tests.vdb.test_vector_store",
"tests.integration_tests.vdb.__mock.baiduvectordb",
)
class BaiduVectorTest(AbstractVectorTest): class BaiduVectorTest(AbstractVectorTest):

View File

@ -124,7 +124,7 @@ def _build_fake_pymochow_modules():
def baidu_module(monkeypatch): def baidu_module(monkeypatch):
for name, module in _build_fake_pymochow_modules().items(): for name, module in _build_fake_pymochow_modules().items():
monkeypatch.setitem(sys.modules, name, module) monkeypatch.setitem(sys.modules, name, module)
import core.rag.datasource.vdb.baidu.baidu_vector as module import dify_vdb_baidu.baidu_vector as module
return importlib.reload(module) return importlib.reload(module)

View File

@ -0,0 +1,13 @@
[project]
name = "dify-vdb-chroma"
version = "0.0.1"
dependencies = [
"chromadb==0.5.20",
]
description = "Dify vector store backend (dify-vdb-chroma)."
[project.entry-points."dify.vector_backends"]
chroma = "dify_vdb_chroma.chroma_vector:ChromaVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@ -1,13 +1,11 @@
import chromadb import chromadb
from dify_vdb_chroma.chroma_vector import ChromaConfig, ChromaVector
from core.rag.datasource.vdb.chroma.chroma_vector import ChromaConfig, ChromaVector from core.rag.datasource.vdb.vector_integration_test_support import (
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest, AbstractVectorTest,
get_example_text, get_example_text,
) )
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
class ChromaVectorTest(AbstractVectorTest): class ChromaVectorTest(AbstractVectorTest):
def __init__(self): def __init__(self):

View File

@ -47,7 +47,7 @@ def _build_fake_chroma_modules():
def chroma_module(monkeypatch): def chroma_module(monkeypatch):
fake_chroma = _build_fake_chroma_modules() fake_chroma = _build_fake_chroma_modules()
monkeypatch.setitem(sys.modules, "chromadb", fake_chroma) monkeypatch.setitem(sys.modules, "chromadb", fake_chroma)
import core.rag.datasource.vdb.chroma.chroma_vector as module import dify_vdb_chroma.chroma_vector as module
return importlib.reload(module) return importlib.reload(module)

View File

@ -198,4 +198,4 @@ Clickzetta supports advanced full-text search with multiple analyzers:
- [Clickzetta Vector Search Documentation](https://yunqi.tech/documents/vector-search) - [Clickzetta Vector Search Documentation](https://yunqi.tech/documents/vector-search)
- [Clickzetta Inverted Index Documentation](https://yunqi.tech/documents/inverted-index) - [Clickzetta Inverted Index Documentation](https://yunqi.tech/documents/inverted-index)
- [Clickzetta SQL Functions](https://yunqi.tech/documents/sql-reference) - [Clickzetta SQL Functions](https://yunqi.tech/documents/sql-reference)

View File

@ -0,0 +1,14 @@
[project]
name = "dify-vdb-clickzetta"
version = "0.0.1"
dependencies = [
"clickzetta-connector-python>=0.8.102",
]
description = "Dify vector store backend (dify-vdb-clickzetta)."
[project.entry-points."dify.vector_backends"]
clickzetta = "dify_vdb_clickzetta.clickzetta_vector:ClickzettaVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@ -2,10 +2,10 @@ import contextlib
import os import os
import pytest import pytest
from dify_vdb_clickzetta.clickzetta_vector import ClickzettaConfig, ClickzettaVector
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaConfig, ClickzettaVector from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text
from core.rag.models.document import Document from core.rag.models.document import Document
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
class TestClickzettaVector(AbstractVectorTest): class TestClickzettaVector(AbstractVectorTest):
@ -14,9 +14,8 @@ class TestClickzettaVector(AbstractVectorTest):
""" """
@pytest.fixture @pytest.fixture
def vector_store(self): def vector_store(self, setup_mock_redis):
"""Create a Clickzetta vector store instance for testing.""" """Create a Clickzetta vector store instance for testing."""
# Skip test if Clickzetta credentials are not configured
if not os.getenv("CLICKZETTA_USERNAME"): if not os.getenv("CLICKZETTA_USERNAME"):
pytest.skip("CLICKZETTA_USERNAME is not configured") pytest.skip("CLICKZETTA_USERNAME is not configured")
if not os.getenv("CLICKZETTA_PASSWORD"): if not os.getenv("CLICKZETTA_PASSWORD"):
@ -32,21 +31,19 @@ class TestClickzettaVector(AbstractVectorTest):
workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"), workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"),
vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"), vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"),
schema=os.getenv("CLICKZETTA_SCHEMA", "dify_test"), schema=os.getenv("CLICKZETTA_SCHEMA", "dify_test"),
batch_size=10, # Small batch size for testing batch_size=10,
enable_inverted_index=True, enable_inverted_index=True,
analyzer_type="chinese", analyzer_type="chinese",
analyzer_mode="smart", analyzer_mode="smart",
vector_distance_function="cosine_distance", vector_distance_function="cosine_distance",
) )
with setup_mock_redis(): vector = ClickzettaVector(collection_name="test_collection_" + str(os.getpid()), config=config)
vector = ClickzettaVector(collection_name="test_collection_" + str(os.getpid()), config=config)
yield vector yield vector
# Cleanup: delete the test collection with contextlib.suppress(Exception):
with contextlib.suppress(Exception): vector.delete()
vector.delete()
def test_clickzetta_vector_basic_operations(self, vector_store): def test_clickzetta_vector_basic_operations(self, vector_store):
"""Test basic CRUD operations on Clickzetta vector store.""" """Test basic CRUD operations on Clickzetta vector store."""

View File

@ -3,16 +3,19 @@
Test Clickzetta integration in Docker environment Test Clickzetta integration in Docker environment
""" """
import logging
import os import os
import time import time
import httpx import httpx
from clickzetta import connect from clickzetta import connect
logger = logging.getLogger(__name__)
def test_clickzetta_connection(): def test_clickzetta_connection():
"""Test direct connection to Clickzetta""" """Test direct connection to Clickzetta"""
print("=== Testing direct Clickzetta connection ===") logger.info("=== Testing direct Clickzetta connection ===")
try: try:
conn = connect( conn = connect(
username=os.getenv("CLICKZETTA_USERNAME", "test_user"), username=os.getenv("CLICKZETTA_USERNAME", "test_user"),
@ -25,100 +28,93 @@ def test_clickzetta_connection():
) )
with conn.cursor() as cursor: with conn.cursor() as cursor:
# Test basic connectivity
cursor.execute("SELECT 1 as test") cursor.execute("SELECT 1 as test")
result = cursor.fetchone() result = cursor.fetchone()
print(f"✓ Connection test: {result}") logger.info("✓ Connection test: %s", result)
# Check if our test table exists
cursor.execute("SHOW TABLES IN dify") cursor.execute("SHOW TABLES IN dify")
tables = cursor.fetchall() tables = cursor.fetchall()
print(f"✓ Existing tables: {[t[1] for t in tables if t[0] == 'dify']}") logger.info("✓ Existing tables: %s", [t[1] for t in tables if t[0] == "dify"])
# Check if test collection exists
test_collection = "collection_test_dataset" test_collection = "collection_test_dataset"
if test_collection in [t[1] for t in tables if t[0] == "dify"]: if test_collection in [t[1] for t in tables if t[0] == "dify"]:
cursor.execute(f"DESCRIBE dify.{test_collection}") cursor.execute(f"DESCRIBE dify.{test_collection}")
columns = cursor.fetchall() columns = cursor.fetchall()
print(f"✓ Table structure for {test_collection}:") logger.info("✓ Table structure for %s:", test_collection)
for col in columns: for col in columns:
print(f" - {col[0]}: {col[1]}") logger.info(" - %s: %s", col[0], col[1])
# Check for indexes
cursor.execute(f"SHOW INDEXES IN dify.{test_collection}") cursor.execute(f"SHOW INDEXES IN dify.{test_collection}")
indexes = cursor.fetchall() indexes = cursor.fetchall()
print(f"✓ Indexes on {test_collection}:") logger.info("✓ Indexes on %s:", test_collection)
for idx in indexes: for idx in indexes:
print(f" - {idx}") logger.info(" - %s", idx)
return True return True
except Exception as e: except Exception:
print(f"✗ Connection test failed: {e}") logger.exception("✗ Connection test failed")
return False return False
def test_dify_api(): def test_dify_api():
"""Test Dify API with Clickzetta backend""" """Test Dify API with Clickzetta backend"""
print("\n=== Testing Dify API ===") logger.info("\n=== Testing Dify API ===")
base_url = "http://localhost:5001" base_url = "http://localhost:5001"
# Wait for API to be ready
max_retries = 30 max_retries = 30
for i in range(max_retries): for i in range(max_retries):
try: try:
response = httpx.get(f"{base_url}/console/api/health") response = httpx.get(f"{base_url}/console/api/health")
if response.status_code == 200: if response.status_code == 200:
print("✓ Dify API is ready") logger.info("✓ Dify API is ready")
break break
except: except:
if i == max_retries - 1: if i == max_retries - 1:
print("✗ Dify API is not responding") logger.exception("✗ Dify API is not responding")
return False return False
time.sleep(2) time.sleep(2)
# Check vector store configuration
try: try:
# This is a simplified check - in production, you'd use proper auth logger.info("✓ Dify is configured to use Clickzetta as vector store")
print("✓ Dify is configured to use Clickzetta as vector store")
return True return True
except Exception as e: except Exception:
print(f"✗ API test failed: {e}") logger.exception("✗ API test failed")
return False return False
def verify_table_structure(): def verify_table_structure():
"""Verify the table structure meets Dify requirements""" """Verify the table structure meets Dify requirements"""
print("\n=== Verifying Table Structure ===") logger.info("\n=== Verifying Table Structure ===")
expected_columns = { expected_columns = {
"id": "VARCHAR", "id": "VARCHAR",
"page_content": "VARCHAR", "page_content": "VARCHAR",
"metadata": "VARCHAR", # JSON stored as VARCHAR in Clickzetta "metadata": "VARCHAR",
"vector": "ARRAY<FLOAT>", "vector": "ARRAY<FLOAT>",
} }
expected_metadata_fields = ["doc_id", "doc_hash", "document_id", "dataset_id"] expected_metadata_fields = ["doc_id", "doc_hash", "document_id", "dataset_id"]
print("✓ Expected table structure:") logger.info("✓ Expected table structure:")
for col, dtype in expected_columns.items(): for col, dtype in expected_columns.items():
print(f" - {col}: {dtype}") logger.info(" - %s: %s", col, dtype)
print("\n✓ Required metadata fields:") logger.info("\n✓ Required metadata fields:")
for field in expected_metadata_fields: for field in expected_metadata_fields:
print(f" - {field}") logger.info(" - %s", field)
print("\n✓ Index requirements:") logger.info("\n✓ Index requirements:")
print(" - Vector index (HNSW) on 'vector' column") logger.info(" - Vector index (HNSW) on 'vector' column")
print(" - Full-text index on 'page_content' (optional)") logger.info(" - Full-text index on 'page_content' (optional)")
print(" - Functional index on metadata->>'$.doc_id' (recommended)") logger.info(" - Functional index on metadata->>'$.doc_id' (recommended)")
print(" - Functional index on metadata->>'$.document_id' (recommended)") logger.info(" - Functional index on metadata->>'$.document_id' (recommended)")
return True return True
def main(): def main():
"""Run all tests""" """Run all tests"""
print("Starting Clickzetta integration tests for Dify Docker\n") logger.info("Starting Clickzetta integration tests for Dify Docker\n")
tests = [ tests = [
("Direct Clickzetta Connection", test_clickzetta_connection), ("Direct Clickzetta Connection", test_clickzetta_connection),
@ -131,33 +127,34 @@ def main():
try: try:
success = test_func() success = test_func()
results.append((test_name, success)) results.append((test_name, success))
except Exception as e: except Exception:
print(f"\n{test_name} crashed: {e}") logger.exception("\n%s crashed", test_name)
results.append((test_name, False)) results.append((test_name, False))
# Summary logger.info("\n%s", "=" * 50)
print("\n" + "=" * 50) logger.info("Test Summary:")
print("Test Summary:") logger.info("=" * 50)
print("=" * 50)
passed = sum(1 for _, success in results if success) passed = sum(1 for _, success in results if success)
total = len(results) total = len(results)
for test_name, success in results: for test_name, success in results:
status = "✅ PASSED" if success else "❌ FAILED" status = "✅ PASSED" if success else "❌ FAILED"
print(f"{test_name}: {status}") logger.info("%s: %s", test_name, status)
print(f"\nTotal: {passed}/{total} tests passed") logger.info("\nTotal: %s/%s tests passed", passed, total)
if passed == total: if passed == total:
print("\n🎉 All tests passed! Clickzetta is ready for Dify Docker deployment.") logger.info("\n🎉 All tests passed! Clickzetta is ready for Dify Docker deployment.")
print("\nNext steps:") logger.info("\nNext steps:")
print("1. Run: cd docker && docker-compose -f docker-compose.yaml -f docker-compose.clickzetta.yaml up -d") logger.info(
print("2. Access Dify at http://localhost:3000") "1. Run: cd docker && docker-compose -f docker-compose.yaml -f docker-compose.clickzetta.yaml up -d"
print("3. Create a dataset and test vector storage with Clickzetta") )
logger.info("2. Access Dify at http://localhost:3000")
logger.info("3. Create a dataset and test vector storage with Clickzetta")
return 0 return 0
else: else:
print("\n⚠️ Some tests failed. Please check the errors above.") logger.error("\n⚠️ Some tests failed. Please check the errors above.")
return 1 return 1

View File

@ -47,7 +47,7 @@ def _build_fake_clickzetta_module():
@pytest.fixture @pytest.fixture
def clickzetta_module(monkeypatch): def clickzetta_module(monkeypatch):
monkeypatch.setitem(sys.modules, "clickzetta", _build_fake_clickzetta_module()) monkeypatch.setitem(sys.modules, "clickzetta", _build_fake_clickzetta_module())
import core.rag.datasource.vdb.clickzetta.clickzetta_vector as module import dify_vdb_clickzetta.clickzetta_vector as module
return importlib.reload(module) return importlib.reload(module)

View File

@ -0,0 +1,14 @@
[project]
name = "dify-vdb-couchbase"
version = "0.0.1"
dependencies = [
"couchbase~=4.6.0",
]
description = "Dify vector store backend (dify-vdb-couchbase)."
[project.entry-points."dify.vector_backends"]
couchbase = "dify_vdb_couchbase.couchbase_vector:CouchbaseVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@ -1,12 +1,14 @@
import logging
import subprocess import subprocess
import time import time
from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseConfig, CouchbaseVector from dify_vdb_couchbase.couchbase_vector import CouchbaseConfig, CouchbaseVector
from tests.integration_tests.vdb.test_vector_store import (
from core.rag.datasource.vdb.vector_integration_test_support import (
AbstractVectorTest, AbstractVectorTest,
) )
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) logger = logging.getLogger(__name__)
def wait_for_healthy_container(service_name="couchbase-server", timeout=300): def wait_for_healthy_container(service_name="couchbase-server", timeout=300):
@ -16,10 +18,10 @@ def wait_for_healthy_container(service_name="couchbase-server", timeout=300):
["docker", "inspect", "--format", "{{.State.Health.Status}}", service_name], capture_output=True, text=True ["docker", "inspect", "--format", "{{.State.Health.Status}}", service_name], capture_output=True, text=True
) )
if result.stdout.strip() == "healthy": if result.stdout.strip() == "healthy":
print(f"{service_name} is healthy!") logger.info("%s is healthy!", service_name)
return True return True
else: else:
print(f"Waiting for {service_name} to be healthy...") logger.info("Waiting for %s to be healthy...", service_name)
time.sleep(10) time.sleep(10)
raise TimeoutError(f"{service_name} did not become healthy in time") raise TimeoutError(f"{service_name} did not become healthy in time")

View File

@ -154,7 +154,7 @@ def couchbase_module(monkeypatch):
for name, module in _build_fake_couchbase_modules().items(): for name, module in _build_fake_couchbase_modules().items():
monkeypatch.setitem(sys.modules, name, module) monkeypatch.setitem(sys.modules, name, module)
import core.rag.datasource.vdb.couchbase.couchbase_vector as module import dify_vdb_couchbase.couchbase_vector as module
return importlib.reload(module) return importlib.reload(module)

View File

@ -0,0 +1,15 @@
[project]
name = "dify-vdb-elasticsearch"
version = "0.0.1"
dependencies = [
"elasticsearch==8.14.0",
]
description = "Dify vector store backend (dify-vdb-elasticsearch)."
[project.entry-points."dify.vector_backends"]
elasticsearch = "dify_vdb_elasticsearch.elasticsearch_vector:ElasticSearchVectorFactory"
elasticsearch-ja = "dify_vdb_elasticsearch.elasticsearch_ja_vector:ElasticSearchJaVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@ -4,14 +4,14 @@ from typing import Any
from flask import current_app from flask import current_app
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ( from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from dify_vdb_elasticsearch.elasticsearch_vector import (
ElasticSearchConfig, ElasticSearchConfig,
ElasticSearchVector, ElasticSearchVector,
ElasticSearchVectorFactory, ElasticSearchVectorFactory,
) )
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset from models.dataset import Dataset

View File

@ -1,10 +1,9 @@
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchConfig, ElasticSearchVector from dify_vdb_elasticsearch.elasticsearch_vector import ElasticSearchConfig, ElasticSearchVector
from tests.integration_tests.vdb.test_vector_store import (
from core.rag.datasource.vdb.vector_integration_test_support import (
AbstractVectorTest, AbstractVectorTest,
) )
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
class ElasticSearchVectorTest(AbstractVectorTest): class ElasticSearchVectorTest(AbstractVectorTest):
def __init__(self): def __init__(self):

View File

@ -32,8 +32,8 @@ def elasticsearch_ja_module(monkeypatch):
for name, module in _build_fake_elasticsearch_modules().items(): for name, module in _build_fake_elasticsearch_modules().items():
monkeypatch.setitem(sys.modules, name, module) monkeypatch.setitem(sys.modules, name, module)
import core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector as ja_module import dify_vdb_elasticsearch.elasticsearch_ja_vector as ja_module
import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as base_module import dify_vdb_elasticsearch.elasticsearch_vector as base_module
importlib.reload(base_module) importlib.reload(base_module)
return importlib.reload(ja_module) return importlib.reload(ja_module)

View File

@ -42,7 +42,7 @@ def elasticsearch_module(monkeypatch):
for name, module in _build_fake_elasticsearch_modules().items(): for name, module in _build_fake_elasticsearch_modules().items():
monkeypatch.setitem(sys.modules, name, module) monkeypatch.setitem(sys.modules, name, module)
import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as module import dify_vdb_elasticsearch.elasticsearch_vector as module
return importlib.reload(module) return importlib.reload(module)

View File

@ -0,0 +1,14 @@
[project]
name = "dify-vdb-hologres"
version = "0.0.1"
dependencies = [
"holo-search-sdk>=0.4.2",
]
description = "Dify vector store backend (dify-vdb-hologres)."
[project.entry-points."dify.vector_backends"]
hologres = "dify_vdb_hologres.hologres_vector:HologresVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@ -1,7 +1,7 @@
import json import json
import logging import logging
import time import time
from typing import Any from typing import Any, cast
import holo_search_sdk as holo # type: ignore import holo_search_sdk as holo # type: ignore
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
@ -351,9 +351,9 @@ class HologresVectorFactory(AbstractVectorFactory):
access_key_id=dify_config.HOLOGRES_ACCESS_KEY_ID or "", access_key_id=dify_config.HOLOGRES_ACCESS_KEY_ID or "",
access_key_secret=dify_config.HOLOGRES_ACCESS_KEY_SECRET or "", access_key_secret=dify_config.HOLOGRES_ACCESS_KEY_SECRET or "",
schema_name=dify_config.HOLOGRES_SCHEMA, schema_name=dify_config.HOLOGRES_SCHEMA,
tokenizer=dify_config.HOLOGRES_TOKENIZER, tokenizer=cast(TokenizerType, dify_config.HOLOGRES_TOKENIZER),
distance_method=dify_config.HOLOGRES_DISTANCE_METHOD, distance_method=cast(DistanceType, dify_config.HOLOGRES_DISTANCE_METHOD),
base_quantization_type=dify_config.HOLOGRES_BASE_QUANTIZATION_TYPE, base_quantization_type=cast(BaseQuantizationType, dify_config.HOLOGRES_BASE_QUANTIZATION_TYPE),
max_degree=dify_config.HOLOGRES_MAX_DEGREE, max_degree=dify_config.HOLOGRES_MAX_DEGREE,
ef_construction=dify_config.HOLOGRES_EF_CONSTRUCTION, ef_construction=dify_config.HOLOGRES_EF_CONSTRUCTION,
), ),

View File

@ -7,13 +7,10 @@ import pytest
from _pytest.monkeypatch import MonkeyPatch from _pytest.monkeypatch import MonkeyPatch
from psycopg import sql as psql from psycopg import sql as psql
# Shared in-memory storage: {table_name: {doc_id: {"id", "text", "meta", "embedding"}}}
_mock_tables: dict[str, dict[str, dict[str, Any]]] = {} _mock_tables: dict[str, dict[str, dict[str, Any]]] = {}
class MockSearchQuery: class MockSearchQuery:
"""Mock query builder for search_vector and search_text results."""
def __init__(self, table_name: str, search_type: str): def __init__(self, table_name: str, search_type: str):
self._table_name = table_name self._table_name = table_name
self._search_type = search_type self._search_type = search_type
@ -32,17 +29,13 @@ class MockSearchQuery:
return self return self
def _apply_filter(self, row: dict[str, Any]) -> bool: def _apply_filter(self, row: dict[str, Any]) -> bool:
"""Apply the filter SQL to check if a row matches."""
if self._filter_sql is None: if self._filter_sql is None:
return True return True
# Extract literals (the document IDs) from the filter SQL
# Filter format: meta->>'document_id' IN ('doc1', 'doc2')
literals = [v for t, v in _extract_identifiers_and_literals(self._filter_sql) if t == "literal"] literals = [v for t, v in _extract_identifiers_and_literals(self._filter_sql) if t == "literal"]
if not literals: if not literals:
return True return True
# Get the document_id from the row's meta field
meta = row.get("meta", "{}") meta = row.get("meta", "{}")
if isinstance(meta, str): if isinstance(meta, str):
meta = json.loads(meta) meta = json.loads(meta)
@ -54,22 +47,17 @@ class MockSearchQuery:
data = _mock_tables.get(self._table_name, {}) data = _mock_tables.get(self._table_name, {})
results = [] results = []
for row in list(data.values())[: self._limit_val]: for row in list(data.values())[: self._limit_val]:
# Apply filter if present
if not self._apply_filter(row): if not self._apply_filter(row):
continue continue
if self._search_type == "vector": if self._search_type == "vector":
# row format expected by _process_vector_results: (distance, id, text, meta)
results.append((0.1, row["id"], row["text"], row["meta"])) results.append((0.1, row["id"], row["text"], row["meta"]))
else: else:
# row format expected by _process_full_text_results: (id, text, meta, embedding, score)
results.append((row["id"], row["text"], row["meta"], row.get("embedding", []), 0.9)) results.append((row["id"], row["text"], row["meta"], row.get("embedding", []), 0.9))
return results return results
class MockTable: class MockTable:
"""Mock table object returned by client.open_table()."""
def __init__(self, table_name: str): def __init__(self, table_name: str):
self._table_name = table_name self._table_name = table_name
@ -97,7 +85,6 @@ class MockTable:
def _extract_sql_template(query) -> str: def _extract_sql_template(query) -> str:
"""Extract the SQL template string from a psycopg Composed object."""
if isinstance(query, psql.Composed): if isinstance(query, psql.Composed):
for part in query: for part in query:
if isinstance(part, psql.SQL): if isinstance(part, psql.SQL):
@ -108,7 +95,6 @@ def _extract_sql_template(query) -> str:
def _extract_identifiers_and_literals(query) -> list[Any]: def _extract_identifiers_and_literals(query) -> list[Any]:
"""Extract Identifier and Literal values from a psycopg Composed object."""
values: list[Any] = [] values: list[Any] = []
if isinstance(query, psql.Composed): if isinstance(query, psql.Composed):
for part in query: for part in query:
@ -117,7 +103,6 @@ def _extract_identifiers_and_literals(query) -> list[Any]:
elif isinstance(part, psql.Literal): elif isinstance(part, psql.Literal):
values.append(("literal", part._obj)) values.append(("literal", part._obj))
elif isinstance(part, psql.Composed): elif isinstance(part, psql.Composed):
# Handles SQL(...).join(...) for IN clauses
for sub in part: for sub in part:
if isinstance(sub, psql.Literal): if isinstance(sub, psql.Literal):
values.append(("literal", sub._obj)) values.append(("literal", sub._obj))
@ -125,8 +110,6 @@ def _extract_identifiers_and_literals(query) -> list[Any]:
class MockHologresClient: class MockHologresClient:
"""Mock holo_search_sdk client that stores data in memory."""
def connect(self): def connect(self):
pass pass
@ -141,21 +124,18 @@ class MockHologresClient:
params = _extract_identifiers_and_literals(query) params = _extract_identifiers_and_literals(query)
if "CREATE TABLE" in template.upper(): if "CREATE TABLE" in template.upper():
# Extract table name from first identifier
table_name = next((v for t, v in params if t == "ident"), "unknown") table_name = next((v for t, v in params if t == "ident"), "unknown")
if table_name not in _mock_tables: if table_name not in _mock_tables:
_mock_tables[table_name] = {} _mock_tables[table_name] = {}
return None return None
if "SELECT 1" in template: if "SELECT 1" in template:
# text_exists: SELECT 1 FROM {table} WHERE id = {id} LIMIT 1
table_name = next((v for t, v in params if t == "ident"), "") table_name = next((v for t, v in params if t == "ident"), "")
doc_id = next((v for t, v in params if t == "literal"), "") doc_id = next((v for t, v in params if t == "literal"), "")
data = _mock_tables.get(table_name, {}) data = _mock_tables.get(table_name, {})
return [(1,)] if doc_id in data else [] return [(1,)] if doc_id in data else []
if "SELECT id" in template: if "SELECT id" in template:
# get_ids_by_metadata_field: SELECT id FROM {table} WHERE meta->>{key} = {value}
table_name = next((v for t, v in params if t == "ident"), "") table_name = next((v for t, v in params if t == "ident"), "")
literals = [v for t, v in params if t == "literal"] literals = [v for t, v in params if t == "literal"]
key = literals[0] if len(literals) > 0 else "" key = literals[0] if len(literals) > 0 else ""
@ -166,12 +146,10 @@ class MockHologresClient:
if "DELETE" in template.upper(): if "DELETE" in template.upper():
table_name = next((v for t, v in params if t == "ident"), "") table_name = next((v for t, v in params if t == "ident"), "")
if "id IN" in template: if "id IN" in template:
# delete_by_ids
ids_to_delete = [v for t, v in params if t == "literal"] ids_to_delete = [v for t, v in params if t == "literal"]
for did in ids_to_delete: for did in ids_to_delete:
_mock_tables.get(table_name, {}).pop(did, None) _mock_tables.get(table_name, {}).pop(did, None)
elif "meta->>" in template: elif "meta->>" in template:
# delete_by_metadata_field
literals = [v for t, v in params if t == "literal"] literals = [v for t, v in params if t == "literal"]
key = literals[0] if len(literals) > 0 else "" key = literals[0] if len(literals) > 0 else ""
value = literals[1] if len(literals) > 1 else "" value = literals[1] if len(literals) > 1 else ""
@ -190,7 +168,6 @@ class MockHologresClient:
def mock_connect(**kwargs): def mock_connect(**kwargs):
"""Replacement for holo_search_sdk.connect() that returns a mock client."""
return MockHologresClient() return MockHologresClient()

View File

@ -2,16 +2,11 @@ import os
import uuid import uuid
from typing import cast from typing import cast
from dify_vdb_hologres.hologres_vector import HologresVector, HologresVectorConfig
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
from core.rag.datasource.vdb.hologres.hologres_vector import HologresVector, HologresVectorConfig from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text
from core.rag.models.document import Document from core.rag.models.document import Document
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text
pytest_plugins = (
"tests.integration_tests.vdb.test_vector_store",
"tests.integration_tests.vdb.__mock.hologres",
)
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"

View File

@ -42,7 +42,7 @@ def hologres_module(monkeypatch):
for name, module in _build_fake_hologres_modules().items(): for name, module in _build_fake_hologres_modules().items():
monkeypatch.setitem(sys.modules, name, module) monkeypatch.setitem(sys.modules, name, module)
import core.rag.datasource.vdb.hologres.hologres_vector as module import dify_vdb_hologres.hologres_vector as module
return importlib.reload(module) return importlib.reload(module)

View File

@ -0,0 +1,14 @@
[project]
name = "dify-vdb-huawei-cloud"
version = "0.0.1"
dependencies = [
"elasticsearch==8.14.0",
]
description = "Dify vector store backend (dify-vdb-huawei-cloud)."
[project.entry-points."dify.vector_backends"]
huawei_cloud = "dify_vdb_huawei_cloud.huawei_cloud_vector:HuaweiCloudVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@ -1,10 +1,6 @@
from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVector, HuaweiCloudVectorConfig from dify_vdb_huawei_cloud.huawei_cloud_vector import HuaweiCloudVector, HuaweiCloudVectorConfig
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text
pytest_plugins = ( from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text
"tests.integration_tests.vdb.test_vector_store",
"tests.integration_tests.vdb.__mock.huaweicloudvectordb",
)
class HuaweiCloudVectorTest(AbstractVectorTest): class HuaweiCloudVectorTest(AbstractVectorTest):

View File

@ -33,7 +33,7 @@ def huawei_module(monkeypatch):
for name, module in _build_fake_elasticsearch_modules().items(): for name, module in _build_fake_elasticsearch_modules().items():
monkeypatch.setitem(sys.modules, name, module) monkeypatch.setitem(sys.modules, name, module)
import core.rag.datasource.vdb.huawei.huawei_cloud_vector as module import dify_vdb_huawei_cloud.huawei_cloud_vector as module
return importlib.reload(module) return importlib.reload(module)

View File

@ -0,0 +1,14 @@
[project]
name = "dify-vdb-iris"
version = "0.0.1"
dependencies = [
"intersystems-irispython>=5.1.0",
]
description = "Dify vector store backend (dify-vdb-iris)."
[project.entry-points."dify.vector_backends"]
iris = "dify_vdb_iris.iris_vector:IrisVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@ -1,12 +1,11 @@
"""Integration tests for IRIS vector database.""" """Integration tests for IRIS vector database."""
from core.rag.datasource.vdb.iris.iris_vector import IrisVector, IrisVectorConfig from dify_vdb_iris.iris_vector import IrisVector, IrisVectorConfig
from tests.integration_tests.vdb.test_vector_store import (
from core.rag.datasource.vdb.vector_integration_test_support import (
AbstractVectorTest, AbstractVectorTest,
) )
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
class IrisVectorTest(AbstractVectorTest): class IrisVectorTest(AbstractVectorTest):
"""Test suite for IRIS vector store implementation.""" """Test suite for IRIS vector store implementation."""

View File

@ -26,7 +26,7 @@ def _build_fake_iris_module():
def iris_module(monkeypatch): def iris_module(monkeypatch):
monkeypatch.setitem(sys.modules, "iris", _build_fake_iris_module()) monkeypatch.setitem(sys.modules, "iris", _build_fake_iris_module())
import core.rag.datasource.vdb.iris.iris_vector as module import dify_vdb_iris.iris_vector as module
reloaded = importlib.reload(module) reloaded = importlib.reload(module)
reloaded._pool_instance = None reloaded._pool_instance = None

View File

@ -0,0 +1,15 @@
[project]
name = "dify-vdb-lindorm"
version = "0.0.1"
dependencies = [
"opensearch-py==3.1.0",
"tenacity>=8.0.0",
]
description = "Dify vector store backend (dify-vdb-lindorm)."
[project.entry-points."dify.vector_backends"]
lindorm = "dify_vdb_lindorm.lindorm_vector:LindormVectorStoreFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@ -1,9 +1,8 @@
import os import os
from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStore, LindormVectorStoreConfig from dify_vdb_lindorm.lindorm_vector import LindormVectorStore, LindormVectorStoreConfig
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest
class Config: class Config:

View File

@ -51,7 +51,7 @@ def lindorm_module(monkeypatch):
for name, module in _build_fake_opensearch_modules().items(): for name, module in _build_fake_opensearch_modules().items():
monkeypatch.setitem(sys.modules, name, module) monkeypatch.setitem(sys.modules, name, module)
import core.rag.datasource.vdb.lindorm.lindorm_vector as module import dify_vdb_lindorm.lindorm_vector as module
return importlib.reload(module) return importlib.reload(module)

View File

@ -0,0 +1,14 @@
[project]
name = "dify-vdb-matrixone"
version = "0.0.1"
dependencies = [
"mo-vector~=0.1.13",
]
description = "Dify vector store backend (dify-vdb-matrixone)."
[project.entry-points."dify.vector_backends"]
matrixone = "dify_vdb_matrixone.matrixone_vector:MatrixoneVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@ -1,10 +1,9 @@
from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneConfig, MatrixoneVector from dify_vdb_matrixone.matrixone_vector import MatrixoneConfig, MatrixoneVector
from tests.integration_tests.vdb.test_vector_store import (
from core.rag.datasource.vdb.vector_integration_test_support import (
AbstractVectorTest, AbstractVectorTest,
) )
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
class MatrixoneVectorTest(AbstractVectorTest): class MatrixoneVectorTest(AbstractVectorTest):
def __init__(self): def __init__(self):

View File

@ -36,7 +36,7 @@ def matrixone_module(monkeypatch):
for name, module in _build_fake_mo_vector_modules().items(): for name, module in _build_fake_mo_vector_modules().items():
monkeypatch.setitem(sys.modules, name, module) monkeypatch.setitem(sys.modules, name, module)
import core.rag.datasource.vdb.matrixone.matrixone_vector as module import dify_vdb_matrixone.matrixone_vector as module
return importlib.reload(module) return importlib.reload(module)

View File

@ -0,0 +1,14 @@
[project]
name = "dify-vdb-milvus"
version = "0.0.1"
dependencies = [
"pymilvus~=2.6.12",
]
description = "Dify vector store backend (dify-vdb-milvus)."
[project.entry-points."dify.vector_backends"]
milvus = "dify_vdb_milvus.milvus_vector:MilvusVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@ -1,11 +1,10 @@
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector from dify_vdb_milvus.milvus_vector import MilvusConfig, MilvusVector
from tests.integration_tests.vdb.test_vector_store import (
from core.rag.datasource.vdb.vector_integration_test_support import (
AbstractVectorTest, AbstractVectorTest,
get_example_text, get_example_text,
) )
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
class MilvusVectorTest(AbstractVectorTest): class MilvusVectorTest(AbstractVectorTest):
def __init__(self): def __init__(self):

View File

@ -103,7 +103,7 @@ def milvus_module(monkeypatch):
for name, module in _build_fake_pymilvus_modules().items(): for name, module in _build_fake_pymilvus_modules().items():
monkeypatch.setitem(sys.modules, name, module) monkeypatch.setitem(sys.modules, name, module)
import core.rag.datasource.vdb.milvus.milvus_vector as module import dify_vdb_milvus.milvus_vector as module
return importlib.reload(module) return importlib.reload(module)

View File

@ -0,0 +1,14 @@
[project]
name = "dify-vdb-myscale"
version = "0.0.1"
dependencies = [
"clickhouse-connect~=0.15.0",
]
description = "Dify vector store backend (dify-vdb-myscale)."
[project.entry-points."dify.vector_backends"]
myscale = "dify_vdb_myscale.myscale_vector:MyScaleVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@ -1,10 +1,9 @@
from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleConfig, MyScaleVector from dify_vdb_myscale.myscale_vector import MyScaleConfig, MyScaleVector
from tests.integration_tests.vdb.test_vector_store import (
from core.rag.datasource.vdb.vector_integration_test_support import (
AbstractVectorTest, AbstractVectorTest,
) )
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
class MyScaleVectorTest(AbstractVectorTest): class MyScaleVectorTest(AbstractVectorTest):
def __init__(self): def __init__(self):

View File

@ -42,7 +42,7 @@ def myscale_module(monkeypatch):
fake_module = _build_fake_clickhouse_connect_module() fake_module = _build_fake_clickhouse_connect_module()
monkeypatch.setitem(sys.modules, "clickhouse_connect", fake_module) monkeypatch.setitem(sys.modules, "clickhouse_connect", fake_module)
import core.rag.datasource.vdb.myscale.myscale_vector as module import dify_vdb_myscale.myscale_vector as module
return importlib.reload(module) return importlib.reload(module)

View File

@ -0,0 +1,16 @@
[project]
name = "dify-vdb-oceanbase"
version = "0.0.1"
dependencies = [
"pyobvector~=0.2.17",
"mysql-connector-python>=9.3.0",
]
description = "Dify vector store backend (dify-vdb-oceanbase)."
[project.entry-points."dify.vector_backends"]
oceanbase = "dify_vdb_oceanbase.oceanbase_vector:OceanBaseVectorFactory"
seekdb = "dify_vdb_oceanbase.oceanbase_vector:OceanBaseVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@ -2,11 +2,12 @@
Benchmark: OceanBase vector store old (single-row) vs new (batch) insertion, Benchmark: OceanBase vector store old (single-row) vs new (batch) insertion,
metadata query with/without functional index, and vector search across metrics. metadata query with/without functional index, and vector search across metrics.
Usage: Usage (from repo root):
uv run --project api python -m tests.integration_tests.vdb.oceanbase.bench_oceanbase uv run --project api python api/packages/dify-vdb-oceanbase/tests/bench_oceanbase.py
""" """
import json import json
import logging
import random import random
import statistics import statistics
import time import time
@ -16,6 +17,8 @@ from pyobvector import VECTOR, ObVecClient, cosine_distance, inner_product, l2_d
from sqlalchemy import JSON, Column, String, text from sqlalchemy import JSON, Column, String, text
from sqlalchemy.dialects.mysql import LONGTEXT from sqlalchemy.dialects.mysql import LONGTEXT
logger = logging.getLogger(__name__)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Config # Config
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -114,7 +117,7 @@ def bench_metadata_query(client, table, doc_id, with_index=False):
try: try:
client.perform_raw_text_sql(f"CREATE INDEX idx_metadata_doc_id ON `{table}` ((metadata->>'$.document_id'))") client.perform_raw_text_sql(f"CREATE INDEX idx_metadata_doc_id ON `{table}` ((metadata->>'$.document_id'))")
except Exception: except Exception:
pass # already exists logger.debug("Index idx_metadata_doc_id already exists, skipping creation")
sql = text(f"SELECT id FROM `{table}` WHERE metadata->>'$.document_id' = :val") sql = text(f"SELECT id FROM `{table}` WHERE metadata->>'$.document_id' = :val")
times = [] times = []
@ -164,11 +167,11 @@ def main():
client = _make_client() client = _make_client()
client_pooled = _make_client(pool_size=5, max_overflow=10, pool_recycle=3600, pool_pre_ping=True) client_pooled = _make_client(pool_size=5, max_overflow=10, pool_recycle=3600, pool_pre_ping=True)
print("=" * 70) logger.info("=" * 70)
print("OceanBase Vector Store — Performance Benchmark") logger.info("OceanBase Vector Store — Performance Benchmark")
print(f" Endpoint : {HOST}:{PORT}") logger.info(" Endpoint : %s:%s", HOST, PORT)
print(f" Vec dim : {VEC_DIM}") logger.info(" Vec dim : %s", VEC_DIM)
print("=" * 70) logger.info("=" * 70)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# 1. Insertion benchmark # 1. Insertion benchmark
@ -187,10 +190,10 @@ def main():
t_batch = bench_insert_batch(client_pooled, tbl_batch, rows, batch_size=100) t_batch = bench_insert_batch(client_pooled, tbl_batch, rows, batch_size=100)
speedup = t_single / t_batch if t_batch > 0 else float("inf") speedup = t_single / t_batch if t_batch > 0 else float("inf")
print(f"\n[Insert {n_docs} docs]") logger.info("\n[Insert %s docs]", n_docs)
print(f" Single-row : {t_single:.2f}s") logger.info(" Single-row : %.2fs", t_single)
print(f" Batch(100) : {t_batch:.2f}s") logger.info(" Batch(100) : %.2fs", t_batch)
print(f" Speedup : {speedup:.1f}x") logger.info(" Speedup : %.1fx", speedup)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# 2. Metadata query benchmark (use the 1000-doc batch table) # 2. Metadata query benchmark (use the 1000-doc batch table)
@ -203,16 +206,16 @@ def main():
res = conn.execute(text(f"SELECT metadata->>'$.document_id' FROM `{tbl_meta}` LIMIT 1")) res = conn.execute(text(f"SELECT metadata->>'$.document_id' FROM `{tbl_meta}` LIMIT 1"))
doc_id_1000 = res.fetchone()[0] doc_id_1000 = res.fetchone()[0]
print("\n[Metadata filter query — 1000 rows, by document_id]") logger.info("\n[Metadata filter query — 1000 rows, by document_id]")
times_no_idx = bench_metadata_query(client, tbl_meta, doc_id_1000, with_index=False) times_no_idx = bench_metadata_query(client, tbl_meta, doc_id_1000, with_index=False)
print(f" Without index : {_fmt(times_no_idx)}") logger.info(" Without index : %s", _fmt(times_no_idx))
times_with_idx = bench_metadata_query(client, tbl_meta, doc_id_1000, with_index=True) times_with_idx = bench_metadata_query(client, tbl_meta, doc_id_1000, with_index=True)
print(f" With index : {_fmt(times_with_idx)}") logger.info(" With index : %s", _fmt(times_with_idx))
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# 3. Vector search benchmark — across metrics # 3. Vector search benchmark — across metrics
# ------------------------------------------------------------------ # ------------------------------------------------------------------
print("\n[Vector search — top-10, 20 queries each, on 1000 rows]") logger.info("\n[Vector search — top-10, 20 queries each, on 1000 rows]")
for metric in ["l2", "cosine", "inner_product"]: for metric in ["l2", "cosine", "inner_product"]:
tbl_vs = f"bench_vs_{metric}" tbl_vs = f"bench_vs_{metric}"
@ -222,7 +225,7 @@ def main():
rows_vs, _ = _gen_rows(1000) rows_vs, _ = _gen_rows(1000)
bench_insert_batch(client_pooled, tbl_vs, rows_vs, batch_size=100) bench_insert_batch(client_pooled, tbl_vs, rows_vs, batch_size=100)
times = bench_vector_search(client_pooled, tbl_vs, metric, topk=10, n_queries=20) times = bench_vector_search(client_pooled, tbl_vs, metric, topk=10, n_queries=20)
print(f" {metric:15s}: {_fmt(times)}") logger.info(" %-15s: %s", metric, _fmt(times))
_drop(client_pooled, tbl_vs) _drop(client_pooled, tbl_vs)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@ -232,9 +235,9 @@ def main():
_drop(client, f"bench_single_{n}") _drop(client, f"bench_single_{n}")
_drop(client, f"bench_batch_{n}") _drop(client, f"bench_batch_{n}")
print("\n" + "=" * 70) logger.info("\n%s", "=" * 70)
print("Benchmark complete.") logger.info("Benchmark complete.")
print("=" * 70) logger.info("=" * 70)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,15 +1,13 @@
import pytest import pytest
from dify_vdb_oceanbase.oceanbase_vector import (
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import (
OceanBaseVector, OceanBaseVector,
OceanBaseVectorConfig, OceanBaseVectorConfig,
) )
from tests.integration_tests.vdb.test_vector_store import (
from core.rag.datasource.vdb.vector_integration_test_support import (
AbstractVectorTest, AbstractVectorTest,
) )
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
@pytest.fixture @pytest.fixture
def oceanbase_vector(): def oceanbase_vector():

View File

@ -56,7 +56,7 @@ def _build_fake_pyobvector_module():
def oceanbase_module(monkeypatch): def oceanbase_module(monkeypatch):
monkeypatch.setitem(sys.modules, "pyobvector", _build_fake_pyobvector_module()) monkeypatch.setitem(sys.modules, "pyobvector", _build_fake_pyobvector_module())
import core.rag.datasource.vdb.oceanbase.oceanbase_vector as module import dify_vdb_oceanbase.oceanbase_vector as module
return importlib.reload(module) return importlib.reload(module)

Some files were not shown because too many files have changed in this diff Show More