diff --git a/api/models/types.py b/api/models/types.py index 9ab694759f..c1d9c3845a 100644 --- a/api/models/types.py +++ b/api/models/types.py @@ -1,6 +1,6 @@ import enum import uuid -from typing import Any +from typing import Any, cast import sqlalchemy as sa from sqlalchemy import CHAR, TEXT, VARCHAR, LargeBinary, TypeDecorator @@ -143,8 +143,14 @@ class EnumText[T: enum.StrEnum](TypeDecorator[T | None]): def process_result_value(self, value: str | None, dialect: Dialect) -> T | None: if value is None or value == "": return None - # Type annotation guarantees value is str at this point - return self._enum_class(value) + try: + # Type annotation guarantees value is str at this point + return self._enum_class(value) + except ValueError: + value_of = getattr(self._enum_class, "value_of", None) + if callable(value_of): + return cast(T, value_of(value)) + raise def compare_values(self, x: T | None, y: T | None) -> bool: if x is None or y is None: diff --git a/api/tests/test_containers_integration_tests/models/test_types_enum_text.py b/api/tests/test_containers_integration_tests/models/test_types_enum_text.py index 9cf96c1ca7..8aec6b6acc 100644 --- a/api/tests/test_containers_integration_tests/models/test_types_enum_text.py +++ b/api/tests/test_containers_integration_tests/models/test_types_enum_text.py @@ -4,6 +4,7 @@ from typing import Any, NamedTuple import pytest import sqlalchemy as sa +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import exc as sa_exc from sqlalchemy import insert from sqlalchemy.engine import Connection, Engine @@ -58,6 +59,13 @@ class _ColumnTest(_Base): long_value: Mapped[_EnumWithLongValue] = mapped_column(EnumText(enum_class=_EnumWithLongValue), nullable=False) +class _LegacyModelTypeRecord(_Base): + __tablename__ = "enum_text_legacy_model_type_test" + + id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) + model_type: Mapped[ModelType] = mapped_column(EnumText(enum_class=ModelType), nullable=False) + + def _first[T](it: Iterable[T]) -> T: ls = list(it) if not ls: @@ -201,3 +209,23 @@ class TestEnumText: _user = session.query(_User).where(_User.id == 1).first() assert str(exc.value) == "'invalid' is not a valid _UserType" + + def test_select_legacy_model_type_values(self, engine_with_containers: Engine): + insertion_sql = """ + INSERT INTO enum_text_legacy_model_type_test (id, model_type) VALUES + (1, 'text-generation'), + (2, 'embeddings'), + (3, 'reranking'); + """ + with Session(engine_with_containers) as session: + session.execute(sa.text(insertion_sql)) + session.commit() + + with Session(engine_with_containers) as session: + records = session.query(_LegacyModelTypeRecord).order_by(_LegacyModelTypeRecord.id).all() + + assert [record.model_type for record in records] == [ + ModelType.LLM, + ModelType.TEXT_EMBEDDING, + ModelType.RERANK, + ]