mirror of
https://github.com/langgenius/dify.git
synced 2026-06-07 16:32:01 +08:00
feat(api): introduce model-type migration script (#36520)
This commit is contained in:
parent
dade318f00
commit
5c5a6e83e5
@ -223,10 +223,11 @@ def initialize_extensions(app: DifyApp):
|
||||
|
||||
def create_migrations_app() -> DifyApp:
    """Build an app instance carrying only the extensions migration tooling uses.

    The database, migrate and commands extensions are initialized, in that
    order, on a freshly configured app; no other extension is touched here.
    """
    migrations_app = create_flask_app_with_configs()

    # Imported lazily, after the app object exists, as in the original layout.
    from extensions import ext_commands, ext_database, ext_migrate

    # Initialize only required extensions
    for extension in (ext_database, ext_migrate, ext_commands):
        extension.init_app(migrations_app)

    return migrations_app
|
||||
|
||||
@ -3,6 +3,7 @@ CLI command modules extracted from `commands.py`.
|
||||
"""
|
||||
|
||||
from .account import create_tenant, reset_email, reset_password
|
||||
from .data_migrate import data_migrate, legacy_model_types
|
||||
from .plugin import (
|
||||
extract_plugins,
|
||||
extract_unique_plugins,
|
||||
@ -44,6 +45,7 @@ __all__ = [
|
||||
"clear_orphaned_file_records",
|
||||
"convert_to_agent_apps",
|
||||
"create_tenant",
|
||||
"data_migrate",
|
||||
"delete_archived_workflow_runs",
|
||||
"export_app_messages",
|
||||
"extract_plugins",
|
||||
@ -52,6 +54,7 @@ __all__ = [
|
||||
"fix_app_site_missing",
|
||||
"install_plugins",
|
||||
"install_rag_pipeline_plugins",
|
||||
"legacy_model_types",
|
||||
"migrate_annotation_vector_database",
|
||||
"migrate_data_for_plugin",
|
||||
"migrate_knowledge_vector_database",
|
||||
|
||||
169
api/commands/data_migrate.py
Normal file
169
api/commands/data_migrate.py
Normal file
@ -0,0 +1,169 @@
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import click
|
||||
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from services.legacy_model_type_migration import (
|
||||
VALID_TABLE_NAMES,
|
||||
LegacyModelTypeMigrationService,
|
||||
load_tenant_ids_from_file,
|
||||
)
|
||||
|
||||
# String values accepted by --model-types; also used to validate user input.
_SUPPORTED_MODEL_TYPE_CHOICES = tuple(
    model_type.value
    for model_type in (
        ModelType.LLM,
        ModelType.TEXT_EMBEDDING,
        ModelType.RERANK,
    )
)
# Default worker count; falls back to 1 when os.cpu_count() returns a falsy value.
_DEFAULT_CONCURRENCY = os.cpu_count() or 1
|
||||
|
||||
|
||||
def _normalize_multi_value_option(
|
||||
values: tuple[str, ...],
|
||||
*,
|
||||
valid_values: tuple[str, ...],
|
||||
option_name: str,
|
||||
) -> tuple[str, ...]:
|
||||
normalized_values: list[str] = []
|
||||
seen_values: set[str] = set()
|
||||
|
||||
for value in values:
|
||||
for item in value.split(","):
|
||||
normalized_item = item.strip()
|
||||
if not normalized_item:
|
||||
continue
|
||||
if normalized_item not in valid_values:
|
||||
raise click.BadParameter(
|
||||
f"invalid value '{normalized_item}'. valid values: {', '.join(valid_values)}",
|
||||
param_hint=option_name,
|
||||
)
|
||||
if normalized_item in seen_values:
|
||||
continue
|
||||
seen_values.add(normalized_item)
|
||||
normalized_values.append(normalized_item)
|
||||
|
||||
return tuple(normalized_values)
|
||||
|
||||
|
||||
@click.group(name="data-migrate", help="Online data migration commands.")
def data_migrate() -> None:
    """Namespace for production data migration commands."""
|
||||
|
||||
|
||||
@click.command(
    "legacy-model-types",
    help=(
        "Migrate legacy provider model_type values to canonical values. "
        "Default is dry-run and emits JSON lines only. "
        "If --tables includes provider_model_credentials, the command may also update "
        "provider_models and load_balancing_model_configs references so merged credentials stay reachable."
    ),
)
@click.option(
    "--apply",
    is_flag=True,
    default=False,
    help="Apply the migration. Default is dry-run.",
)
@click.option(
    "--tables",
    "tables",
    multiple=True,
    type=str,
    help=(
        "Limit model_type migration to specific tables. Accepts comma-separated values or repeated flags. "
        "When provider_model_credentials is selected, provider_models and "
        "load_balancing_model_configs may also be updated for credential reference rewrites. "
        # The original text ran two sentences together and ended in an empty "Default to: ".
        f"Valid values: {', '.join(VALID_TABLE_NAMES)}. "
        "When omitted, the migration service's default table selection is used."
    ),
)
@click.option(
    "--model-types",
    "model_types",
    multiple=True,
    type=str,
    help=(
        "Canonical model types to migrate. Accepts comma-separated values or repeated flags. "
        "Defaults to: `llm,text-embedding,rerank`"
    ),
)
@click.option(
    "--tenant-id-file",
    type=click.Path(exists=True, dir_okay=False, readable=True, resolve_path=True),
    help="Optional file containing tenant ids, one per line.",
)
@click.option(
    "--output",
    type=click.Path(dir_okay=False, resolve_path=True, path_type=Path),
    help="Optional file path for JSON lines event logs. Defaults to stdout.",
)
@click.option(
    "--concurrency",
    type=click.IntRange(min=1),
    default=_DEFAULT_CONCURRENCY,
    show_default=True,
    help="Number of tenant-level worker threads to run in parallel.",
)
def legacy_model_types(
    apply: bool,
    tables: tuple[str, ...],
    model_types: tuple[str, ...],
    tenant_id_file: str | None,
    output: Path | None,
    concurrency: int = _DEFAULT_CONCURRENCY,
) -> None:
    """
    Migrate legacy provider-related model_type values and emit JSON lines events.

    Args:
        apply: Forwarded to the migration service; the CLI documents False as dry-run.
        tables: Raw ``--tables`` occurrences (comma-separated and/or repeated).
        model_types: Raw ``--model-types`` occurrences (comma-separated and/or repeated).
        tenant_id_file: Optional path to a file of tenant ids.
        output: Optional file for the event log; stdout is used when omitted.
        concurrency: Worker count forwarded to the migration service.

    Raises:
        click.BadParameter: If ``--tables`` or ``--model-types`` holds an unknown value.
        click.ClickException: If the output file cannot be opened for writing.
    """

    normalized_tables = _normalize_multi_value_option(
        tables,
        valid_values=VALID_TABLE_NAMES,
        option_name="--tables",
    )
    normalized_model_types = _normalize_multi_value_option(
        model_types,
        valid_values=_SUPPORTED_MODEL_TYPE_CHOICES,
        option_name="--model-types",
    )
    # No explicit selection means every supported canonical model type.
    selected_model_types = (
        tuple(ModelType.value_of(model_type) for model_type in normalized_model_types)
        if normalized_model_types
        else (
            ModelType.LLM,
            ModelType.TEXT_EMBEDDING,
            ModelType.RERANK,
        )
    )
    tenant_ids = load_tenant_ids_from_file(tenant_id_file) if tenant_id_file else None

    output_context: AbstractContextManager[io.TextIOBase]
    if output is None:
        # nullcontext keeps stdout open after the `with` block exits.
        output_context = nullcontext(cast(io.TextIOBase, sys.stdout))
    else:
        try:
            output_context = output.open("w", encoding="utf-8")
        except OSError as exc:
            raise click.ClickException(f"failed to open output file '{output}': {exc.strerror or exc}") from exc

    with output_context as output_stream:
        LegacyModelTypeMigrationService(
            engine=db.engine,
            apply=apply,
            concurrency=concurrency,
            output=cast(io.TextIOBase, output_stream),
            # An empty selection is passed as None so the service applies its own default.
            tables=normalized_tables or None,
            model_types=selected_model_types,
            tenant_ids=tenant_ids,
        ).migrate()


# Expose the command as `data-migrate legacy-model-types`.
data_migrate.add_command(legacy_model_types)
|
||||
@ -12,6 +12,7 @@ def init_app(app: DifyApp):
|
||||
clear_orphaned_file_records,
|
||||
convert_to_agent_apps,
|
||||
create_tenant,
|
||||
data_migrate,
|
||||
delete_archived_workflow_runs,
|
||||
export_app_messages,
|
||||
extract_plugins,
|
||||
@ -44,6 +45,7 @@ def init_app(app: DifyApp):
|
||||
convert_to_agent_apps,
|
||||
add_qdrant_index,
|
||||
create_tenant,
|
||||
data_migrate,
|
||||
upgrade_db,
|
||||
fix_app_site_missing,
|
||||
migrate_data_for_plugin,
|
||||
|
||||
@ -102,10 +102,7 @@ dify-trace-weave = { workspace = true }
|
||||
[tool.uv]
|
||||
default-groups = ["storage", "tools", "vdb-all", "trace-all"]
|
||||
package = false
|
||||
override-dependencies = [
|
||||
"litellm>=1.83.10,<2.0.0",
|
||||
"pyarrow>=23.0.1,<24.0.0",
|
||||
]
|
||||
override-dependencies = ["litellm>=1.83.10,<2.0.0", "pyarrow>=23.0.1,<24.0.0"]
|
||||
|
||||
[dependency-groups]
|
||||
|
||||
|
||||
2464
api/services/legacy_model_type_migration.py
Normal file
2464
api/services/legacy_model_type_migration.py
Normal file
File diff suppressed because it is too large
Load Diff
1
api/tests/helpers/__init__.py
Normal file
1
api/tests/helpers/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""Shared test helpers for backend migration tests."""
|
||||
366
api/tests/helpers/legacy_model_type_migration.py
Normal file
366
api/tests/helpers/legacy_model_type_migration.py
Normal file
@ -0,0 +1,366 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
from models.account import Tenant
|
||||
from models.enums import CredentialSourceType
|
||||
from models.provider import (
|
||||
LoadBalancingModelConfig,
|
||||
ProviderModel,
|
||||
ProviderModelCredential,
|
||||
ProviderModelSetting,
|
||||
TenantDefaultModel,
|
||||
)
|
||||
|
||||
# Legacy model_type spelling -> canonical spelling.
LEGACY_TO_CANONICAL: dict[str, str] = {
    "text-generation": "llm",
    "embeddings": "text-embedding",
    "reranking": "rerank",
}
# Model types that have no legacy alias in the mapping above.
UNCHANGED_MODEL_TYPES: tuple[str, ...] = ("speech2text", "moderation", "tts")
# Names of every table inspected by the helpers in this module.
ALL_TABLE_NAMES: tuple[str, ...] = tuple(
    model.__tablename__
    for model in (
        ProviderModel,
        TenantDefaultModel,
        ProviderModelSetting,
        LoadBalancingModelConfig,
        ProviderModelCredential,
    )
)
# Fixed tenant ids used when the caller does not supply its own.
DEFAULT_PRIMARY_TENANT_ID = "00000000-0000-0000-0000-000000000101"
DEFAULT_SECONDARY_TENANT_ID = "00000000-0000-0000-0000-000000000202"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DirtyTenantFixture:
|
||||
tenant_id: str
|
||||
winner_credential_id: str
|
||||
loser_credential_id: str
|
||||
distinct_credential_id: str
|
||||
provider_model_id: str
|
||||
load_balancing_config_id: str
|
||||
provider_model_setting_id: str
|
||||
tenant_default_model_id: str
|
||||
embedding_provider_model_id: str
|
||||
embedding_setting_id: str
|
||||
loser_credential_name: str
|
||||
distinct_credential_name: str
|
||||
loser_encrypted_config: str
|
||||
winner_encrypted_config: str
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DirtyDataFixture:
|
||||
primary: DirtyTenantFixture
|
||||
secondary: DirtyTenantFixture
|
||||
|
||||
|
||||
def create_minimal_legacy_model_type_schema(engine: Engine) -> None:
    """Create the tenant table and the five provider tables unless they already exist."""
    required_tables = [
        Tenant.__table__,
        ProviderModel.__table__,
        TenantDefaultModel.__table__,
        ProviderModelSetting.__table__,
        LoadBalancingModelConfig.__table__,
        ProviderModelCredential.__table__,
    ]
    # checkfirst=True skips tables that are already present.
    Tenant.__table__.metadata.create_all(engine, tables=required_tables, checkfirst=True)
|
||||
|
||||
|
||||
def drop_minimal_legacy_model_type_schema(engine: Engine) -> None:
    """Drop the tables created by the minimal-schema helper, if they exist."""
    doomed_tables = [
        LoadBalancingModelConfig.__table__,
        ProviderModelSetting.__table__,
        TenantDefaultModel.__table__,
        ProviderModel.__table__,
        ProviderModelCredential.__table__,
        Tenant.__table__,
    ]
    # checkfirst=True makes the call a no-op for tables that are absent.
    Tenant.__table__.metadata.drop_all(engine, tables=doomed_tables, checkfirst=True)
|
||||
|
||||
|
||||
def seed_legacy_model_type_dirty_data(
    engine: Engine,
    *,
    primary_tenant_id: str = DEFAULT_PRIMARY_TENANT_ID,
    secondary_tenant_id: str = DEFAULT_SECONDARY_TENANT_ID,
) -> DirtyDataFixture:
    """Ensure the minimal schema exists, then seed two tenants with legacy rows.

    Both tenants are seeded for the ``openai`` provider, primary first.

    Returns:
        The per-tenant fixtures describing what was inserted.
    """
    create_minimal_legacy_model_type_schema(engine)
    # The generator is consumed in order, so the primary tenant is seeded first.
    primary, secondary = (
        _seed_tenant(engine, tenant_id=seeded_tenant_id, provider_name="openai")
        for seeded_tenant_id in (primary_tenant_id, secondary_tenant_id)
    )
    return DirtyDataFixture(primary=primary, secondary=secondary)
|
||||
|
||||
|
||||
def snapshot_legacy_model_type_state(engine: Engine) -> dict[str, list[dict[str, object]]]:
    """Return every row of every table in ``ALL_TABLE_NAMES``, keyed by table name."""
    return {table_name: fetch_table_rows(engine, table_name) for table_name in ALL_TABLE_NAMES}
|
||||
|
||||
|
||||
def fetch_table_rows(
    engine: Engine,
    table_name: str,
    *,
    tenant_id: str | None = None,
) -> list[dict[str, object]]:
    """Read rows of ``table_name`` ordered by ``id``, optionally for one tenant.

    ``datetime`` values are converted to ISO-8601 strings and ``uuid.UUID``
    values to ``str``; all other values are returned untouched.

    NOTE: ``table_name`` is interpolated into the SQL text, so callers must
    pass trusted table names only.
    """

    def _to_plain(value: object) -> object:
        # Normalize driver-specific types into plain, comparable values.
        if isinstance(value, datetime):
            return value.isoformat()
        if isinstance(value, uuid.UUID):
            return str(value)
        return value

    clauses = [f"SELECT * FROM {table_name}"]
    bind_params: dict[str, object] = {}
    if tenant_id is not None:
        clauses.append("WHERE tenant_id = :tenant_id")
        bind_params["tenant_id"] = tenant_id
    clauses.append("ORDER BY id ASC")
    statement = sa.text(" ".join(clauses))

    with engine.begin() as conn:
        raw_rows = conn.execute(statement, bind_params).mappings().all()

    return [{column: _to_plain(value) for column, value in dict(raw_row).items()} for raw_row in raw_rows]
|
||||
|
||||
|
||||
def fetch_model_types_for_tenant(engine: Engine, table_name: str, tenant_id: str) -> list[str]:
    """Return the ``model_type`` column, as strings, of one tenant's rows in ``table_name``."""
    return [
        str(tenant_row["model_type"])
        for tenant_row in fetch_table_rows(engine, table_name, tenant_id=tenant_id)
    ]
|
||||
|
||||
|
||||
def assert_tenant_rows_use_only_canonical_model_types(engine: Engine, tenant_id: str) -> None:
    """Assert that no table in ``ALL_TABLE_NAMES`` holds a legacy model_type for the tenant.

    Allowed values are the canonical targets of ``LEGACY_TO_CANONICAL`` plus
    ``UNCHANGED_MODEL_TYPES``. On failure the assertion message carries the
    offending table name and the model types found in it.
    """
    allowed_model_types = {*LEGACY_TO_CANONICAL.values(), *UNCHANGED_MODEL_TYPES}
    for table_name in ALL_TABLE_NAMES:
        model_types = fetch_model_types_for_tenant(engine, table_name, tenant_id)
        assert set(model_types).issubset(allowed_model_types), (
            table_name,
            model_types,
        )
|
||||
|
||||
|
||||
def count_rows(engine: Engine, table_name: str, *, tenant_id: str) -> int:
    """Count the rows of ``table_name`` that belong to ``tenant_id``.

    NOTE: ``table_name`` is interpolated into the SQL text; pass trusted names only.
    """
    query = sa.text(f"SELECT COUNT(*) FROM {table_name} WHERE tenant_id = :tenant_id")
    with engine.begin() as conn:
        total = conn.execute(query, {"tenant_id": tenant_id}).scalar_one()
    return int(total)
|
||||
|
||||
|
||||
def _seed_tenant(engine: Engine, *, tenant_id: str, provider_name: str) -> DirtyTenantFixture:
    """Insert one tenant plus a fixed set of legacy/canonical provider rows.

    Seeded in a single transaction:
      * three credentials for ``gpt-4o-mini``: a canonical ``llm`` "winner", a
        legacy ``text-generation`` "loser" sharing the winner's name, and a
        legacy "distinct" credential with its own name;
      * provider models, a tenant default model, model settings and one load
        balancing config, all stored with legacy model_type values, with the
        chat model and load balancing rows pointing at the loser credential.

    Returns:
        The generated ids, names and serialized configs.
    """
    now = datetime(2025, 1, 1, 12, 0, 0)
    created_at = now - timedelta(days=2)

    # Nine fresh ids, generated in the same order as the names on the left.
    (
        winner_credential_id,
        loser_credential_id,
        distinct_credential_id,
        provider_model_id,
        load_balancing_config_id,
        provider_model_setting_id,
        tenant_default_model_id,
        embedding_provider_model_id,
        embedding_setting_id,
    ) = (str(uuid4()) for _ in range(9))

    loser_credential_name = f"{tenant_id}-shared"
    distinct_credential_name = f"{tenant_id}-distinct"
    winner_encrypted_config = json.dumps({"api_key": f"{tenant_id}-winner"})
    loser_encrypted_config = json.dumps({"api_key": f"{tenant_id}-loser"})
    distinct_encrypted_config = json.dumps({"api_key": f"{tenant_id}-distinct"})

    # Bind parameters shared by every INSERT below.
    scope: dict[str, object] = {
        "tenant_id": tenant_id,
        "provider_name": provider_name,
        "created_at": created_at,
    }

    insert_credentials = sa.text(
        """
        INSERT INTO provider_model_credentials
        (
            id, tenant_id, provider_name, model_name,
            model_type, credential_name, encrypted_config,
            created_at, updated_at
        )
        VALUES
        (
            :winner_id, :tenant_id, :provider_name, 'gpt-4o-mini',
            'llm', :shared_name, :winner_config,
            :created_at, :winner_updated_at
        ),
        (
            :loser_id, :tenant_id, :provider_name, 'gpt-4o-mini',
            'text-generation', :shared_name, :loser_config,
            :created_at, :loser_updated_at
        ),
        (
            :distinct_id, :tenant_id, :provider_name, 'gpt-4o-mini',
            'text-generation', :distinct_name, :distinct_config,
            :created_at, :distinct_updated_at
        )
        """
    )
    insert_provider_models = sa.text(
        """
        INSERT INTO provider_models
        (
            id, tenant_id, provider_name, model_name,
            model_type, credential_id, is_valid,
            created_at, updated_at
        )
        VALUES
        (
            :provider_model_id, :tenant_id, :provider_name, 'gpt-4o-mini',
            'text-generation', :loser_id, :is_valid,
            :created_at, :updated_at
        ),
        (
            :embedding_provider_model_id, :tenant_id, :provider_name, 'text-embedding-3-large',
            'embeddings', NULL, :is_valid,
            :created_at, :updated_at
        )
        """
    )
    insert_default_model = sa.text(
        """
        INSERT INTO tenant_default_models
        (id, tenant_id, provider_name, model_name, model_type, created_at, updated_at)
        VALUES
        (
            :tenant_default_model_id, :tenant_id, :provider_name, 'gpt-4o-mini',
            'text-generation', :created_at, :updated_at
        )
        """
    )
    insert_settings = sa.text(
        """
        INSERT INTO provider_model_settings
        (
            id, tenant_id, provider_name, model_name,
            model_type, enabled, load_balancing_enabled,
            created_at, updated_at
        )
        VALUES
        (
            :provider_model_setting_id, :tenant_id, :provider_name, 'gpt-4o-mini',
            'text-generation', :enabled, :load_balancing_enabled,
            :created_at, :updated_at
        ),
        (
            :embedding_setting_id, :tenant_id, :provider_name, 'text-embedding-3-large',
            'embeddings', :enabled, :embedding_load_balancing_enabled,
            :created_at, :updated_at
        )
        """
    )
    insert_load_balancing = sa.text(
        """
        INSERT INTO load_balancing_model_configs
        (
            id, tenant_id, provider_name, model_name, model_type,
            name, encrypted_config, credential_id, credential_source_type,
            enabled, created_at, updated_at
        )
        VALUES
        (
            :load_balancing_config_id, :tenant_id, :provider_name, 'gpt-4o-mini', 'text-generation',
            :lb_name, :loser_config, :loser_id, :credential_source_type,
            :enabled, :created_at, :updated_at
        )
        """
    )

    with engine.begin() as conn:
        conn.execute(
            Tenant.__table__.insert().values(
                id=tenant_id,
                name=f"Tenant {tenant_id}",
                plan="basic",
                status="normal",
            )
        )
        conn.execute(
            insert_credentials,
            {
                **scope,
                "winner_id": winner_credential_id,
                "loser_id": loser_credential_id,
                "distinct_id": distinct_credential_id,
                "shared_name": loser_credential_name,
                "distinct_name": distinct_credential_name,
                "winner_config": winner_encrypted_config,
                "loser_config": loser_encrypted_config,
                "distinct_config": distinct_encrypted_config,
                # The winner carries the most recent updated_at of the three.
                "winner_updated_at": now,
                "loser_updated_at": now - timedelta(days=1),
                "distinct_updated_at": now - timedelta(hours=12),
            },
        )
        conn.execute(
            insert_provider_models,
            {
                **scope,
                "provider_model_id": provider_model_id,
                "embedding_provider_model_id": embedding_provider_model_id,
                "loser_id": loser_credential_id,
                "is_valid": True,
                "updated_at": now - timedelta(hours=6),
            },
        )
        conn.execute(
            insert_default_model,
            {
                **scope,
                "tenant_default_model_id": tenant_default_model_id,
                "updated_at": now - timedelta(hours=4),
            },
        )
        conn.execute(
            insert_settings,
            {
                **scope,
                "provider_model_setting_id": provider_model_setting_id,
                "embedding_setting_id": embedding_setting_id,
                "enabled": True,
                "load_balancing_enabled": True,
                "embedding_load_balancing_enabled": False,
                "updated_at": now - timedelta(hours=3),
            },
        )
        conn.execute(
            insert_load_balancing,
            {
                **scope,
                "load_balancing_config_id": load_balancing_config_id,
                "lb_name": loser_credential_name,
                "loser_config": loser_encrypted_config,
                "loser_id": loser_credential_id,
                "credential_source_type": CredentialSourceType.CUSTOM_MODEL.value,
                "enabled": True,
                "updated_at": now - timedelta(hours=2),
            },
        )

    return DirtyTenantFixture(
        tenant_id=tenant_id,
        winner_credential_id=winner_credential_id,
        loser_credential_id=loser_credential_id,
        distinct_credential_id=distinct_credential_id,
        provider_model_id=provider_model_id,
        load_balancing_config_id=load_balancing_config_id,
        provider_model_setting_id=provider_model_setting_id,
        tenant_default_model_id=tenant_default_model_id,
        embedding_provider_model_id=embedding_provider_model_id,
        embedding_setting_id=embedding_setting_id,
        loser_credential_name=loser_credential_name,
        distinct_credential_name=distinct_credential_name,
        loser_encrypted_config=loser_encrypted_config,
        winner_encrypted_config=winner_encrypted_config,
    )
|
||||
82
api/tests/seed_legacy_model_type_dirty_data.py
Normal file
82
api/tests/seed_legacy_model_type_dirty_data.py
Normal file
@ -0,0 +1,82 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
API_PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(API_PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(API_PROJECT_ROOT))
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from tests.helpers.legacy_model_type_migration import (
|
||||
DEFAULT_PRIMARY_TENANT_ID,
|
||||
DEFAULT_SECONDARY_TENANT_ID,
|
||||
create_minimal_legacy_model_type_schema,
|
||||
seed_legacy_model_type_dirty_data,
|
||||
)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Parse the seed script's command line.

    Options: ``--db-url`` (required), ``--primary-tenant-id``,
    ``--secondary-tenant-id`` and the ``--create-minimal-schema`` switch.
    """
    description = (
        "Seed dirty legacy model_type rows for manual migration experiments. "
        "Example: uv run --project api python api/tests/seed_legacy_model_type_dirty_data.py "
        "--db-url postgresql://postgres:postgres@127.0.0.1:5432/dify"
    )
    parser = argparse.ArgumentParser(description=description)
    add = parser.add_argument

    add("--db-url", required=True, help="SQLAlchemy database URL for the target database.")
    add(
        "--primary-tenant-id",
        default=DEFAULT_PRIMARY_TENANT_ID,
        help="Tenant that will contain the main conflict scenario.",
    )
    add(
        "--secondary-tenant-id",
        default=DEFAULT_SECONDARY_TENANT_ID,
        help="Tenant used to verify tenant filtering behavior.",
    )
    add(
        "--create-minimal-schema",
        action="store_true",
        help="Create the minimal tables needed for the seed when running against an empty scratch database.",
    )
    return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> int:
    """Seed the target database and print the generated ids as JSON.

    Returns:
        ``0`` once the summary has been printed.
    """
    options = parse_args()
    engine = sa.create_engine(options.db_url)
    try:
        # NOTE(review): seed_legacy_model_type_dirty_data() also calls
        # create_minimal_legacy_model_type_schema() itself, so this flag does
        # not appear to change the outcome — confirm the intended behavior.
        if options.create_minimal_schema:
            create_minimal_legacy_model_type_schema(engine)

        fixture = seed_legacy_model_type_dirty_data(
            engine,
            primary_tenant_id=options.primary_tenant_id,
            secondary_tenant_id=options.secondary_tenant_id,
        )
    finally:
        # Release pooled connections whether or not seeding succeeded.
        engine.dispose()

    primary = fixture.primary
    summary = {
        "primary_tenant_id": primary.tenant_id,
        "secondary_tenant_id": fixture.secondary.tenant_id,
        "winner_credential_id": primary.winner_credential_id,
        "loser_credential_id": primary.loser_credential_id,
        "provider_model_id": primary.provider_model_id,
        "load_balancing_config_id": primary.load_balancing_config_id,
    }
    print(json.dumps(summary, indent=2, sort_keys=True))
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
|
||||
@ -0,0 +1,408 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import io
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from tests.helpers.legacy_model_type_migration import (
|
||||
assert_tenant_rows_use_only_canonical_model_types,
|
||||
count_rows,
|
||||
fetch_table_rows,
|
||||
seed_legacy_model_type_dirty_data,
|
||||
)
|
||||
|
||||
|
||||
def _parse_json_lines(output: io.StringIO) -> list[dict[str, object]]:
|
||||
return [json.loads(line) for line in output.getvalue().splitlines() if line.strip()]
|
||||
|
||||
|
||||
def _json_key(value: object) -> str:
|
||||
return json.dumps(value, sort_keys=True)
|
||||
|
||||
|
||||
def _lb_processing_signatures(lines: list[dict[str, object]]) -> set[tuple[object, ...]]:
    """Reduce load-balancing events to hashable signatures for set comparison.

    Only events whose ``attrs`` is a dict with
    ``table_name == "load_balancing_model_configs"`` are considered, and only
    the ``row_updated``, ``row_deleted`` and ``group_processed`` event kinds
    produce a signature.
    """
    collected: set[tuple[object, ...]] = set()
    for entry in lines:
        attrs = entry.get("attrs")
        if not isinstance(attrs, dict) or attrs.get("table_name") != "load_balancing_model_configs":
            continue

        event = entry.get("event")
        signature: tuple[object, ...]
        if event == "row_updated":
            signature = (
                event,
                attrs.get("id"),
                _json_key(attrs.get("old_values")),
                _json_key(attrs.get("new_values")),
            )
        elif event == "row_deleted":
            signature = (event, attrs.get("id"), attrs.get("merge_winner_id"))
        elif event == "group_processed":
            signature = (
                event,
                attrs.get("table_name"),
                _json_key(attrs.get("business_key")),
                tuple(attrs.get("group_row_ids", [])),
            )
        else:
            # Any other event kind is not part of the comparison.
            continue
        collected.add(signature)
    return collected
|
||||
|
||||
|
||||
def _insert_load_balancing_model_config(
    engine: sa.Engine,
    *,
    row_id: str,
    tenant_id: str,
    provider_name: str,
    model_name: str,
    model_type: str,
    name: str,
    encrypted_config: str,
    credential_id: str,
    enabled: bool,
    created_at: datetime,
    updated_at: datetime,
) -> None:
    """Insert one ``load_balancing_model_configs`` row in its own transaction.

    ``credential_source_type`` is always written as ``"custom_model"``; every
    other column comes from the keyword arguments.
    """
    statement = sa.text(
        """
        INSERT INTO load_balancing_model_configs
        (
            id, tenant_id, provider_name, model_name, model_type, name,
            encrypted_config, credential_id, credential_source_type, enabled, created_at, updated_at
        )
        VALUES
        (
            :id, :tenant_id, :provider_name, :model_name, :model_type, :name,
            :encrypted_config, :credential_id, :credential_source_type, :enabled, :created_at, :updated_at
        )
        """
    )
    bind_params = {
        "id": row_id,
        "tenant_id": tenant_id,
        "provider_name": provider_name,
        "model_name": model_name,
        "model_type": model_type,
        "name": name,
        "encrypted_config": encrypted_config,
        "credential_id": credential_id,
        "credential_source_type": "custom_model",
        "enabled": enabled,
        "created_at": created_at,
        "updated_at": updated_at,
    }
    with engine.begin() as conn:
        conn.execute(statement, bind_params)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def migration_module():
    """Import and return ``services.legacy_model_type_migration`` once per session.

    The test run fails with an explicit message when the service module (or
    its parent package) cannot be found. A ``ModuleNotFoundError`` raised for
    any other module — for example a dependency imported by the service — is
    re-raised unchanged so the real cause is not masked.
    """
    module_name = "services.legacy_model_type_migration"
    try:
        return importlib.import_module(module_name)
    except ModuleNotFoundError as exc:  # pragma: no cover - explicit TDD failure path
        missing = exc.name
        # exc.name is the module that could not be located; only the target
        # itself or one of its parent packages means "not implemented yet".
        if missing is not None and missing != module_name and not module_name.startswith(f"{missing}."):
            raise
        pytest.fail(
            "services.legacy_model_type_migration is missing. "
            "Implement LegacyModelTypeMigrationService before running these tests."
        )
|
||||
|
||||
|
||||
@pytest.fixture(params=("postgresql", "mysql"), scope="session")
def container_engine(request: pytest.FixtureRequest) -> Generator[tuple[str, sa.Engine], None, None]:
    """Yield ``(backend_name, engine)`` backed by a throwaway database container.

    Parametrized over PostgreSQL and MySQL. A backend is skipped when its
    testcontainers module cannot be imported. The engine is disposed and the
    container stopped on teardown, including when setup fails after the
    container has started.
    """
    backend_name = request.param
    if backend_name == "postgresql":
        testcontainers_postgres = pytest.importorskip("testcontainers.postgres")
        container = testcontainers_postgres.PostgresContainer("postgres:15-alpine")
    else:
        testcontainers_mysql = pytest.importorskip("testcontainers.mysql")
        container = testcontainers_mysql.MySqlContainer("mysql:8.0")

    container.start()
    try:
        # Everything after start() sits under this try so a failure while
        # building the engine cannot leave the container running.
        raw_url = container.get_connection_url()
        # Select the PyMySQL driver for bare mysql:// URLs; other URLs are unchanged.
        engine_url = raw_url.replace("mysql://", "mysql+pymysql://", 1)
        engine = sa.create_engine(engine_url)
        try:
            yield backend_name, engine
        finally:
            engine.dispose()
    finally:
        container.stop()
|
||||
|
||||
|
||||
def test_legacy_model_type_migration_end_to_end_across_supported_backends(
    migration_module,
    container_engine: tuple[str, sa.Engine],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Dry-run changes nothing, apply merges and rewrites, and a second apply is a no-op."""
    backend_name, engine = container_engine
    helper_module = importlib.import_module("tests.helpers.legacy_model_type_migration")
    # Start from a clean schema so earlier parametrized runs cannot leak rows.
    helper_module.drop_minimal_legacy_model_type_schema(engine)
    fixture = seed_legacy_model_type_dirty_data(engine)
    primary = fixture.primary

    tracked_tables = (
        "provider_models",
        "tenant_default_models",
        "provider_model_settings",
        "load_balancing_model_configs",
        "provider_model_credentials",
    )
    deleted_cache_keys: list[str] = []

    def _record_delete(self) -> None:
        deleted_cache_keys.append(self.cache_key)

    # Record cache invalidations instead of performing them.
    monkeypatch.setattr(migration_module.ProviderCredentialsCache, "delete", _record_delete)

    def _migrate(*, apply: bool) -> io.StringIO:
        # Run the service for the primary tenant and return its event log.
        sink = io.StringIO()
        migration_module.LegacyModelTypeMigrationService(
            engine=engine,
            apply=apply,
            output=sink,
            tenant_ids=(primary.tenant_id,),
        ).migrate()
        return sink

    def _tenant_state() -> dict[str, list[dict[str, object]]]:
        return {
            table_name: fetch_table_rows(engine, table_name, tenant_id=primary.tenant_id)
            for table_name in tracked_tables
        }

    # Dry-run: no rows merged and no cache entry cleared.
    _migrate(apply=False)
    assert count_rows(engine, "provider_model_credentials", tenant_id=primary.tenant_id) == 3
    assert deleted_cache_keys == []

    # First apply: the loser credential is merged into the winner.
    _migrate(apply=True)
    first_apply_state = _tenant_state()

    assert_tenant_rows_use_only_canonical_model_types(engine, primary.tenant_id)
    assert count_rows(engine, "provider_model_credentials", tenant_id=primary.tenant_id) == 2

    provider_model_row = next(
        row for row in first_apply_state["provider_models"] if row["id"] == primary.provider_model_id
    )
    assert provider_model_row["credential_id"] == primary.winner_credential_id

    credential_ids = {str(row["id"]) for row in first_apply_state["provider_model_credentials"]}
    assert credential_ids == {
        primary.winner_credential_id,
        primary.distinct_credential_id,
    }

    lb_row = next(
        row
        for row in first_apply_state["load_balancing_model_configs"]
        if row["id"] == primary.load_balancing_config_id
    )
    assert lb_row["credential_id"] == primary.winner_credential_id
    assert lb_row["encrypted_config"] == primary.winner_encrypted_config
    assert deleted_cache_keys, f"{backend_name} apply run should clear cache keys"

    # Second apply: the stored state must be identical to the first apply.
    _migrate(apply=True)
    second_apply_state = _tenant_state()
    assert second_apply_state == first_apply_state
|
||||
|
||||
|
||||
def test_load_balancing_inherit_deduplication_is_applied_consistently_across_supported_backends(
    migration_module,
    container_engine: tuple[str, sa.Engine],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check ``__inherit__`` load-balancing deduplication in dry-run and apply mode.

    On top of the seeded dirty fixture, three ``load_balancing_model_configs`` rows
    are inserted for the primary tenant: two ``__inherit__`` rows (one stored as
    ``llm``, one as ``text-generation``) and one named row stored as ``llm``.

    The dry run must leave the table and the credentials cache untouched while
    reporting the planned update/delete events; the apply run must report the same
    processing, actually clear cache keys, and leave exactly one ``__inherit__`` row
    (the one with the later ``updated_at``) carrying ``model_type == "llm"``.
    """
    lb_table = "load_balancing_model_configs"

    _, engine = container_engine
    helper_module = importlib.import_module("tests.helpers.legacy_model_type_migration")
    helper_module.drop_minimal_legacy_model_type_schema(engine)
    fixture = seed_legacy_model_type_dirty_data(engine)

    tenant_id = fixture.primary.tenant_id
    older_inherit_row_id = "00000000-0000-0000-0000-00000000ee01"
    newer_inherit_row_id = "00000000-0000-0000-0000-00000000ee02"
    canonical_non_inherit_row_id = "00000000-0000-0000-0000-00000000ee03"
    created_at = datetime(2025, 1, 1, 8, 0, 0)

    # (row id, stored model_type, name, encrypted_config, credential id, updated_at offset in minutes)
    seeded_rows = (
        (
            older_inherit_row_id,
            "llm",
            "__inherit__",
            '{"api_key":"older-inherit"}',
            fixture.primary.winner_credential_id,
            15,
        ),
        (
            newer_inherit_row_id,
            "text-generation",
            "__inherit__",
            '{"api_key":"newer-inherit"}',
            fixture.primary.distinct_credential_id,
            30,
        ),
        (
            canonical_non_inherit_row_id,
            "llm",
            f"{tenant_id}-second-shared",
            '{"api_key":"non-inherit-canonical"}',
            fixture.primary.distinct_credential_id,
            45,
        ),
    )
    for row_id, stored_model_type, row_name, encrypted_config, credential_id, minutes in seeded_rows:
        _insert_load_balancing_model_config(
            engine,
            row_id=row_id,
            tenant_id=tenant_id,
            provider_name="openai",
            model_name="gpt-4o-mini",
            model_type=stored_model_type,
            name=row_name,
            encrypted_config=encrypted_config,
            credential_id=credential_id,
            enabled=True,
            created_at=created_at,
            updated_at=created_at + timedelta(minutes=minutes),
        )

    def _run_migration(*, apply: bool, output: io.StringIO) -> None:
        """Run the migration service restricted to the LB table, LLM type and this tenant."""
        service = migration_module.LegacyModelTypeMigrationService(
            engine=engine,
            apply=apply,
            output=output,
            tables=(lb_table,),
            model_types=(migration_module.ModelType.LLM,),
            tenant_ids=(tenant_id,),
        )
        service.migrate()

    def _cache_events(lines):
        """Collect the ``event`` value of every line whose event name starts with ``cache_``."""
        events = []
        for line in lines:
            if str(line.get("event")).startswith("cache_"):
                events.append(line["event"])
        return events

    def _lb_event_attrs(lines, event_name):
        """Collect the ``attrs`` dicts of ``event_name`` lines that target the LB table."""
        matched = []
        for line in lines:
            if line.get("event") != event_name:
                continue
            attrs = line.get("attrs")
            if isinstance(attrs, dict) and attrs.get("table_name") == lb_table:
                matched.append(attrs)
        return matched

    def _group_signature(attrs):
        """Reduce a ``group_processed`` payload to the fields compared between runs."""
        return (
            attrs["table_name"],
            _json_key(attrs["business_key"]),
            tuple(attrs["group_row_ids"]),
        )

    before_dry_run = fetch_table_rows(engine, lb_table, tenant_id=tenant_id)
    deleted_cache_keys: list[str] = []

    def _record_delete(self) -> None:
        deleted_cache_keys.append(self.cache_key)

    # Replace the real cache deletion so the test can observe which keys would be cleared.
    monkeypatch.setattr(migration_module.ProviderCredentialsCache, "delete", _record_delete)

    # ---- dry run -----------------------------------------------------------
    dry_run_output = io.StringIO()
    _run_migration(apply=False, output=dry_run_output)

    after_dry_run = fetch_table_rows(engine, lb_table, tenant_id=tenant_id)
    dry_run_lines = _parse_json_lines(dry_run_output)
    dry_run_cache_events = _cache_events(dry_run_lines)
    dry_run_row_updates = {str(attrs["id"]) for attrs in _lb_event_attrs(dry_run_lines, "row_updated")}
    dry_run_row_deletes = {str(attrs["id"]) for attrs in _lb_event_attrs(dry_run_lines, "row_deleted")}
    dry_run_group_processed = _lb_event_attrs(dry_run_lines, "group_processed")

    assert after_dry_run == before_dry_run
    assert deleted_cache_keys == []
    assert dry_run_row_deletes == {older_inherit_row_id}
    assert dry_run_row_updates == {
        fixture.primary.load_balancing_config_id,
        newer_inherit_row_id,
    }
    assert canonical_non_inherit_row_id not in dry_run_row_updates
    assert "cache_delete_planned" in dry_run_cache_events
    assert "cache_deleted" not in dry_run_cache_events
    assert len(dry_run_group_processed) == 1
    processed_group = dry_run_group_processed[0]
    assert processed_group["table_name"] == lb_table
    assert processed_group["business_key"] == {
        "tenant_id": tenant_id,
        "provider_name": "openai",
        "model_name": "gpt-4o-mini",
        "model_type": "llm",
    }
    assert set(processed_group["group_row_ids"]) == {
        older_inherit_row_id,
        newer_inherit_row_id,
    }

    # ---- apply run ---------------------------------------------------------
    apply_output = io.StringIO()
    _run_migration(apply=True, output=apply_output)

    apply_lines = _parse_json_lines(apply_output)
    apply_cache_events = _cache_events(apply_lines)
    apply_group_processed = _lb_event_attrs(apply_lines, "group_processed")

    assert _lb_processing_signatures(apply_lines) == _lb_processing_signatures(dry_run_lines)
    assert "cache_deleted" in apply_cache_events
    assert deleted_cache_keys
    assert len(apply_group_processed) == len(dry_run_group_processed)
    assert [_group_signature(attrs) for attrs in apply_group_processed] == [
        _group_signature(attrs) for attrs in dry_run_group_processed
    ]

    # ---- resulting table state ---------------------------------------------
    lb_rows = fetch_table_rows(engine, lb_table, tenant_id=tenant_id)
    surviving_inherit_rows = []
    surviving_non_inherit_rows = []
    for row in lb_rows:
        if row["name"] == "__inherit__":
            surviving_inherit_rows.append(row)
        else:
            surviving_non_inherit_rows.append(row)

    assert {str(row["id"]) for row in surviving_inherit_rows} == {newer_inherit_row_id}
    sole_inherit_row = surviving_inherit_rows[0]
    assert sole_inherit_row["model_type"] == "llm"
    assert sole_inherit_row["credential_id"] == fixture.primary.distinct_credential_id

    # Only these two named rows are checked; other named rows from the fixture are ignored.
    expected_named_ids = {fixture.primary.load_balancing_config_id, canonical_non_inherit_row_id}
    tracked_named_rows = [row for row in surviving_non_inherit_rows if str(row["id"]) in expected_named_ids]
    assert {str(row["id"]) for row in tracked_named_rows} == expected_named_ids
    assert all(row["model_type"] == "llm" for row in tracked_named_rows)
    assert count_rows(engine, lb_table, tenant_id=tenant_id) == len(before_dry_run) - 1
|
||||
2025
api/tests/unit_tests/commands/test_legacy_model_type_migration.py
Normal file
2025
api/tests/unit_tests/commands/test_legacy_model_type_migration.py
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user