mirror of
https://github.com/langgenius/dify.git
synced 2026-04-28 11:56:55 +08:00
refactor(api): replace heartbeat mechanism with AutoRenewRedisLock for database migration
- Removed the manual heartbeat function for renewing the Redis lock during database migrations. - Integrated AutoRenewRedisLock to handle lock renewal automatically, simplifying the upgrade_db command. - Updated unit tests to reflect changes in lock handling and error management during migrations. (cherry picked from commit 8814256eb5fa20b29e554264f3b659b027bc4c9a)
This commit is contained in:
parent
8d4bd5636b
commit
94603b5408
@ -3,15 +3,13 @@ import datetime
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
import threading
|
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import Any
|
||||||
|
|
||||||
import click
|
import click
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
from pydantic import TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
from redis.exceptions import LockNotOwnedError, RedisError
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
@ -32,6 +30,7 @@ from extensions.ext_redis import redis_client
|
|||||||
from extensions.ext_storage import storage
|
from extensions.ext_storage import storage
|
||||||
from extensions.storage.opendal_storage import OpenDALStorage
|
from extensions.storage.opendal_storage import OpenDALStorage
|
||||||
from extensions.storage.storage_type import StorageType
|
from extensions.storage.storage_type import StorageType
|
||||||
|
from libs.auto_renew_redis_lock import AutoRenewRedisLock
|
||||||
from libs.helper import email as email_validate
|
from libs.helper import email as email_validate
|
||||||
from libs.password import hash_password, password_pattern, valid_password
|
from libs.password import hash_password, password_pattern, valid_password
|
||||||
from libs.rsa import generate_key_pair
|
from libs.rsa import generate_key_pair
|
||||||
@ -56,36 +55,9 @@ from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from redis.lock import Lock
|
|
||||||
|
|
||||||
DB_UPGRADE_LOCK_TTL_SECONDS = 60
|
DB_UPGRADE_LOCK_TTL_SECONDS = 60
|
||||||
|
|
||||||
|
|
||||||
def _heartbeat_db_upgrade_lock(lock: "Lock", stop_event: threading.Event, ttl_seconds: float) -> None:
|
|
||||||
"""
|
|
||||||
Keep the DB upgrade lock alive while migrations are running.
|
|
||||||
|
|
||||||
We intentionally keep the base TTL small (e.g. 60s) so that if the process is killed and can't
|
|
||||||
release the lock, the lock will naturally expire soon. While the process is alive, this
|
|
||||||
heartbeat periodically resets the TTL via `lock.reacquire()`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
interval_seconds = max(0.1, ttl_seconds / 3)
|
|
||||||
while not stop_event.wait(interval_seconds):
|
|
||||||
try:
|
|
||||||
lock.reacquire()
|
|
||||||
except LockNotOwnedError:
|
|
||||||
# Another process took over / TTL expired; continuing to retry won't help.
|
|
||||||
logger.warning("DB migration lock is no longer owned during heartbeat; stop renewing.")
|
|
||||||
return
|
|
||||||
except RedisError:
|
|
||||||
# Best-effort: keep trying while the process is alive.
|
|
||||||
logger.warning("Failed to renew DB migration lock due to Redis error; will retry.", exc_info=True)
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Unexpected error while renewing DB migration lock; will retry.", exc_info=True)
|
|
||||||
|
|
||||||
|
|
||||||
@click.command("reset-password", help="Reset the account password.")
|
@click.command("reset-password", help="Reset the account password.")
|
||||||
@click.option("--email", prompt=True, help="Account email to reset password for")
|
@click.option("--email", prompt=True, help="Account email to reset password for")
|
||||||
@click.option("--new-password", prompt=True, help="New password")
|
@click.option("--new-password", prompt=True, help="New password")
|
||||||
@ -758,21 +730,14 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
|
|||||||
@click.command("upgrade-db", help="Upgrade the database")
|
@click.command("upgrade-db", help="Upgrade the database")
|
||||||
def upgrade_db():
|
def upgrade_db():
|
||||||
click.echo("Preparing database migration...")
|
click.echo("Preparing database migration...")
|
||||||
# Use a short base TTL + heartbeat renewal, so a crashed process doesn't block migrations for long.
|
lock = AutoRenewRedisLock(
|
||||||
# thread_local=False is required because heartbeat runs in a separate thread.
|
redis_client=redis_client,
|
||||||
lock = redis_client.lock(
|
|
||||||
name="db_upgrade_lock",
|
name="db_upgrade_lock",
|
||||||
timeout=DB_UPGRADE_LOCK_TTL_SECONDS,
|
ttl_seconds=DB_UPGRADE_LOCK_TTL_SECONDS,
|
||||||
thread_local=False,
|
logger=logger,
|
||||||
|
log_context="db_migration",
|
||||||
)
|
)
|
||||||
if lock.acquire(blocking=False):
|
if lock.acquire(blocking=False):
|
||||||
stop_event = threading.Event()
|
|
||||||
heartbeat_thread = threading.Thread(
|
|
||||||
target=_heartbeat_db_upgrade_lock,
|
|
||||||
args=(lock, stop_event, float(DB_UPGRADE_LOCK_TTL_SECONDS)),
|
|
||||||
daemon=True,
|
|
||||||
)
|
|
||||||
heartbeat_thread.start()
|
|
||||||
migration_succeeded = False
|
migration_succeeded = False
|
||||||
try:
|
try:
|
||||||
click.echo(click.style("Starting database migration.", fg="green"))
|
click.echo(click.style("Starting database migration.", fg="green"))
|
||||||
@ -790,23 +755,8 @@ def upgrade_db():
|
|||||||
click.echo(click.style(f"Database migration failed: {e}", fg="red"))
|
click.echo(click.style(f"Database migration failed: {e}", fg="red"))
|
||||||
raise SystemExit(1)
|
raise SystemExit(1)
|
||||||
finally:
|
finally:
|
||||||
stop_event.set()
|
status = "successful" if migration_succeeded else "failed"
|
||||||
heartbeat_thread.join(timeout=5)
|
lock.release_safely(status=status)
|
||||||
# Lock release errors should never mask the real migration failure.
|
|
||||||
try:
|
|
||||||
lock.release()
|
|
||||||
except LockNotOwnedError:
|
|
||||||
status = "successful" if migration_succeeded else "failed"
|
|
||||||
logger.warning(
|
|
||||||
"DB migration lock not owned on release after %s migration (likely expired); ignoring.", status
|
|
||||||
)
|
|
||||||
except RedisError:
|
|
||||||
status = "successful" if migration_succeeded else "failed"
|
|
||||||
logger.warning(
|
|
||||||
"Failed to release DB migration lock due to Redis error after %s migration; ignoring.",
|
|
||||||
status,
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
click.echo("Database migration skipped")
|
click.echo("Database migration skipped")
|
||||||
|
|
||||||
|
|||||||
198
api/libs/auto_renew_redis_lock.py
Normal file
198
api/libs/auto_renew_redis_lock.py
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
"""
|
||||||
|
Auto-renewing Redis distributed lock (redis-py Lock).
|
||||||
|
|
||||||
|
Why this exists:
|
||||||
|
- A fixed, long lock TTL can leave a stale lock for a long time if the process is killed
|
||||||
|
before releasing it.
|
||||||
|
- A fixed, short lock TTL can expire during long critical sections (e.g. DB migrations),
|
||||||
|
allowing another instance to acquire the same lock concurrently.
|
||||||
|
|
||||||
|
This wrapper keeps a short base TTL and renews it in a daemon thread using `Lock.reacquire()`
|
||||||
|
while the process is alive. If the process is terminated, the renewal stops and the lock
|
||||||
|
expires soon.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from redis.exceptions import LockNotOwnedError, RedisError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AutoRenewRedisLock:
|
||||||
|
"""
|
||||||
|
Redis lock wrapper that automatically renews TTL while held.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- We force `thread_local=False` when creating the underlying redis-py lock, because the
|
||||||
|
lock token must be accessible from the heartbeat thread for `reacquire()` to work.
|
||||||
|
- `release_safely()` is best-effort: it never raises, so it won't mask the caller's
|
||||||
|
primary error/exit code.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_redis_client: Any
|
||||||
|
_name: str
|
||||||
|
_ttl_seconds: float
|
||||||
|
_renew_interval_seconds: float
|
||||||
|
_log_context: str | None
|
||||||
|
_logger: logging.Logger
|
||||||
|
|
||||||
|
_lock: Any
|
||||||
|
_stop_event: threading.Event | None
|
||||||
|
_thread: threading.Thread | None
|
||||||
|
_acquired: bool
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
redis_client: Any,
|
||||||
|
name: str,
|
||||||
|
ttl_seconds: float = 60,
|
||||||
|
renew_interval_seconds: float | None = None,
|
||||||
|
*,
|
||||||
|
logger: logging.Logger | None = None,
|
||||||
|
log_context: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
self._redis_client = redis_client
|
||||||
|
self._name = name
|
||||||
|
self._ttl_seconds = float(ttl_seconds)
|
||||||
|
self._renew_interval_seconds = (
|
||||||
|
float(renew_interval_seconds) if renew_interval_seconds is not None else max(0.1, self._ttl_seconds / 3)
|
||||||
|
)
|
||||||
|
self._logger = logger or logging.getLogger(__name__)
|
||||||
|
self._log_context = log_context
|
||||||
|
|
||||||
|
self._lock = None
|
||||||
|
self._stop_event = None
|
||||||
|
self._thread = None
|
||||||
|
self._acquired = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
def acquire(self, *args: Any, **kwargs: Any) -> bool:
|
||||||
|
"""
|
||||||
|
Acquire the lock and start auto-renew heartbeat on success.
|
||||||
|
|
||||||
|
Accepts the same args/kwargs as redis-py `Lock.acquire()`.
|
||||||
|
"""
|
||||||
|
self._lock = self._redis_client.lock(
|
||||||
|
name=self._name,
|
||||||
|
timeout=self._ttl_seconds,
|
||||||
|
thread_local=False,
|
||||||
|
)
|
||||||
|
acquired = bool(self._lock.acquire(*args, **kwargs))
|
||||||
|
self._acquired = acquired
|
||||||
|
if acquired:
|
||||||
|
self._start_heartbeat()
|
||||||
|
return acquired
|
||||||
|
|
||||||
|
def owned(self) -> bool:
|
||||||
|
if self._lock is None:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
return bool(self._lock.owned())
|
||||||
|
except Exception:
|
||||||
|
# Ownership checks are best-effort and must not break callers.
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _start_heartbeat(self) -> None:
|
||||||
|
if self._lock is None:
|
||||||
|
return
|
||||||
|
if self._stop_event is not None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._stop_event = threading.Event()
|
||||||
|
self._thread = threading.Thread(
|
||||||
|
target=self._heartbeat_loop,
|
||||||
|
args=(self._lock, self._stop_event),
|
||||||
|
daemon=True,
|
||||||
|
name=f"AutoRenewRedisLock({self._name})",
|
||||||
|
)
|
||||||
|
self._thread.start()
|
||||||
|
|
||||||
|
def _heartbeat_loop(self, lock: Any, stop_event: threading.Event) -> None:
|
||||||
|
while not stop_event.wait(self._renew_interval_seconds):
|
||||||
|
try:
|
||||||
|
lock.reacquire()
|
||||||
|
except LockNotOwnedError:
|
||||||
|
self._logger.warning(
|
||||||
|
"Auto-renew lock is no longer owned during heartbeat%s; stop renewing.",
|
||||||
|
f" ({self._log_context})" if self._log_context else "",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
except RedisError:
|
||||||
|
self._logger.warning(
|
||||||
|
"Failed to renew auto-renew lock due to Redis error%s; will retry.",
|
||||||
|
f" ({self._log_context})" if self._log_context else "",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
self._logger.warning(
|
||||||
|
"Unexpected error while renewing auto-renew lock%s; will retry.",
|
||||||
|
f" ({self._log_context})" if self._log_context else "",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def release_safely(self, *, status: str | None = None) -> None:
|
||||||
|
"""
|
||||||
|
Stop heartbeat and release lock. Never raises.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
status: Optional caller-provided status (e.g. 'successful'/'failed') to add context to logs.
|
||||||
|
"""
|
||||||
|
lock = self._lock
|
||||||
|
if lock is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._stop_heartbeat()
|
||||||
|
|
||||||
|
# Lock release errors should never mask the real error/exit code.
|
||||||
|
try:
|
||||||
|
lock.release()
|
||||||
|
except LockNotOwnedError:
|
||||||
|
self._logger.warning(
|
||||||
|
"Auto-renew lock not owned on release%s%s; ignoring.",
|
||||||
|
f" after {status} operation" if status else "",
|
||||||
|
f" ({self._log_context})" if self._log_context else "",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
except RedisError:
|
||||||
|
self._logger.warning(
|
||||||
|
"Failed to release auto-renew lock due to Redis error%s%s; ignoring.",
|
||||||
|
f" after {status} operation" if status else "",
|
||||||
|
f" ({self._log_context})" if self._log_context else "",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
self._logger.warning(
|
||||||
|
"Unexpected error while releasing auto-renew lock%s%s; ignoring.",
|
||||||
|
f" after {status} operation" if status else "",
|
||||||
|
f" ({self._log_context})" if self._log_context else "",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
self._acquired = False
|
||||||
|
|
||||||
|
def _stop_heartbeat(self) -> None:
|
||||||
|
if self._stop_event is None:
|
||||||
|
return
|
||||||
|
self._stop_event.set()
|
||||||
|
if self._thread is not None:
|
||||||
|
# Best-effort join: if Redis calls are blocked, the daemon thread may remain alive.
|
||||||
|
join_timeout_seconds = max(0.5, min(5.0, self._renew_interval_seconds * 2))
|
||||||
|
self._thread.join(timeout=join_timeout_seconds)
|
||||||
|
if self._thread.is_alive():
|
||||||
|
self._logger.warning(
|
||||||
|
"Auto-renew lock heartbeat thread did not stop within %.2fs%s; ignoring.",
|
||||||
|
join_timeout_seconds,
|
||||||
|
f" ({self._log_context})" if self._log_context else "",
|
||||||
|
)
|
||||||
|
self._stop_event = None
|
||||||
|
self._thread = None
|
||||||
|
|
||||||
@ -0,0 +1,39 @@
|
|||||||
|
"""
|
||||||
|
Integration tests for AutoRenewRedisLock using real Redis via TestContainers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
from libs.auto_renew_redis_lock import AutoRenewRedisLock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("flask_app_with_containers")
|
||||||
|
def test_auto_renew_redis_lock_renews_ttl_and_releases():
|
||||||
|
lock_name = f"test:auto_renew_lock:{uuid.uuid4().hex}"
|
||||||
|
|
||||||
|
# Keep base TTL very small, and renew frequently so the test is stable even on slower CI.
|
||||||
|
lock = AutoRenewRedisLock(
|
||||||
|
redis_client=redis_client,
|
||||||
|
name=lock_name,
|
||||||
|
ttl_seconds=1.0,
|
||||||
|
renew_interval_seconds=0.2,
|
||||||
|
log_context="test_auto_renew_redis_lock",
|
||||||
|
)
|
||||||
|
|
||||||
|
acquired = lock.acquire(blocking=True, blocking_timeout=5)
|
||||||
|
assert acquired is True
|
||||||
|
|
||||||
|
# Wait beyond the base TTL; key should still exist due to renewal.
|
||||||
|
time.sleep(1.5)
|
||||||
|
ttl = redis_client.ttl(lock_name)
|
||||||
|
assert ttl > 0
|
||||||
|
|
||||||
|
lock.release_safely(status="successful")
|
||||||
|
|
||||||
|
# After release, the key should not exist.
|
||||||
|
assert redis_client.exists(lock_name) == 0
|
||||||
|
|
||||||
@ -4,8 +4,9 @@ import types
|
|||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import commands
|
import commands
|
||||||
|
from libs.auto_renew_redis_lock import LockNotOwnedError, RedisError
|
||||||
|
|
||||||
HEARTBEAT_WAIT_TIMEOUT_SECONDS = 1.0
|
HEARTBEAT_WAIT_TIMEOUT_SECONDS = 5.0
|
||||||
|
|
||||||
|
|
||||||
def _install_fake_flask_migrate(monkeypatch, upgrade_impl) -> None:
|
def _install_fake_flask_migrate(monkeypatch, upgrade_impl) -> None:
|
||||||
@ -45,7 +46,7 @@ def test_upgrade_db_failure_not_masked_by_lock_release(monkeypatch, capsys):
|
|||||||
|
|
||||||
lock = MagicMock()
|
lock = MagicMock()
|
||||||
lock.acquire.return_value = True
|
lock.acquire.return_value = True
|
||||||
lock.release.side_effect = commands.LockNotOwnedError("simulated")
|
lock.release.side_effect = LockNotOwnedError("simulated")
|
||||||
commands.redis_client.lock.return_value = lock
|
commands.redis_client.lock.return_value = lock
|
||||||
|
|
||||||
def _upgrade():
|
def _upgrade():
|
||||||
@ -69,7 +70,7 @@ def test_upgrade_db_success_ignores_lock_not_owned_on_release(monkeypatch, capsy
|
|||||||
|
|
||||||
lock = MagicMock()
|
lock = MagicMock()
|
||||||
lock.acquire.return_value = True
|
lock.acquire.return_value = True
|
||||||
lock.release.side_effect = commands.LockNotOwnedError("simulated")
|
lock.release.side_effect = LockNotOwnedError("simulated")
|
||||||
commands.redis_client.lock.return_value = lock
|
commands.redis_client.lock.return_value = lock
|
||||||
|
|
||||||
_install_fake_flask_migrate(monkeypatch, lambda: None)
|
_install_fake_flask_migrate(monkeypatch, lambda: None)
|
||||||
@ -129,7 +130,7 @@ def test_upgrade_db_ignores_reacquire_errors(monkeypatch, capsys):
|
|||||||
|
|
||||||
def _reacquire():
|
def _reacquire():
|
||||||
attempted.set()
|
attempted.set()
|
||||||
raise commands.RedisError("simulated")
|
raise RedisError("simulated")
|
||||||
|
|
||||||
lock.reacquire.side_effect = _reacquire
|
lock.reacquire.side_effect = _reacquire
|
||||||
|
|
||||||
|
|||||||
@ -0,0 +1,125 @@
|
|||||||
|
"""Unit tests for enterprise service integrations.
|
||||||
|
|
||||||
|
This module covers the enterprise-only default workspace auto-join behavior:
|
||||||
|
- Enterprise mode disabled: no external calls
|
||||||
|
- Successful join / skipped join: no errors
|
||||||
|
- Failures (network/invalid response/invalid UUID): soft-fail wrapper must not raise
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from services.enterprise.enterprise_service import (
|
||||||
|
DefaultWorkspaceJoinResult,
|
||||||
|
EnterpriseService,
|
||||||
|
try_join_default_workspace,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestJoinDefaultWorkspace:
|
||||||
|
def test_join_default_workspace_success(self):
|
||||||
|
account_id = "11111111-1111-1111-1111-111111111111"
|
||||||
|
response = {"workspace_id": "22222222-2222-2222-2222-222222222222", "joined": True, "message": "ok"}
|
||||||
|
|
||||||
|
with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request:
|
||||||
|
mock_send_request.return_value = response
|
||||||
|
|
||||||
|
result = EnterpriseService.join_default_workspace(account_id=account_id)
|
||||||
|
|
||||||
|
assert isinstance(result, DefaultWorkspaceJoinResult)
|
||||||
|
assert result.workspace_id == response["workspace_id"]
|
||||||
|
assert result.joined is True
|
||||||
|
assert result.message == "ok"
|
||||||
|
|
||||||
|
mock_send_request.assert_called_once_with(
|
||||||
|
"POST",
|
||||||
|
"/default-workspace/members",
|
||||||
|
json={"account_id": account_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_join_default_workspace_invalid_response_format_raises(self):
|
||||||
|
account_id = "11111111-1111-1111-1111-111111111111"
|
||||||
|
|
||||||
|
with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request:
|
||||||
|
mock_send_request.return_value = "not-a-dict"
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Invalid response format"):
|
||||||
|
EnterpriseService.join_default_workspace(account_id=account_id)
|
||||||
|
|
||||||
|
def test_join_default_workspace_invalid_account_id_raises(self):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
EnterpriseService.join_default_workspace(account_id="not-a-uuid")
|
||||||
|
|
||||||
|
|
||||||
|
class TestTryJoinDefaultWorkspace:
|
||||||
|
def test_try_join_default_workspace_enterprise_disabled_noop(self):
|
||||||
|
with (
|
||||||
|
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
|
||||||
|
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
|
||||||
|
):
|
||||||
|
mock_config.ENTERPRISE_ENABLED = False
|
||||||
|
|
||||||
|
try_join_default_workspace("11111111-1111-1111-1111-111111111111")
|
||||||
|
|
||||||
|
mock_join.assert_not_called()
|
||||||
|
|
||||||
|
def test_try_join_default_workspace_successful_join_does_not_raise(self):
|
||||||
|
account_id = "11111111-1111-1111-1111-111111111111"
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
|
||||||
|
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
|
||||||
|
):
|
||||||
|
mock_config.ENTERPRISE_ENABLED = True
|
||||||
|
mock_join.return_value = DefaultWorkspaceJoinResult(
|
||||||
|
workspace_id="22222222-2222-2222-2222-222222222222",
|
||||||
|
joined=True,
|
||||||
|
message="ok",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not raise
|
||||||
|
try_join_default_workspace(account_id)
|
||||||
|
|
||||||
|
mock_join.assert_called_once_with(account_id=account_id)
|
||||||
|
|
||||||
|
def test_try_join_default_workspace_skipped_join_does_not_raise(self):
|
||||||
|
account_id = "11111111-1111-1111-1111-111111111111"
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
|
||||||
|
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
|
||||||
|
):
|
||||||
|
mock_config.ENTERPRISE_ENABLED = True
|
||||||
|
mock_join.return_value = DefaultWorkspaceJoinResult(
|
||||||
|
workspace_id="",
|
||||||
|
joined=False,
|
||||||
|
message="no default workspace configured",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not raise
|
||||||
|
try_join_default_workspace(account_id)
|
||||||
|
|
||||||
|
mock_join.assert_called_once_with(account_id=account_id)
|
||||||
|
|
||||||
|
def test_try_join_default_workspace_api_failure_soft_fails(self):
|
||||||
|
account_id = "11111111-1111-1111-1111-111111111111"
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
|
||||||
|
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
|
||||||
|
):
|
||||||
|
mock_config.ENTERPRISE_ENABLED = True
|
||||||
|
mock_join.side_effect = Exception("network failure")
|
||||||
|
|
||||||
|
# Should not raise
|
||||||
|
try_join_default_workspace(account_id)
|
||||||
|
|
||||||
|
mock_join.assert_called_once_with(account_id=account_id)
|
||||||
|
|
||||||
|
def test_try_join_default_workspace_invalid_account_id_soft_fails(self):
|
||||||
|
with patch("services.enterprise.enterprise_service.dify_config") as mock_config:
|
||||||
|
mock_config.ENTERPRISE_ENABLED = True
|
||||||
|
|
||||||
|
# Should not raise even though UUID parsing fails inside join_default_workspace
|
||||||
|
try_join_default_workspace("not-a-uuid")
|
||||||
Loading…
Reference in New Issue
Block a user