mirror of
https://github.com/langgenius/dify.git
synced 2026-06-26 14:51:13 +08:00
db_migration_utils.py: add try_create_db_if_not_exists & test
This commit is contained in:
parent
aad0b3c157
commit
3613c7bbab
95
api/libs/db_migration_utils.py
Normal file
95
api/libs/db_migration_utils.py
Normal file
@ -0,0 +1,95 @@
|
||||
"""
|
||||
Database migration utility helpers.
|
||||
|
||||
These are intentionally migration-specific utilities. Do NOT use them in normal
|
||||
application code paths.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def try_create_db_if_not_exists(
|
||||
db_type: str,
|
||||
host: str,
|
||||
port: int,
|
||||
username: str,
|
||||
password: str,
|
||||
database: str,
|
||||
) -> None:
|
||||
"""Best-effort attempt to create the target database if it does not exist.
|
||||
|
||||
Only supports PostgreSQL and MySQL. For other database types this function
|
||||
is a no-op. Failures are logged as warnings and never re-raised, so callers
|
||||
(e.g. the migration command) are not interrupted when the database already
|
||||
exists or when the user lacks CREATE DATABASE privileges.
|
||||
|
||||
Args:
|
||||
db_type: One of the supported DB_TYPE values (e.g. "postgresql", "mysql").
|
||||
host: Database server hostname or IP.
|
||||
port: Database server port.
|
||||
username: Database username.
|
||||
password: Database password.
|
||||
database: Name of the database to create if absent.
|
||||
"""
|
||||
try:
|
||||
if db_type == "postgresql":
|
||||
_try_create_postgresql(host, port, username, password, database)
|
||||
elif db_type == "mysql":
|
||||
_try_create_mysql(host, port, username, password, database)
|
||||
else:
|
||||
logger.debug(
|
||||
"try_create_db_if_not_exists: unsupported db_type=%r, skipping.",
|
||||
db_type,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"try_create_db_if_not_exists: failed to create database %r (db_type=%r). "
|
||||
"Proceeding anyway — migration will fail if the database truly does not exist.",
|
||||
database,
|
||||
db_type,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
def _try_create_postgresql(host: str, port: int, username: str, password: str, database: str) -> None:
|
||||
# Connect to the default 'postgres' maintenance database so we can issue
|
||||
# CREATE DATABASE without requiring the target database to already exist.
|
||||
admin_uri = f"postgresql://{quote_plus(username)}:{quote_plus(password)}@{host}:{port}/postgres"
|
||||
engine = sa.create_engine(admin_uri, isolation_level="AUTOCOMMIT")
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
exists = conn.execute(
|
||||
sa.text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||
{"name": database},
|
||||
).scalar()
|
||||
if not exists:
|
||||
# Identifier quoting guards against names with special characters.
|
||||
conn.execute(sa.text(f'CREATE DATABASE "{database}"'))
|
||||
logger.info("try_create_db_if_not_exists: PostgreSQL database %r created.", database)
|
||||
else:
|
||||
logger.debug("try_create_db_if_not_exists: PostgreSQL database %r already exists.", database)
|
||||
finally:
|
||||
engine.dispose()
|
||||
|
||||
|
||||
def _try_create_mysql(host: str, port: int, username: str, password: str, database: str) -> None:
|
||||
# Connect without specifying a database so the target need not exist yet.
|
||||
admin_uri = f"mysql+pymysql://{quote_plus(username)}:{quote_plus(password)}@{host}:{port}/"
|
||||
engine = sa.create_engine(admin_uri)
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
# MySQL supports IF NOT EXISTS natively — no need to check first.
|
||||
conn.execute(
|
||||
sa.text(f"CREATE DATABASE IF NOT EXISTS `{database}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci")
|
||||
)
|
||||
conn.commit()
|
||||
logger.info("try_create_db_if_not_exists: MySQL database %r ensured.", database)
|
||||
finally:
|
||||
engine.dispose()
|
||||
152
api/tests/unit_tests/libs/test_db_migration_utils.py
Normal file
152
api/tests/unit_tests/libs/test_db_migration_utils.py
Normal file
@ -0,0 +1,152 @@
|
||||
"""Unit tests for libs.db_migration_utils."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from libs.db_migration_utils import try_create_db_if_not_exists
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PG_ARGS = {"host": "localhost", "port": 5432, "username": "postgres", "password": "secret", "database": "dify"}
|
||||
_MY_ARGS = {"host": "localhost", "port": 3306, "username": "root", "password": "secret", "database": "dify"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unsupported DB types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUnsupportedDbType:
|
||||
def test_oceanbase_is_noop(self):
|
||||
"""Unsupported db_type should not raise and should not touch SQLAlchemy."""
|
||||
with patch("libs.db_migration_utils.sa.create_engine") as mock_engine:
|
||||
try_create_db_if_not_exists(db_type="oceanbase", **_PG_ARGS)
|
||||
mock_engine.assert_not_called()
|
||||
|
||||
def test_seekdb_is_noop(self):
|
||||
with patch("libs.db_migration_utils.sa.create_engine") as mock_engine:
|
||||
try_create_db_if_not_exists(db_type="seekdb", **_PG_ARGS)
|
||||
mock_engine.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PostgreSQL
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPostgreSQL:
|
||||
def _make_engine_mock(self, db_exists: bool) -> MagicMock:
|
||||
"""Return a mock engine whose connection reports whether the DB exists."""
|
||||
scalar_result = 1 if db_exists else None
|
||||
conn = MagicMock()
|
||||
conn.execute.return_value.scalar.return_value = scalar_result
|
||||
conn.__enter__ = MagicMock(return_value=conn)
|
||||
conn.__exit__ = MagicMock(return_value=False)
|
||||
engine = MagicMock()
|
||||
engine.connect.return_value = conn
|
||||
return engine, conn
|
||||
|
||||
def test_creates_database_when_not_exists(self):
|
||||
"""Should issue CREATE DATABASE when pg_database lookup returns nothing."""
|
||||
engine, conn = self._make_engine_mock(db_exists=False)
|
||||
with patch("libs.db_migration_utils.sa.create_engine", return_value=engine):
|
||||
try_create_db_if_not_exists(db_type="postgresql", **_PG_ARGS)
|
||||
|
||||
# First call: SELECT from pg_database; second call: CREATE DATABASE
|
||||
assert conn.execute.call_count == 2
|
||||
create_call_sql = str(conn.execute.call_args_list[1][0][0])
|
||||
assert "CREATE DATABASE" in create_call_sql
|
||||
|
||||
def test_skips_create_when_database_exists(self):
|
||||
"""Should NOT issue CREATE DATABASE when the database already exists."""
|
||||
engine, conn = self._make_engine_mock(db_exists=True)
|
||||
with patch("libs.db_migration_utils.sa.create_engine", return_value=engine):
|
||||
try_create_db_if_not_exists(db_type="postgresql", **_PG_ARGS)
|
||||
|
||||
assert conn.execute.call_count == 1 # only the SELECT
|
||||
|
||||
def test_connects_to_postgres_maintenance_db(self):
|
||||
"""Admin connection must target the 'postgres' maintenance database, not the target DB."""
|
||||
engine, _ = self._make_engine_mock(db_exists=True)
|
||||
with patch("libs.db_migration_utils.sa.create_engine", return_value=engine) as mock_create:
|
||||
try_create_db_if_not_exists(db_type="postgresql", **_PG_ARGS)
|
||||
|
||||
uri_used: str = mock_create.call_args[0][0]
|
||||
assert uri_used.endswith("/postgres"), f"Expected URI ending in /postgres, got: {uri_used}"
|
||||
|
||||
def test_engine_disposed_on_success(self):
|
||||
engine, _ = self._make_engine_mock(db_exists=True)
|
||||
with patch("libs.db_migration_utils.sa.create_engine", return_value=engine):
|
||||
try_create_db_if_not_exists(db_type="postgresql", **_PG_ARGS)
|
||||
engine.dispose.assert_called_once()
|
||||
|
||||
def test_engine_disposed_on_connection_error(self):
|
||||
"""dispose() must be called even when the connection raises."""
|
||||
engine = MagicMock()
|
||||
engine.connect.side_effect = Exception("connection refused")
|
||||
with patch("libs.db_migration_utils.sa.create_engine", return_value=engine):
|
||||
# Should not raise — failure is swallowed and logged as warning
|
||||
try_create_db_if_not_exists(db_type="postgresql", **_PG_ARGS)
|
||||
engine.dispose.assert_called_once()
|
||||
|
||||
def test_exception_is_swallowed_not_raised(self):
|
||||
"""Any exception during DB creation must be caught; caller must not see it."""
|
||||
engine = MagicMock()
|
||||
engine.connect.side_effect = RuntimeError("boom")
|
||||
with patch("libs.db_migration_utils.sa.create_engine", return_value=engine):
|
||||
try_create_db_if_not_exists(db_type="postgresql", **_PG_ARGS) # must not raise
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MySQL
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMySQL:
|
||||
def _make_engine_mock(self) -> tuple[MagicMock, MagicMock]:
|
||||
conn = MagicMock()
|
||||
conn.__enter__ = MagicMock(return_value=conn)
|
||||
conn.__exit__ = MagicMock(return_value=False)
|
||||
engine = MagicMock()
|
||||
engine.connect.return_value = conn
|
||||
return engine, conn
|
||||
|
||||
def test_issues_create_database_if_not_exists(self):
|
||||
"""MySQL path should always call CREATE DATABASE IF NOT EXISTS."""
|
||||
engine, conn = self._make_engine_mock()
|
||||
with patch("libs.db_migration_utils.sa.create_engine", return_value=engine):
|
||||
try_create_db_if_not_exists(db_type="mysql", **_MY_ARGS)
|
||||
|
||||
assert conn.execute.called
|
||||
sql = str(conn.execute.call_args[0][0])
|
||||
assert "CREATE DATABASE IF NOT EXISTS" in sql
|
||||
|
||||
def test_commits_after_create(self):
|
||||
engine, conn = self._make_engine_mock()
|
||||
with patch("libs.db_migration_utils.sa.create_engine", return_value=engine):
|
||||
try_create_db_if_not_exists(db_type="mysql", **_MY_ARGS)
|
||||
conn.commit.assert_called_once()
|
||||
|
||||
def test_connects_without_database_in_uri(self):
|
||||
"""MySQL admin URI must not include the target database name."""
|
||||
engine, _ = self._make_engine_mock()
|
||||
with patch("libs.db_migration_utils.sa.create_engine", return_value=engine) as mock_create:
|
||||
try_create_db_if_not_exists(db_type="mysql", **_MY_ARGS)
|
||||
|
||||
uri_used: str = mock_create.call_args[0][0]
|
||||
assert uri_used.endswith("/"), f"Expected URI ending in '/', got: {uri_used}"
|
||||
|
||||
def test_engine_disposed_on_success(self):
|
||||
engine, _ = self._make_engine_mock()
|
||||
with patch("libs.db_migration_utils.sa.create_engine", return_value=engine):
|
||||
try_create_db_if_not_exists(db_type="mysql", **_MY_ARGS)
|
||||
engine.dispose.assert_called_once()
|
||||
|
||||
def test_exception_is_swallowed_not_raised(self):
|
||||
engine = MagicMock()
|
||||
engine.connect.side_effect = RuntimeError("boom")
|
||||
with patch("libs.db_migration_utils.sa.create_engine", return_value=engine):
|
||||
try_create_db_if_not_exists(db_type="mysql", **_MY_ARGS) # must not raise
|
||||
Loading…
Reference in New Issue
Block a user