diff --git a/api/app_factory.py b/api/app_factory.py index e9094fd8ad..49be025731 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -223,10 +223,11 @@ def initialize_extensions(app: DifyApp): def create_migrations_app() -> DifyApp: app = create_flask_app_with_configs() - from extensions import ext_database, ext_migrate + from extensions import ext_commands, ext_database, ext_migrate # Initialize only required extensions ext_database.init_app(app) ext_migrate.init_app(app) + ext_commands.init_app(app) return app diff --git a/api/commands/__init__.py b/api/commands/__init__.py index d62d0dbd7c..9d1bf7d0fe 100644 --- a/api/commands/__init__.py +++ b/api/commands/__init__.py @@ -3,6 +3,7 @@ CLI command modules extracted from `commands.py`. """ from .account import create_tenant, reset_email, reset_password +from .data_migrate import data_migrate, legacy_model_types from .plugin import ( extract_plugins, extract_unique_plugins, @@ -44,6 +45,7 @@ __all__ = [ "clear_orphaned_file_records", "convert_to_agent_apps", "create_tenant", + "data_migrate", "delete_archived_workflow_runs", "export_app_messages", "extract_plugins", @@ -52,6 +54,7 @@ __all__ = [ "fix_app_site_missing", "install_plugins", "install_rag_pipeline_plugins", + "legacy_model_types", "migrate_annotation_vector_database", "migrate_data_for_plugin", "migrate_knowledge_vector_database", diff --git a/api/commands/data_migrate.py b/api/commands/data_migrate.py new file mode 100644 index 0000000000..6e3910e619 --- /dev/null +++ b/api/commands/data_migrate.py @@ -0,0 +1,169 @@ +import io +import os +import sys +from contextlib import AbstractContextManager, nullcontext +from pathlib import Path +from typing import cast + +import click + +from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType +from services.legacy_model_type_migration import ( + VALID_TABLE_NAMES, + LegacyModelTypeMigrationService, + load_tenant_ids_from_file, +) + +_SUPPORTED_MODEL_TYPE_CHOICES = ( + ModelType.LLM.value, + ModelType.TEXT_EMBEDDING.value, + ModelType.RERANK.value, +) +_DEFAULT_CONCURRENCY = os.cpu_count() or 1 + + +def _normalize_multi_value_option( + values: tuple[str, ...], + *, + valid_values: tuple[str, ...], + option_name: str, +) -> tuple[str, ...]: + normalized_values: list[str] = [] + seen_values: set[str] = set() + + for value in values: + for item in value.split(","): + normalized_item = item.strip() + if not normalized_item: + continue + if normalized_item not in valid_values: + raise click.BadParameter( + f"invalid value '{normalized_item}'. valid values: {', '.join(valid_values)}", + param_hint=option_name, + ) + if normalized_item in seen_values: + continue + seen_values.add(normalized_item) + normalized_values.append(normalized_item) + + return tuple(normalized_values) + + +@click.group( + "data-migrate", + help="Online data migration commands.", +) +def data_migrate() -> None: + """Namespace for production data migration commands.""" + + +@click.command( + "legacy-model-types", + help=( + "Migrate legacy provider model_type values to canonical values. " + "Default is dry-run and emits JSON lines only. " + "If --tables includes provider_model_credentials, the command may also update " + "provider_models and load_balancing_model_configs references so merged credentials stay reachable." + ), +) +@click.option( + "--apply", + is_flag=True, + default=False, + help="Apply the migration. Default is dry-run.", +) +@click.option( + "--tables", + "tables", + multiple=True, + type=str, + help=( + "Limit model_type migration to specific tables. Accepts comma-separated values or repeated flags. " + "When provider_model_credentials is selected, provider_models and " + "load_balancing_model_configs may also be updated for credential reference rewrites." + "Default to: " + ), +) +@click.option( + "--model-types", + "model_types", + multiple=True, + type=str, + help=( + "Canonical model types to migrate. Accepts comma-separated values or repeated flags. " + "Defaults to: `llm,text-embedding,rerank`" + ), +) +@click.option( + "--tenant-id-file", + type=click.Path(exists=True, dir_okay=False, readable=True, resolve_path=True), + help="Optional file containing tenant ids, one per line.", +) +@click.option( + "--output", + type=click.Path(dir_okay=False, resolve_path=True, path_type=Path), + help="Optional file path for JSON lines event logs. Defaults to stdout.", +) +@click.option( + "--concurrency", + type=click.IntRange(min=1), + default=_DEFAULT_CONCURRENCY, + show_default=True, + help="Number of tenant-level worker threads to run in parallel.", +) +def legacy_model_types( + apply: bool, + tables: tuple[str, ...], + model_types: tuple[str, ...], + tenant_id_file: str | None, + output: Path | None, + concurrency: int = _DEFAULT_CONCURRENCY, +) -> None: + """ + Migrate legacy provider-related model_type values and emit JSON lines events. + """ + + normalized_tables = _normalize_multi_value_option( + tables, + valid_values=VALID_TABLE_NAMES, + option_name="--tables", + ) + normalized_model_types = _normalize_multi_value_option( + model_types, + valid_values=_SUPPORTED_MODEL_TYPE_CHOICES, + option_name="--model-types", + ) + selected_model_types = ( + tuple(ModelType.value_of(model_type) for model_type in normalized_model_types) + if normalized_model_types + else ( + ModelType.LLM, + ModelType.TEXT_EMBEDDING, + ModelType.RERANK, + ) + ) + tenant_ids = load_tenant_ids_from_file(tenant_id_file) if tenant_id_file else None + + output_context: AbstractContextManager[io.TextIOBase] + if output is None: + output_context = nullcontext(cast(io.TextIOBase, sys.stdout)) + else: + try: + output_context = output.open("w", encoding="utf-8") + except OSError as exc: + raise click.ClickException(f"failed to open output file '{output}': {exc.strerror or exc}") from exc + + with output_context as output_stream: + LegacyModelTypeMigrationService( + engine=db.engine, + apply=apply, + concurrency=concurrency, + output=cast(io.TextIOBase, output_stream), + tables=normalized_tables or None, + model_types=selected_model_types, + tenant_ids=tenant_ids, + ).migrate() + + +data_migrate.add_command(legacy_model_types) diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py index fe95cc5816..18a0c75aca 100644 --- a/api/extensions/ext_commands.py +++ b/api/extensions/ext_commands.py @@ -12,6 +12,7 @@ def init_app(app: DifyApp): clear_orphaned_file_records, convert_to_agent_apps, create_tenant, + data_migrate, delete_archived_workflow_runs, export_app_messages, extract_plugins, @@ -44,6 +45,7 @@ def init_app(app: DifyApp): convert_to_agent_apps, add_qdrant_index, create_tenant, + data_migrate, upgrade_db, fix_app_site_missing, migrate_data_for_plugin, diff --git a/api/pyproject.toml b/api/pyproject.toml index 6ed2862d6a..95f764aef7 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -102,10 +102,7 @@ dify-trace-weave = { workspace = true } [tool.uv] default-groups = ["storage", "tools", "vdb-all", "trace-all"] package = false -override-dependencies = [ - "litellm>=1.83.10,<2.0.0", - "pyarrow>=23.0.1,<24.0.0", -] +override-dependencies = ["litellm>=1.83.10,<2.0.0", "pyarrow>=23.0.1,<24.0.0"] [dependency-groups] diff --git a/api/services/legacy_model_type_migration.py b/api/services/legacy_model_type_migration.py new file mode 100644 index 0000000000..2de5e7f7f3 --- /dev/null +++ b/api/services/legacy_model_type_migration.py @@ -0,0 +1,2464 @@ +""" +Migrate legacy provider-related model_type values to canonical values. + +The grouped tables scan legacy candidates in id order, then reload the full business-key +group inside a transaction before deciding a winner row and the loser rows to delete. +Those grouped flows share the same dry-run/apply handling for group reloads, winner-loser +decisions, row updates, row deletes, and structured logging. Only some grouped tables +also add cache cleanup; that includes `provider_models` and +`provider_model_credentials`. Provider-model-credential groups extend that flow by +rewriting credential references in provider models and load-balancing configs before +removing loser credential rows. `load_balancing_model_configs` stays mostly row-level, +but it first deduplicates `name="__inherit__"` rows by business key before it +canonicalizes the remaining legacy rows independently with row-level cache cleanup. + +Tenant scheduling has two modes. When callers provide an explicit tenant list, the +service preserves the original tenant-scoped execution model and runs all selected tables +for each tenant. When callers omit `tenant_ids`, the service discovers tenant +ids per table and then runs only that table for the discovered tenants. Most +tables keep the active `model_types` filter in the discovery query, while +`load_balancing_model_configs` deliberately uses a whole-table tenant scan so +that query stays easy to understand. +""" + +from __future__ import annotations + +import io +import json +import sys +import threading +import traceback +import uuid +from collections.abc import Iterable, Sequence +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import asdict, dataclass +from datetime import datetime +from enum import IntEnum, StrEnum +from typing import Protocol, cast + +import sqlalchemy as sa +from sqlalchemy.exc import OperationalError +from sqlalchemy.orm import Session +from sqlalchemy.sql import select + +from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType +from graphon.model_runtime.entities.model_entities import ModelType +from libs.datetime_utils import naive_utc_now +from models import LoadBalancingModelConfig, ProviderModel, ProviderModelSetting, TenantDefaultModel +from models.base import TypeBase +from models.provider import ProviderModelCredential + +type ORMModel = type[TypeBase] + + +def _json_default(value: object) -> object: + if isinstance(value, datetime): + return value.isoformat() + if isinstance(value, (IntEnum, StrEnum)): + return value.value + return value + + +def _normalize_log_value(field_name: str, value: object) -> object: + if field_name == "encrypted_config" and isinstance(value, str): + try: + return json.loads(value) + except json.JSONDecodeError: + return value + return value + + +def _normalize_log_mapping(values: dict[str, object]) -> dict[str, object]: + return {key: _normalize_log_value(key, value) for key, value in values.items()} + + +def _normalize_log_payload(value: object) -> object: + if value is None or isinstance(value, bool | int | float | str): + return value + if isinstance(value, datetime): + return value.isoformat() + if isinstance(value, (IntEnum, StrEnum)): + return value.value + if isinstance(value, dict): + return {str(key): _normalize_log_payload(item) for key, item in value.items()} + if isinstance(value, (list, tuple)): + return [_normalize_log_payload(item) for item in value] + if isinstance(value, (set, frozenset)): + normalized_items = [_normalize_log_payload(item) for item in value] + return sorted(normalized_items, key=lambda item: json.dumps(item, sort_keys=True)) + + table_name = getattr(value, "__tablename__", None) + if isinstance(table_name, str): + return table_name + + name = getattr(value, "name", None) + if isinstance(name, str): + return name + + table = getattr(value, "table", None) + if table is not None: + referenced_table_name = getattr(table, "name", None) + if isinstance(referenced_table_name, str): + return referenced_table_name + + return f"<{type(value).__module__}.{type(value).__qualname__}>" + + +def _format_exception_stacktrace(exc: BaseException) -> str: + return "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)) + + +@dataclass(frozen=True, slots=True) +class _RowWithRawModelType[T: TypeBase]: + row: T + raw_model_type: str + canonical_model_type: ModelType + + +@dataclass(frozen=True, slots=True) +class _CacheDeletePlan: + tenant_id: str + identity_id: str + cache_type: ProviderCredentialsCacheType + table_name: str + row_id: str + tx_id: str + business_key: _BusinessKey + + +@dataclass(frozen=True, slots=True) +class _BusinessKey: + """Marker base type for structured migration business keys.""" + + +class _HasRowId(Protocol): + id: object + + +class _HasRowIdAndUpdatedAt(_HasRowId, Protocol): + updated_at: datetime + + +def _normalize_error_code_string(value: object) -> str | None: + if isinstance(value, str): + normalized_value = value.strip().upper() + return normalized_value or None + return None + + +def _normalize_error_code_int(value: object) -> int | None: + if isinstance(value, bool): + return None + if isinstance(value, int): + return value + if isinstance(value, str): + normalized_value = value.strip() + if normalized_value.isdigit(): + return int(normalized_value) + return None + + +@dataclass(frozen=True, slots=True) +class _ProviderModelBusinessKey(_BusinessKey): + """unique index: unique_provider_model_name""" + + tenant_id: str + provider_name: str + model_name: str + model_type: ModelType + + +@dataclass(frozen=True, slots=True) +class _TenantDefaultModelBusinessKey(_BusinessKey): + """unique index: unique_tenant_default_model_type""" + + tenant_id: str + model_type: ModelType + + +@dataclass(frozen=True, slots=True) +class _ProviderModelSettingBusinessKey(_BusinessKey): + """Although `ProviderModelSetting` does not have the unique index + (tenant_id, provider_name. model_name, model_type). The acutal business logic + relies on this uniqueness property heavily. + """ + + tenant_id: str + provider_name: str + model_name: str + model_type: ModelType + + +@dataclass(frozen=True, slots=True) +class _LoadBalancingModelConfigInheritBusinessKey(_BusinessKey): + """Business key for `name="__inherit__"` load-balancing configs.""" + + tenant_id: str + provider_name: str + model_name: str + model_type: ModelType + + +@dataclass(frozen=True, slots=True) +class _ProviderModelCredentialBusinessKey(_BusinessKey): + """Although `ProviderModelCredential` does not have the unique index + (tenant_id, provider_name. model_name, model_type, credential_name). + The acutal business logic implies it.""" + + tenant_id: str + provider_name: str + model_name: str + credential_name: str + model_type: ModelType + + +@dataclass(frozen=True, slots=True) +class _ProviderModelGroupPlan: + group_row_ids: list[str] + winner: _RowWithRawModelType[ProviderModel] | None + loser_rows: list[_RowWithRawModelType[ProviderModel]] + + +@dataclass(frozen=True, slots=True) +class _TenantDefaultModelGroupPlan: + group_row_ids: list[str] + winner: _RowWithRawModelType[TenantDefaultModel] | None + loser_rows: list[_RowWithRawModelType[TenantDefaultModel]] + + +@dataclass(frozen=True, slots=True) +class _ProviderModelSettingGroupPlan: + group_row_ids: list[str] + winner: _RowWithRawModelType[ProviderModelSetting] | None + loser_rows: list[_RowWithRawModelType[ProviderModelSetting]] + + +@dataclass(frozen=True, slots=True) +class _LoadBalancingModelConfigInheritGroupPlan: + group_row_ids: list[str] + winner: _RowWithRawModelType[LoadBalancingModelConfig] | None + loser_rows: list[_RowWithRawModelType[LoadBalancingModelConfig]] + + +@dataclass(frozen=True, slots=True) +class _ProviderModelReferenceRewritePlan: + row_id: str + old_credential_id: str + new_credential_id: str + + +@dataclass(frozen=True, slots=True) +class _LoadBalancingCredentialRewritePlan: + row_id: str + old_credential_id: str | None + old_name: str + old_encrypted_config: str | None + new_credential_id: str + new_name: str + new_encrypted_config: str | None + + +@dataclass(frozen=True, slots=True) +class _ProviderModelCredentialGroupPlan: + group_row_ids: list[str] + winner: _RowWithRawModelType[ProviderModelCredential] | None + loser_rows: list[_RowWithRawModelType[ProviderModelCredential]] + provider_model_rewrites: list[_ProviderModelReferenceRewritePlan] + load_balancing_rewrites: list[_LoadBalancingCredentialRewritePlan] + + +VALID_TABLE_NAMES: tuple[str, ...] = ( + ProviderModel.__tablename__, + TenantDefaultModel.__tablename__, + ProviderModelSetting.__tablename__, + LoadBalancingModelConfig.__tablename__, + ProviderModelCredential.__tablename__, +) + +_SUPPORTED_MODEL_TYPES: tuple[ModelType, ...] = ( + ModelType.LLM, + ModelType.TEXT_EMBEDDING, + ModelType.RERANK, +) +_CANONICAL_TO_LEGACY: dict[ModelType, tuple[str, ...]] = { + ModelType.LLM: ("text-generation",), + ModelType.TEXT_EMBEDDING: ("embeddings",), + ModelType.RERANK: ("reranking",), +} +_LEGACY_TO_CANONICAL: dict[str, ModelType] = { + legacy_value: canonical_model_type + for canonical_model_type, legacy_values in _CANONICAL_TO_LEGACY.items() + for legacy_value in legacy_values +} +_POSTGRES_LOCK_TIMEOUT_SQLSTATES: frozenset[str] = frozenset({"55P03"}) +_MYSQL_LOCK_TIMEOUT_ERRNOS: frozenset[int] = frozenset({1205}) +_LOCK_TIMEOUT_FALLBACK_MESSAGES: tuple[str, ...] = ( + "canceling statement due to lock timeout", + "lock wait timeout exceeded", +) +_RAW_MODEL_TYPE_COLUMN = "_raw_model_type" + + +def _selected_legacy_values(model_types: Sequence[ModelType]) -> list[str]: + legacy_values: list[str] = [] + for model_type in model_types: + legacy_values.extend(_CANONICAL_TO_LEGACY[model_type]) + return legacy_values + + +def _selected_model_type_values(model_types: Sequence[ModelType]) -> list[str]: + model_type_values: list[str] = [] + for model_type in model_types: + model_type_values.append(model_type.value) + model_type_values.extend(_CANONICAL_TO_LEGACY[model_type]) + return list(dict.fromkeys(model_type_values)) + + +def _session_factory(engine: sa.Engine) -> Session: + return Session(bind=engine, expire_on_commit=False) + + +class _ThreadSafeLineWriter(io.TextIOBase): + """ + Serialize line-oriented writes to a shared text stream across tenant workers. + + `Migration._log_event` writes one JSON document per `print(..., flush=True)` call. The + wrapper buffers fragments per thread until a newline arrives, then emits the full line + while holding a process-local lock so concurrent tenants cannot interleave bytes. + """ + + _stream: io.TextIOBase + _lock: threading.Lock + _local: threading.local + + def __init__(self, stream: io.TextIOBase) -> None: + super().__init__() + self._stream = stream + self._lock = threading.Lock() + self._local = threading.local() + + def writable(self) -> bool: + return True + + def write(self, text: str) -> int: + if not text: + return 0 + + buffered_text = self._buffer + text + lines = buffered_text.splitlines(keepends=True) + remainder = "" + if lines and not lines[-1].endswith(("\n", "\r")): + remainder = lines.pop() + + for line in lines: + self._write_line(line) + + self._buffer = remainder + return len(text) + + def flush(self) -> None: + buffered_text = self._buffer + if buffered_text: + self._write_line(buffered_text) + self._buffer = "" + + with self._lock: + self._stream.flush() + + @property + def _buffer(self) -> str: + return cast(str, getattr(self._local, "buffer", "")) + + @_buffer.setter + def _buffer(self, value: str) -> None: + self._local.buffer = value + + def _write_line(self, text: str) -> None: + with self._lock: + self._stream.write(text) + + +class LegacyModelTypeMigrationService: + """ + Migrate legacy provider-related model_type values to canonical values. + + The command can scope the migration by table, tenant, and canonical model type. When + `provider_model_credentials` is selected, that migration also rewrites references in + `provider_models` and `load_balancing_model_configs`. Tenant migrations can run in a + thread pool; JSONL output remains line-safe through a shared synchronized writer. + + If `tenant_ids` is omitted, tenant discovery becomes table-scoped: each selected ORM + model loads its own tenant ids, then only that table is dispatched for those tenants. + Most tables keep the active model-type filter in discovery, while + `load_balancing_model_configs` intentionally uses the whole table so the tenant query + stays simple. This still avoids merging tenant ids across unrelated tables. + """ + + _engine: sa.Engine + _apply: bool + _concurrency: int + _output: io.TextIOBase + _model_types: tuple[ModelType, ...] + _orm_models: tuple[ORMModel, ...] + _tenant_ids: tuple[str, ...] | None + + def __init__( + self, + engine: sa.Engine, + *, + apply: bool = False, + concurrency: int = 1, + output: io.TextIOBase | None = None, + tables: Sequence[str] | None = None, + model_types: Sequence[ModelType] = _SUPPORTED_MODEL_TYPES, + tenant_ids: Sequence[str] | None = None, + ) -> None: + if concurrency < 1: + raise ValueError("concurrency must be greater than or equal to 1") + + self._engine = engine + self._apply = apply + self._concurrency = concurrency + self._output = cast(io.TextIOBase, sys.stdout if output is None else output) + self._model_types = tuple(dict.fromkeys(model_types)) + self._orm_models = self._resolve_models(tables) + self._tenant_ids = tuple(dict.fromkeys(tenant_ids)) if tenant_ids is not None else None + + def _resolve_models(self, tables: Sequence[str] | None) -> tuple[ORMModel, ...]: + if tables is None: + return ( + ProviderModel, + TenantDefaultModel, + ProviderModelSetting, + LoadBalancingModelConfig, + ProviderModelCredential, + ) + + ordered_models: list[ORMModel] = [] + seen_tables: set[str] = set() + for table_name in tables: + if table_name in seen_tables: + continue + seen_tables.add(table_name) + if table_name == ProviderModel.__tablename__: + ordered_models.append(ProviderModel) + elif table_name == TenantDefaultModel.__tablename__: + ordered_models.append(TenantDefaultModel) + elif table_name == ProviderModelSetting.__tablename__: + ordered_models.append(ProviderModelSetting) + elif table_name == LoadBalancingModelConfig.__tablename__: + ordered_models.append(LoadBalancingModelConfig) + elif table_name == ProviderModelCredential.__tablename__: + ordered_models.append(ProviderModelCredential) + else: + raise ValueError(f"invalid table name: {table_name}") + return tuple(ordered_models) + + def migrate(self) -> None: + output = _ThreadSafeLineWriter(self._output) + if self._tenant_ids is not None: + self._migrate_explicit_tenants(output) + return + + self._migrate_tables_with_discovered_tenants(output) + + def _migrate_explicit_tenants(self, output: io.TextIOBase) -> None: + tenant_ids = self._tenant_ids + if not tenant_ids: + return + + self._run_migrations_for_tenants(tenant_ids, self._orm_models, output) + + def _migrate_tables_with_discovered_tenants(self, output: io.TextIOBase) -> None: + for orm_model in self._orm_models: + tenant_ids = self._load_tenant_ids_for_model(orm_model) + if not tenant_ids: + continue + self._run_migrations_for_tenants(tenant_ids, (orm_model,), output) + + def _run_migrations_for_tenants( + self, + tenant_ids: Sequence[str], + orm_models: Sequence[ORMModel], + output: io.TextIOBase, + ) -> None: + if self._concurrency == 1 or len(tenant_ids) == 1: + for tenant_id in tenant_ids: + self._run_tenant_migration(tenant_id, orm_models, output) + return + + with ThreadPoolExecutor(max_workers=min(self._concurrency, len(tenant_ids))) as executor: + futures = [ + executor.submit(self._run_tenant_migration, tenant_id, orm_models, output) for tenant_id in tenant_ids + ] + for future in as_completed(futures): + future.result() + + def _run_tenant_migration( + self, + tenant_id: str, + orm_models: Sequence[ORMModel], + output: io.TextIOBase, + ) -> None: + """ + Execute one tenant migration with the shared, line-synchronized output stream. + """ + + Migration( + tenant_id=tenant_id, + engine=self._engine, + apply=self._apply, + output=output, + model_types=self._model_types, + orm_models=orm_models, + ).run() + + def _load_tenant_ids_for_model(self, orm_model: ORMModel) -> tuple[str, ...]: + """ + Discover only the tenants that have candidate rows for the current table. + + In automatic tenant mode we keep discovery table-scoped so large shared tenant + populations do not force empty work for unrelated tables. Most table queries + still apply the active `model_types` filter before scheduling migrations, while + `load_balancing_model_configs` intentionally trades a wider tenant set for a + simpler discovery query. + """ + + legacy_model_type_values = _selected_legacy_values(self._model_types) + with _session_factory(self._engine) as session: + if orm_model is ProviderModel: + tenant_ids = ( + session.execute( + select(ProviderModel.tenant_id) + .where(sa.type_coerce(ProviderModel.model_type, sa.String()).in_(legacy_model_type_values)) + .distinct() + .order_by(ProviderModel.tenant_id.asc()) + ) + .scalars() + .all() + ) + elif orm_model is TenantDefaultModel: + tenant_ids = ( + session.execute( + select(TenantDefaultModel.tenant_id) + .where(sa.type_coerce(TenantDefaultModel.model_type, sa.String()).in_(legacy_model_type_values)) + .distinct() + .order_by(TenantDefaultModel.tenant_id.asc()) + ) + .scalars() + .all() + ) + elif orm_model is ProviderModelSetting: + tenant_ids = ( + session.execute( + select(ProviderModelSetting.tenant_id) + .where( + sa.type_coerce(ProviderModelSetting.model_type, sa.String()).in_(legacy_model_type_values) + ) + .distinct() + .order_by(ProviderModelSetting.tenant_id.asc()) + ) + .scalars() + .all() + ) + elif orm_model is LoadBalancingModelConfig: + # Deliberately discover tenants from the whole table so the query stays + # easier to understand than the legacy/canonical mixed-row filter. + tenant_ids = ( + session.execute( + select(LoadBalancingModelConfig.tenant_id) + .distinct() + .order_by(LoadBalancingModelConfig.tenant_id.asc()) + ) + .scalars() + .all() + ) + elif orm_model is ProviderModelCredential: + tenant_ids = ( + session.execute( + select(ProviderModelCredential.tenant_id) + .where( + sa.type_coerce(ProviderModelCredential.model_type, sa.String()).in_( + legacy_model_type_values + ) + ) + .distinct() + .order_by(ProviderModelCredential.tenant_id.asc()) + ) + .scalars() + .all() + ) + else: + raise ValueError(f"unsupported orm model: {orm_model}") + + return tuple(tenant_ids) + + +class Migration: + """ + Execute the migration for one tenant. + + The implementation is intentionally table-specific. Each table has its own scan function + and its own apply/dry-run path so the online migration logic stays explicit and auditable. + """ + + _tenant_id: str + _engine: sa.Engine + _apply: bool + _output: io.TextIOBase + _model_types: tuple[ModelType, ...] + _orm_models: tuple[ORMModel, ...] + _batch_size: int + _lock_timeout_seconds: int + + def __init__( + self, + tenant_id: str, + engine: sa.Engine, + apply: bool, + output: io.TextIOBase, + model_types: Sequence[ModelType], + orm_models: Sequence[ORMModel], + ) -> None: + self._tenant_id = tenant_id + self._engine = engine + self._apply = apply + self._output = output + self._model_types = tuple(model_types) + self._orm_models = tuple(orm_models) + self._batch_size = 200 + self._lock_timeout_seconds = 5 + + def run(self) -> None: + self._log_event( + "tenant_started", + "Started tenant migration.", + { + "tenant_id": self._tenant_id, + "apply": self._apply, + "tables": [model.__tablename__ for model in self._orm_models], + "model_types": [model_type.value for model_type in self._model_types], + }, + ) + + for orm_model in self._orm_models: + if orm_model is ProviderModel: + self._migrate_provider_models() + elif orm_model is TenantDefaultModel: + self._migrate_tenant_default_models() + elif orm_model is ProviderModelSetting: + self._migrate_provider_model_settings() + elif orm_model is LoadBalancingModelConfig: + self._migrate_load_balancing_model_configs() + elif orm_model is ProviderModelCredential: + self._migrate_provider_model_credentials() + + self._log_event( + "tenant_completed", + "Completed tenant migration.", + {"tenant_id": self._tenant_id, "apply": self._apply}, + ) + + def _selected_legacy_values(self) -> list[str]: + return _selected_legacy_values(self._model_types) + + def _selected_model_type_values(self) -> list[str]: + return _selected_model_type_values(self._model_types) + + def _allowed_values_for_canonical_model_type(self, canonical_model_type: ModelType) -> tuple[str, ...]: + return (*_CANONICAL_TO_LEGACY[canonical_model_type], canonical_model_type.value) + + def _normalize_selected_model_type(self, raw_model_type: str) -> ModelType | None: + canonical_model_type = _LEGACY_TO_CANONICAL.get(raw_model_type) + if canonical_model_type is not None: + return canonical_model_type + + try: + parsed_model_type = ModelType(raw_model_type) + except ValueError: + return None + + if parsed_model_type not in self._model_types: + return None + return parsed_model_type + + def _has_legacy_rows[T: TypeBase](self, rows: Sequence[_RowWithRawModelType[T]]) -> bool: + return any(row.raw_model_type in _LEGACY_TO_CANONICAL for row in rows) + + def _select_winner[T: TypeBase](self, rows: Sequence[_RowWithRawModelType[T]]) -> _RowWithRawModelType[T]: + return max(rows, key=lambda row: self._winner_sort_key(row.row)) + + def _winner_sort_key(self, row: TypeBase) -> tuple[datetime, str]: + typed_row = cast(_HasRowIdAndUpdatedAt, row) + return typed_row.updated_at, str(typed_row.id) + + def _row_id(self, row: TypeBase) -> str: + return str(cast(_HasRowId, row).id) + + def _new_tx_id(self) -> str: + return str(uuid.uuid4()) + + def _migrate_provider_models(self) -> None: + self._log_event( + "table_started", + "Started table migration.", + {"tenant_id": self._tenant_id, "apply": self._apply, "table_name": ProviderModel.__tablename__}, + ) + + seen_business_keys: dict[_ProviderModelBusinessKey, list[str]] = {} + processed_groups = 0 + last_id: str | None = None + + while True: + candidates = self._load_provider_model_candidates(last_id) + if not candidates: + break + + for candidate in candidates: + last_id = str(candidate.row.id) + business_key = _ProviderModelBusinessKey( + tenant_id=candidate.row.tenant_id, + provider_name=candidate.row.provider_name, + model_name=candidate.row.model_name, + model_type=candidate.canonical_model_type, + ) + if business_key in seen_business_keys: + continue + + seen_business_keys[business_key] = self._process_provider_model_group(candidate, business_key) + processed_groups += 1 + + self._log_event( + "table_completed", + "Completed table migration.", + { + "tenant_id": self._tenant_id, + "apply": self._apply, + "table_name": ProviderModel.__tablename__, + "processed_groups": processed_groups, + }, + ) + + def _load_provider_model_candidates(self, last_id: str | None) -> list[_RowWithRawModelType[ProviderModel]]: + raw_model_type = sa.type_coerce(ProviderModel.model_type, sa.String()).label(_RAW_MODEL_TYPE_COLUMN) + with _session_factory(self._engine) as session: + stmt = ( + select(ProviderModel, raw_model_type) + .where( + ProviderModel.tenant_id == self._tenant_id, + sa.type_coerce(ProviderModel.model_type, sa.String()).in_(self._selected_legacy_values()), + ) + .order_by(ProviderModel.id.asc()) + .limit(self._batch_size) + ) + if last_id is not None: + stmt = stmt.where(ProviderModel.id > last_id) + rows = session.execute(stmt).all() + + wrapped_rows: list[_RowWithRawModelType[ProviderModel]] = [] + for provider_model, raw_value in rows: + canonical_model_type = _LEGACY_TO_CANONICAL.get(str(raw_value)) + if canonical_model_type is None: + self._log_event( + event="invalid_model_type", + message=f"invalid model type: {raw_value}", + attrs={"id": provider_model.id, "table_name": provider_model.__tablename__}, + ) + continue + wrapped_rows.append( + _RowWithRawModelType( + row=provider_model, + raw_model_type=str(raw_value), + canonical_model_type=canonical_model_type, + ) + ) + return wrapped_rows + + def _load_provider_model_group( + self, + session: Session, + candidate: _RowWithRawModelType[ProviderModel], + *, + lock_rows: bool, + ) -> list[_RowWithRawModelType[ProviderModel]]: + raw_model_type = sa.type_coerce(ProviderModel.model_type, sa.String()).label(_RAW_MODEL_TYPE_COLUMN) + stmt = ( + select(ProviderModel, raw_model_type) + .where( + ProviderModel.tenant_id == candidate.row.tenant_id, + ProviderModel.provider_name == candidate.row.provider_name, + ProviderModel.model_name == candidate.row.model_name, + sa.type_coerce(ProviderModel.model_type, sa.String()).in_( + self._allowed_values_for_canonical_model_type(candidate.canonical_model_type) + ), + ) + .order_by(ProviderModel.id.asc()) + ) + if lock_rows: + stmt = stmt.with_for_update() + + rows = session.execute(stmt).all() + wrapped_rows: list[_RowWithRawModelType[ProviderModel]] = [] + for provider_model, raw_value in rows: + raw_model_type_value = str(raw_value) + wrapped_rows.append( + _RowWithRawModelType( + row=provider_model, + raw_model_type=raw_model_type_value, + canonical_model_type=_LEGACY_TO_CANONICAL.get( + raw_model_type_value, + candidate.canonical_model_type, + ), + ) + ) + return wrapped_rows + + def _build_provider_model_group_plan( + self, + session: Session, + candidate: _RowWithRawModelType[ProviderModel], + *, + lock_rows: bool, + ) -> _ProviderModelGroupPlan: + rows = self._load_provider_model_group(session, candidate, lock_rows=lock_rows) + group_row_ids = [str(row.row.id) for row in rows] + if not self._has_legacy_rows(rows): + return _ProviderModelGroupPlan(group_row_ids=group_row_ids, winner=None, loser_rows=[]) + + winner = self._select_winner(rows) + return _ProviderModelGroupPlan( + group_row_ids=group_row_ids, + winner=winner, + loser_rows=[row for row in rows if row.row.id != winner.row.id], + ) + + def _emit_provider_model_group_plan( + self, + plan: _ProviderModelGroupPlan, + *, + session: Session, + tx_id: str, + business_key: _BusinessKey, + ) -> None: + if plan.winner is None: + return + + cache_plans: list[_CacheDeletePlan] = [] + for loser in plan.loser_rows: + if self._apply: + session.execute(sa.delete(ProviderModel).where(ProviderModel.id == str(loser.row.id))) + self._log_row_deleted( + ProviderModel.__tablename__, + loser, + tx_id=tx_id, + business_key=business_key, + related_winner_id=str(plan.winner.row.id), + ) + cache_plans.append( + _CacheDeletePlan( + tenant_id=self._tenant_id, + identity_id=str(loser.row.id), + cache_type=ProviderCredentialsCacheType.MODEL, + table_name=ProviderModel.__tablename__, + row_id=str(loser.row.id), + tx_id=tx_id, + business_key=business_key, + ) + ) + + if plan.winner.raw_model_type != plan.winner.canonical_model_type.value: + if self._apply: + session.execute( + sa.update(ProviderModel) + .where(ProviderModel.id == str(plan.winner.row.id)) + .values(model_type=plan.winner.canonical_model_type.value) + ) + self._log_row_updated( + ProviderModel.__tablename__, + str(plan.winner.row.id), + {"model_type": plan.winner.raw_model_type}, + {"model_type": plan.winner.canonical_model_type.value}, + tx_id=tx_id, + business_key=business_key, + ) + cache_plans.append( + _CacheDeletePlan( + tenant_id=self._tenant_id, + identity_id=str(plan.winner.row.id), + cache_type=ProviderCredentialsCacheType.MODEL, + table_name=ProviderModel.__tablename__, + row_id=str(plan.winner.row.id), + tx_id=tx_id, + business_key=business_key, + ) + ) + + self._log_cache_plans(cache_plans, apply=self._apply) + self._log_group_processed( + ProviderModel.__tablename__, + business_key, + plan.group_row_ids, + tx_id=tx_id, + ) + + def _process_provider_model_group( + self, + candidate: _RowWithRawModelType[ProviderModel], + business_key: _ProviderModelBusinessKey, + ) -> list[str]: + tx_id = self._new_tx_id() + group_row_ids = [str(candidate.row.id)] + + try: + with _session_factory(self._engine) as session, session.begin(): + self._configure_lock_timeout(session) + plan = self._build_provider_model_group_plan(session, candidate, lock_rows=True) + group_row_ids = plan.group_row_ids or group_row_ids + self._emit_provider_model_group_plan( + plan, + session=session, + tx_id=tx_id, + business_key=business_key, + ) + except OperationalError as exc: + if self._is_lock_timeout_error(exc): + self._log_lock_timeout( + ProviderModel.__tablename__, + str(candidate.row.id), + tx_id, + business_key, + exc, + ) + return group_row_ids + raise + + return group_row_ids + + def _migrate_tenant_default_models(self) -> None: + self._log_event( + "table_started", + "Started table migration.", + {"tenant_id": self._tenant_id, "apply": self._apply, "table_name": TenantDefaultModel.__tablename__}, + ) + + seen_business_keys: dict[_TenantDefaultModelBusinessKey, list[str]] = {} + processed_groups = 0 + last_id: str | None = None + + while True: + candidates = self._load_tenant_default_model_candidates(last_id) + if not candidates: + break + + for candidate in candidates: + last_id = str(candidate.row.id) + business_key = _TenantDefaultModelBusinessKey( + tenant_id=candidate.row.tenant_id, + model_type=candidate.canonical_model_type, + ) + if business_key in seen_business_keys: + continue + + seen_business_keys[business_key] = self._process_tenant_default_model_group(candidate, business_key) + processed_groups += 1 + + self._log_event( + "table_completed", + "Completed table migration.", + { + "tenant_id": self._tenant_id, + "apply": self._apply, + "table_name": TenantDefaultModel.__tablename__, + "processed_groups": processed_groups, + }, + ) + + def _load_tenant_default_model_candidates( + self, last_id: str | None + ) -> list[_RowWithRawModelType[TenantDefaultModel]]: + raw_model_type = sa.type_coerce(TenantDefaultModel.model_type, sa.String()).label(_RAW_MODEL_TYPE_COLUMN) + with _session_factory(self._engine) as session: + stmt = ( + select(TenantDefaultModel, raw_model_type) + .where( + TenantDefaultModel.tenant_id == self._tenant_id, + sa.type_coerce(TenantDefaultModel.model_type, sa.String()).in_(self._selected_legacy_values()), + ) + .order_by(TenantDefaultModel.id.asc()) + .limit(self._batch_size) + ) + if last_id is not None: + stmt = stmt.where(TenantDefaultModel.id > last_id) + rows = session.execute(stmt).all() + + wrapped_rows: list[_RowWithRawModelType[TenantDefaultModel]] = [] + for tenant_default_model, raw_value in rows: + canonical_model_type = _LEGACY_TO_CANONICAL.get(str(raw_value)) + if canonical_model_type is None: + self._log_event( + event="invalid_model_type", + message=f"invalid model type: {raw_value}", + attrs={"id": tenant_default_model.id, "table_name": tenant_default_model.__tablename__}, + ) + continue + wrapped_rows.append( + _RowWithRawModelType( + row=tenant_default_model, + raw_model_type=str(raw_value), + canonical_model_type=canonical_model_type, + ) + ) + return wrapped_rows + + def _load_tenant_default_model_group( + self, + session: Session, + candidate: _RowWithRawModelType[TenantDefaultModel], + *, + lock_rows: bool, + ) -> list[_RowWithRawModelType[TenantDefaultModel]]: + raw_model_type = sa.type_coerce(TenantDefaultModel.model_type, sa.String()).label(_RAW_MODEL_TYPE_COLUMN) + stmt = ( + select(TenantDefaultModel, raw_model_type) + .where( + TenantDefaultModel.tenant_id == candidate.row.tenant_id, + sa.type_coerce(TenantDefaultModel.model_type, sa.String()).in_( + self._allowed_values_for_canonical_model_type(candidate.canonical_model_type) + ), + ) + .order_by(TenantDefaultModel.id.asc()) + ) + if lock_rows: + stmt = stmt.with_for_update() + + rows = session.execute(stmt).all() + wrapped_rows: list[_RowWithRawModelType[TenantDefaultModel]] = [] + for tenant_default_model, raw_value in rows: + raw_model_type_value = str(raw_value) + wrapped_rows.append( + _RowWithRawModelType( + row=tenant_default_model, + raw_model_type=raw_model_type_value, + canonical_model_type=_LEGACY_TO_CANONICAL.get( + raw_model_type_value, + candidate.canonical_model_type, + ), + ) + ) + return wrapped_rows + + def _build_tenant_default_model_group_plan( + self, + session: Session, + candidate: _RowWithRawModelType[TenantDefaultModel], + *, + lock_rows: bool, + ) -> _TenantDefaultModelGroupPlan: + rows = self._load_tenant_default_model_group(session, candidate, lock_rows=lock_rows) + group_row_ids = [str(row.row.id) for row in rows] + if not self._has_legacy_rows(rows): + return _TenantDefaultModelGroupPlan(group_row_ids=group_row_ids, winner=None, loser_rows=[]) + + winner = self._select_winner(rows) + return _TenantDefaultModelGroupPlan( + group_row_ids=group_row_ids, + winner=winner, + loser_rows=[row for row in rows if row.row.id != winner.row.id], + ) + + def _emit_tenant_default_model_group_plan( + self, + plan: _TenantDefaultModelGroupPlan, + *, + session: Session, + tx_id: str, + business_key: _BusinessKey, + ) -> None: + if plan.winner is None: + return + + for loser in plan.loser_rows: + if self._apply: + session.execute(sa.delete(TenantDefaultModel).where(TenantDefaultModel.id == str(loser.row.id))) + self._log_row_deleted( + TenantDefaultModel.__tablename__, + loser, + tx_id=tx_id, + business_key=business_key, + related_winner_id=str(plan.winner.row.id), + ) + if plan.winner.raw_model_type != plan.winner.canonical_model_type.value: + if self._apply: + session.execute( + sa.update(TenantDefaultModel) + .where(TenantDefaultModel.id == str(plan.winner.row.id)) + .values(model_type=plan.winner.canonical_model_type.value) + ) + self._log_row_updated( + TenantDefaultModel.__tablename__, + str(plan.winner.row.id), + {"model_type": plan.winner.raw_model_type}, + {"model_type": plan.winner.canonical_model_type.value}, + tx_id=tx_id, + business_key=business_key, + ) + + self._log_group_processed( + TenantDefaultModel.__tablename__, + business_key, + plan.group_row_ids, + tx_id=tx_id, + ) + + def _process_tenant_default_model_group( + self, + candidate: _RowWithRawModelType[TenantDefaultModel], + business_key: _TenantDefaultModelBusinessKey, + ) -> list[str]: + tx_id = self._new_tx_id() + group_row_ids = [str(candidate.row.id)] + + try: + with _session_factory(self._engine) as session, session.begin(): + self._configure_lock_timeout(session) + plan = self._build_tenant_default_model_group_plan(session, candidate, lock_rows=True) + group_row_ids = plan.group_row_ids or group_row_ids + self._emit_tenant_default_model_group_plan( + plan, + session=session, + tx_id=tx_id, + business_key=business_key, + ) + except OperationalError as exc: + if self._is_lock_timeout_error(exc): + self._log_lock_timeout( + TenantDefaultModel.__tablename__, + str(candidate.row.id), + tx_id, + business_key, + exc, + ) + return group_row_ids + raise + return group_row_ids + + def _migrate_provider_model_settings(self) -> None: + self._log_event( + "table_started", + "Started table migration.", + {"tenant_id": self._tenant_id, "apply": self._apply, "table_name": ProviderModelSetting.__tablename__}, + ) + + seen_business_keys: dict[_ProviderModelSettingBusinessKey, list[str]] = {} + processed_groups = 0 + last_id: str | None = None + + while True: + candidates = self._load_provider_model_setting_candidates(last_id) + if not candidates: + break + + for candidate in candidates: + last_id = str(candidate.row.id) + business_key = _ProviderModelSettingBusinessKey( + tenant_id=candidate.row.tenant_id, + provider_name=candidate.row.provider_name, + model_name=candidate.row.model_name, + model_type=candidate.canonical_model_type, + ) + if business_key in seen_business_keys: + continue + + seen_business_keys[business_key] = self._process_provider_model_setting_group(candidate, business_key) + processed_groups += 1 + + self._log_event( + "table_completed", + "Completed table migration.", + { + "tenant_id": self._tenant_id, + "apply": self._apply, + "table_name": ProviderModelSetting.__tablename__, + "processed_groups": processed_groups, + }, + ) + + def _load_provider_model_setting_candidates( + self, last_id: str | None + ) -> list[_RowWithRawModelType[ProviderModelSetting]]: + raw_model_type = sa.type_coerce(ProviderModelSetting.model_type, sa.String()).label(_RAW_MODEL_TYPE_COLUMN) + with _session_factory(self._engine) as session: + stmt = ( + select(ProviderModelSetting, raw_model_type) + .where( + ProviderModelSetting.tenant_id == self._tenant_id, + sa.type_coerce(ProviderModelSetting.model_type, sa.String()).in_(self._selected_legacy_values()), + ) + .order_by(ProviderModelSetting.id.asc()) + .limit(self._batch_size) + ) + if last_id is not None: + stmt = stmt.where(ProviderModelSetting.id > last_id) + rows = session.execute(stmt).all() + + wrapped_rows: list[_RowWithRawModelType[ProviderModelSetting]] = [] + for provider_model_setting, raw_value in rows: + canonical_model_type = _LEGACY_TO_CANONICAL.get(str(raw_value)) + if canonical_model_type is None: + self._log_event( + event="invalid_model_type", + message=f"invalid model type: {raw_value}", + attrs={"id": provider_model_setting.id, "table_name": provider_model_setting.__tablename__}, + ) + continue + wrapped_rows.append( + _RowWithRawModelType( + row=provider_model_setting, + raw_model_type=str(raw_value), + canonical_model_type=canonical_model_type, + ) + ) + return wrapped_rows + + def _load_provider_model_setting_group( + self, + session: Session, + candidate: _RowWithRawModelType[ProviderModelSetting], + *, + lock_rows: bool, + ) -> list[_RowWithRawModelType[ProviderModelSetting]]: + raw_model_type = sa.type_coerce(ProviderModelSetting.model_type, sa.String()).label(_RAW_MODEL_TYPE_COLUMN) + stmt = ( + select(ProviderModelSetting, raw_model_type) + .where( + ProviderModelSetting.tenant_id == candidate.row.tenant_id, + ProviderModelSetting.provider_name == candidate.row.provider_name, + ProviderModelSetting.model_name == candidate.row.model_name, + sa.type_coerce(ProviderModelSetting.model_type, sa.String()).in_( + self._allowed_values_for_canonical_model_type(candidate.canonical_model_type) + ), + ) + .order_by(ProviderModelSetting.id.asc()) + ) + if lock_rows: + stmt = stmt.with_for_update() + + rows = session.execute(stmt).all() + wrapped_rows: list[_RowWithRawModelType[ProviderModelSetting]] = [] + for provider_model_setting, raw_value in rows: + raw_model_type_value = str(raw_value) + wrapped_rows.append( + _RowWithRawModelType( + row=provider_model_setting, + raw_model_type=raw_model_type_value, + canonical_model_type=_LEGACY_TO_CANONICAL.get( + raw_model_type_value, + candidate.canonical_model_type, + ), + ) + ) + return wrapped_rows + + def _build_provider_model_setting_group_plan( + self, + session: Session, + candidate: _RowWithRawModelType[ProviderModelSetting], + *, + lock_rows: bool, + ) -> _ProviderModelSettingGroupPlan: + rows = self._load_provider_model_setting_group(session, candidate, lock_rows=lock_rows) + group_row_ids = [str(row.row.id) for row in rows] + if not self._has_legacy_rows(rows): + return _ProviderModelSettingGroupPlan(group_row_ids=group_row_ids, winner=None, loser_rows=[]) + + winner = self._select_winner(rows) + return _ProviderModelSettingGroupPlan( + group_row_ids=group_row_ids, + winner=winner, + loser_rows=[row for row in rows if row.row.id != winner.row.id], + ) + + def _emit_provider_model_setting_group_plan( + self, + plan: _ProviderModelSettingGroupPlan, + *, + session: Session, + tx_id: str, + business_key: _BusinessKey, + ) -> None: + if plan.winner is None: + return + + for loser in plan.loser_rows: + if self._apply: + session.execute(sa.delete(ProviderModelSetting).where(ProviderModelSetting.id == str(loser.row.id))) + self._log_row_deleted( + ProviderModelSetting.__tablename__, + loser, + tx_id=tx_id, + business_key=business_key, + related_winner_id=str(plan.winner.row.id), + ) + + if plan.winner.raw_model_type != plan.winner.canonical_model_type.value: + if self._apply: + session.execute( + sa.update(ProviderModelSetting) + .where(ProviderModelSetting.id == str(plan.winner.row.id)) + .values(model_type=plan.winner.canonical_model_type.value) + ) + self._log_row_updated( + ProviderModelSetting.__tablename__, + str(plan.winner.row.id), + {"model_type": plan.winner.raw_model_type}, + {"model_type": plan.winner.canonical_model_type.value}, + tx_id=tx_id, + business_key=business_key, + ) + + self._log_group_processed( + ProviderModelSetting.__tablename__, + business_key, + plan.group_row_ids, + tx_id=tx_id, + ) + + def _process_provider_model_setting_group( + self, + candidate: _RowWithRawModelType[ProviderModelSetting], + business_key: _ProviderModelSettingBusinessKey, + ) -> list[str]: + tx_id = self._new_tx_id() + group_row_ids = [str(candidate.row.id)] + + try: + with _session_factory(self._engine) as session, session.begin(): + self._configure_lock_timeout(session) + plan = self._build_provider_model_setting_group_plan(session, candidate, lock_rows=True) + group_row_ids = plan.group_row_ids or group_row_ids + self._emit_provider_model_setting_group_plan( + plan, + session=session, + tx_id=tx_id, + business_key=business_key, + ) + except OperationalError as exc: + if self._is_lock_timeout_error(exc): + self._log_lock_timeout( + ProviderModelSetting.__tablename__, + str(candidate.row.id), + tx_id, + business_key, + exc, + ) + return group_row_ids + raise + return group_row_ids + + def _migrate_load_balancing_model_configs(self) -> None: + """ + Migrate load-balancing configs row by row. + + This table first deduplicates `name="__inherit__"` rows per normalized + `(tenant_id, provider_name, model_name, model_type)` business key, then + canonicalizes the remaining legacy rows independently. The pre-pass must run + first so a legacy/canonical `__inherit__` pair keeps only the newest row before + the row-level canonicalization would collapse them onto the same canonical key. + """ + self._log_event( + "table_started", + "Started table migration.", + { + "tenant_id": self._tenant_id, + "apply": self._apply, + "table_name": LoadBalancingModelConfig.__tablename__, + }, + ) + + processed_inherit_groups = self._deduplicate_inherit_load_balancing_model_configs() + processed_rows = 0 + last_id: str | None = None + + while True: + candidates = self._load_load_balancing_model_config_candidates(last_id) + if not candidates: + break + + for candidate in candidates: + last_id = str(candidate.row.id) + processed_rows += 1 + self._process_load_balancing_model_config_row(candidate) + + self._log_event( + "table_completed", + "Completed table migration.", + { + "tenant_id": self._tenant_id, + "apply": self._apply, + "table_name": LoadBalancingModelConfig.__tablename__, + "processed_inherit_groups": processed_inherit_groups, + "processed_rows": processed_rows, + }, + ) + + def _deduplicate_inherit_load_balancing_model_configs(self) -> int: + seen_business_keys: dict[_LoadBalancingModelConfigInheritBusinessKey, list[str]] = {} + processed_groups = 0 + last_id: str | None = None + + while True: + candidates = self._load_load_balancing_inherit_candidates(last_id) + if not candidates: + break + + for candidate in candidates: + last_id = str(candidate.row.id) + business_key = _LoadBalancingModelConfigInheritBusinessKey( + tenant_id=candidate.row.tenant_id, + provider_name=candidate.row.provider_name, + model_name=candidate.row.model_name, + model_type=candidate.canonical_model_type, + ) + if business_key in seen_business_keys: + continue + + seen_business_keys[business_key] = self._process_load_balancing_inherit_group(candidate, business_key) + processed_groups += 1 + + return processed_groups + + def _load_load_balancing_inherit_candidates( + self, last_id: str | None + ) -> list[_RowWithRawModelType[LoadBalancingModelConfig]]: + raw_model_type = sa.type_coerce(LoadBalancingModelConfig.model_type, sa.String()).label(_RAW_MODEL_TYPE_COLUMN) + with _session_factory(self._engine) as session: + stmt = ( + select(LoadBalancingModelConfig, raw_model_type) + .where( + LoadBalancingModelConfig.tenant_id == self._tenant_id, + LoadBalancingModelConfig.name == "__inherit__", + sa.type_coerce(LoadBalancingModelConfig.model_type, sa.String()).in_( + self._selected_model_type_values() + ), + ) + .order_by(LoadBalancingModelConfig.id.asc()) + .limit(self._batch_size) + ) + if last_id is not None: + stmt = stmt.where(LoadBalancingModelConfig.id > last_id) + rows = session.execute(stmt).all() + + wrapped_rows: list[_RowWithRawModelType[LoadBalancingModelConfig]] = [] + for load_balancing_model_config, raw_value in rows: + raw_model_type_value = str(raw_value) + canonical_model_type = self._normalize_selected_model_type(raw_model_type_value) + if canonical_model_type is None: + self._log_event( + event="invalid_model_type", + message=f"invalid model type: {raw_value}", + attrs={ + "id": load_balancing_model_config.id, + "table_name": load_balancing_model_config.__tablename__, + }, + ) + continue + + wrapped_rows.append( + _RowWithRawModelType( + row=load_balancing_model_config, + raw_model_type=raw_model_type_value, + canonical_model_type=canonical_model_type, + ) + ) + return wrapped_rows + + def _load_load_balancing_inherit_group( + self, + session: Session, + candidate: _RowWithRawModelType[LoadBalancingModelConfig], + *, + lock_rows: bool, + ) -> list[_RowWithRawModelType[LoadBalancingModelConfig]]: + raw_model_type = sa.type_coerce(LoadBalancingModelConfig.model_type, sa.String()).label(_RAW_MODEL_TYPE_COLUMN) + stmt = ( + select(LoadBalancingModelConfig, raw_model_type) + .where( + LoadBalancingModelConfig.tenant_id == candidate.row.tenant_id, + LoadBalancingModelConfig.provider_name == candidate.row.provider_name, + LoadBalancingModelConfig.model_name == candidate.row.model_name, + LoadBalancingModelConfig.name == "__inherit__", + sa.type_coerce(LoadBalancingModelConfig.model_type, sa.String()).in_( + self._allowed_values_for_canonical_model_type(candidate.canonical_model_type) + ), + ) + .order_by(LoadBalancingModelConfig.id.asc()) + ) + if lock_rows: + stmt = stmt.with_for_update() + + rows = session.execute(stmt).all() + wrapped_rows: list[_RowWithRawModelType[LoadBalancingModelConfig]] = [] + for load_balancing_model_config, raw_value in rows: + raw_model_type_value = str(raw_value) + canonical_model_type = self._normalize_selected_model_type(raw_model_type_value) + if canonical_model_type is None: + continue + wrapped_rows.append( + _RowWithRawModelType( + row=load_balancing_model_config, + raw_model_type=raw_model_type_value, + canonical_model_type=canonical_model_type, + ) + ) + return wrapped_rows + + def _build_load_balancing_inherit_group_plan( + self, + session: Session, + candidate: _RowWithRawModelType[LoadBalancingModelConfig], + *, + lock_rows: bool, + ) -> _LoadBalancingModelConfigInheritGroupPlan: + rows = self._load_load_balancing_inherit_group(session, candidate, lock_rows=lock_rows) + group_row_ids = [str(row.row.id) for row in rows] + if len(rows) <= 1: + return _LoadBalancingModelConfigInheritGroupPlan(group_row_ids=group_row_ids, winner=None, loser_rows=[]) + + winner = self._select_winner(rows) + return _LoadBalancingModelConfigInheritGroupPlan( + group_row_ids=group_row_ids, + winner=winner, + loser_rows=[row for row in rows if row.row.id != winner.row.id], + ) + + def _emit_load_balancing_inherit_group_plan( + self, + plan: _LoadBalancingModelConfigInheritGroupPlan, + *, + session: Session, + tx_id: str, + business_key: _LoadBalancingModelConfigInheritBusinessKey, + ) -> None: + if plan.winner is None: + return + + cache_plans: list[_CacheDeletePlan] = [] + for loser in plan.loser_rows: + loser_row_id = str(loser.row.id) + if self._apply: + session.execute(sa.delete(LoadBalancingModelConfig).where(LoadBalancingModelConfig.id == loser_row_id)) + self._log_row_deleted( + LoadBalancingModelConfig.__tablename__, + loser, + tx_id=tx_id, + business_key=business_key, + related_winner_id=str(plan.winner.row.id), + ) + cache_plans.append( + _CacheDeletePlan( + tenant_id=self._tenant_id, + identity_id=loser_row_id, + cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, + table_name=LoadBalancingModelConfig.__tablename__, + row_id=loser_row_id, + tx_id=tx_id, + business_key=business_key, + ) + ) + + self._log_cache_plans(cache_plans, apply=self._apply) + self._log_group_processed( + LoadBalancingModelConfig.__tablename__, + business_key, + plan.group_row_ids, + tx_id=tx_id, + ) + + def _process_load_balancing_inherit_group( + self, + candidate: _RowWithRawModelType[LoadBalancingModelConfig], + business_key: _LoadBalancingModelConfigInheritBusinessKey, + ) -> list[str]: + tx_id = self._new_tx_id() + group_row_ids = [str(candidate.row.id)] + + try: + with _session_factory(self._engine) as session, session.begin(): + self._configure_lock_timeout(session) + plan = self._build_load_balancing_inherit_group_plan(session, candidate, lock_rows=True) + group_row_ids = plan.group_row_ids or group_row_ids + self._emit_load_balancing_inherit_group_plan( + plan, + session=session, + tx_id=tx_id, + business_key=business_key, + ) + except OperationalError as exc: + if self._is_lock_timeout_error(exc): + self._log_lock_timeout( + LoadBalancingModelConfig.__tablename__, + str(candidate.row.id), + tx_id, + business_key, + exc, + ) + return group_row_ids + raise + + return group_row_ids + + def _load_load_balancing_model_config_candidates( + self, last_id: str | None + ) -> list[_RowWithRawModelType[LoadBalancingModelConfig]]: + raw_model_type = sa.type_coerce(LoadBalancingModelConfig.model_type, sa.String()).label(_RAW_MODEL_TYPE_COLUMN) + with _session_factory(self._engine) as session: + stmt = ( + select(LoadBalancingModelConfig, raw_model_type) + .where( + LoadBalancingModelConfig.tenant_id == self._tenant_id, + sa.type_coerce(LoadBalancingModelConfig.model_type, sa.String()).in_( + self._selected_legacy_values() + ), + ) + .order_by(LoadBalancingModelConfig.id.asc()) + .limit(self._batch_size) + ) + if last_id is not None: + stmt = stmt.where(LoadBalancingModelConfig.id > last_id) + rows = session.execute(stmt).all() + + wrapped_rows: list[_RowWithRawModelType[LoadBalancingModelConfig]] = [] + for load_balancing_model_config, raw_value in rows: + canonical_model_type = _LEGACY_TO_CANONICAL.get(str(raw_value)) + if canonical_model_type is None: + self._log_event( + event="invalid_model_type", + message=f"invalid model type: {raw_value}", + attrs={ + "id": load_balancing_model_config.id, + "table_name": load_balancing_model_config.__tablename__, + }, + ) + continue + wrapped_rows.append( + _RowWithRawModelType( + row=load_balancing_model_config, + raw_model_type=str(raw_value), + canonical_model_type=canonical_model_type, + ) + ) + return wrapped_rows + + def _reload_load_balancing_model_config_candidate( + self, + session: Session, + candidate: _RowWithRawModelType[LoadBalancingModelConfig], + *, + lock_rows: bool, + ) -> _RowWithRawModelType[LoadBalancingModelConfig] | None: + raw_model_type = sa.type_coerce(LoadBalancingModelConfig.model_type, sa.String()).label(_RAW_MODEL_TYPE_COLUMN) + stmt = select(LoadBalancingModelConfig, raw_model_type).where( + LoadBalancingModelConfig.id == candidate.row.id, + LoadBalancingModelConfig.tenant_id == self._tenant_id, + ) + if lock_rows: + stmt = stmt.with_for_update() + + row = session.execute(stmt).first() + if row is None: + return None + + load_balancing_model_config, raw_value = row + raw_model_type_value = str(raw_value) + canonical_model_type = _LEGACY_TO_CANONICAL.get(raw_model_type_value) + if canonical_model_type is None: + return None + + return _RowWithRawModelType( + row=load_balancing_model_config, + raw_model_type=raw_model_type_value, + canonical_model_type=canonical_model_type, + ) + + def _log_load_balancing_model_config_cache_cleanup( + self, + *, + row_id: str, + tx_id: str, + ) -> None: + attrs = { + "tenant_id": self._tenant_id, + "apply": self._apply, + "table_name": LoadBalancingModelConfig.__tablename__, + "id": row_id, + "cache_type": ProviderCredentialsCacheType.LOAD_BALANCING_MODEL.value, + "tx_id": tx_id, + } + if not self._apply: + self._log_event( + "cache_delete_planned", + "Would delete related cache entry in apply mode.", + attrs, + ) + return + + try: + ProviderCredentialsCache( + tenant_id=self._tenant_id, + identity_id=row_id, + cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, + ).delete() + self._log_event("cache_deleted", "Deleted related cache entry.", attrs) + except Exception as exc: + self._log_exception_event( + "cache_delete_failed", + "Failed to delete related cache entry.", + attrs, + exc, + ) + + def _process_load_balancing_model_config_row( + self, candidate: _RowWithRawModelType[LoadBalancingModelConfig] + ) -> None: + tx_id = self._new_tx_id() + processed_row_id: str | None = None + + try: + with _session_factory(self._engine) as session, session.begin(): + self._configure_lock_timeout(session) + current_row = self._reload_load_balancing_model_config_candidate(session, candidate, lock_rows=True) + if current_row is None: + return + processed_row_id = str(current_row.row.id) + + if self._apply: + session.execute( + sa.update(LoadBalancingModelConfig) + .where(LoadBalancingModelConfig.id == processed_row_id) + .values(model_type=current_row.canonical_model_type.value) + ) + self._log_row_updated( + LoadBalancingModelConfig.__tablename__, + processed_row_id, + {"model_type": current_row.raw_model_type}, + {"model_type": current_row.canonical_model_type.value}, + tx_id=tx_id, + ) + except OperationalError as exc: + if self._is_lock_timeout_error(exc): + self._log_lock_timeout( + LoadBalancingModelConfig.__tablename__, + str(candidate.row.id), + tx_id, + None, + exc, + ) + return + raise + + if processed_row_id is not None: + self._log_load_balancing_model_config_cache_cleanup(row_id=processed_row_id, tx_id=tx_id) + + def _migrate_provider_model_credentials(self) -> None: + self._log_event( + "table_started", + "Started table migration.", + { + "tenant_id": self._tenant_id, + "apply": self._apply, + "table_name": ProviderModelCredential.__tablename__, + }, + ) + + seen_business_keys: dict[_ProviderModelCredentialBusinessKey, list[str]] = {} + processed_groups = 0 + last_id: str | None = None + + while True: + candidates = self._load_provider_model_credential_candidates(last_id) + if not candidates: + break + + for candidate in candidates: + last_id = str(candidate.row.id) + business_key = _ProviderModelCredentialBusinessKey( + tenant_id=candidate.row.tenant_id, + provider_name=candidate.row.provider_name, + model_name=candidate.row.model_name, + credential_name=candidate.row.credential_name, + model_type=candidate.canonical_model_type, + ) + if business_key in seen_business_keys: + continue + + seen_business_keys[business_key] = self._process_provider_model_credential_group( + candidate, + business_key, + ) + processed_groups += 1 + + self._log_event( + "table_completed", + "Completed table migration.", + { + "tenant_id": self._tenant_id, + "apply": self._apply, + "table_name": ProviderModelCredential.__tablename__, + "processed_groups": processed_groups, + }, + ) + + def _load_provider_model_credential_candidates( + self, last_id: str | None + ) -> list[_RowWithRawModelType[ProviderModelCredential]]: + raw_model_type = sa.type_coerce(ProviderModelCredential.model_type, sa.String()).label(_RAW_MODEL_TYPE_COLUMN) + with _session_factory(self._engine) as session: + stmt = ( + select(ProviderModelCredential, raw_model_type) + .where( + ProviderModelCredential.tenant_id == self._tenant_id, + sa.type_coerce(ProviderModelCredential.model_type, sa.String()).in_(self._selected_legacy_values()), + ) + .order_by(ProviderModelCredential.id.asc()) + .limit(self._batch_size) + ) + if last_id is not None: + stmt = stmt.where(ProviderModelCredential.id > last_id) + rows = session.execute(stmt).all() + + wrapped_rows: list[_RowWithRawModelType[ProviderModelCredential]] = [] + for provider_model_credential, raw_value in rows: + canonical_model_type = _LEGACY_TO_CANONICAL.get(str(raw_value)) + if canonical_model_type is None: + self._log_event( + event="invalid_model_type", + message=f"invalid model type: {raw_value}", + attrs={"id": provider_model_credential.id, "table_name": provider_model_credential.__tablename__}, + ) + continue + wrapped_rows.append( + _RowWithRawModelType( + row=provider_model_credential, + raw_model_type=str(raw_value), + canonical_model_type=canonical_model_type, + ) + ) + return wrapped_rows + + def _load_provider_model_credential_group( + self, + session: Session, + candidate: _RowWithRawModelType[ProviderModelCredential], + *, + lock_rows: bool, + ) -> list[_RowWithRawModelType[ProviderModelCredential]]: + raw_model_type = sa.type_coerce(ProviderModelCredential.model_type, sa.String()).label(_RAW_MODEL_TYPE_COLUMN) + stmt = ( + select(ProviderModelCredential, raw_model_type) + .where( + ProviderModelCredential.tenant_id == candidate.row.tenant_id, + ProviderModelCredential.provider_name == candidate.row.provider_name, + ProviderModelCredential.model_name == candidate.row.model_name, + ProviderModelCredential.credential_name == candidate.row.credential_name, + sa.type_coerce(ProviderModelCredential.model_type, sa.String()).in_( + self._allowed_values_for_canonical_model_type(candidate.canonical_model_type) + ), + ) + .order_by(ProviderModelCredential.id.asc()) + ) + if lock_rows: + stmt = stmt.with_for_update() + + rows = session.execute(stmt).all() + wrapped_rows: list[_RowWithRawModelType[ProviderModelCredential]] = [] + for provider_model_credential, raw_value in rows: + raw_model_type_value = str(raw_value) + wrapped_rows.append( + _RowWithRawModelType( + row=provider_model_credential, + raw_model_type=raw_model_type_value, + canonical_model_type=_LEGACY_TO_CANONICAL.get( + raw_model_type_value, + candidate.canonical_model_type, + ), + ) + ) + return wrapped_rows + + def _build_provider_model_credential_group_plan( + self, + session: Session, + candidate: _RowWithRawModelType[ProviderModelCredential], + *, + lock_rows: bool, + ) -> _ProviderModelCredentialGroupPlan: + rows = self._load_provider_model_credential_group(session, candidate, lock_rows=lock_rows) + group_row_ids = [str(row.row.id) for row in rows] + if not self._has_legacy_rows(rows): + return _ProviderModelCredentialGroupPlan( + group_row_ids=group_row_ids, + winner=None, + loser_rows=[], + provider_model_rewrites=[], + load_balancing_rewrites=[], + ) + + winner = self._select_winner(rows) + loser_rows = [row for row in rows if row.row.id != winner.row.id] + return _ProviderModelCredentialGroupPlan( + group_row_ids=group_row_ids, + winner=winner, + loser_rows=loser_rows, + provider_model_rewrites=self._plan_provider_model_reference_rewrites( + session, + winner, + loser_rows, + lock_rows=lock_rows, + ), + load_balancing_rewrites=self._plan_load_balancing_reference_rewrites( + session, + winner, + loser_rows, + lock_rows=lock_rows, + ), + ) + + def _emit_provider_model_reference_rewrites( + self, + session: Session, + rewrites: Sequence[_ProviderModelReferenceRewritePlan], + *, + winner_credential_id: str, + loser_credential_ids: Sequence[str], + tx_id: str, + business_key: _BusinessKey, + ) -> list[_CacheDeletePlan]: + cache_plans: list[_CacheDeletePlan] = [] + for rewrite in rewrites: + if self._apply: + session.execute( + sa.update(ProviderModel) + .where(ProviderModel.id == rewrite.row_id) + .values(credential_id=rewrite.new_credential_id) + ) + self._log_row_updated( + ProviderModel.__tablename__, + rewrite.row_id, + {"credential_id": rewrite.old_credential_id}, + {"credential_id": rewrite.new_credential_id}, + tx_id=tx_id, + business_key=business_key, + rewrite_source={ + "rewrite_kind": "credential_reference", + "winner_credential_id": winner_credential_id, + "loser_credential_ids": list(loser_credential_ids), + }, + ) + + cache_plans.append( + _CacheDeletePlan( + tenant_id=self._tenant_id, + identity_id=rewrite.row_id, + cache_type=ProviderCredentialsCacheType.MODEL, + table_name=ProviderModel.__tablename__, + row_id=rewrite.row_id, + tx_id=tx_id, + business_key=business_key, + ) + ) + return cache_plans + + def _emit_load_balancing_reference_rewrites( + self, + session: Session, + rewrites: Sequence[_LoadBalancingCredentialRewritePlan], + *, + winner_credential_id: str, + loser_credential_ids: Sequence[str], + tx_id: str, + business_key: _BusinessKey, + ) -> list[_CacheDeletePlan]: + cache_plans: list[_CacheDeletePlan] = [] + for rewrite in rewrites: + if self._apply: + session.execute( + sa.update(LoadBalancingModelConfig) + .where(LoadBalancingModelConfig.id == rewrite.row_id) + .values( + credential_id=rewrite.new_credential_id, + name=rewrite.new_name, + encrypted_config=rewrite.new_encrypted_config, + ) + ) + + self._log_row_updated( + LoadBalancingModelConfig.__tablename__, + rewrite.row_id, + { + "credential_id": rewrite.old_credential_id, + "encrypted_config": rewrite.old_encrypted_config, + "name": rewrite.old_name, + }, + { + "credential_id": rewrite.new_credential_id, + "encrypted_config": rewrite.new_encrypted_config, + "name": rewrite.new_name, + }, + tx_id=tx_id, + business_key=business_key, + rewrite_source={ + "rewrite_kind": "credential_reference", + "winner_credential_id": winner_credential_id, + "loser_credential_ids": list(loser_credential_ids), + }, + ) + cache_plans.append( + _CacheDeletePlan( + tenant_id=self._tenant_id, + identity_id=rewrite.row_id, + cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, + table_name=LoadBalancingModelConfig.__tablename__, + row_id=rewrite.row_id, + tx_id=tx_id, + business_key=business_key, + ) + ) + return cache_plans + + def _emit_provider_model_credential_group_plan( + self, + plan: _ProviderModelCredentialGroupPlan, + *, + session: Session, + tx_id: str, + business_key: _BusinessKey, + ) -> None: + if plan.winner is None: + return + + loser_credential_ids = [str(row.row.id) for row in plan.loser_rows] + winner_credential_id = str(plan.winner.row.id) + cache_plans: list[_CacheDeletePlan] = [] + cache_plans.extend( + self._emit_provider_model_reference_rewrites( + session, + plan.provider_model_rewrites, + winner_credential_id=winner_credential_id, + loser_credential_ids=loser_credential_ids, + tx_id=tx_id, + business_key=business_key, + ) + ) + cache_plans.extend( + self._emit_load_balancing_reference_rewrites( + session, + plan.load_balancing_rewrites, + winner_credential_id=winner_credential_id, + loser_credential_ids=loser_credential_ids, + tx_id=tx_id, + business_key=business_key, + ) + ) + + for loser in plan.loser_rows: + if self._apply: + session.execute( + sa.delete(ProviderModelCredential).where(ProviderModelCredential.id == str(loser.row.id)) + ) + self._log_row_deleted( + ProviderModelCredential.__tablename__, + loser, + tx_id=tx_id, + business_key=business_key, + related_winner_id=winner_credential_id, + ) + + if plan.winner.raw_model_type != plan.winner.canonical_model_type.value: + if self._apply: + session.execute( + sa.update(ProviderModelCredential) + .where(ProviderModelCredential.id == winner_credential_id) + .values(model_type=plan.winner.canonical_model_type.value) + ) + self._log_row_updated( + ProviderModelCredential.__tablename__, + winner_credential_id, + {"model_type": plan.winner.raw_model_type}, + {"model_type": plan.winner.canonical_model_type.value}, + tx_id=tx_id, + business_key=business_key, + ) + + self._log_cache_plans(cache_plans, apply=self._apply) + self._log_group_processed( + ProviderModelCredential.__tablename__, + business_key, + plan.group_row_ids, + tx_id=tx_id, + ) + + def _process_provider_model_credential_group( + self, + candidate: _RowWithRawModelType[ProviderModelCredential], + business_key: _ProviderModelCredentialBusinessKey, + ) -> list[str]: + tx_id = self._new_tx_id() + group_row_ids = [str(candidate.row.id)] + + try: + with _session_factory(self._engine) as session, session.begin(): + self._configure_lock_timeout(session) + plan = self._build_provider_model_credential_group_plan(session, candidate, lock_rows=True) + group_row_ids = plan.group_row_ids or group_row_ids + self._emit_provider_model_credential_group_plan( + plan, + session=session, + tx_id=tx_id, + business_key=business_key, + ) + except OperationalError as exc: + if self._is_lock_timeout_error(exc): + self._log_lock_timeout( + ProviderModelCredential.__tablename__, + str(candidate.row.id), + tx_id, + business_key, + exc, + ) + return group_row_ids + raise + + return group_row_ids + + def _plan_provider_model_reference_rewrites( + self, + session: Session, + winner: _RowWithRawModelType[ProviderModelCredential], + loser_rows: Sequence[_RowWithRawModelType[ProviderModelCredential]], + *, + lock_rows: bool, + ) -> list[_ProviderModelReferenceRewritePlan]: + loser_ids = [str(row.row.id) for row in loser_rows] + if not loser_ids: + return [] + + stmt = ( + select(ProviderModel) + .where( + ProviderModel.tenant_id == self._tenant_id, + ProviderModel.credential_id.in_(loser_ids), + ) + .order_by(ProviderModel.id.asc()) + ) + if lock_rows: + stmt = stmt.with_for_update() + + rewrite_plans: list[_ProviderModelReferenceRewritePlan] = [] + provider_models = session.execute(stmt).scalars().all() + for provider_model in provider_models: + rewrite_plans.append( + _ProviderModelReferenceRewritePlan( + row_id=str(provider_model.id), + old_credential_id=str(provider_model.credential_id), + new_credential_id=str(winner.row.id), + ) + ) + return rewrite_plans + + def _plan_load_balancing_reference_rewrites( + self, + session: Session, + winner: _RowWithRawModelType[ProviderModelCredential], + loser_rows: Sequence[_RowWithRawModelType[ProviderModelCredential]], + *, + lock_rows: bool, + ) -> list[_LoadBalancingCredentialRewritePlan]: + loser_ids = [str(row.row.id) for row in loser_rows] + if not loser_ids: + return [] + + stmt = ( + select(LoadBalancingModelConfig) + .where( + LoadBalancingModelConfig.tenant_id == self._tenant_id, + LoadBalancingModelConfig.credential_id.in_(loser_ids), + ) + .order_by(LoadBalancingModelConfig.id.asc()) + ) + if lock_rows: + stmt = stmt.with_for_update() + + winner_credential = winner.row + winner_credential_id = str(winner_credential.id) + winner_credential_name = winner_credential.credential_name + winner_encrypted_config = winner_credential.encrypted_config + + rewrite_plans: list[_LoadBalancingCredentialRewritePlan] = [] + load_balancing_model_configs = session.execute(stmt).scalars().all() + for load_balancing_model_config in load_balancing_model_configs: + rewrite_plans.append( + _LoadBalancingCredentialRewritePlan( + row_id=str(load_balancing_model_config.id), + old_credential_id=load_balancing_model_config.credential_id, + old_name=load_balancing_model_config.name, + old_encrypted_config=load_balancing_model_config.encrypted_config, + new_credential_id=winner_credential_id, + new_name=winner_credential_name, + new_encrypted_config=winner_encrypted_config, + ) + ) + return rewrite_plans + + def _configure_lock_timeout(self, session: Session) -> None: + dialect_name = session.get_bind().dialect.name + if dialect_name == "postgresql": + session.execute(sa.text("SET LOCAL lock_timeout = :timeout"), {"timeout": f"{self._lock_timeout_seconds}s"}) + return + if dialect_name == "mysql": + session.execute( + sa.text("SET SESSION innodb_lock_wait_timeout = :timeout"), + {"timeout": self._lock_timeout_seconds}, + ) + session.execute( + sa.text("SET SESSION lock_wait_timeout = :timeout"), + {"timeout": self._lock_timeout_seconds}, + ) + + def _is_lock_timeout_error(self, exc: OperationalError) -> bool: + orig = exc.orig + structured_string_codes: set[str] = set() + structured_int_codes: set[int] = set() + + if orig is not None: + for raw_code in ( + getattr(orig, "sqlstate", None), + getattr(orig, "pgcode", None), + getattr(orig, "code", None), + getattr(orig, "errno", None), + ): + normalized_string_code = _normalize_error_code_string(raw_code) + if normalized_string_code is not None: + structured_string_codes.add(normalized_string_code) + + normalized_int_code = _normalize_error_code_int(raw_code) + if normalized_int_code is not None: + structured_int_codes.add(normalized_int_code) + + raw_args = getattr(orig, "args", None) + if isinstance(raw_args, tuple | list) and raw_args: + first_arg = raw_args[0] + normalized_string_code = _normalize_error_code_string(first_arg) + if normalized_string_code is not None: + structured_string_codes.add(normalized_string_code) + + normalized_int_code = _normalize_error_code_int(first_arg) + if normalized_int_code is not None: + structured_int_codes.add(normalized_int_code) + + if structured_string_codes & _POSTGRES_LOCK_TIMEOUT_SQLSTATES: + return True + if structured_int_codes & _MYSQL_LOCK_TIMEOUT_ERRNOS: + return True + + error_message = str(orig if orig is not None else exc).lower() + return any(message in error_message for message in _LOCK_TIMEOUT_FALLBACK_MESSAGES) + + def _log_lock_timeout( + self, + table_name: str, + row_id: str, + tx_id: str, + business_key: _BusinessKey | None, + exc: OperationalError, + ) -> None: + attrs: dict[str, object] = { + "tenant_id": self._tenant_id, + "apply": self._apply, + "table_name": table_name, + "id": row_id, + "tx_id": tx_id, + } + if business_key is not None: + attrs["business_key"] = self._business_key_to_dict(business_key) + self._log_exception_event( + "lock_timeout_skipped", + "Skipped transaction because row lock timed out.", + attrs, + exc, + ) + + def _business_key_to_dict(self, business_key: _BusinessKey) -> dict[str, object]: + return cast(dict[str, object], asdict(business_key)) + + def _row_to_dict(self, row: TypeBase, *, raw_model_type: str | None = None) -> dict[str, object]: + mapper = sa.inspect(row).mapper + row_dict = {column.key: row.__dict__[column.key] for column in mapper.column_attrs} + if raw_model_type is not None and "model_type" in row_dict: + row_dict["model_type"] = raw_model_type + return _normalize_log_mapping(row_dict) + + def _log_row_deleted[T: TypeBase]( + self, + table_name: str, + row: _RowWithRawModelType[T], + *, + tx_id: str, + business_key: _BusinessKey, + related_winner_id: str, + ) -> None: + self._log_event( + "row_deleted", + "Deleted loser row during canonicalization.", + { + "tenant_id": self._tenant_id, + "apply": self._apply, + "table_name": table_name, + "id": self._row_id(row.row), + "tx_id": tx_id, + "business_key": self._business_key_to_dict(business_key), + "merge_winner_id": related_winner_id, + "row": self._row_to_dict(row.row, raw_model_type=row.raw_model_type), + }, + ) + + def _log_row_updated( + self, + table_name: str, + row_id: str, + old_values: dict[str, object], + new_values: dict[str, object], + *, + tx_id: str, + business_key: _BusinessKey | None = None, + rewrite_source: dict[str, object] | None = None, + ) -> None: + attrs: dict[str, object] = { + "tenant_id": self._tenant_id, + "apply": self._apply, + "table_name": table_name, + "id": row_id, + "tx_id": tx_id, + "old_values": _normalize_log_mapping(old_values), + "new_values": _normalize_log_mapping(new_values), + } + if business_key is not None: + attrs["business_key"] = self._business_key_to_dict(business_key) + if rewrite_source is not None: + attrs["rewrite_source"] = rewrite_source + self._log_event("row_updated", "Updated row values during canonicalization.", attrs) + + def _log_group_processed( + self, + table_name: str, + business_key: _BusinessKey, + group_row_ids: Sequence[str], + *, + tx_id: str, + ) -> None: + self._log_event( + "group_processed", + "Processed business-key group during canonicalization.", + { + "tenant_id": self._tenant_id, + "apply": self._apply, + "table_name": table_name, + "business_key": self._business_key_to_dict(business_key), + "group_row_ids": list(group_row_ids), + "tx_id": tx_id, + }, + ) + + def _log_cache_plans(self, cache_plans: Iterable[_CacheDeletePlan], *, apply: bool) -> None: + for cache_plan in cache_plans: + if apply: + try: + ProviderCredentialsCache( + tenant_id=cache_plan.tenant_id, + identity_id=cache_plan.identity_id, + cache_type=cache_plan.cache_type, + ).delete() + self._log_event( + "cache_deleted", + "Deleted related cache entry.", + { + "tenant_id": cache_plan.tenant_id, + "apply": apply, + "table_name": cache_plan.table_name, + "id": cache_plan.row_id, + "cache_type": cache_plan.cache_type.value, + "tx_id": cache_plan.tx_id, + "business_key": self._business_key_to_dict(cache_plan.business_key), + }, + ) + except Exception as exc: + self._log_exception_event( + "cache_delete_failed", + "Failed to delete related cache entry.", + { + "tenant_id": cache_plan.tenant_id, + "apply": apply, + "table_name": cache_plan.table_name, + "id": cache_plan.row_id, + "cache_type": cache_plan.cache_type.value, + "tx_id": cache_plan.tx_id, + "business_key": self._business_key_to_dict(cache_plan.business_key), + }, + exc, + ) + else: + self._log_event( + "cache_delete_planned", + "Would delete related cache entry in apply mode.", + { + "tenant_id": cache_plan.tenant_id, + "apply": apply, + "table_name": cache_plan.table_name, + "id": cache_plan.row_id, + "cache_type": cache_plan.cache_type.value, + "tx_id": cache_plan.tx_id, + "business_key": self._business_key_to_dict(cache_plan.business_key), + }, + ) + + def _log_exception_event( + self, + event: str, + message: str, + attrs: dict[str, object], + exc: BaseException, + ) -> None: + self._log_event( + event, + message, + { + **attrs, + "error": str(exc), + "stacktrace": _format_exception_stacktrace(exc), + }, + ) + + def _log_event(self, event: str, message: str, attrs: dict[str, object]) -> None: + record = { + "event": event, + "message": message, + "attrs": _normalize_log_payload(attrs), + "ts": naive_utc_now().isoformat(), + } + print(json.dumps(record, default=_json_default), file=self._output, flush=True) + + +def load_tenant_ids_from_file(path: str) -> list[str]: + """ + Load tenant ids from a plain-text file, one tenant id per line. + """ + + tenant_ids: list[str] = [] + seen_tenant_ids: set[str] = set() + with open(path, encoding="utf-8") as file: + for raw_line in file: + tenant_id = raw_line.strip() + if not tenant_id or tenant_id in seen_tenant_ids: + continue + seen_tenant_ids.add(tenant_id) + tenant_ids.append(tenant_id) + return tenant_ids diff --git a/api/tests/helpers/__init__.py b/api/tests/helpers/__init__.py new file mode 100644 index 0000000000..5183591f40 --- /dev/null +++ b/api/tests/helpers/__init__.py @@ -0,0 +1 @@ +"""Shared test helpers for backend migration tests.""" diff --git a/api/tests/helpers/legacy_model_type_migration.py b/api/tests/helpers/legacy_model_type_migration.py new file mode 100644 index 0000000000..12f092a0fe --- /dev/null +++ b/api/tests/helpers/legacy_model_type_migration.py @@ -0,0 +1,366 @@ +from __future__ import annotations + +import json +import uuid +from dataclasses import dataclass +from datetime import datetime, timedelta +from uuid import uuid4 + +import sqlalchemy as sa +from sqlalchemy.engine import Engine + +from models.account import Tenant +from models.enums import CredentialSourceType +from models.provider import ( + LoadBalancingModelConfig, + ProviderModel, + ProviderModelCredential, + ProviderModelSetting, + TenantDefaultModel, +) + +LEGACY_TO_CANONICAL: dict[str, str] = { + "text-generation": "llm", + "embeddings": "text-embedding", + "reranking": "rerank", +} +UNCHANGED_MODEL_TYPES: tuple[str, ...] = ("speech2text", "moderation", "tts") +ALL_TABLE_NAMES: tuple[str, ...] = ( + ProviderModel.__tablename__, + TenantDefaultModel.__tablename__, + ProviderModelSetting.__tablename__, + LoadBalancingModelConfig.__tablename__, + ProviderModelCredential.__tablename__, +) +DEFAULT_PRIMARY_TENANT_ID = "00000000-0000-0000-0000-000000000101" +DEFAULT_SECONDARY_TENANT_ID = "00000000-0000-0000-0000-000000000202" + + +@dataclass(frozen=True, slots=True) +class DirtyTenantFixture: + tenant_id: str + winner_credential_id: str + loser_credential_id: str + distinct_credential_id: str + provider_model_id: str + load_balancing_config_id: str + provider_model_setting_id: str + tenant_default_model_id: str + embedding_provider_model_id: str + embedding_setting_id: str + loser_credential_name: str + distinct_credential_name: str + loser_encrypted_config: str + winner_encrypted_config: str + + +@dataclass(frozen=True, slots=True) +class DirtyDataFixture: + primary: DirtyTenantFixture + secondary: DirtyTenantFixture + + +def create_minimal_legacy_model_type_schema(engine: Engine) -> None: + metadata = Tenant.__table__.metadata + metadata.create_all( + engine, + tables=[ + Tenant.__table__, + ProviderModel.__table__, + TenantDefaultModel.__table__, + ProviderModelSetting.__table__, + LoadBalancingModelConfig.__table__, + ProviderModelCredential.__table__, + ], + checkfirst=True, + ) + + +def drop_minimal_legacy_model_type_schema(engine: Engine) -> None: + metadata = Tenant.__table__.metadata + metadata.drop_all( + engine, + tables=[ + LoadBalancingModelConfig.__table__, + ProviderModelSetting.__table__, + TenantDefaultModel.__table__, + ProviderModel.__table__, + ProviderModelCredential.__table__, + Tenant.__table__, + ], + checkfirst=True, + ) + + +def seed_legacy_model_type_dirty_data( + engine: Engine, + *, + primary_tenant_id: str = DEFAULT_PRIMARY_TENANT_ID, + secondary_tenant_id: str = DEFAULT_SECONDARY_TENANT_ID, +) -> DirtyDataFixture: + create_minimal_legacy_model_type_schema(engine) + primary = _seed_tenant(engine, tenant_id=primary_tenant_id, provider_name="openai") + secondary = _seed_tenant(engine, tenant_id=secondary_tenant_id, provider_name="openai") + return DirtyDataFixture(primary=primary, secondary=secondary) + + +def snapshot_legacy_model_type_state(engine: Engine) -> dict[str, list[dict[str, object]]]: + snapshots: dict[str, list[dict[str, object]]] = {} + for table_name in ALL_TABLE_NAMES: + snapshots[table_name] = fetch_table_rows(engine, table_name) + return snapshots + + +def fetch_table_rows( + engine: Engine, + table_name: str, + *, + tenant_id: str | None = None, +) -> list[dict[str, object]]: + sql = f"SELECT * FROM {table_name}" + params: dict[str, object] = {} + if tenant_id is not None: + sql += " WHERE tenant_id = :tenant_id" + params["tenant_id"] = tenant_id + sql += " ORDER BY id ASC" + + with engine.begin() as conn: + rows = conn.execute(sa.text(sql), params).mappings().all() + + result: list[dict[str, object]] = [] + for row in rows: + normalized = dict(row) + for key, value in normalized.items(): + if isinstance(value, datetime): + normalized[key] = value.isoformat() + elif isinstance(value, uuid.UUID): + normalized[key] = str(value) + result.append(normalized) + return result + + +def fetch_model_types_for_tenant(engine: Engine, table_name: str, tenant_id: str) -> list[str]: + rows = fetch_table_rows(engine, table_name, tenant_id=tenant_id) + return [str(row["model_type"]) for row in rows] + + +def assert_tenant_rows_use_only_canonical_model_types(engine: Engine, tenant_id: str) -> None: + for table_name in ALL_TABLE_NAMES: + model_types = fetch_model_types_for_tenant(engine, table_name, tenant_id) + assert set(model_types) <= set(LEGACY_TO_CANONICAL.values()) | set(UNCHANGED_MODEL_TYPES), ( + table_name, + model_types, + ) + + +def count_rows(engine: Engine, table_name: str, *, tenant_id: str) -> int: + with engine.begin() as conn: + stmt = sa.text(f"SELECT COUNT(*) FROM {table_name} WHERE tenant_id = :tenant_id") + return int(conn.execute(stmt, {"tenant_id": tenant_id}).scalar_one()) + + +def _seed_tenant(engine: Engine, *, tenant_id: str, provider_name: str) -> DirtyTenantFixture: + now = datetime(2025, 1, 1, 12, 0, 0) + winner_credential_id = str(uuid4()) + loser_credential_id = str(uuid4()) + distinct_credential_id = str(uuid4()) + provider_model_id = str(uuid4()) + load_balancing_config_id = str(uuid4()) + provider_model_setting_id = str(uuid4()) + tenant_default_model_id = str(uuid4()) + embedding_provider_model_id = str(uuid4()) + embedding_setting_id = str(uuid4()) + + loser_credential_name = f"{tenant_id}-shared" + distinct_credential_name = f"{tenant_id}-distinct" + winner_encrypted_config = json.dumps({"api_key": f"{tenant_id}-winner"}) + loser_encrypted_config = json.dumps({"api_key": f"{tenant_id}-loser"}) + distinct_encrypted_config = json.dumps({"api_key": f"{tenant_id}-distinct"}) + + with engine.begin() as conn: + conn.execute( + Tenant.__table__.insert().values( + id=tenant_id, + name=f"Tenant {tenant_id}", + plan="basic", + status="normal", + ) + ) + conn.execute( + sa.text( + """ + INSERT INTO provider_model_credentials + ( + id, tenant_id, provider_name, model_name, + model_type, credential_name, encrypted_config, + created_at, updated_at + ) + VALUES + ( + :winner_id, :tenant_id, :provider_name, 'gpt-4o-mini', + 'llm', :shared_name, :winner_config, + :created_at, :winner_updated_at + ), + ( + :loser_id, :tenant_id, :provider_name, 'gpt-4o-mini', + 'text-generation', :shared_name, :loser_config, + :created_at, :loser_updated_at + ), + ( + :distinct_id, :tenant_id, :provider_name, 'gpt-4o-mini', + 'text-generation', :distinct_name, :distinct_config, + :created_at, :distinct_updated_at + ) + """ + ), + { + "winner_id": winner_credential_id, + "loser_id": loser_credential_id, + "distinct_id": distinct_credential_id, + "tenant_id": tenant_id, + "provider_name": provider_name, + "shared_name": loser_credential_name, + "distinct_name": distinct_credential_name, + "winner_config": winner_encrypted_config, + "loser_config": loser_encrypted_config, + "distinct_config": distinct_encrypted_config, + "created_at": now - timedelta(days=2), + "winner_updated_at": now, + "loser_updated_at": now - timedelta(days=1), + "distinct_updated_at": now - timedelta(hours=12), + }, + ) + conn.execute( + sa.text( + """ + INSERT INTO provider_models + ( + id, tenant_id, provider_name, model_name, + model_type, credential_id, is_valid, + created_at, updated_at + ) + VALUES + ( + :provider_model_id, :tenant_id, :provider_name, 'gpt-4o-mini', + 'text-generation', :loser_id, :is_valid, + :created_at, :updated_at + ), + ( + :embedding_provider_model_id, :tenant_id, :provider_name, 'text-embedding-3-large', + 'embeddings', NULL, :is_valid, + :created_at, :updated_at + ) + """ + ), + { + "provider_model_id": provider_model_id, + "embedding_provider_model_id": embedding_provider_model_id, + "tenant_id": tenant_id, + "provider_name": provider_name, + "loser_id": loser_credential_id, + "is_valid": True, + "created_at": now - timedelta(days=2), + "updated_at": now - timedelta(hours=6), + }, + ) + conn.execute( + sa.text( + """ + INSERT INTO tenant_default_models + (id, tenant_id, provider_name, model_name, model_type, created_at, updated_at) + VALUES + ( + :tenant_default_model_id, :tenant_id, :provider_name, 'gpt-4o-mini', + 'text-generation', :created_at, :updated_at + ) + """ + ), + { + "tenant_default_model_id": tenant_default_model_id, + "tenant_id": tenant_id, + "provider_name": provider_name, + "created_at": now - timedelta(days=2), + "updated_at": now - timedelta(hours=4), + }, + ) + conn.execute( + sa.text( + """ + INSERT INTO provider_model_settings + ( + id, tenant_id, provider_name, model_name, + model_type, enabled, load_balancing_enabled, + created_at, updated_at + ) + VALUES + ( + :provider_model_setting_id, :tenant_id, :provider_name, 'gpt-4o-mini', + 'text-generation', :enabled, :load_balancing_enabled, + :created_at, :updated_at + ), + ( + :embedding_setting_id, :tenant_id, :provider_name, 'text-embedding-3-large', + 'embeddings', :enabled, :embedding_load_balancing_enabled, + :created_at, :updated_at + ) + """ + ), + { + "provider_model_setting_id": provider_model_setting_id, + "embedding_setting_id": embedding_setting_id, + "tenant_id": tenant_id, + "provider_name": provider_name, + "enabled": True, + "load_balancing_enabled": True, + "embedding_load_balancing_enabled": False, + "created_at": now - timedelta(days=2), + "updated_at": now - timedelta(hours=3), + }, + ) + conn.execute( + sa.text( + """ + INSERT INTO load_balancing_model_configs + ( + id, tenant_id, provider_name, model_name, model_type, + name, encrypted_config, credential_id, credential_source_type, + enabled, created_at, updated_at + ) + VALUES + ( + :load_balancing_config_id, :tenant_id, :provider_name, 'gpt-4o-mini', 'text-generation', + :lb_name, :loser_config, :loser_id, :credential_source_type, + :enabled, :created_at, :updated_at + ) + """ + ), + { + "load_balancing_config_id": load_balancing_config_id, + "tenant_id": tenant_id, + "provider_name": provider_name, + "lb_name": loser_credential_name, + "loser_config": loser_encrypted_config, + "loser_id": loser_credential_id, + "credential_source_type": CredentialSourceType.CUSTOM_MODEL.value, + "enabled": True, + "created_at": now - timedelta(days=2), + "updated_at": now - timedelta(hours=2), + }, + ) + + return DirtyTenantFixture( + tenant_id=tenant_id, + winner_credential_id=winner_credential_id, + loser_credential_id=loser_credential_id, + distinct_credential_id=distinct_credential_id, + provider_model_id=provider_model_id, + load_balancing_config_id=load_balancing_config_id, + provider_model_setting_id=provider_model_setting_id, + tenant_default_model_id=tenant_default_model_id, + embedding_provider_model_id=embedding_provider_model_id, + embedding_setting_id=embedding_setting_id, + loser_credential_name=loser_credential_name, + distinct_credential_name=distinct_credential_name, + loser_encrypted_config=loser_encrypted_config, + winner_encrypted_config=winner_encrypted_config, + ) diff --git a/api/tests/seed_legacy_model_type_dirty_data.py b/api/tests/seed_legacy_model_type_dirty_data.py new file mode 100644 index 0000000000..c860cea956 --- /dev/null +++ b/api/tests/seed_legacy_model_type_dirty_data.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path + +API_PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(API_PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(API_PROJECT_ROOT)) + +import sqlalchemy as sa + +from tests.helpers.legacy_model_type_migration import ( + DEFAULT_PRIMARY_TENANT_ID, + DEFAULT_SECONDARY_TENANT_ID, + create_minimal_legacy_model_type_schema, + seed_legacy_model_type_dirty_data, +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Seed dirty legacy model_type rows for manual migration experiments. " + "Example: uv run --project api python api/tests/seed_legacy_model_type_dirty_data.py " + "--db-url postgresql://postgres:postgres@127.0.0.1:5432/dify" + ) + ) + parser.add_argument("--db-url", required=True, help="SQLAlchemy database URL for the target database.") + parser.add_argument( + "--primary-tenant-id", + default=DEFAULT_PRIMARY_TENANT_ID, + help="Tenant that will contain the main conflict scenario.", + ) + parser.add_argument( + "--secondary-tenant-id", + default=DEFAULT_SECONDARY_TENANT_ID, + help="Tenant used to verify tenant filtering behavior.", + ) + parser.add_argument( + "--create-minimal-schema", + action="store_true", + help="Create the minimal tables needed for the seed when running against an empty scratch database.", + ) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + engine = sa.create_engine(args.db_url) + try: + if args.create_minimal_schema: + create_minimal_legacy_model_type_schema(engine) + + fixture = seed_legacy_model_type_dirty_data( + engine, + primary_tenant_id=args.primary_tenant_id, + secondary_tenant_id=args.secondary_tenant_id, + ) + finally: + engine.dispose() + + print( + json.dumps( + { + "primary_tenant_id": fixture.primary.tenant_id, + "secondary_tenant_id": fixture.secondary.tenant_id, + "winner_credential_id": fixture.primary.winner_credential_id, + "loser_credential_id": fixture.primary.loser_credential_id, + "provider_model_id": fixture.primary.provider_model_id, + "load_balancing_config_id": fixture.primary.load_balancing_config_id, + }, + indent=2, + sort_keys=True, + ) + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/api/tests/test_containers_integration_tests/commands/test_legacy_model_type_migration.py b/api/tests/test_containers_integration_tests/commands/test_legacy_model_type_migration.py new file mode 100644 index 0000000000..401696d5ca --- /dev/null +++ b/api/tests/test_containers_integration_tests/commands/test_legacy_model_type_migration.py @@ -0,0 +1,408 @@ +from __future__ import annotations + +import importlib +import io +import json +from collections.abc import Generator +from datetime import datetime, timedelta + +import pytest +import sqlalchemy as sa + +from tests.helpers.legacy_model_type_migration import ( + assert_tenant_rows_use_only_canonical_model_types, + count_rows, + fetch_table_rows, + seed_legacy_model_type_dirty_data, +) + + +def _parse_json_lines(output: io.StringIO) -> list[dict[str, object]]: + return [json.loads(line) for line in output.getvalue().splitlines() if line.strip()] + + +def _json_key(value: object) -> str: + return json.dumps(value, sort_keys=True) + + +def _lb_processing_signatures(lines: list[dict[str, object]]) -> set[tuple[object, ...]]: + signatures: set[tuple[object, ...]] = set() + for line in lines: + attrs = line.get("attrs") + if not isinstance(attrs, dict): + continue + if attrs.get("table_name") != "load_balancing_model_configs": + continue + event = line.get("event") + if event == "row_updated": + signatures.add( + ( + event, + attrs.get("id"), + _json_key(attrs.get("old_values")), + _json_key(attrs.get("new_values")), + ) + ) + elif event == "row_deleted": + signatures.add( + ( + event, + attrs.get("id"), + attrs.get("merge_winner_id"), + ) + ) + elif event == "group_processed": + signatures.add( + ( + event, + attrs.get("table_name"), + _json_key(attrs.get("business_key")), + tuple(attrs.get("group_row_ids", [])), + ) + ) + return signatures + + +def _insert_load_balancing_model_config( + engine: sa.Engine, + *, + row_id: str, + tenant_id: str, + provider_name: str, + model_name: str, + model_type: str, + name: str, + encrypted_config: str, + credential_id: str, + enabled: bool, + created_at: datetime, + updated_at: datetime, +) -> None: + with engine.begin() as conn: + conn.execute( + sa.text( + """ + INSERT INTO load_balancing_model_configs + ( + id, tenant_id, provider_name, model_name, model_type, name, + encrypted_config, credential_id, credential_source_type, enabled, created_at, updated_at + ) + VALUES + ( + :id, :tenant_id, :provider_name, :model_name, :model_type, :name, + :encrypted_config, :credential_id, :credential_source_type, :enabled, :created_at, :updated_at + ) + """ + ), + { + "id": row_id, + "tenant_id": tenant_id, + "provider_name": provider_name, + "model_name": model_name, + "model_type": model_type, + "name": name, + "encrypted_config": encrypted_config, + "credential_id": credential_id, + "credential_source_type": "custom_model", + "enabled": enabled, + "created_at": created_at, + "updated_at": updated_at, + }, + ) + + +@pytest.fixture(scope="session") +def migration_module(): + try: + return importlib.import_module("services.legacy_model_type_migration") + except ModuleNotFoundError as exc: # pragma: no cover - explicit TDD failure path + pytest.fail( + "services.legacy_model_type_migration is missing. " + "Implement LegacyModelTypeMigrationService before running these tests." + ) + + +@pytest.fixture(params=("postgresql", "mysql"), scope="session") +def container_engine(request: pytest.FixtureRequest) -> Generator[tuple[str, sa.Engine], None, None]: + backend_name = request.param + if backend_name == "postgresql": + testcontainers_postgres = pytest.importorskip("testcontainers.postgres") + container = testcontainers_postgres.PostgresContainer("postgres:15-alpine") + else: + testcontainers_mysql = pytest.importorskip("testcontainers.mysql") + container = testcontainers_mysql.MySqlContainer("mysql:8.0") + + container.start() + raw_url = container.get_connection_url() + engine_url = raw_url.replace("mysql://", "mysql+pymysql://", 1) + engine = sa.create_engine(engine_url) + + try: + yield backend_name, engine + finally: + engine.dispose() + container.stop() + + +def test_legacy_model_type_migration_end_to_end_across_supported_backends( + migration_module, + container_engine: tuple[str, sa.Engine], + monkeypatch: pytest.MonkeyPatch, +) -> None: + backend_name, engine = container_engine + helper_module = importlib.import_module("tests.helpers.legacy_model_type_migration") + helper_module.drop_minimal_legacy_model_type_schema(engine) + fixture = seed_legacy_model_type_dirty_data(engine) + + deleted_cache_keys: list[str] = [] + + def _record_delete(self) -> None: + deleted_cache_keys.append(self.cache_key) + + monkeypatch.setattr(migration_module.ProviderCredentialsCache, "delete", _record_delete) + + dry_run_output = io.StringIO() + migration_module.LegacyModelTypeMigrationService( + engine=engine, + apply=False, + output=dry_run_output, + tenant_ids=(fixture.primary.tenant_id,), + ).migrate() + + assert count_rows(engine, "provider_model_credentials", tenant_id=fixture.primary.tenant_id) == 3 + assert deleted_cache_keys == [] + + apply_output = io.StringIO() + migration_module.LegacyModelTypeMigrationService( + engine=engine, + apply=True, + output=apply_output, + tenant_ids=(fixture.primary.tenant_id,), + ).migrate() + first_apply_state = { + table_name: fetch_table_rows(engine, table_name, tenant_id=fixture.primary.tenant_id) + for table_name in ( + "provider_models", + "tenant_default_models", + "provider_model_settings", + "load_balancing_model_configs", + "provider_model_credentials", + ) + } + + assert_tenant_rows_use_only_canonical_model_types(engine, fixture.primary.tenant_id) + assert count_rows(engine, "provider_model_credentials", tenant_id=fixture.primary.tenant_id) == 2 + provider_model_row = next( + row for row in first_apply_state["provider_models"] if row["id"] == fixture.primary.provider_model_id + ) + assert provider_model_row["credential_id"] == fixture.primary.winner_credential_id + credential_ids = {str(row["id"]) for row in first_apply_state["provider_model_credentials"]} + assert credential_ids == { + fixture.primary.winner_credential_id, + fixture.primary.distinct_credential_id, + } + lb_row = next( + row + for row in first_apply_state["load_balancing_model_configs"] + if row["id"] == fixture.primary.load_balancing_config_id + ) + assert lb_row["credential_id"] == fixture.primary.winner_credential_id + assert lb_row["encrypted_config"] == fixture.primary.winner_encrypted_config + assert deleted_cache_keys, f"{backend_name} apply run should clear cache keys" + + migration_module.LegacyModelTypeMigrationService( + engine=engine, + apply=True, + output=io.StringIO(), + tenant_ids=(fixture.primary.tenant_id,), + ).migrate() + second_apply_state = { + table_name: fetch_table_rows(engine, table_name, tenant_id=fixture.primary.tenant_id) + for table_name in first_apply_state + } + assert second_apply_state == first_apply_state + + +def test_load_balancing_inherit_deduplication_is_applied_consistently_across_supported_backends( + migration_module, + container_engine: tuple[str, sa.Engine], + monkeypatch: pytest.MonkeyPatch, +) -> None: + _, engine = container_engine + helper_module = importlib.import_module("tests.helpers.legacy_model_type_migration") + helper_module.drop_minimal_legacy_model_type_schema(engine) + fixture = seed_legacy_model_type_dirty_data(engine) + + tenant_id = fixture.primary.tenant_id + older_inherit_row_id = "00000000-0000-0000-0000-00000000ee01" + newer_inherit_row_id = "00000000-0000-0000-0000-00000000ee02" + canonical_non_inherit_row_id = "00000000-0000-0000-0000-00000000ee03" + created_at = datetime(2025, 1, 1, 8, 0, 0) + + _insert_load_balancing_model_config( + engine, + row_id=older_inherit_row_id, + tenant_id=tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", + model_type="llm", + name="__inherit__", + encrypted_config='{"api_key":"older-inherit"}', + credential_id=fixture.primary.winner_credential_id, + enabled=True, + created_at=created_at, + updated_at=created_at + timedelta(minutes=15), + ) + _insert_load_balancing_model_config( + engine, + row_id=newer_inherit_row_id, + tenant_id=tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", + model_type="text-generation", + name="__inherit__", + encrypted_config='{"api_key":"newer-inherit"}', + credential_id=fixture.primary.distinct_credential_id, + enabled=True, + created_at=created_at, + updated_at=created_at + timedelta(minutes=30), + ) + _insert_load_balancing_model_config( + engine, + row_id=canonical_non_inherit_row_id, + tenant_id=tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", + model_type="llm", + name=f"{tenant_id}-second-shared", + encrypted_config='{"api_key":"non-inherit-canonical"}', + credential_id=fixture.primary.distinct_credential_id, + enabled=True, + created_at=created_at, + updated_at=created_at + timedelta(minutes=45), + ) + + before_dry_run = fetch_table_rows(engine, "load_balancing_model_configs", tenant_id=tenant_id) + deleted_cache_keys: list[str] = [] + + def _record_delete(self) -> None: + deleted_cache_keys.append(self.cache_key) + + monkeypatch.setattr(migration_module.ProviderCredentialsCache, "delete", _record_delete) + + dry_run_output = io.StringIO() + migration_module.LegacyModelTypeMigrationService( + engine=engine, + apply=False, + output=dry_run_output, + tables=("load_balancing_model_configs",), + model_types=(migration_module.ModelType.LLM,), + tenant_ids=(tenant_id,), + ).migrate() + + after_dry_run = fetch_table_rows(engine, "load_balancing_model_configs", tenant_id=tenant_id) + dry_run_lines = _parse_json_lines(dry_run_output) + dry_run_cache_events = [line["event"] for line in dry_run_lines if str(line.get("event")).startswith("cache_")] + dry_run_row_updates = { + str(attrs["id"]) + for line in dry_run_lines + if line.get("event") == "row_updated" + and isinstance((attrs := line.get("attrs")), dict) + and attrs.get("table_name") == "load_balancing_model_configs" + } + dry_run_row_deletes = { + str(attrs["id"]) + for line in dry_run_lines + if line.get("event") == "row_deleted" + and isinstance((attrs := line.get("attrs")), dict) + and attrs.get("table_name") == "load_balancing_model_configs" + } + dry_run_group_processed = [ + attrs + for line in dry_run_lines + if line.get("event") == "group_processed" + and isinstance((attrs := line.get("attrs")), dict) + and attrs.get("table_name") == "load_balancing_model_configs" + ] + + assert after_dry_run == before_dry_run + assert deleted_cache_keys == [] + assert dry_run_row_deletes == {older_inherit_row_id} + assert dry_run_row_updates == { + fixture.primary.load_balancing_config_id, + newer_inherit_row_id, + } + assert canonical_non_inherit_row_id not in dry_run_row_updates + assert "cache_delete_planned" in dry_run_cache_events + assert "cache_deleted" not in dry_run_cache_events + assert len(dry_run_group_processed) == 1 + assert dry_run_group_processed[0]["table_name"] == "load_balancing_model_configs" + assert dry_run_group_processed[0]["business_key"] == { + "tenant_id": tenant_id, + "provider_name": "openai", + "model_name": "gpt-4o-mini", + "model_type": "llm", + } + assert set(dry_run_group_processed[0]["group_row_ids"]) == { + older_inherit_row_id, + newer_inherit_row_id, + } + + apply_output = io.StringIO() + migration_module.LegacyModelTypeMigrationService( + engine=engine, + apply=True, + output=apply_output, + tables=("load_balancing_model_configs",), + model_types=(migration_module.ModelType.LLM,), + tenant_ids=(tenant_id,), + ).migrate() + + apply_lines = _parse_json_lines(apply_output) + apply_cache_events = [line["event"] for line in apply_lines if str(line.get("event")).startswith("cache_")] + apply_group_processed = [ + attrs + for line in apply_lines + if line.get("event") == "group_processed" + and isinstance((attrs := line.get("attrs")), dict) + and attrs.get("table_name") == "load_balancing_model_configs" + ] + assert _lb_processing_signatures(apply_lines) == _lb_processing_signatures(dry_run_lines) + assert "cache_deleted" in apply_cache_events + assert deleted_cache_keys + assert len(apply_group_processed) == len(dry_run_group_processed) + assert [ + ( + attrs["table_name"], + _json_key(attrs["business_key"]), + tuple(attrs["group_row_ids"]), + ) + for attrs in apply_group_processed + ] == [ + ( + attrs["table_name"], + _json_key(attrs["business_key"]), + tuple(attrs["group_row_ids"]), + ) + for attrs in dry_run_group_processed + ] + + lb_rows = fetch_table_rows(engine, "load_balancing_model_configs", tenant_id=tenant_id) + surviving_inherit_rows = [row for row in lb_rows if row["name"] == "__inherit__"] + surviving_non_inherit_rows = [row for row in lb_rows if row["name"] != "__inherit__"] + + assert {str(row["id"]) for row in surviving_inherit_rows} == {newer_inherit_row_id} + assert surviving_inherit_rows[0]["model_type"] == "llm" + assert surviving_inherit_rows[0]["credential_id"] == fixture.primary.distinct_credential_id + + assert { + str(row["id"]) + for row in surviving_non_inherit_rows + if str(row["id"]) in {fixture.primary.load_balancing_config_id, canonical_non_inherit_row_id} + } == {fixture.primary.load_balancing_config_id, canonical_non_inherit_row_id} + assert all( + row["model_type"] == "llm" + for row in surviving_non_inherit_rows + if str(row["id"]) in {fixture.primary.load_balancing_config_id, canonical_non_inherit_row_id} + ) + assert count_rows(engine, "load_balancing_model_configs", tenant_id=tenant_id) == len(before_dry_run) - 1 diff --git a/api/tests/unit_tests/commands/test_legacy_model_type_migration.py b/api/tests/unit_tests/commands/test_legacy_model_type_migration.py new file mode 100644 index 0000000000..7eead948c1 --- /dev/null +++ b/api/tests/unit_tests/commands/test_legacy_model_type_migration.py @@ -0,0 +1,2025 @@ +from __future__ import annotations + +import importlib +import io +import json +import os +import threading +import time +from datetime import datetime, timedelta +from pathlib import Path +from types import SimpleNamespace +from typing import cast + +import pytest +import sqlalchemy as sa +from click.testing import CliRunner +from sqlalchemy.exc import OperationalError + +from graphon.model_runtime.entities.model_entities import ModelType +from models.account import Tenant +from models.enums import CredentialSourceType +from models.provider import ProviderModel +from tests.helpers.legacy_model_type_migration import ( + ALL_TABLE_NAMES, + LEGACY_TO_CANONICAL, + assert_tenant_rows_use_only_canonical_model_types, + count_rows, + create_minimal_legacy_model_type_schema, + fetch_table_rows, + seed_legacy_model_type_dirty_data, + snapshot_legacy_model_type_state, +) + + +@pytest.fixture +def sqlite_engine(tmp_path: Path) -> sa.Engine: + engine = sa.create_engine(f"sqlite:///{tmp_path / 'legacy_model_type_migration.sqlite'}") + try: + yield engine + finally: + engine.dispose() + + +@pytest.fixture +def dirty_fixture(sqlite_engine: sa.Engine): + return seed_legacy_model_type_dirty_data(sqlite_engine) + + +@pytest.fixture +def migration_module(): + try: + return importlib.import_module("services.legacy_model_type_migration") + except ModuleNotFoundError as exc: # pragma: no cover - explicit TDD failure path + pytest.fail( + "services.legacy_model_type_migration is missing. " + "Implement LegacyModelTypeMigrationService before running these tests." + ) + + +@pytest.fixture +def command_module(): + try: + return importlib.import_module("commands.data_migrate") + except ModuleNotFoundError as exc: # pragma: no cover - explicit TDD failure path + pytest.fail( + "commands.data_migrate is missing. " + "Implement the `flask data-migrate legacy-model-types` command group before running these tests." + ) + + +def _parse_json_lines(output: io.StringIO) -> list[dict[str, object]]: + return [json.loads(line) for line in output.getvalue().splitlines() if line.strip()] + + +def _json_key(value: object) -> str: + return json.dumps(value, sort_keys=True) + + +def _event_signature(line: dict[str, object]) -> tuple[object, ...] | None: + event = line.get("event") + attrs = line.get("attrs") + if not isinstance(attrs, dict): + return None + + if event == "row_updated": + return ( + event, + attrs.get("table_name"), + attrs.get("id"), + _json_key(attrs.get("business_key")), + _json_key(attrs.get("old_values")), + _json_key(attrs.get("new_values")), + _json_key(attrs.get("rewrite_source")), + ) + if event == "row_deleted": + return ( + event, + attrs.get("table_name"), + attrs.get("id"), + _json_key(attrs.get("business_key")), + attrs.get("merge_winner_id"), + ) + if event == "group_processed": + return ( + event, + attrs.get("table_name"), + _json_key(attrs.get("business_key")), + tuple(cast(list[str], attrs.get("group_row_ids", []))), + ) + return None + + +def _collect_processing_signatures(lines: list[dict[str, object]]) -> set[tuple[object, ...]]: + signatures: set[tuple[object, ...]] = set() + for line in lines: + signature = _event_signature(line) + if signature is not None: + signatures.add(signature) + return signatures + + +def _cache_event_row_ids( + lines: list[dict[str, object]], + *, + table_name: str, + row_ids: set[str], + event_name: str, +) -> set[str]: + matching_row_ids: set[str] = set() + for line in lines: + if line.get("event") != event_name: + continue + attrs = line.get("attrs") + if not isinstance(attrs, dict): + continue + if attrs.get("table_name") != table_name: + continue + row_id = str(attrs.get("id")) + if row_id in row_ids: + matching_row_ids.add(row_id) + return matching_row_ids + + +def _patch_batch_size( + monkeypatch: pytest.MonkeyPatch, + migration_module, + *, + batch_size: int, +) -> None: + original_init = migration_module.Migration.__init__ + + def _patched_init(self, *args, **kwargs) -> None: + original_init(self, *args, **kwargs) + self._batch_size = batch_size + + monkeypatch.setattr(migration_module.Migration, "__init__", _patched_init) + + +def _insert_provider_model( + engine: sa.Engine, + *, + row_id: str, + tenant_id: str, + provider_name: str, + model_name: str, + model_type: str, + credential_id: str | None, + created_at: datetime, + updated_at: datetime, +) -> None: + with engine.begin() as conn: + conn.execute( + sa.text( + """ + INSERT INTO provider_models + ( + id, tenant_id, provider_name, model_name, model_type, + credential_id, is_valid, created_at, updated_at + ) + VALUES + ( + :id, :tenant_id, :provider_name, :model_name, :model_type, + :credential_id, :is_valid, :created_at, :updated_at + ) + """ + ), + { + "id": row_id, + "tenant_id": tenant_id, + "provider_name": provider_name, + "model_name": model_name, + "model_type": model_type, + "credential_id": credential_id, + "is_valid": True, + "created_at": created_at, + "updated_at": updated_at, + }, + ) + + +def _insert_tenant(engine: sa.Engine, *, tenant_id: str) -> None: + with engine.begin() as conn: + conn.execute( + Tenant.__table__.insert().values( + id=tenant_id, + name=f"Tenant {tenant_id}", + plan="basic", + status="normal", + ) + ) + + +def _insert_tenant_default_model( + engine: sa.Engine, + *, + row_id: str, + tenant_id: str, + provider_name: str, + model_name: str, + model_type: str, + created_at: datetime, + updated_at: datetime, +) -> None: + with engine.begin() as conn: + conn.execute( + sa.text( + """ + INSERT INTO tenant_default_models + (id, tenant_id, provider_name, model_name, model_type, created_at, updated_at) + VALUES + (:id, :tenant_id, :provider_name, :model_name, :model_type, :created_at, :updated_at) + """ + ), + { + "id": row_id, + "tenant_id": tenant_id, + "provider_name": provider_name, + "model_name": model_name, + "model_type": model_type, + "created_at": created_at, + "updated_at": updated_at, + }, + ) + + +def _insert_provider_model_setting( + engine: sa.Engine, + *, + row_id: str, + tenant_id: str, + provider_name: str, + model_name: str, + model_type: str, + enabled: bool, + load_balancing_enabled: bool, + created_at: datetime, + updated_at: datetime, +) -> None: + with engine.begin() as conn: + conn.execute( + sa.text( + """ + INSERT INTO provider_model_settings + ( + id, tenant_id, provider_name, model_name, model_type, + enabled, load_balancing_enabled, + created_at, updated_at + ) + VALUES + ( + :id, :tenant_id, :provider_name, :model_name, :model_type, + :enabled, :load_balancing_enabled, + :created_at, :updated_at + ) + """ + ), + { + "id": row_id, + "tenant_id": tenant_id, + "provider_name": provider_name, + "model_name": model_name, + "model_type": model_type, + "enabled": enabled, + "load_balancing_enabled": load_balancing_enabled, + "created_at": created_at, + "updated_at": updated_at, + }, + ) + + +def _insert_load_balancing_model_config( + engine: sa.Engine, + *, + row_id: str, + tenant_id: str, + provider_name: str, + model_name: str, + model_type: str, + name: str, + encrypted_config: str, + credential_id: str, + enabled: bool, + created_at: datetime, + updated_at: datetime, +) -> None: + with engine.begin() as conn: + conn.execute( + sa.text( + """ + INSERT INTO load_balancing_model_configs + ( + id, tenant_id, provider_name, model_name, model_type, name, + encrypted_config, credential_id, credential_source_type, enabled, created_at, updated_at + ) + VALUES + ( + :id, :tenant_id, :provider_name, :model_name, :model_type, :name, + :encrypted_config, :credential_id, :credential_source_type, :enabled, :created_at, :updated_at + ) + """ + ), + { + "id": row_id, + "tenant_id": tenant_id, + "provider_name": provider_name, + "model_name": model_name, + "model_type": model_type, + "name": name, + "encrypted_config": encrypted_config, + "credential_id": credential_id, + "credential_source_type": CredentialSourceType.CUSTOM_MODEL.value, + "enabled": enabled, + "created_at": created_at, + "updated_at": updated_at, + }, + ) + + +def test_data_migrate_command_defaults_output_to_stdout_stream( + command_module, + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + service_calls: list[dict[str, object]] = [] + fake_stdout = io.StringIO() + + class FakeService: + def __init__( + self, + *, + engine: sa.Engine, + apply: bool, + concurrency: int, + output: io.TextIOBase | None = None, + tables: tuple[str, ...] | None, + model_types: tuple[ModelType, ...], + tenant_ids: tuple[str, ...] | None, + ) -> None: + service_calls.append( + { + "engine": engine, + "apply": apply, + "concurrency": concurrency, + "output": output, + "tables": tables, + "model_types": model_types, + "tenant_ids": tenant_ids, + } + ) + + def migrate(self) -> None: + service_calls.append({"migrated": True}) + + monkeypatch.setattr(command_module, "LegacyModelTypeMigrationService", FakeService) + monkeypatch.setattr(command_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr(command_module.sys, "stdout", fake_stdout) + tenant_id_file = tmp_path / "tenant_ids.txt" + tenant_id_file.write_text("tenant-alpha\n", encoding="utf-8") + + data_migrate = command_module.data_migrate + legacy_model_types = cast(object, data_migrate.commands["legacy-model-types"]) + + legacy_model_types.callback( + apply=True, + tables=("provider_models",), + model_types=("llm", "text-embedding"), + tenant_id_file=str(tenant_id_file), + output=None, + concurrency=7, + ) + + assert service_calls[0]["apply"] is True + assert service_calls[0]["concurrency"] == 7 + assert service_calls[0]["output"] is fake_stdout + assert service_calls[0]["tables"] == ("provider_models",) + assert tuple(cast(list[str], service_calls[0]["tenant_ids"])) == ("tenant-alpha",) + assert service_calls[0]["model_types"] == (ModelType.LLM, ModelType.TEXT_EMBEDDING) + assert service_calls[1] == {"migrated": True} + + +def test_data_migrate_command_opens_output_file_and_closes_stream( + command_module, + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + service_calls: list[dict[str, object]] = [] + + class FakeService: + def __init__( + self, + *, + engine: sa.Engine, + apply: bool, + concurrency: int, + output: io.TextIOBase | None = None, + tables: tuple[str, ...] | None, + model_types: tuple[ModelType, ...], + tenant_ids: tuple[str, ...] | None, + ) -> None: + service_calls.append( + { + "engine": engine, + "apply": apply, + "concurrency": concurrency, + "output": output, + "tables": tables, + "model_types": model_types, + "tenant_ids": tenant_ids, + } + ) + + def migrate(self) -> None: + output = cast(io.TextIOBase, service_calls[0]["output"]) + output.write('{"event":"test"}\n') + service_calls.append({"migrated": True}) + + monkeypatch.setattr(command_module, "LegacyModelTypeMigrationService", FakeService) + monkeypatch.setattr(command_module, "db", SimpleNamespace(engine=object())) + output_path = tmp_path / "migration.jsonl" + + data_migrate = command_module.data_migrate + legacy_model_types = cast(object, data_migrate.commands["legacy-model-types"]) + + legacy_model_types.callback( + apply=False, + tables=(), + model_types=(), + tenant_id_file=None, + output=output_path, + concurrency=3, + ) + + output_stream = cast(io.TextIOBase, service_calls[0]["output"]) + assert service_calls[0]["concurrency"] == 3 + assert output_stream is not output_path + assert isinstance(output_stream, io.TextIOBase) + assert Path(output_stream.name) == output_path + assert output_stream.closed is True + assert output_path.read_text(encoding="utf-8") == '{"event":"test"}\n' + assert service_calls[1] == {"migrated": True} + + +@pytest.mark.parametrize( + ("cpu_count", "expected_concurrency"), + [ + (8, 8), + (None, 1), + ], +) +def test_data_migrate_command_defaults_concurrency_from_cpu_count_or_falls_back_to_one( + monkeypatch: pytest.MonkeyPatch, + cpu_count: int | None, + expected_concurrency: int, +) -> None: + service_calls: list[dict[str, object]] = [] + command_module = importlib.import_module("commands.data_migrate") + + class FakeService: + def __init__( + self, + *, + engine: sa.Engine, + apply: bool, + concurrency: int, + output: io.TextIOBase | None = None, + tables: tuple[str, ...] | None, + model_types: tuple[ModelType, ...], + tenant_ids: tuple[str, ...] | None, + ) -> None: + service_calls.append( + { + "engine": engine, + "apply": apply, + "concurrency": concurrency, + "output": output, + "tables": tables, + "model_types": model_types, + "tenant_ids": tenant_ids, + } + ) + + def migrate(self) -> None: + service_calls.append({"migrated": True}) + + monkeypatch.setattr(os, "cpu_count", lambda: cpu_count) + importlib.reload(command_module) + try: + monkeypatch.setattr(command_module, "LegacyModelTypeMigrationService", FakeService) + monkeypatch.setattr(command_module, "db", SimpleNamespace(engine=object())) + + result = CliRunner().invoke(command_module.data_migrate, ["legacy-model-types"]) + + assert result.exit_code == 0, result.output + assert expected_concurrency == command_module._DEFAULT_CONCURRENCY + assert service_calls[0]["concurrency"] == expected_concurrency + assert service_calls[1] == {"migrated": True} + finally: + monkeypatch.undo() + importlib.reload(command_module) + + +def test_service_migrate_batches_by_tenant_respects_selected_tables_without_reverse_dependency_expansion( + migration_module, + sqlite_engine: sa.Engine, +) -> None: + seen_runs: list[tuple[str, tuple[str, ...], tuple[ModelType, ...]]] = [] + + class FakeMigration: + def __init__( + self, + *, + tenant_id: str, + engine: sa.Engine, + apply: bool, + output: io.TextIOBase, + model_types: tuple[ModelType, ...], + orm_models: tuple[type[object], ...], + ) -> None: + assert engine is sqlite_engine + assert apply is False + seen_runs.append((tenant_id, tuple(model.__table__.name for model in orm_models), model_types)) + + def run(self) -> None: + return None + + monkeypatch = pytest.MonkeyPatch() + try: + monkeypatch.setattr(migration_module, "Migration", FakeMigration) + service = migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=False, + concurrency=1, + tables=("provider_models", "tenant_default_models"), + model_types=(ModelType.LLM,), + tenant_ids=("tenant-alpha", "tenant-beta"), + ) + + service.migrate() + finally: + monkeypatch.undo() + + assert seen_runs == [ + ("tenant-alpha", ("provider_models", "tenant_default_models"), (ModelType.LLM,)), + ("tenant-beta", ("provider_models", "tenant_default_models"), (ModelType.LLM,)), + ] + + +def test_service_migrate_without_tenant_ids_discovers_tenants_per_selected_table_without_querying_tenants( + migration_module, + sqlite_engine: sa.Engine, + monkeypatch: pytest.MonkeyPatch, +) -> None: + create_minimal_legacy_model_type_schema(sqlite_engine) + provider_tenant_id = "00000000-0000-0000-0000-000000000111" + default_tenant_id = "00000000-0000-0000-0000-000000000222" + empty_tenant_id = "00000000-0000-0000-0000-000000000333" + for tenant_id in (provider_tenant_id, default_tenant_id, empty_tenant_id): + _insert_tenant(sqlite_engine, tenant_id=tenant_id) + + created_at = datetime(2025, 1, 1, 12, 0, 0) + updated_at = created_at + timedelta(minutes=1) + _insert_provider_model( + sqlite_engine, + row_id="10000000-0000-0000-0000-000000000111", + tenant_id=provider_tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", + model_type="text-generation", + credential_id=None, + created_at=created_at, + updated_at=updated_at, + ) + _insert_tenant_default_model( + sqlite_engine, + row_id="20000000-0000-0000-0000-000000000222", + tenant_id=default_tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", + model_type="text-generation", + created_at=created_at, + updated_at=updated_at, + ) + + seen_runs: list[tuple[str, tuple[str, ...], tuple[ModelType, ...]]] = [] + executed_sql: list[str] = [] + + class FakeMigration: + def __init__( + self, + *, + tenant_id: str, + engine: sa.Engine, + apply: bool, + output: io.TextIOBase, + model_types: tuple[ModelType, ...], + orm_models: tuple[type[object], ...], + ) -> None: + assert engine is sqlite_engine + assert apply is False + seen_runs.append((tenant_id, tuple(model.__table__.name for model in orm_models), model_types)) + + def run(self) -> None: + return None + + def _record_sql( + conn: sa.engine.Connection, + cursor: object, + statement: str, + parameters: object, + context: object, + executemany: bool, + ) -> None: + del conn, cursor, parameters, context, executemany + executed_sql.append(statement) + + sa.event.listen(sqlite_engine, "before_cursor_execute", _record_sql) + try: + monkeypatch.setattr(migration_module, "Migration", FakeMigration) + service = migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=False, + tables=("provider_models", "tenant_default_models"), + model_types=(ModelType.LLM,), + ) + + service.migrate() + finally: + sa.event.remove(sqlite_engine, "before_cursor_execute", _record_sql) + + assert seen_runs == [ + (provider_tenant_id, ("provider_models",), (ModelType.LLM,)), + (default_tenant_id, ("tenant_default_models",), (ModelType.LLM,)), + ] + normalized_statements = [" ".join(statement.lower().split()) for statement in executed_sql] + discovery_statements = [statement for statement in normalized_statements if statement.startswith("select")] + table_names = ("provider_models", "tenant_default_models") + table_discovery_statements = [ + statement + for statement in discovery_statements + if any(f" from {table_name} " in f" {statement} " for table_name in table_names) + ] + + assert [statement for statement in discovery_statements if " from tenants " in f" {statement} "] == [] + assert [statement for statement in discovery_statements if " union " in f" {statement} "] == [] + assert [ + next(table_name for table_name in table_names if f" from {table_name} " in f" {statement} ") + for statement in table_discovery_statements + ] == list(table_names) + + +def test_service_migrate_without_tenant_ids_filters_provider_model_tenants_by_selected_model_types( + migration_module, + sqlite_engine: sa.Engine, + monkeypatch: pytest.MonkeyPatch, +) -> None: + create_minimal_legacy_model_type_schema(sqlite_engine) + llm_tenant_id = "00000000-0000-0000-0000-000000000411" + embedding_tenant_id = "00000000-0000-0000-0000-000000000422" + empty_tenant_id = "00000000-0000-0000-0000-000000000433" + for tenant_id in (llm_tenant_id, embedding_tenant_id, empty_tenant_id): + _insert_tenant(sqlite_engine, tenant_id=tenant_id) + + created_at = datetime(2025, 1, 2, 12, 0, 0) + updated_at = created_at + timedelta(minutes=1) + _insert_provider_model( + sqlite_engine, + row_id="30000000-0000-0000-0000-000000000411", + tenant_id=llm_tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", + model_type="text-generation", + credential_id=None, + created_at=created_at, + updated_at=updated_at, + ) + _insert_provider_model( + sqlite_engine, + row_id="30000000-0000-0000-0000-000000000422", + tenant_id=embedding_tenant_id, + provider_name="openai", + model_name="text-embedding-3-large", + model_type="embeddings", + credential_id=None, + created_at=created_at, + updated_at=updated_at, + ) + + seen_runs: list[tuple[str, tuple[str, ...], tuple[ModelType, ...]]] = [] + + class FakeMigration: + def __init__( + self, + *, + tenant_id: str, + engine: sa.Engine, + apply: bool, + output: io.TextIOBase, + model_types: tuple[ModelType, ...], + orm_models: tuple[type[object], ...], + ) -> None: + assert engine is sqlite_engine + assert apply is False + seen_runs.append((tenant_id, tuple(model.__table__.name for model in orm_models), model_types)) + + def run(self) -> None: + return None + + monkeypatch.setattr(migration_module, "Migration", FakeMigration) + service = migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=False, + tables=("provider_models",), + model_types=(ModelType.LLM,), + ) + + service.migrate() + + assert seen_runs == [ + (llm_tenant_id, ("provider_models",), (ModelType.LLM,)), + ] + + +def test_service_migrate_without_tenant_ids_discovers_all_load_balancing_tenants_for_simpler_table_scoped_query( + migration_module, + sqlite_engine: sa.Engine, + monkeypatch: pytest.MonkeyPatch, +) -> None: + create_minimal_legacy_model_type_schema(sqlite_engine) + inherit_llm_tenant_id = "00000000-0000-0000-0000-000000000511" + inherit_embedding_tenant_id = "00000000-0000-0000-0000-000000000522" + empty_tenant_id = "00000000-0000-0000-0000-000000000533" + for tenant_id in (inherit_llm_tenant_id, inherit_embedding_tenant_id, empty_tenant_id): + _insert_tenant(sqlite_engine, tenant_id=tenant_id) + + created_at = datetime(2025, 1, 3, 12, 0, 0) + updated_at = created_at + timedelta(minutes=1) + _insert_load_balancing_model_config( + sqlite_engine, + row_id="40000000-0000-0000-0000-000000000511", + tenant_id=inherit_llm_tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", + model_type=ModelType.LLM.value, + name="__inherit__", + encrypted_config=json.dumps({"api_key": "inherit-llm"}), + credential_id="50000000-0000-0000-0000-000000000511", + enabled=True, + created_at=created_at, + updated_at=updated_at, + ) + _insert_load_balancing_model_config( + sqlite_engine, + row_id="40000000-0000-0000-0000-000000000522", + tenant_id=inherit_embedding_tenant_id, + provider_name="openai", + model_name="text-embedding-3-large", + model_type=ModelType.TEXT_EMBEDDING.value, + name="__inherit__", + encrypted_config=json.dumps({"api_key": "inherit-embedding"}), + credential_id="50000000-0000-0000-0000-000000000522", + enabled=True, + created_at=created_at, + updated_at=updated_at, + ) + + seen_runs: list[tuple[str, tuple[str, ...], tuple[ModelType, ...]]] = [] + + class FakeMigration: + def __init__( + self, + *, + tenant_id: str, + engine: sa.Engine, + apply: bool, + output: io.TextIOBase, + model_types: tuple[ModelType, ...], + orm_models: tuple[type[object], ...], + ) -> None: + assert engine is sqlite_engine + assert apply is False + seen_runs.append((tenant_id, tuple(model.__table__.name for model in orm_models), model_types)) + + def run(self) -> None: + return None + + monkeypatch.setattr(migration_module, "Migration", FakeMigration) + # Load-balancing tenant discovery is a deliberate exception: it scans the + # whole table so the discovery query stays easy to understand, even when + # the scheduled tenant set is wider than the selected model types. + service = migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=False, + tables=("load_balancing_model_configs",), + model_types=(ModelType.LLM,), + ) + + service.migrate() + + assert seen_runs == [ + (inherit_llm_tenant_id, ("load_balancing_model_configs",), (ModelType.LLM,)), + (inherit_embedding_tenant_id, ("load_balancing_model_configs",), (ModelType.LLM,)), + ] + + +def test_service_migrate_with_concurrency_greater_than_one_runs_tenants_in_parallel_without_changing_migration_scope( + migration_module, + sqlite_engine: sa.Engine, +) -> None: + init_calls: list[dict[str, object]] = [] + started_tenants: list[str] = [] + worker_errors: list[BaseException] = [] + release_runs = threading.Event() + all_started = threading.Event() + active_runs = 0 + max_active_runs = 0 + state_lock = threading.Lock() + + class FakeMigration: + def __init__( + self, + *, + tenant_id: str, + engine: sa.Engine, + apply: bool, + output: io.TextIOBase, + model_types: tuple[ModelType, ...], + orm_models: tuple[type[object], ...], + ) -> None: + self._tenant_id = tenant_id + init_calls.append( + { + "tenant_id": tenant_id, + "engine": engine, + "apply": apply, + "model_types": model_types, + "table_names": tuple(model.__table__.name for model in orm_models), + } + ) + + def run(self) -> None: + nonlocal active_runs, max_active_runs + with state_lock: + active_runs += 1 + max_active_runs = max(max_active_runs, active_runs) + started_tenants.append(self._tenant_id) + if len(started_tenants) == 2: + all_started.set() + + release_runs.wait(timeout=1) + + with state_lock: + active_runs -= 1 + + monkeypatch = pytest.MonkeyPatch() + try: + monkeypatch.setattr(migration_module, "Migration", FakeMigration) + service = migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=False, + concurrency=2, + tables=("provider_models",), + model_types=(ModelType.LLM,), + tenant_ids=("tenant-alpha", "tenant-beta"), + ) + + def _run_service() -> None: + try: + service.migrate() + except BaseException as exc: # pragma: no cover - test harness + worker_errors.append(exc) + + worker = threading.Thread(target=_run_service) + worker.start() + started_in_parallel = all_started.wait(timeout=0.5) + release_runs.set() + worker.join(timeout=1) + finally: + monkeypatch.undo() + + assert worker_errors == [] + assert started_in_parallel is True + assert worker.is_alive() is False + assert max_active_runs == 2 + assert {call["tenant_id"] for call in init_calls} == {"tenant-alpha", "tenant-beta"} + for call in init_calls: + assert tuple(cast(tuple[str, ...], call["table_names"])) == ("provider_models",) + assert call["model_types"] == (ModelType.LLM,) + + +def test_service_parallel_migrate_serializes_shared_output_by_line( + migration_module, + sqlite_engine: sa.Engine, +) -> None: + worker_errors: list[BaseException] = [] + start_barrier = threading.Barrier(2) + + class SlowLineOutput(io.StringIO): + def __init__(self) -> None: + super().__init__() + self.overlap_count = 0 + self._in_write = False + self._state_lock = threading.Lock() + + def write(self, s: str) -> int: + with self._state_lock: + if self._in_write: + self.overlap_count += 1 + self._in_write = True + try: + time.sleep(0.01) + return super().write(s) + finally: + with self._state_lock: + self._in_write = False + + class FakeMigration: + def __init__( + self, + *, + tenant_id: str, + engine: sa.Engine, + apply: bool, + output: io.TextIOBase, + model_types: tuple[ModelType, ...], + orm_models: tuple[type[object], ...], + ) -> None: + self._tenant_id = tenant_id + self._output = output + + def run(self) -> None: + try: + start_barrier.wait(timeout=1) + except threading.BrokenBarrierError as exc: + raise AssertionError("parallel migrate should schedule both tenant runs together") from exc + + for index in range(3): + self._output.write(f"{self._tenant_id}:line-{index}\n") + + monkeypatch = pytest.MonkeyPatch() + output = SlowLineOutput() + try: + monkeypatch.setattr(migration_module, "Migration", FakeMigration) + service = migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=False, + concurrency=2, + output=output, + tables=("provider_models",), + model_types=(ModelType.LLM,), + tenant_ids=("tenant-alpha", "tenant-beta"), + ) + + def _run_service() -> None: + try: + service.migrate() + except BaseException as exc: # pragma: no cover - test harness + worker_errors.append(exc) + + worker = threading.Thread(target=_run_service) + worker.start() + worker.join(timeout=2) + finally: + monkeypatch.undo() + + assert worker.is_alive() is False + assert worker_errors == [] + assert output.overlap_count == 0 + assert sorted(output.getvalue().splitlines()) == sorted( + [ + "tenant-alpha:line-0", + "tenant-alpha:line-1", + "tenant-alpha:line-2", + "tenant-beta:line-0", + "tenant-beta:line-1", + "tenant-beta:line-2", + ] + ) + + +def test_migration_dry_run_emits_json_lines_without_db_or_cache_mutation( + migration_module, + sqlite_engine: sa.Engine, + dirty_fixture, + monkeypatch: pytest.MonkeyPatch, +) -> None: + before = snapshot_legacy_model_type_state(sqlite_engine) + deleted_cache_keys: list[str] = [] + + def _record_delete(self) -> None: + deleted_cache_keys.append(self.cache_key) + + monkeypatch.setattr(migration_module.ProviderCredentialsCache, "delete", _record_delete) + output = io.StringIO() + + service = migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=False, + output=output, + tenant_ids=(dirty_fixture.primary.tenant_id,), + ) + + service.migrate() + + after = snapshot_legacy_model_type_state(sqlite_engine) + assert after == before + assert deleted_cache_keys == [] + + lines = [json.loads(line) for line in output.getvalue().splitlines() if line.strip()] + assert lines, "dry-run should emit JSON lines" + assert all({"event", "message", "attrs", "ts"} <= set(line) for line in lines) + rendered_output = output.getvalue() + assert dirty_fixture.primary.loser_credential_id in rendered_output + assert dirty_fixture.primary.loser_credential_name in rendered_output + assert dirty_fixture.primary.loser_encrypted_config in rendered_output + + +def test_dry_run_and_apply_share_processing_scope_and_differ_only_on_side_effects( + migration_module, + sqlite_engine: sa.Engine, + dirty_fixture, + monkeypatch: pytest.MonkeyPatch, +) -> None: + before = snapshot_legacy_model_type_state(sqlite_engine) + deleted_cache_keys: list[str] = [] + + def _record_delete(self) -> None: + deleted_cache_keys.append(self.cache_key) + + monkeypatch.setattr(migration_module.ProviderCredentialsCache, "delete", _record_delete) + + dry_run_output = io.StringIO() + migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=False, + output=dry_run_output, + tenant_ids=(dirty_fixture.primary.tenant_id,), + ).migrate() + after_dry_run = snapshot_legacy_model_type_state(sqlite_engine) + dry_run_lines = _parse_json_lines(dry_run_output) + + apply_output = io.StringIO() + migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=True, + output=apply_output, + tenant_ids=(dirty_fixture.primary.tenant_id,), + ).migrate() + after_apply = snapshot_legacy_model_type_state(sqlite_engine) + apply_lines = _parse_json_lines(apply_output) + + assert after_dry_run == before + assert after_apply != before + + dry_run_signatures = _collect_processing_signatures(dry_run_lines) + apply_signatures = _collect_processing_signatures(apply_lines) + assert apply_signatures == dry_run_signatures + + dry_run_cache_events = [line["event"] for line in dry_run_lines if str(line.get("event")).startswith("cache_")] + apply_cache_events = [line["event"] for line in apply_lines if str(line.get("event")).startswith("cache_")] + assert "cache_deleted" not in dry_run_cache_events + assert "cache_delete_planned" in dry_run_cache_events + assert "cache_deleted" in apply_cache_events + assert deleted_cache_keys + + dry_run_rewrite_signatures = { + signature + for signature in dry_run_signatures + if signature[0] == "row_updated" + and signature[1] in {"provider_models", "load_balancing_model_configs"} + and signature[-1] != _json_key(None) + } + apply_rewrite_signatures = { + signature + for signature in apply_signatures + if signature[0] == "row_updated" + and signature[1] in {"provider_models", "load_balancing_model_configs"} + and signature[-1] != _json_key(None) + } + assert apply_rewrite_signatures == dry_run_rewrite_signatures + + dry_run_lb_signatures = { + signature + for signature in dry_run_signatures + if signature[0] == "row_updated" and signature[1] == "load_balancing_model_configs" + } + apply_lb_signatures = { + signature + for signature in apply_signatures + if signature[0] == "row_updated" and signature[1] == "load_balancing_model_configs" + } + assert apply_lb_signatures == dry_run_lb_signatures + + +def test_provider_models_processing_uses_same_plan_locking_and_transaction_entry_for_dry_run_and_apply( + migration_module, + sqlite_engine: sa.Engine, + dirty_fixture, + monkeypatch: pytest.MonkeyPatch, +) -> None: + dry_migration = migration_module.Migration( + tenant_id=dirty_fixture.primary.tenant_id, + engine=sqlite_engine, + apply=False, + output=io.StringIO(), + model_types=(ModelType.LLM,), + orm_models=(ProviderModel,), + ) + candidate = dry_migration._load_provider_model_candidates(None)[0] + business_key = migration_module._ProviderModelBusinessKey( + tenant_id=candidate.row.tenant_id, + provider_name=candidate.row.provider_name, + model_name=candidate.row.model_name, + model_type=candidate.canonical_model_type, + ) + + apply_migration = migration_module.Migration( + tenant_id=dirty_fixture.primary.tenant_id, + engine=sqlite_engine, + apply=True, + output=io.StringIO(), + model_types=(ModelType.LLM,), + orm_models=(ProviderModel,), + ) + + current_phase = {"name": "dry"} + lock_rows_seen: list[tuple[str, bool]] = [] + begin_calls: list[str] = [] + configure_calls: list[str] = [] + + class _FakeBeginContext: + def __init__(self, phase: str) -> None: + self._phase = phase + + def __enter__(self) -> None: + begin_calls.append(self._phase) + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + class _FakeSession: + def __init__(self, phase: str) -> None: + self._phase = phase + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + def begin(self) -> _FakeBeginContext: + return _FakeBeginContext(self._phase) + + def _fake_session_factory(engine: sa.Engine) -> _FakeSession: + return _FakeSession(current_phase["name"]) + + def _fake_build_plan(self, session, candidate, *, lock_rows: bool): + lock_rows_seen.append((current_phase["name"], lock_rows)) + return SimpleNamespace(group_row_ids=[str(candidate.row.id)], winner=None, loser_rows=[]) + + def _fake_emit_plan(self, plan, *, session, tx_id: str, business_key: dict[str, object]) -> None: + return None + + def _fake_configure(self, session) -> None: + configure_calls.append(current_phase["name"]) + + monkeypatch.setattr(migration_module, "_session_factory", _fake_session_factory) + monkeypatch.setattr(migration_module.Migration, "_build_provider_model_group_plan", _fake_build_plan) + monkeypatch.setattr(migration_module.Migration, "_emit_provider_model_group_plan", _fake_emit_plan) + monkeypatch.setattr(migration_module.Migration, "_configure_lock_timeout", _fake_configure) + + dry_migration._process_provider_model_group(candidate, business_key) + current_phase["name"] = "apply" + apply_migration._process_provider_model_group(candidate, business_key) + + assert [phase for phase, _ in lock_rows_seen] == ["dry", "apply"] + assert lock_rows_seen[0][1] == lock_rows_seen[1][1] + assert begin_calls == ["dry", "apply"] + assert configure_calls == ["dry", "apply"] + + +@pytest.mark.parametrize( + ("orig", "expected"), + [ + (SimpleNamespace(pgcode="55P03"), True), + (SimpleNamespace(sqlstate="55P03"), True), + (SimpleNamespace(errno=1205), True), + (RuntimeError("canceling statement due to lock timeout"), True), + (SimpleNamespace(pgcode="23505"), False), + (SimpleNamespace(errno=1213), False), + ], +) +def test_is_lock_timeout_error_prefers_structured_backend_codes( + migration_module, + sqlite_engine: sa.Engine, + orig: object, + expected: bool, +) -> None: + migration = migration_module.Migration( + tenant_id="tenant-1", + engine=sqlite_engine, + apply=True, + output=io.StringIO(), + model_types=(ModelType.LLM,), + orm_models=(), + ) + exc = OperationalError("SELECT 1", {}, orig) + + assert migration._is_lock_timeout_error(exc) is expected + + +def test_process_load_balancing_model_config_row_logs_stacktrace_for_lock_timeout( + migration_module, + sqlite_engine: sa.Engine, + monkeypatch: pytest.MonkeyPatch, +) -> None: + output = io.StringIO() + migration = migration_module.Migration( + tenant_id="tenant-1", + engine=sqlite_engine, + apply=True, + output=output, + model_types=(ModelType.LLM,), + orm_models=(migration_module.LoadBalancingModelConfig,), + ) + candidate = migration_module._RowWithRawModelType( + row=SimpleNamespace(id="lb-row-1"), + raw_model_type="text-generation", + canonical_model_type=ModelType.LLM, + ) + lock_timeout_exc = OperationalError("SELECT 1", {}, SimpleNamespace(pgcode="55P03")) + + class _FakeBeginContext: + def __enter__(self) -> None: + return None + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + class _FakeSession: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + def begin(self) -> _FakeBeginContext: + return _FakeBeginContext() + + def _fake_session_factory(engine: sa.Engine) -> _FakeSession: + return _FakeSession() + + def _fake_reload(self, session, original_candidate, *, lock_rows: bool): + raise lock_timeout_exc + + monkeypatch.setattr(migration_module, "_session_factory", _fake_session_factory) + monkeypatch.setattr(migration_module.Migration, "_configure_lock_timeout", lambda self, session: None) + monkeypatch.setattr( + migration_module.Migration, + "_reload_load_balancing_model_config_candidate", + _fake_reload, + ) + + migration._process_load_balancing_model_config_row(candidate) + + lines = _parse_json_lines(output) + assert len(lines) == 1 + assert lines[0]["event"] == "lock_timeout_skipped" + attrs = cast(dict[str, object], lines[0]["attrs"]) + assert attrs["table_name"] == "load_balancing_model_configs" + assert attrs["id"] == "lb-row-1" + assert attrs["error"] == str(lock_timeout_exc) + assert isinstance(attrs["stacktrace"], str) + assert "OperationalError" in attrs["stacktrace"] + + +def test_process_load_balancing_model_config_row_logs_update_after_sql_execution( + migration_module, + sqlite_engine: sa.Engine, + monkeypatch: pytest.MonkeyPatch, +) -> None: + migration = migration_module.Migration( + tenant_id="tenant-1", + engine=sqlite_engine, + apply=True, + output=io.StringIO(), + model_types=(ModelType.LLM,), + orm_models=(migration_module.LoadBalancingModelConfig,), + ) + candidate = migration_module._RowWithRawModelType( + row=SimpleNamespace(id="lb-row-1"), + raw_model_type="text-generation", + canonical_model_type=ModelType.LLM, + ) + action_log: list[str] = [] + + class _FakeBeginContext: + def __enter__(self) -> None: + action_log.append("begin") + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + class _FakeSession: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + def begin(self) -> _FakeBeginContext: + return _FakeBeginContext() + + def execute(self, stmt) -> None: + action_log.append("sql_execute") + + def _fake_session_factory(engine: sa.Engine) -> _FakeSession: + return _FakeSession() + + def _fake_configure(self, session) -> None: + action_log.append("configure_lock_timeout") + + def _fake_reload(self, session, original_candidate, *, lock_rows: bool): + action_log.append(f"reload_candidate:{lock_rows}") + return candidate + + def _fake_log_row_updated(self, *args, **kwargs) -> None: + action_log.append("log_row_updated") + + def _fake_cache_cleanup(self, *, row_id: str, tx_id: str) -> None: + action_log.append("cache_cleanup") + + monkeypatch.setattr(migration_module, "_session_factory", _fake_session_factory) + monkeypatch.setattr(migration_module.Migration, "_configure_lock_timeout", _fake_configure) + monkeypatch.setattr( + migration_module.Migration, + "_reload_load_balancing_model_config_candidate", + _fake_reload, + ) + monkeypatch.setattr(migration_module.Migration, "_log_row_updated", _fake_log_row_updated) + monkeypatch.setattr( + migration_module.Migration, + "_log_load_balancing_model_config_cache_cleanup", + _fake_cache_cleanup, + ) + + migration._process_load_balancing_model_config_row(candidate) + + assert action_log == [ + "begin", + "configure_lock_timeout", + "reload_candidate:True", + "sql_execute", + "log_row_updated", + "cache_cleanup", + ] + + +def test_load_balancing_model_config_cache_delete_failure_logs_stacktrace( + migration_module, + sqlite_engine: sa.Engine, + dirty_fixture, + monkeypatch: pytest.MonkeyPatch, +) -> None: + def _raise_delete_failure(self) -> None: + raise RuntimeError("cache delete boom") + + monkeypatch.setattr(migration_module.ProviderCredentialsCache, "delete", _raise_delete_failure) + + output = io.StringIO() + migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=True, + output=output, + tables=("load_balancing_model_configs",), + model_types=(ModelType.LLM,), + tenant_ids=(dirty_fixture.primary.tenant_id,), + ).migrate() + + failed_events = [ + cast(dict[str, object], line["attrs"]) + for line in _parse_json_lines(output) + if line.get("event") == "cache_delete_failed" + and isinstance(line.get("attrs"), dict) + and cast(dict[str, object], line["attrs"]).get("table_name") == "load_balancing_model_configs" + ] + + assert len(failed_events) == 1 + assert failed_events[0]["error"] == "cache delete boom" + assert isinstance(failed_events[0]["stacktrace"], str) + assert "RuntimeError: cache delete boom" in cast(str, failed_events[0]["stacktrace"]) + + +def test_group_completed_logs_exist_for_all_grouped_tables_and_use_canonical_model_type( + migration_module, + sqlite_engine: sa.Engine, + dirty_fixture, +) -> None: + output = io.StringIO() + + service = migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=False, + output=output, + tenant_ids=(dirty_fixture.primary.tenant_id,), + ) + + service.migrate() + + lines = _parse_json_lines(output) + group_completed_records = [ + line + for line in lines + if isinstance(line.get("attrs"), dict) and "group_row_ids" in cast(dict[str, object], line["attrs"]) + ] + grouped_table_names = { + cast(dict[str, object], record["attrs"]).get("table_name") for record in group_completed_records + } + + assert grouped_table_names >= { + "provider_models", + "tenant_default_models", + "provider_model_settings", + "provider_model_credentials", + } + + for record in group_completed_records: + attrs = cast(dict[str, object], record["attrs"]) + business_key = cast(dict[str, object], attrs["business_key"]) + assert isinstance(attrs["group_row_ids"], list) + assert attrs["group_row_ids"] + if "model_type" in business_key: + assert business_key["model_type"] in { + ModelType.LLM.value, + ModelType.TEXT_EMBEDDING.value, + ModelType.RERANK.value, + } + assert business_key["model_type"] not in LEGACY_TO_CANONICAL + + +def test_provider_models_group_completed_log_includes_related_canonical_row_ids( + migration_module, + sqlite_engine: sa.Engine, + dirty_fixture, + monkeypatch: pytest.MonkeyPatch, +) -> None: + _patch_batch_size(monkeypatch, migration_module, batch_size=1) + inserted_row_id = "00000000-0000-0000-0000-00000000aa01" + created_at = datetime(2025, 1, 1, 10, 0, 0) + updated_at = created_at + timedelta(minutes=5) + _insert_provider_model( + sqlite_engine, + row_id=inserted_row_id, + tenant_id=dirty_fixture.primary.tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", + model_type=ModelType.LLM.value, + credential_id=dirty_fixture.primary.distinct_credential_id, + created_at=created_at, + updated_at=updated_at, + ) + + output = io.StringIO() + migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=False, + output=output, + tables=("provider_models",), + model_types=(ModelType.LLM,), + tenant_ids=(dirty_fixture.primary.tenant_id,), + ).migrate() + + lines = _parse_json_lines(output) + matching_records = [] + for line in lines: + attrs = line.get("attrs") + if not isinstance(attrs, dict): + continue + business_key = attrs.get("business_key") + if not isinstance(business_key, dict): + continue + if ( + attrs.get("table_name") == "provider_models" + and business_key.get("tenant_id") == dirty_fixture.primary.tenant_id + and business_key.get("provider_name") == "openai" + and business_key.get("model_name") == "gpt-4o-mini" + and business_key.get("model_type") == ModelType.LLM.value + and "group_row_ids" in attrs + ): + matching_records.append(attrs) + + assert len(matching_records) == 1 + assert set(cast(list[str], matching_records[0]["group_row_ids"])) == { + dirty_fixture.primary.provider_model_id, + inserted_row_id, + } + + +def test_provider_model_settings_group_crossing_batches_is_completed_once_with_all_group_row_ids( + migration_module, + sqlite_engine: sa.Engine, + dirty_fixture, + monkeypatch: pytest.MonkeyPatch, +) -> None: + _patch_batch_size(monkeypatch, migration_module, batch_size=1) + inserted_row_id = "00000000-0000-0000-0000-00000000cc01" + created_at = datetime(2025, 1, 1, 9, 0, 0) + updated_at = created_at + timedelta(minutes=10) + _insert_provider_model_setting( + sqlite_engine, + row_id=inserted_row_id, + tenant_id=dirty_fixture.primary.tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", + model_type="text-generation", + enabled=True, + load_balancing_enabled=False, + created_at=created_at, + updated_at=updated_at, + ) + + output = io.StringIO() + migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=False, + output=output, + tables=("provider_model_settings",), + model_types=(ModelType.LLM,), + tenant_ids=(dirty_fixture.primary.tenant_id,), + ).migrate() + + lines = _parse_json_lines(output) + matching_records = [] + for line in lines: + attrs = line.get("attrs") + if not isinstance(attrs, dict): + continue + business_key = attrs.get("business_key") + if not isinstance(business_key, dict): + continue + if ( + attrs.get("table_name") == "provider_model_settings" + and business_key.get("tenant_id") == dirty_fixture.primary.tenant_id + and business_key.get("provider_name") == "openai" + and business_key.get("model_name") == "gpt-4o-mini" + and business_key.get("model_type") == ModelType.LLM.value + and "group_row_ids" in attrs + ): + matching_records.append(attrs) + + assert len(matching_records) == 1 + assert set(cast(list[str], matching_records[0]["group_row_ids"])) == { + dirty_fixture.primary.provider_model_setting_id, + inserted_row_id, + } + + +def test_load_balancing_inherit_rows_are_deduplicated_by_normalized_model_type_before_canonicalization( + migration_module, + sqlite_engine: sa.Engine, + dirty_fixture, + monkeypatch: pytest.MonkeyPatch, +) -> None: + older_canonical_row_id = "00000000-0000-0000-0000-00000000dd01" + newer_legacy_row_id = "00000000-0000-0000-0000-00000000dd02" + created_at = datetime(2025, 1, 1, 8, 0, 0) + older_updated_at = created_at + timedelta(minutes=15) + newer_updated_at = created_at + timedelta(minutes=30) + _insert_load_balancing_model_config( + sqlite_engine, + row_id=older_canonical_row_id, + tenant_id=dirty_fixture.primary.tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", + model_type=ModelType.LLM.value, + name="__inherit__", + encrypted_config='{"api_key":"older-inherit"}', + credential_id=dirty_fixture.primary.winner_credential_id, + enabled=True, + created_at=created_at, + updated_at=older_updated_at, + ) + _insert_load_balancing_model_config( + sqlite_engine, + row_id=newer_legacy_row_id, + tenant_id=dirty_fixture.primary.tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", + model_type="text-generation", + name="__inherit__", + encrypted_config='{"api_key":"newer-inherit"}', + credential_id=dirty_fixture.primary.distinct_credential_id, + enabled=True, + created_at=created_at, + updated_at=newer_updated_at, + ) + + deleted_cache_keys: list[str] = [] + + def _record_delete(self) -> None: + deleted_cache_keys.append(self.cache_key) + + monkeypatch.setattr(migration_module.ProviderCredentialsCache, "delete", _record_delete) + + tenant_id = dirty_fixture.primary.tenant_id + table_name = "load_balancing_model_configs" + expected_row_ids = {older_canonical_row_id, newer_legacy_row_id} + + dry_run_output = io.StringIO() + migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=False, + output=dry_run_output, + tables=(table_name,), + model_types=(ModelType.LLM,), + tenant_ids=(tenant_id,), + ).migrate() + + dry_run_lines = _parse_json_lines(dry_run_output) + dry_run_signatures = { + signature + for signature in _collect_processing_signatures(dry_run_lines) + if signature[1] == table_name and signature[2] in expected_row_ids + } + dry_run_row_updates = [ + cast(dict[str, object], line["attrs"]) + for line in dry_run_lines + if line.get("event") == "row_updated" + and isinstance(line.get("attrs"), dict) + and cast(dict[str, object], line["attrs"]).get("table_name") == table_name + and str(cast(dict[str, object], line["attrs"]).get("id")) in expected_row_ids + ] + assert len(dry_run_row_updates) == 1 + assert str(dry_run_row_updates[0]["id"]) == newer_legacy_row_id + assert dry_run_row_updates[0]["old_values"] == {"model_type": "text-generation"} + assert dry_run_row_updates[0]["new_values"] == {"model_type": ModelType.LLM.value} + assert all("rewrite_source" not in attrs for attrs in dry_run_row_updates) + + dry_run_row_deletes = [ + cast(dict[str, object], line["attrs"]) + for line in dry_run_lines + if line.get("event") == "row_deleted" + and isinstance(line.get("attrs"), dict) + and cast(dict[str, object], line["attrs"]).get("table_name") == table_name + and str(cast(dict[str, object], line["attrs"]).get("id")) in expected_row_ids + ] + assert len(dry_run_row_deletes) == 1 + assert dry_run_row_deletes[0]["business_key"] == { + "tenant_id": tenant_id, + "provider_name": "openai", + "model_name": "gpt-4o-mini", + "model_type": ModelType.LLM.value, + } + assert dry_run_row_deletes[0]["merge_winner_id"] == newer_legacy_row_id + assert dry_run_row_deletes[0]["row"] == { + "id": older_canonical_row_id, + "tenant_id": tenant_id, + "provider_name": "openai", + "model_name": "gpt-4o-mini", + "model_type": ModelType.LLM.value, + "name": "__inherit__", + "encrypted_config": {"api_key": "older-inherit"}, + "credential_id": dirty_fixture.primary.winner_credential_id, + "credential_source_type": CredentialSourceType.CUSTOM_MODEL.value, + "enabled": True, + "created_at": created_at.isoformat(), + "updated_at": older_updated_at.isoformat(), + } + + dry_run_deleted_index = next( + index + for index, line in enumerate(dry_run_lines) + if line.get("event") == "row_deleted" + and isinstance(line.get("attrs"), dict) + and cast(dict[str, object], line["attrs"]).get("id") == older_canonical_row_id + ) + dry_run_updated_index = next( + index + for index, line in enumerate(dry_run_lines) + if line.get("event") == "row_updated" + and isinstance(line.get("attrs"), dict) + and cast(dict[str, object], line["attrs"]).get("id") == newer_legacy_row_id + ) + assert dry_run_deleted_index < dry_run_updated_index + + dry_run_cache_plan_ids = _cache_event_row_ids( + dry_run_lines, + table_name=table_name, + row_ids=expected_row_ids, + event_name="cache_delete_planned", + ) + assert newer_legacy_row_id in dry_run_cache_plan_ids + + apply_output = io.StringIO() + migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=True, + output=apply_output, + tables=(table_name,), + model_types=(ModelType.LLM,), + tenant_ids=(tenant_id,), + ).migrate() + + apply_lines = _parse_json_lines(apply_output) + apply_signatures = { + signature + for signature in _collect_processing_signatures(apply_lines) + if signature[1] == table_name and signature[2] in expected_row_ids + } + apply_row_updates = [ + cast(dict[str, object], line["attrs"]) + for line in apply_lines + if line.get("event") == "row_updated" + and isinstance(line.get("attrs"), dict) + and cast(dict[str, object], line["attrs"]).get("table_name") == table_name + and str(cast(dict[str, object], line["attrs"]).get("id")) in expected_row_ids + ] + assert len(apply_row_updates) == 1 + assert str(apply_row_updates[0]["id"]) == newer_legacy_row_id + assert apply_signatures == dry_run_signatures + + apply_cache_delete_ids = _cache_event_row_ids( + apply_lines, + table_name=table_name, + row_ids=expected_row_ids, + event_name="cache_deleted", + ) + assert apply_cache_delete_ids == dry_run_cache_plan_ids + assert deleted_cache_keys + + lb_rows = fetch_table_rows(sqlite_engine, table_name, tenant_id=tenant_id) + surviving_rows = [row for row in lb_rows if str(row["id"]) in expected_row_ids] + assert len(surviving_rows) == 1 + surviving_row = surviving_rows[0] + assert surviving_row["id"] == newer_legacy_row_id + assert surviving_row["tenant_id"] == tenant_id + assert surviving_row["provider_name"] == "openai" + assert surviving_row["model_name"] == "gpt-4o-mini" + assert surviving_row["model_type"] == ModelType.LLM.value + assert surviving_row["name"] == "__inherit__" + assert surviving_row["encrypted_config"] == '{"api_key":"newer-inherit"}' + assert surviving_row["credential_id"] == dirty_fixture.primary.distinct_credential_id + assert surviving_row["credential_source_type"] == CredentialSourceType.CUSTOM_MODEL.value + + +def test_load_balancing_non_inherit_rows_do_not_participate_in_normalized_model_type_deduplication( + migration_module, + sqlite_engine: sa.Engine, + dirty_fixture, +) -> None: + inserted_row_id = "00000000-0000-0000-0000-00000000dd03" + created_at = datetime(2025, 1, 1, 8, 0, 0) + updated_at = created_at + timedelta(minutes=15) + _insert_load_balancing_model_config( + sqlite_engine, + row_id=inserted_row_id, + tenant_id=dirty_fixture.primary.tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", + model_type=ModelType.LLM.value, + name=dirty_fixture.primary.loser_credential_name, + encrypted_config='{"api_key":"second-lb"}', + credential_id=dirty_fixture.primary.distinct_credential_id, + enabled=True, + created_at=created_at, + updated_at=updated_at, + ) + + output = io.StringIO() + migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=True, + output=output, + tables=("load_balancing_model_configs",), + model_types=(ModelType.LLM,), + tenant_ids=(dirty_fixture.primary.tenant_id,), + ).migrate() + + lines = _parse_json_lines(output) + row_deleted_events = [ + cast(dict[str, object], line["attrs"]) + for line in lines + if line.get("event") == "row_deleted" + and isinstance(line.get("attrs"), dict) + and cast(dict[str, object], line["attrs"]).get("table_name") == "load_balancing_model_configs" + ] + assert row_deleted_events == [] + + lb_rows = fetch_table_rows( + sqlite_engine, + "load_balancing_model_configs", + tenant_id=dirty_fixture.primary.tenant_id, + ) + matching_rows = [ + row for row in lb_rows if str(row["id"]) in {dirty_fixture.primary.load_balancing_config_id, inserted_row_id} + ] + assert len(matching_rows) == 2 + assert all(row["model_type"] == ModelType.LLM.value for row in matching_rows) + + +def test_migration_apply_updates_all_five_tables_and_rewrites_credential_references( + migration_module, + sqlite_engine: sa.Engine, + dirty_fixture, + monkeypatch: pytest.MonkeyPatch, +) -> None: + deleted_cache_keys: list[str] = [] + + def _record_delete(self) -> None: + deleted_cache_keys.append(self.cache_key) + + monkeypatch.setattr(migration_module.ProviderCredentialsCache, "delete", _record_delete) + output = io.StringIO() + + service = migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=True, + output=output, + tenant_ids=(dirty_fixture.primary.tenant_id,), + ) + + service.migrate() + + assert_tenant_rows_use_only_canonical_model_types(sqlite_engine, dirty_fixture.primary.tenant_id) + + provider_model_rows = fetch_table_rows(sqlite_engine, "provider_models", tenant_id=dirty_fixture.primary.tenant_id) + provider_model_row = next( + row for row in provider_model_rows if row["id"] == dirty_fixture.primary.provider_model_id + ) + assert provider_model_row["model_type"] == LEGACY_TO_CANONICAL["text-generation"] + assert provider_model_row["credential_id"] == dirty_fixture.primary.winner_credential_id + + lb_rows = fetch_table_rows(sqlite_engine, "load_balancing_model_configs", tenant_id=dirty_fixture.primary.tenant_id) + lb_row = next(row for row in lb_rows if row["id"] == dirty_fixture.primary.load_balancing_config_id) + assert lb_row["model_type"] == LEGACY_TO_CANONICAL["text-generation"] + assert lb_row["credential_id"] == dirty_fixture.primary.winner_credential_id + assert lb_row["encrypted_config"] == dirty_fixture.primary.winner_encrypted_config + + credential_rows = fetch_table_rows( + sqlite_engine, "provider_model_credentials", tenant_id=dirty_fixture.primary.tenant_id + ) + assert ( + count_rows( + sqlite_engine, + "provider_model_credentials", + tenant_id=dirty_fixture.primary.tenant_id, + ) + == 2 + ) + credential_ids = {str(row["id"]) for row in credential_rows} + assert credential_ids == { + dirty_fixture.primary.winner_credential_id, + dirty_fixture.primary.distinct_credential_id, + } + distinct_row = next(row for row in credential_rows if row["id"] == dirty_fixture.primary.distinct_credential_id) + assert distinct_row["credential_name"] == dirty_fixture.primary.distinct_credential_name + assert distinct_row["model_type"] == LEGACY_TO_CANONICAL["text-generation"] + + rendered_output = output.getvalue() + assert dirty_fixture.primary.loser_credential_id in rendered_output + assert dirty_fixture.primary.loser_encrypted_config in rendered_output + assert any("load_balancing_provider_model_credentials" in key for key in deleted_cache_keys) or any( + "load_balancing_provider_model" in key for key in deleted_cache_keys + ) + + +def test_migration_filters_by_tenant_model_types_and_tables( + migration_module, + sqlite_engine: sa.Engine, + dirty_fixture, +) -> None: + before_primary_credentials = fetch_table_rows( + sqlite_engine, + "provider_model_credentials", + tenant_id=dirty_fixture.primary.tenant_id, + ) + before_secondary = { + table_name: fetch_table_rows(sqlite_engine, table_name, tenant_id=dirty_fixture.secondary.tenant_id) + for table_name in ALL_TABLE_NAMES + } + + service = migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=True, + output=io.StringIO(), + tables=("provider_models",), + model_types=(ModelType.LLM,), + tenant_ids=(dirty_fixture.primary.tenant_id,), + ) + + service.migrate() + + assert ( + count_rows( + sqlite_engine, + "provider_model_credentials", + tenant_id=dirty_fixture.primary.tenant_id, + ) + == 3 + ) + credential_rows = fetch_table_rows( + sqlite_engine, + "provider_model_credentials", + tenant_id=dirty_fixture.primary.tenant_id, + ) + assert credential_rows == before_primary_credentials + provider_model_rows = fetch_table_rows( + sqlite_engine, + "provider_models", + tenant_id=dirty_fixture.primary.tenant_id, + ) + provider_model_row = next( + row for row in provider_model_rows if row["id"] == dirty_fixture.primary.provider_model_id + ) + embedding_provider_model_row = next( + row for row in provider_model_rows if row["id"] == dirty_fixture.primary.embedding_provider_model_id + ) + assert provider_model_row["model_type"] == LEGACY_TO_CANONICAL["text-generation"] + assert embedding_provider_model_row["model_type"] == "embeddings" + + tenant_default_row = fetch_table_rows( + sqlite_engine, + "tenant_default_models", + tenant_id=dirty_fixture.primary.tenant_id, + )[0] + assert tenant_default_row["model_type"] == "text-generation" + + provider_model_setting_rows = fetch_table_rows( + sqlite_engine, + "provider_model_settings", + tenant_id=dirty_fixture.primary.tenant_id, + ) + llm_setting_row = next( + row for row in provider_model_setting_rows if row["id"] == dirty_fixture.primary.provider_model_setting_id + ) + embedding_setting_row = next( + row for row in provider_model_setting_rows if row["id"] == dirty_fixture.primary.embedding_setting_id + ) + assert llm_setting_row["model_type"] == "text-generation" + assert embedding_setting_row["model_type"] == "embeddings" + + lb_row = fetch_table_rows( + sqlite_engine, + "load_balancing_model_configs", + tenant_id=dirty_fixture.primary.tenant_id, + )[0] + assert lb_row["model_type"] == "text-generation" + + after_secondary = { + table_name: fetch_table_rows(sqlite_engine, table_name, tenant_id=dirty_fixture.secondary.tenant_id) + for table_name in ALL_TABLE_NAMES + } + assert after_secondary == before_secondary + + +def test_migration_does_not_merge_credentials_with_different_credential_name( + migration_module, + sqlite_engine: sa.Engine, + dirty_fixture, +) -> None: + service = migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=True, + output=io.StringIO(), + tenant_ids=(dirty_fixture.primary.tenant_id,), + ) + + service.migrate() + + credential_rows = fetch_table_rows( + sqlite_engine, + "provider_model_credentials", + tenant_id=dirty_fixture.primary.tenant_id, + ) + distinct_row = next(row for row in credential_rows if row["id"] == dirty_fixture.primary.distinct_credential_id) + assert distinct_row["credential_name"] == dirty_fixture.primary.distinct_credential_name + assert distinct_row["model_type"] == LEGACY_TO_CANONICAL["text-generation"] + assert ( + count_rows( + sqlite_engine, + "provider_model_credentials", + tenant_id=dirty_fixture.primary.tenant_id, + ) + == 2 + ) + + +def test_migration_is_idempotent_on_second_apply( + migration_module, + sqlite_engine: sa.Engine, + dirty_fixture, +) -> None: + service = migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=True, + output=io.StringIO(), + tenant_ids=(dirty_fixture.primary.tenant_id,), + ) + + service.migrate() + after_first = snapshot_legacy_model_type_state(sqlite_engine) + + service.migrate() + after_second = snapshot_legacy_model_type_state(sqlite_engine) + + assert after_second == after_first