diff --git a/api/commands/system.py b/api/commands/system.py index 7755d3b5bcd..a3b0ceb092d 100644 --- a/api/commands/system.py +++ b/api/commands/system.py @@ -10,6 +10,7 @@ from events.app_event import app_was_created from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.db_migration_lock import DbMigrationAutoRenewLock +from libs.db_migration_utils import try_create_db_if_not_exists from libs.rsa import generate_key_pair from models import Tenant from models.model import App, AppMode, Conversation @@ -147,6 +148,16 @@ def upgrade_db(): try: click.echo(click.style("Starting database migration.", fg="green")) + # ensure the target database exists before migrations + try_create_db_if_not_exists( + db_type=dify_config.DB_TYPE, + host=dify_config.DB_HOST, + port=dify_config.DB_PORT, + username=dify_config.DB_USERNAME, + password=dify_config.DB_PASSWORD, + database=dify_config.DB_DATABASE, + ) + # run db migration import flask_migrate diff --git a/api/libs/db_migration_utils.py b/api/libs/db_migration_utils.py new file mode 100644 index 00000000000..a69b6e6d5bd --- /dev/null +++ b/api/libs/db_migration_utils.py @@ -0,0 +1,95 @@ +""" +Database migration utility helpers. + +These are intentionally migration-specific utilities. Do NOT use them in normal +application code paths. +""" + +from __future__ import annotations + +import logging +from urllib.parse import quote_plus + +import sqlalchemy as sa + +logger = logging.getLogger(__name__) + + +def try_create_db_if_not_exists( + db_type: str, + host: str, + port: int, + username: str, + password: str, + database: str, +) -> None: + """Best-effort attempt to create the target database if it does not exist. + + Only supports PostgreSQL and MySQL. For other database types this function + is a no-op. Failures are logged as warnings and never re-raised, so callers + (e.g. the migration command) are not interrupted when the database already + exists or when the user lacks CREATE DATABASE privileges. + + Args: + db_type: One of the supported DB_TYPE values (e.g. "postgresql", "mysql"). + host: Database server hostname or IP. + port: Database server port. + username: Database username. + password: Database password. + database: Name of the database to create if absent. + """ + try: + if db_type == "postgresql": + _try_create_postgresql(host, port, username, password, database) + elif db_type == "mysql": + _try_create_mysql(host, port, username, password, database) + else: + logger.debug( + "try_create_db_if_not_exists: unsupported db_type=%r, skipping.", + db_type, + ) + except Exception: + logger.warning( + "try_create_db_if_not_exists: failed to create database %r (db_type=%r). " + "Proceeding anyway — migration will fail if the database truly does not exist.", + database, + db_type, + exc_info=True, + ) + + +def _try_create_postgresql(host: str, port: int, username: str, password: str, database: str) -> None: + # Connect to the default 'postgres' maintenance database so we can issue + # CREATE DATABASE without requiring the target database to already exist. + admin_uri = f"postgresql://{quote_plus(username)}:{quote_plus(password)}@{host}:{port}/postgres" + engine = sa.create_engine(admin_uri, isolation_level="AUTOCOMMIT") + try: + with engine.connect() as conn: + exists = conn.execute( + sa.text("SELECT 1 FROM pg_database WHERE datname = :name"), + {"name": database}, + ).scalar() + if not exists: + # Identifier quoting guards against names with special characters. + conn.execute(sa.text(f'CREATE DATABASE "{database}"')) + logger.info("try_create_db_if_not_exists: PostgreSQL database %r created.", database) + else: + logger.debug("try_create_db_if_not_exists: PostgreSQL database %r already exists.", database) + finally: + engine.dispose() + + +def _try_create_mysql(host: str, port: int, username: str, password: str, database: str) -> None: + # Connect without specifying a database so the target need not exist yet. + admin_uri = f"mysql+pymysql://{quote_plus(username)}:{quote_plus(password)}@{host}:{port}/" + engine = sa.create_engine(admin_uri) + try: + with engine.connect() as conn: + # MySQL supports IF NOT EXISTS natively — no need to check first. + conn.execute( + sa.text(f"CREATE DATABASE IF NOT EXISTS `{database}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci") + ) + conn.commit() + logger.info("try_create_db_if_not_exists: MySQL database %r ensured.", database) + finally: + engine.dispose() diff --git a/api/tests/unit_tests/libs/test_db_migration_utils.py b/api/tests/unit_tests/libs/test_db_migration_utils.py new file mode 100644 index 00000000000..c3738f1f3d1 --- /dev/null +++ b/api/tests/unit_tests/libs/test_db_migration_utils.py @@ -0,0 +1,152 @@ +"""Unit tests for libs.db_migration_utils.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from libs.db_migration_utils import try_create_db_if_not_exists + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_PG_ARGS = {"host": "localhost", "port": 5432, "username": "postgres", "password": "secret", "database": "dify"} +_MY_ARGS = {"host": "localhost", "port": 3306, "username": "root", "password": "secret", "database": "dify"} + + +# --------------------------------------------------------------------------- +# Unsupported DB types +# --------------------------------------------------------------------------- + + +class TestUnsupportedDbType: + def test_oceanbase_is_noop(self): + """Unsupported db_type should not raise and should not touch SQLAlchemy.""" + with patch("libs.db_migration_utils.sa.create_engine") as mock_engine: + try_create_db_if_not_exists(db_type="oceanbase", **_PG_ARGS) + mock_engine.assert_not_called() + + def test_seekdb_is_noop(self): + with patch("libs.db_migration_utils.sa.create_engine") as mock_engine: + try_create_db_if_not_exists(db_type="seekdb", **_PG_ARGS) + mock_engine.assert_not_called() + + +# --------------------------------------------------------------------------- +# PostgreSQL +# --------------------------------------------------------------------------- + + +class TestPostgreSQL: + def _make_engine_mock(self, db_exists: bool) -> MagicMock: + """Return a mock engine whose connection reports whether the DB exists.""" + scalar_result = 1 if db_exists else None + conn = MagicMock() + conn.execute.return_value.scalar.return_value = scalar_result + conn.__enter__ = MagicMock(return_value=conn) + conn.__exit__ = MagicMock(return_value=False) + engine = MagicMock() + engine.connect.return_value = conn + return engine, conn + + def test_creates_database_when_not_exists(self): + """Should issue CREATE DATABASE when pg_database lookup returns nothing.""" + engine, conn = self._make_engine_mock(db_exists=False) + with patch("libs.db_migration_utils.sa.create_engine", return_value=engine): + try_create_db_if_not_exists(db_type="postgresql", **_PG_ARGS) + + # First call: SELECT from pg_database; second call: CREATE DATABASE + assert conn.execute.call_count == 2 + create_call_sql = str(conn.execute.call_args_list[1][0][0]) + assert "CREATE DATABASE" in create_call_sql + + def test_skips_create_when_database_exists(self): + """Should NOT issue CREATE DATABASE when the database already exists.""" + engine, conn = self._make_engine_mock(db_exists=True) + with patch("libs.db_migration_utils.sa.create_engine", return_value=engine): + try_create_db_if_not_exists(db_type="postgresql", **_PG_ARGS) + + assert conn.execute.call_count == 1 # only the SELECT + + def test_connects_to_postgres_maintenance_db(self): + """Admin connection must target the 'postgres' maintenance database, not the target DB.""" + engine, _ = self._make_engine_mock(db_exists=True) + with patch("libs.db_migration_utils.sa.create_engine", return_value=engine) as mock_create: + try_create_db_if_not_exists(db_type="postgresql", **_PG_ARGS) + + uri_used: str = mock_create.call_args[0][0] + assert uri_used.endswith("/postgres"), f"Expected URI ending in /postgres, got: {uri_used}" + + def test_engine_disposed_on_success(self): + engine, _ = self._make_engine_mock(db_exists=True) + with patch("libs.db_migration_utils.sa.create_engine", return_value=engine): + try_create_db_if_not_exists(db_type="postgresql", **_PG_ARGS) + engine.dispose.assert_called_once() + + def test_engine_disposed_on_connection_error(self): + """dispose() must be called even when the connection raises.""" + engine = MagicMock() + engine.connect.side_effect = Exception("connection refused") + with patch("libs.db_migration_utils.sa.create_engine", return_value=engine): + # Should not raise — failure is swallowed and logged as warning + try_create_db_if_not_exists(db_type="postgresql", **_PG_ARGS) + engine.dispose.assert_called_once() + + def test_exception_is_swallowed_not_raised(self): + """Any exception during DB creation must be caught; caller must not see it.""" + engine = MagicMock() + engine.connect.side_effect = RuntimeError("boom") + with patch("libs.db_migration_utils.sa.create_engine", return_value=engine): + try_create_db_if_not_exists(db_type="postgresql", **_PG_ARGS) # must not raise + + +# --------------------------------------------------------------------------- +# MySQL +# --------------------------------------------------------------------------- + + +class TestMySQL: + def _make_engine_mock(self) -> tuple[MagicMock, MagicMock]: + conn = MagicMock() + conn.__enter__ = MagicMock(return_value=conn) + conn.__exit__ = MagicMock(return_value=False) + engine = MagicMock() + engine.connect.return_value = conn + return engine, conn + + def test_issues_create_database_if_not_exists(self): + """MySQL path should always call CREATE DATABASE IF NOT EXISTS.""" + engine, conn = self._make_engine_mock() + with patch("libs.db_migration_utils.sa.create_engine", return_value=engine): + try_create_db_if_not_exists(db_type="mysql", **_MY_ARGS) + + assert conn.execute.called + sql = str(conn.execute.call_args[0][0]) + assert "CREATE DATABASE IF NOT EXISTS" in sql + + def test_commits_after_create(self): + engine, conn = self._make_engine_mock() + with patch("libs.db_migration_utils.sa.create_engine", return_value=engine): + try_create_db_if_not_exists(db_type="mysql", **_MY_ARGS) + conn.commit.assert_called_once() + + def test_connects_without_database_in_uri(self): + """MySQL admin URI must not include the target database name.""" + engine, _ = self._make_engine_mock() + with patch("libs.db_migration_utils.sa.create_engine", return_value=engine) as mock_create: + try_create_db_if_not_exists(db_type="mysql", **_MY_ARGS) + + uri_used: str = mock_create.call_args[0][0] + assert uri_used.endswith("/"), f"Expected URI ending in '/', got: {uri_used}" + + def test_engine_disposed_on_success(self): + engine, _ = self._make_engine_mock() + with patch("libs.db_migration_utils.sa.create_engine", return_value=engine): + try_create_db_if_not_exists(db_type="mysql", **_MY_ARGS) + engine.dispose.assert_called_once() + + def test_exception_is_swallowed_not_raised(self): + engine = MagicMock() + engine.connect.side_effect = RuntimeError("boom") + with patch("libs.db_migration_utils.sa.create_engine", return_value=engine): + try_create_db_if_not_exists(db_type="mysql", **_MY_ARGS) # must not raise