feat(api): introduce model-type migration script (#36520)

This commit is contained in:
QuantumGhost 2026-05-27 10:12:11 +08:00 committed by GitHub
parent dade318f00
commit 5c5a6e83e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 5523 additions and 5 deletions

View File

@ -223,10 +223,11 @@ def initialize_extensions(app: DifyApp):
def create_migrations_app() -> DifyApp:
app = create_flask_app_with_configs()
from extensions import ext_database, ext_migrate
from extensions import ext_commands, ext_database, ext_migrate
# Initialize only required extensions
ext_database.init_app(app)
ext_migrate.init_app(app)
ext_commands.init_app(app)
return app

View File

@ -3,6 +3,7 @@ CLI command modules extracted from `commands.py`.
"""
from .account import create_tenant, reset_email, reset_password
from .data_migrate import data_migrate, legacy_model_types
from .plugin import (
extract_plugins,
extract_unique_plugins,
@ -44,6 +45,7 @@ __all__ = [
"clear_orphaned_file_records",
"convert_to_agent_apps",
"create_tenant",
"data_migrate",
"delete_archived_workflow_runs",
"export_app_messages",
"extract_plugins",
@ -52,6 +54,7 @@ __all__ = [
"fix_app_site_missing",
"install_plugins",
"install_rag_pipeline_plugins",
"legacy_model_types",
"migrate_annotation_vector_database",
"migrate_data_for_plugin",
"migrate_knowledge_vector_database",

View File

@ -0,0 +1,169 @@
import io
import os
import sys
from contextlib import AbstractContextManager, nullcontext
from pathlib import Path
from typing import cast
import click
from extensions.ext_database import db
from graphon.model_runtime.entities.model_entities import ModelType
from services.legacy_model_type_migration import (
VALID_TABLE_NAMES,
LegacyModelTypeMigrationService,
load_tenant_ids_from_file,
)
_SUPPORTED_MODEL_TYPE_CHOICES = (
ModelType.LLM.value,
ModelType.TEXT_EMBEDDING.value,
ModelType.RERANK.value,
)
_DEFAULT_CONCURRENCY = os.cpu_count() or 1
def _normalize_multi_value_option(
values: tuple[str, ...],
*,
valid_values: tuple[str, ...],
option_name: str,
) -> tuple[str, ...]:
normalized_values: list[str] = []
seen_values: set[str] = set()
for value in values:
for item in value.split(","):
normalized_item = item.strip()
if not normalized_item:
continue
if normalized_item not in valid_values:
raise click.BadParameter(
f"invalid value '{normalized_item}'. valid values: {', '.join(valid_values)}",
param_hint=option_name,
)
if normalized_item in seen_values:
continue
seen_values.add(normalized_item)
normalized_values.append(normalized_item)
return tuple(normalized_values)
@click.group(
"data-migrate",
help="Online data migration commands.",
)
def data_migrate() -> None:
"""Namespace for production data migration commands."""
@click.command(
"legacy-model-types",
help=(
"Migrate legacy provider model_type values to canonical values. "
"Default is dry-run and emits JSON lines only. "
"If --tables includes provider_model_credentials, the command may also update "
"provider_models and load_balancing_model_configs references so merged credentials stay reachable."
),
)
@click.option(
"--apply",
is_flag=True,
default=False,
help="Apply the migration. Default is dry-run.",
)
@click.option(
"--tables",
"tables",
multiple=True,
type=str,
help=(
"Limit model_type migration to specific tables. Accepts comma-separated values or repeated flags. "
"When provider_model_credentials is selected, provider_models and "
"load_balancing_model_configs may also be updated for credential reference rewrites."
"Default to: "
),
)
@click.option(
"--model-types",
"model_types",
multiple=True,
type=str,
help=(
"Canonical model types to migrate. Accepts comma-separated values or repeated flags. "
"Defaults to: `llm,text-embedding,rerank`"
),
)
@click.option(
"--tenant-id-file",
type=click.Path(exists=True, dir_okay=False, readable=True, resolve_path=True),
help="Optional file containing tenant ids, one per line.",
)
@click.option(
"--output",
type=click.Path(dir_okay=False, resolve_path=True, path_type=Path),
help="Optional file path for JSON lines event logs. Defaults to stdout.",
)
@click.option(
"--concurrency",
type=click.IntRange(min=1),
default=_DEFAULT_CONCURRENCY,
show_default=True,
help="Number of tenant-level worker threads to run in parallel.",
)
def legacy_model_types(
apply: bool,
tables: tuple[str, ...],
model_types: tuple[str, ...],
tenant_id_file: str | None,
output: Path | None,
concurrency: int = _DEFAULT_CONCURRENCY,
) -> None:
"""
Migrate legacy provider-related model_type values and emit JSON lines events.
"""
normalized_tables = _normalize_multi_value_option(
tables,
valid_values=VALID_TABLE_NAMES,
option_name="--tables",
)
normalized_model_types = _normalize_multi_value_option(
model_types,
valid_values=_SUPPORTED_MODEL_TYPE_CHOICES,
option_name="--model-types",
)
selected_model_types = (
tuple(ModelType.value_of(model_type) for model_type in normalized_model_types)
if normalized_model_types
else (
ModelType.LLM,
ModelType.TEXT_EMBEDDING,
ModelType.RERANK,
)
)
tenant_ids = load_tenant_ids_from_file(tenant_id_file) if tenant_id_file else None
output_context: AbstractContextManager[io.TextIOBase]
if output is None:
output_context = nullcontext(cast(io.TextIOBase, sys.stdout))
else:
try:
output_context = output.open("w", encoding="utf-8")
except OSError as exc:
raise click.ClickException(f"failed to open output file '{output}': {exc.strerror or exc}") from exc
with output_context as output_stream:
LegacyModelTypeMigrationService(
engine=db.engine,
apply=apply,
concurrency=concurrency,
output=cast(io.TextIOBase, output_stream),
tables=normalized_tables or None,
model_types=selected_model_types,
tenant_ids=tenant_ids,
).migrate()
data_migrate.add_command(legacy_model_types)

View File

@ -12,6 +12,7 @@ def init_app(app: DifyApp):
clear_orphaned_file_records,
convert_to_agent_apps,
create_tenant,
data_migrate,
delete_archived_workflow_runs,
export_app_messages,
extract_plugins,
@ -44,6 +45,7 @@ def init_app(app: DifyApp):
convert_to_agent_apps,
add_qdrant_index,
create_tenant,
data_migrate,
upgrade_db,
fix_app_site_missing,
migrate_data_for_plugin,

View File

@ -102,10 +102,7 @@ dify-trace-weave = { workspace = true }
[tool.uv]
default-groups = ["storage", "tools", "vdb-all", "trace-all"]
package = false
override-dependencies = [
"litellm>=1.83.10,<2.0.0",
"pyarrow>=23.0.1,<24.0.0",
]
override-dependencies = ["litellm>=1.83.10,<2.0.0", "pyarrow>=23.0.1,<24.0.0"]
[dependency-groups]

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
"""Shared test helpers for backend migration tests."""

View File

@ -0,0 +1,366 @@
from __future__ import annotations
import json
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta
from uuid import uuid4
import sqlalchemy as sa
from sqlalchemy.engine import Engine
from models.account import Tenant
from models.enums import CredentialSourceType
from models.provider import (
LoadBalancingModelConfig,
ProviderModel,
ProviderModelCredential,
ProviderModelSetting,
TenantDefaultModel,
)
LEGACY_TO_CANONICAL: dict[str, str] = {
"text-generation": "llm",
"embeddings": "text-embedding",
"reranking": "rerank",
}
UNCHANGED_MODEL_TYPES: tuple[str, ...] = ("speech2text", "moderation", "tts")
ALL_TABLE_NAMES: tuple[str, ...] = (
ProviderModel.__tablename__,
TenantDefaultModel.__tablename__,
ProviderModelSetting.__tablename__,
LoadBalancingModelConfig.__tablename__,
ProviderModelCredential.__tablename__,
)
DEFAULT_PRIMARY_TENANT_ID = "00000000-0000-0000-0000-000000000101"
DEFAULT_SECONDARY_TENANT_ID = "00000000-0000-0000-0000-000000000202"
@dataclass(frozen=True, slots=True)
class DirtyTenantFixture:
tenant_id: str
winner_credential_id: str
loser_credential_id: str
distinct_credential_id: str
provider_model_id: str
load_balancing_config_id: str
provider_model_setting_id: str
tenant_default_model_id: str
embedding_provider_model_id: str
embedding_setting_id: str
loser_credential_name: str
distinct_credential_name: str
loser_encrypted_config: str
winner_encrypted_config: str
@dataclass(frozen=True, slots=True)
class DirtyDataFixture:
primary: DirtyTenantFixture
secondary: DirtyTenantFixture
def create_minimal_legacy_model_type_schema(engine: Engine) -> None:
metadata = Tenant.__table__.metadata
metadata.create_all(
engine,
tables=[
Tenant.__table__,
ProviderModel.__table__,
TenantDefaultModel.__table__,
ProviderModelSetting.__table__,
LoadBalancingModelConfig.__table__,
ProviderModelCredential.__table__,
],
checkfirst=True,
)
def drop_minimal_legacy_model_type_schema(engine: Engine) -> None:
metadata = Tenant.__table__.metadata
metadata.drop_all(
engine,
tables=[
LoadBalancingModelConfig.__table__,
ProviderModelSetting.__table__,
TenantDefaultModel.__table__,
ProviderModel.__table__,
ProviderModelCredential.__table__,
Tenant.__table__,
],
checkfirst=True,
)
def seed_legacy_model_type_dirty_data(
engine: Engine,
*,
primary_tenant_id: str = DEFAULT_PRIMARY_TENANT_ID,
secondary_tenant_id: str = DEFAULT_SECONDARY_TENANT_ID,
) -> DirtyDataFixture:
create_minimal_legacy_model_type_schema(engine)
primary = _seed_tenant(engine, tenant_id=primary_tenant_id, provider_name="openai")
secondary = _seed_tenant(engine, tenant_id=secondary_tenant_id, provider_name="openai")
return DirtyDataFixture(primary=primary, secondary=secondary)
def snapshot_legacy_model_type_state(engine: Engine) -> dict[str, list[dict[str, object]]]:
snapshots: dict[str, list[dict[str, object]]] = {}
for table_name in ALL_TABLE_NAMES:
snapshots[table_name] = fetch_table_rows(engine, table_name)
return snapshots
def fetch_table_rows(
engine: Engine,
table_name: str,
*,
tenant_id: str | None = None,
) -> list[dict[str, object]]:
sql = f"SELECT * FROM {table_name}"
params: dict[str, object] = {}
if tenant_id is not None:
sql += " WHERE tenant_id = :tenant_id"
params["tenant_id"] = tenant_id
sql += " ORDER BY id ASC"
with engine.begin() as conn:
rows = conn.execute(sa.text(sql), params).mappings().all()
result: list[dict[str, object]] = []
for row in rows:
normalized = dict(row)
for key, value in normalized.items():
if isinstance(value, datetime):
normalized[key] = value.isoformat()
elif isinstance(value, uuid.UUID):
normalized[key] = str(value)
result.append(normalized)
return result
def fetch_model_types_for_tenant(engine: Engine, table_name: str, tenant_id: str) -> list[str]:
rows = fetch_table_rows(engine, table_name, tenant_id=tenant_id)
return [str(row["model_type"]) for row in rows]
def assert_tenant_rows_use_only_canonical_model_types(engine: Engine, tenant_id: str) -> None:
for table_name in ALL_TABLE_NAMES:
model_types = fetch_model_types_for_tenant(engine, table_name, tenant_id)
assert set(model_types) <= set(LEGACY_TO_CANONICAL.values()) | set(UNCHANGED_MODEL_TYPES), (
table_name,
model_types,
)
def count_rows(engine: Engine, table_name: str, *, tenant_id: str) -> int:
with engine.begin() as conn:
stmt = sa.text(f"SELECT COUNT(*) FROM {table_name} WHERE tenant_id = :tenant_id")
return int(conn.execute(stmt, {"tenant_id": tenant_id}).scalar_one())
def _seed_tenant(engine: Engine, *, tenant_id: str, provider_name: str) -> DirtyTenantFixture:
now = datetime(2025, 1, 1, 12, 0, 0)
winner_credential_id = str(uuid4())
loser_credential_id = str(uuid4())
distinct_credential_id = str(uuid4())
provider_model_id = str(uuid4())
load_balancing_config_id = str(uuid4())
provider_model_setting_id = str(uuid4())
tenant_default_model_id = str(uuid4())
embedding_provider_model_id = str(uuid4())
embedding_setting_id = str(uuid4())
loser_credential_name = f"{tenant_id}-shared"
distinct_credential_name = f"{tenant_id}-distinct"
winner_encrypted_config = json.dumps({"api_key": f"{tenant_id}-winner"})
loser_encrypted_config = json.dumps({"api_key": f"{tenant_id}-loser"})
distinct_encrypted_config = json.dumps({"api_key": f"{tenant_id}-distinct"})
with engine.begin() as conn:
conn.execute(
Tenant.__table__.insert().values(
id=tenant_id,
name=f"Tenant {tenant_id}",
plan="basic",
status="normal",
)
)
conn.execute(
sa.text(
"""
INSERT INTO provider_model_credentials
(
id, tenant_id, provider_name, model_name,
model_type, credential_name, encrypted_config,
created_at, updated_at
)
VALUES
(
:winner_id, :tenant_id, :provider_name, 'gpt-4o-mini',
'llm', :shared_name, :winner_config,
:created_at, :winner_updated_at
),
(
:loser_id, :tenant_id, :provider_name, 'gpt-4o-mini',
'text-generation', :shared_name, :loser_config,
:created_at, :loser_updated_at
),
(
:distinct_id, :tenant_id, :provider_name, 'gpt-4o-mini',
'text-generation', :distinct_name, :distinct_config,
:created_at, :distinct_updated_at
)
"""
),
{
"winner_id": winner_credential_id,
"loser_id": loser_credential_id,
"distinct_id": distinct_credential_id,
"tenant_id": tenant_id,
"provider_name": provider_name,
"shared_name": loser_credential_name,
"distinct_name": distinct_credential_name,
"winner_config": winner_encrypted_config,
"loser_config": loser_encrypted_config,
"distinct_config": distinct_encrypted_config,
"created_at": now - timedelta(days=2),
"winner_updated_at": now,
"loser_updated_at": now - timedelta(days=1),
"distinct_updated_at": now - timedelta(hours=12),
},
)
conn.execute(
sa.text(
"""
INSERT INTO provider_models
(
id, tenant_id, provider_name, model_name,
model_type, credential_id, is_valid,
created_at, updated_at
)
VALUES
(
:provider_model_id, :tenant_id, :provider_name, 'gpt-4o-mini',
'text-generation', :loser_id, :is_valid,
:created_at, :updated_at
),
(
:embedding_provider_model_id, :tenant_id, :provider_name, 'text-embedding-3-large',
'embeddings', NULL, :is_valid,
:created_at, :updated_at
)
"""
),
{
"provider_model_id": provider_model_id,
"embedding_provider_model_id": embedding_provider_model_id,
"tenant_id": tenant_id,
"provider_name": provider_name,
"loser_id": loser_credential_id,
"is_valid": True,
"created_at": now - timedelta(days=2),
"updated_at": now - timedelta(hours=6),
},
)
conn.execute(
sa.text(
"""
INSERT INTO tenant_default_models
(id, tenant_id, provider_name, model_name, model_type, created_at, updated_at)
VALUES
(
:tenant_default_model_id, :tenant_id, :provider_name, 'gpt-4o-mini',
'text-generation', :created_at, :updated_at
)
"""
),
{
"tenant_default_model_id": tenant_default_model_id,
"tenant_id": tenant_id,
"provider_name": provider_name,
"created_at": now - timedelta(days=2),
"updated_at": now - timedelta(hours=4),
},
)
conn.execute(
sa.text(
"""
INSERT INTO provider_model_settings
(
id, tenant_id, provider_name, model_name,
model_type, enabled, load_balancing_enabled,
created_at, updated_at
)
VALUES
(
:provider_model_setting_id, :tenant_id, :provider_name, 'gpt-4o-mini',
'text-generation', :enabled, :load_balancing_enabled,
:created_at, :updated_at
),
(
:embedding_setting_id, :tenant_id, :provider_name, 'text-embedding-3-large',
'embeddings', :enabled, :embedding_load_balancing_enabled,
:created_at, :updated_at
)
"""
),
{
"provider_model_setting_id": provider_model_setting_id,
"embedding_setting_id": embedding_setting_id,
"tenant_id": tenant_id,
"provider_name": provider_name,
"enabled": True,
"load_balancing_enabled": True,
"embedding_load_balancing_enabled": False,
"created_at": now - timedelta(days=2),
"updated_at": now - timedelta(hours=3),
},
)
conn.execute(
sa.text(
"""
INSERT INTO load_balancing_model_configs
(
id, tenant_id, provider_name, model_name, model_type,
name, encrypted_config, credential_id, credential_source_type,
enabled, created_at, updated_at
)
VALUES
(
:load_balancing_config_id, :tenant_id, :provider_name, 'gpt-4o-mini', 'text-generation',
:lb_name, :loser_config, :loser_id, :credential_source_type,
:enabled, :created_at, :updated_at
)
"""
),
{
"load_balancing_config_id": load_balancing_config_id,
"tenant_id": tenant_id,
"provider_name": provider_name,
"lb_name": loser_credential_name,
"loser_config": loser_encrypted_config,
"loser_id": loser_credential_id,
"credential_source_type": CredentialSourceType.CUSTOM_MODEL.value,
"enabled": True,
"created_at": now - timedelta(days=2),
"updated_at": now - timedelta(hours=2),
},
)
return DirtyTenantFixture(
tenant_id=tenant_id,
winner_credential_id=winner_credential_id,
loser_credential_id=loser_credential_id,
distinct_credential_id=distinct_credential_id,
provider_model_id=provider_model_id,
load_balancing_config_id=load_balancing_config_id,
provider_model_setting_id=provider_model_setting_id,
tenant_default_model_id=tenant_default_model_id,
embedding_provider_model_id=embedding_provider_model_id,
embedding_setting_id=embedding_setting_id,
loser_credential_name=loser_credential_name,
distinct_credential_name=distinct_credential_name,
loser_encrypted_config=loser_encrypted_config,
winner_encrypted_config=winner_encrypted_config,
)

View File

@ -0,0 +1,82 @@
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
API_PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(API_PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(API_PROJECT_ROOT))
import sqlalchemy as sa
from tests.helpers.legacy_model_type_migration import (
DEFAULT_PRIMARY_TENANT_ID,
DEFAULT_SECONDARY_TENANT_ID,
create_minimal_legacy_model_type_schema,
seed_legacy_model_type_dirty_data,
)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description=(
"Seed dirty legacy model_type rows for manual migration experiments. "
"Example: uv run --project api python api/tests/seed_legacy_model_type_dirty_data.py "
"--db-url postgresql://postgres:postgres@127.0.0.1:5432/dify"
)
)
parser.add_argument("--db-url", required=True, help="SQLAlchemy database URL for the target database.")
parser.add_argument(
"--primary-tenant-id",
default=DEFAULT_PRIMARY_TENANT_ID,
help="Tenant that will contain the main conflict scenario.",
)
parser.add_argument(
"--secondary-tenant-id",
default=DEFAULT_SECONDARY_TENANT_ID,
help="Tenant used to verify tenant filtering behavior.",
)
parser.add_argument(
"--create-minimal-schema",
action="store_true",
help="Create the minimal tables needed for the seed when running against an empty scratch database.",
)
return parser.parse_args()
def main() -> int:
args = parse_args()
engine = sa.create_engine(args.db_url)
try:
if args.create_minimal_schema:
create_minimal_legacy_model_type_schema(engine)
fixture = seed_legacy_model_type_dirty_data(
engine,
primary_tenant_id=args.primary_tenant_id,
secondary_tenant_id=args.secondary_tenant_id,
)
finally:
engine.dispose()
print(
json.dumps(
{
"primary_tenant_id": fixture.primary.tenant_id,
"secondary_tenant_id": fixture.secondary.tenant_id,
"winner_credential_id": fixture.primary.winner_credential_id,
"loser_credential_id": fixture.primary.loser_credential_id,
"provider_model_id": fixture.primary.provider_model_id,
"load_balancing_config_id": fixture.primary.load_balancing_config_id,
},
indent=2,
sort_keys=True,
)
)
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@ -0,0 +1,408 @@
from __future__ import annotations
import importlib
import io
import json
from collections.abc import Generator
from datetime import datetime, timedelta
import pytest
import sqlalchemy as sa
from tests.helpers.legacy_model_type_migration import (
assert_tenant_rows_use_only_canonical_model_types,
count_rows,
fetch_table_rows,
seed_legacy_model_type_dirty_data,
)
def _parse_json_lines(output: io.StringIO) -> list[dict[str, object]]:
return [json.loads(line) for line in output.getvalue().splitlines() if line.strip()]
def _json_key(value: object) -> str:
return json.dumps(value, sort_keys=True)
def _lb_processing_signatures(lines: list[dict[str, object]]) -> set[tuple[object, ...]]:
signatures: set[tuple[object, ...]] = set()
for line in lines:
attrs = line.get("attrs")
if not isinstance(attrs, dict):
continue
if attrs.get("table_name") != "load_balancing_model_configs":
continue
event = line.get("event")
if event == "row_updated":
signatures.add(
(
event,
attrs.get("id"),
_json_key(attrs.get("old_values")),
_json_key(attrs.get("new_values")),
)
)
elif event == "row_deleted":
signatures.add(
(
event,
attrs.get("id"),
attrs.get("merge_winner_id"),
)
)
elif event == "group_processed":
signatures.add(
(
event,
attrs.get("table_name"),
_json_key(attrs.get("business_key")),
tuple(attrs.get("group_row_ids", [])),
)
)
return signatures
def _insert_load_balancing_model_config(
engine: sa.Engine,
*,
row_id: str,
tenant_id: str,
provider_name: str,
model_name: str,
model_type: str,
name: str,
encrypted_config: str,
credential_id: str,
enabled: bool,
created_at: datetime,
updated_at: datetime,
) -> None:
with engine.begin() as conn:
conn.execute(
sa.text(
"""
INSERT INTO load_balancing_model_configs
(
id, tenant_id, provider_name, model_name, model_type, name,
encrypted_config, credential_id, credential_source_type, enabled, created_at, updated_at
)
VALUES
(
:id, :tenant_id, :provider_name, :model_name, :model_type, :name,
:encrypted_config, :credential_id, :credential_source_type, :enabled, :created_at, :updated_at
)
"""
),
{
"id": row_id,
"tenant_id": tenant_id,
"provider_name": provider_name,
"model_name": model_name,
"model_type": model_type,
"name": name,
"encrypted_config": encrypted_config,
"credential_id": credential_id,
"credential_source_type": "custom_model",
"enabled": enabled,
"created_at": created_at,
"updated_at": updated_at,
},
)
@pytest.fixture(scope="session")
def migration_module():
try:
return importlib.import_module("services.legacy_model_type_migration")
except ModuleNotFoundError as exc: # pragma: no cover - explicit TDD failure path
pytest.fail(
"services.legacy_model_type_migration is missing. "
"Implement LegacyModelTypeMigrationService before running these tests."
)
@pytest.fixture(params=("postgresql", "mysql"), scope="session")
def container_engine(request: pytest.FixtureRequest) -> Generator[tuple[str, sa.Engine], None, None]:
backend_name = request.param
if backend_name == "postgresql":
testcontainers_postgres = pytest.importorskip("testcontainers.postgres")
container = testcontainers_postgres.PostgresContainer("postgres:15-alpine")
else:
testcontainers_mysql = pytest.importorskip("testcontainers.mysql")
container = testcontainers_mysql.MySqlContainer("mysql:8.0")
container.start()
raw_url = container.get_connection_url()
engine_url = raw_url.replace("mysql://", "mysql+pymysql://", 1)
engine = sa.create_engine(engine_url)
try:
yield backend_name, engine
finally:
engine.dispose()
container.stop()
def test_legacy_model_type_migration_end_to_end_across_supported_backends(
migration_module,
container_engine: tuple[str, sa.Engine],
monkeypatch: pytest.MonkeyPatch,
) -> None:
backend_name, engine = container_engine
helper_module = importlib.import_module("tests.helpers.legacy_model_type_migration")
helper_module.drop_minimal_legacy_model_type_schema(engine)
fixture = seed_legacy_model_type_dirty_data(engine)
deleted_cache_keys: list[str] = []
def _record_delete(self) -> None:
deleted_cache_keys.append(self.cache_key)
monkeypatch.setattr(migration_module.ProviderCredentialsCache, "delete", _record_delete)
dry_run_output = io.StringIO()
migration_module.LegacyModelTypeMigrationService(
engine=engine,
apply=False,
output=dry_run_output,
tenant_ids=(fixture.primary.tenant_id,),
).migrate()
assert count_rows(engine, "provider_model_credentials", tenant_id=fixture.primary.tenant_id) == 3
assert deleted_cache_keys == []
apply_output = io.StringIO()
migration_module.LegacyModelTypeMigrationService(
engine=engine,
apply=True,
output=apply_output,
tenant_ids=(fixture.primary.tenant_id,),
).migrate()
first_apply_state = {
table_name: fetch_table_rows(engine, table_name, tenant_id=fixture.primary.tenant_id)
for table_name in (
"provider_models",
"tenant_default_models",
"provider_model_settings",
"load_balancing_model_configs",
"provider_model_credentials",
)
}
assert_tenant_rows_use_only_canonical_model_types(engine, fixture.primary.tenant_id)
assert count_rows(engine, "provider_model_credentials", tenant_id=fixture.primary.tenant_id) == 2
provider_model_row = next(
row for row in first_apply_state["provider_models"] if row["id"] == fixture.primary.provider_model_id
)
assert provider_model_row["credential_id"] == fixture.primary.winner_credential_id
credential_ids = {str(row["id"]) for row in first_apply_state["provider_model_credentials"]}
assert credential_ids == {
fixture.primary.winner_credential_id,
fixture.primary.distinct_credential_id,
}
lb_row = next(
row
for row in first_apply_state["load_balancing_model_configs"]
if row["id"] == fixture.primary.load_balancing_config_id
)
assert lb_row["credential_id"] == fixture.primary.winner_credential_id
assert lb_row["encrypted_config"] == fixture.primary.winner_encrypted_config
assert deleted_cache_keys, f"{backend_name} apply run should clear cache keys"
migration_module.LegacyModelTypeMigrationService(
engine=engine,
apply=True,
output=io.StringIO(),
tenant_ids=(fixture.primary.tenant_id,),
).migrate()
second_apply_state = {
table_name: fetch_table_rows(engine, table_name, tenant_id=fixture.primary.tenant_id)
for table_name in first_apply_state
}
assert second_apply_state == first_apply_state
def test_load_balancing_inherit_deduplication_is_applied_consistently_across_supported_backends(
migration_module,
container_engine: tuple[str, sa.Engine],
monkeypatch: pytest.MonkeyPatch,
) -> None:
_, engine = container_engine
helper_module = importlib.import_module("tests.helpers.legacy_model_type_migration")
helper_module.drop_minimal_legacy_model_type_schema(engine)
fixture = seed_legacy_model_type_dirty_data(engine)
tenant_id = fixture.primary.tenant_id
older_inherit_row_id = "00000000-0000-0000-0000-00000000ee01"
newer_inherit_row_id = "00000000-0000-0000-0000-00000000ee02"
canonical_non_inherit_row_id = "00000000-0000-0000-0000-00000000ee03"
created_at = datetime(2025, 1, 1, 8, 0, 0)
_insert_load_balancing_model_config(
engine,
row_id=older_inherit_row_id,
tenant_id=tenant_id,
provider_name="openai",
model_name="gpt-4o-mini",
model_type="llm",
name="__inherit__",
encrypted_config='{"api_key":"older-inherit"}',
credential_id=fixture.primary.winner_credential_id,
enabled=True,
created_at=created_at,
updated_at=created_at + timedelta(minutes=15),
)
_insert_load_balancing_model_config(
engine,
row_id=newer_inherit_row_id,
tenant_id=tenant_id,
provider_name="openai",
model_name="gpt-4o-mini",
model_type="text-generation",
name="__inherit__",
encrypted_config='{"api_key":"newer-inherit"}',
credential_id=fixture.primary.distinct_credential_id,
enabled=True,
created_at=created_at,
updated_at=created_at + timedelta(minutes=30),
)
_insert_load_balancing_model_config(
engine,
row_id=canonical_non_inherit_row_id,
tenant_id=tenant_id,
provider_name="openai",
model_name="gpt-4o-mini",
model_type="llm",
name=f"{tenant_id}-second-shared",
encrypted_config='{"api_key":"non-inherit-canonical"}',
credential_id=fixture.primary.distinct_credential_id,
enabled=True,
created_at=created_at,
updated_at=created_at + timedelta(minutes=45),
)
before_dry_run = fetch_table_rows(engine, "load_balancing_model_configs", tenant_id=tenant_id)
deleted_cache_keys: list[str] = []
def _record_delete(self) -> None:
deleted_cache_keys.append(self.cache_key)
monkeypatch.setattr(migration_module.ProviderCredentialsCache, "delete", _record_delete)
dry_run_output = io.StringIO()
migration_module.LegacyModelTypeMigrationService(
engine=engine,
apply=False,
output=dry_run_output,
tables=("load_balancing_model_configs",),
model_types=(migration_module.ModelType.LLM,),
tenant_ids=(tenant_id,),
).migrate()
after_dry_run = fetch_table_rows(engine, "load_balancing_model_configs", tenant_id=tenant_id)
dry_run_lines = _parse_json_lines(dry_run_output)
dry_run_cache_events = [line["event"] for line in dry_run_lines if str(line.get("event")).startswith("cache_")]
dry_run_row_updates = {
str(attrs["id"])
for line in dry_run_lines
if line.get("event") == "row_updated"
and isinstance((attrs := line.get("attrs")), dict)
and attrs.get("table_name") == "load_balancing_model_configs"
}
dry_run_row_deletes = {
str(attrs["id"])
for line in dry_run_lines
if line.get("event") == "row_deleted"
and isinstance((attrs := line.get("attrs")), dict)
and attrs.get("table_name") == "load_balancing_model_configs"
}
dry_run_group_processed = [
attrs
for line in dry_run_lines
if line.get("event") == "group_processed"
and isinstance((attrs := line.get("attrs")), dict)
and attrs.get("table_name") == "load_balancing_model_configs"
]
assert after_dry_run == before_dry_run
assert deleted_cache_keys == []
assert dry_run_row_deletes == {older_inherit_row_id}
assert dry_run_row_updates == {
fixture.primary.load_balancing_config_id,
newer_inherit_row_id,
}
assert canonical_non_inherit_row_id not in dry_run_row_updates
assert "cache_delete_planned" in dry_run_cache_events
assert "cache_deleted" not in dry_run_cache_events
assert len(dry_run_group_processed) == 1
assert dry_run_group_processed[0]["table_name"] == "load_balancing_model_configs"
assert dry_run_group_processed[0]["business_key"] == {
"tenant_id": tenant_id,
"provider_name": "openai",
"model_name": "gpt-4o-mini",
"model_type": "llm",
}
assert set(dry_run_group_processed[0]["group_row_ids"]) == {
older_inherit_row_id,
newer_inherit_row_id,
}
apply_output = io.StringIO()
migration_module.LegacyModelTypeMigrationService(
engine=engine,
apply=True,
output=apply_output,
tables=("load_balancing_model_configs",),
model_types=(migration_module.ModelType.LLM,),
tenant_ids=(tenant_id,),
).migrate()
apply_lines = _parse_json_lines(apply_output)
apply_cache_events = [line["event"] for line in apply_lines if str(line.get("event")).startswith("cache_")]
apply_group_processed = [
attrs
for line in apply_lines
if line.get("event") == "group_processed"
and isinstance((attrs := line.get("attrs")), dict)
and attrs.get("table_name") == "load_balancing_model_configs"
]
assert _lb_processing_signatures(apply_lines) == _lb_processing_signatures(dry_run_lines)
assert "cache_deleted" in apply_cache_events
assert deleted_cache_keys
assert len(apply_group_processed) == len(dry_run_group_processed)
assert [
(
attrs["table_name"],
_json_key(attrs["business_key"]),
tuple(attrs["group_row_ids"]),
)
for attrs in apply_group_processed
] == [
(
attrs["table_name"],
_json_key(attrs["business_key"]),
tuple(attrs["group_row_ids"]),
)
for attrs in dry_run_group_processed
]
lb_rows = fetch_table_rows(engine, "load_balancing_model_configs", tenant_id=tenant_id)
surviving_inherit_rows = [row for row in lb_rows if row["name"] == "__inherit__"]
surviving_non_inherit_rows = [row for row in lb_rows if row["name"] != "__inherit__"]
assert {str(row["id"]) for row in surviving_inherit_rows} == {newer_inherit_row_id}
assert surviving_inherit_rows[0]["model_type"] == "llm"
assert surviving_inherit_rows[0]["credential_id"] == fixture.primary.distinct_credential_id
assert {
str(row["id"])
for row in surviving_non_inherit_rows
if str(row["id"]) in {fixture.primary.load_balancing_config_id, canonical_non_inherit_row_id}
} == {fixture.primary.load_balancing_config_id, canonical_non_inherit_row_id}
assert all(
row["model_type"] == "llm"
for row in surviving_non_inherit_rows
if str(row["id"]) in {fixture.primary.load_balancing_config_id, canonical_non_inherit_row_id}
)
assert count_rows(engine, "load_balancing_model_configs", tenant_id=tenant_id) == len(before_dry_run) - 1

File diff suppressed because it is too large Load Diff