mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 18:06:36 +08:00
fix: legacy model_type deserialization regression (#34717)
This commit is contained in:
parent
a65e1f71b4
commit
e138523123
@ -1,6 +1,6 @@
|
||||
import enum
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import CHAR, TEXT, VARCHAR, LargeBinary, TypeDecorator
|
||||
@ -143,8 +143,14 @@ class EnumText[T: enum.StrEnum](TypeDecorator[T | None]):
|
||||
def process_result_value(self, value: str | None, dialect: Dialect) -> T | None:
|
||||
if value is None or value == "":
|
||||
return None
|
||||
# Type annotation guarantees value is str at this point
|
||||
return self._enum_class(value)
|
||||
try:
|
||||
# Type annotation guarantees value is str at this point
|
||||
return self._enum_class(value)
|
||||
except ValueError:
|
||||
value_of = getattr(self._enum_class, "value_of", None)
|
||||
if callable(value_of):
|
||||
return cast(T, value_of(value))
|
||||
raise
|
||||
|
||||
def compare_values(self, x: T | None, y: T | None) -> bool:
|
||||
if x is None or y is None:
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import Any, NamedTuple
|
||||
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from sqlalchemy import exc as sa_exc
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy.engine import Connection, Engine
|
||||
@ -58,6 +59,13 @@ class _ColumnTest(_Base):
|
||||
long_value: Mapped[_EnumWithLongValue] = mapped_column(EnumText(enum_class=_EnumWithLongValue), nullable=False)
|
||||
|
||||
|
||||
class _LegacyModelTypeRecord(_Base):
|
||||
__tablename__ = "enum_text_legacy_model_type_test"
|
||||
|
||||
id: Mapped[int] = mapped_column(sa.Integer, primary_key=True)
|
||||
model_type: Mapped[ModelType] = mapped_column(EnumText(enum_class=ModelType), nullable=False)
|
||||
|
||||
|
||||
def _first[T](it: Iterable[T]) -> T:
|
||||
ls = list(it)
|
||||
if not ls:
|
||||
@ -201,3 +209,23 @@ class TestEnumText:
|
||||
_user = session.query(_User).where(_User.id == 1).first()
|
||||
|
||||
assert str(exc.value) == "'invalid' is not a valid _UserType"
|
||||
|
||||
def test_select_legacy_model_type_values(self, engine_with_containers: Engine):
|
||||
insertion_sql = """
|
||||
INSERT INTO enum_text_legacy_model_type_test (id, model_type) VALUES
|
||||
(1, 'text-generation'),
|
||||
(2, 'embeddings'),
|
||||
(3, 'reranking');
|
||||
"""
|
||||
with Session(engine_with_containers) as session:
|
||||
session.execute(sa.text(insertion_sql))
|
||||
session.commit()
|
||||
|
||||
with Session(engine_with_containers) as session:
|
||||
records = session.query(_LegacyModelTypeRecord).order_by(_LegacyModelTypeRecord.id).all()
|
||||
|
||||
assert [record.model_type for record in records] == [
|
||||
ModelType.LLM,
|
||||
ModelType.TEXT_EMBEDDING,
|
||||
ModelType.RERANK,
|
||||
]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user