fix: legacy model_type deserialization regression (#34717)

This commit is contained in:
Will 2026-04-08 13:08:12 +08:00 committed by GitHub
parent a65e1f71b4
commit e138523123
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 37 additions and 3 deletions

View File

@ -1,6 +1,6 @@
import enum
import uuid
from typing import Any
from typing import Any, cast
import sqlalchemy as sa
from sqlalchemy import CHAR, TEXT, VARCHAR, LargeBinary, TypeDecorator
@ -143,8 +143,14 @@ class EnumText[T: enum.StrEnum](TypeDecorator[T | None]):
def process_result_value(self, value: str | None, dialect: Dialect) -> T | None:
if value is None or value == "":
return None
# Type annotation guarantees value is str at this point
return self._enum_class(value)
try:
# Type annotation guarantees value is str at this point
return self._enum_class(value)
except ValueError:
value_of = getattr(self._enum_class, "value_of", None)
if callable(value_of):
return cast(T, value_of(value))
raise
def compare_values(self, x: T | None, y: T | None) -> bool:
if x is None or y is None:

View File

@ -4,6 +4,7 @@ from typing import Any, NamedTuple
import pytest
import sqlalchemy as sa
from graphon.model_runtime.entities.model_entities import ModelType
from sqlalchemy import exc as sa_exc
from sqlalchemy import insert
from sqlalchemy.engine import Connection, Engine
@ -58,6 +59,13 @@ class _ColumnTest(_Base):
long_value: Mapped[_EnumWithLongValue] = mapped_column(EnumText(enum_class=_EnumWithLongValue), nullable=False)
class _LegacyModelTypeRecord(_Base):
__tablename__ = "enum_text_legacy_model_type_test"
id: Mapped[int] = mapped_column(sa.Integer, primary_key=True)
model_type: Mapped[ModelType] = mapped_column(EnumText(enum_class=ModelType), nullable=False)
def _first[T](it: Iterable[T]) -> T:
ls = list(it)
if not ls:
@ -201,3 +209,23 @@ class TestEnumText:
_user = session.query(_User).where(_User.id == 1).first()
assert str(exc.value) == "'invalid' is not a valid _UserType"
def test_select_legacy_model_type_values(self, engine_with_containers: Engine):
insertion_sql = """
INSERT INTO enum_text_legacy_model_type_test (id, model_type) VALUES
(1, 'text-generation'),
(2, 'embeddings'),
(3, 'reranking');
"""
with Session(engine_with_containers) as session:
session.execute(sa.text(insertion_sql))
session.commit()
with Session(engine_with_containers) as session:
records = session.query(_LegacyModelTypeRecord).order_by(_LegacyModelTypeRecord.id).all()
assert [record.model_type for record in records] == [
ModelType.LLM,
ModelType.TEXT_EMBEDDING,
ModelType.RERANK,
]