Serde: fix enum type conversion, ensure roundtrip stability
This commit is contained in:
@@ -233,16 +233,18 @@ def construct_value(
|
||||
|
||||
# Enums
|
||||
if origin is Enum:
|
||||
if field_value not in origin.__members__:
|
||||
msg = f"Expected one of {', '.join(origin.__members__)}, got {field_value}"
|
||||
raise ClanError(msg, location=f"{loc}")
|
||||
return origin.__members__[field_value] # type: ignore
|
||||
try:
|
||||
return t(field_value) # type: ignore
|
||||
except ValueError:
|
||||
msg = f"Expected one of {', '.join(str(origin))}, got {field_value}"
|
||||
raise ClanError(msg, location=f"{loc}") from ValueError
|
||||
|
||||
if isinstance(t, type) and issubclass(t, Enum):
|
||||
if field_value not in t.__members__:
|
||||
try:
|
||||
return t(field_value) # type: ignore
|
||||
except ValueError:
|
||||
msg = f"Expected one of {', '.join(t.__members__)}, got {field_value}"
|
||||
raise ClanError(msg, location=f"{loc}")
|
||||
return t.__members__[field_value] # type: ignore
|
||||
raise ClanError(msg, location=f"{loc}") from ValueError
|
||||
|
||||
if origin is Annotated:
|
||||
(base_type,) = get_args(t)
|
||||
|
||||
@@ -18,6 +18,8 @@ from typing import (
|
||||
is_typeddict,
|
||||
)
|
||||
|
||||
from clan_cli.api.serde import dataclass_to_dict
|
||||
|
||||
|
||||
class JSchemaTypeError(Exception):
|
||||
pass
|
||||
@@ -257,7 +259,8 @@ def type_to_dict(
|
||||
if type(t) is EnumType:
|
||||
return {
|
||||
"type": "string",
|
||||
"enum": list(t.__members__),
|
||||
# Construct every enum value and use the same method as the serde module for converting it into the same literal string
|
||||
"enum": [dataclass_to_dict(t(value)) for value in t], # type: ignore
|
||||
}
|
||||
if t is Any:
|
||||
msg = f"{scope} - Usage of the Any type is not supported for API functions. In: {scope}"
|
||||
|
||||
@@ -266,3 +266,31 @@ def test_literal_field() -> None:
|
||||
with pytest.raises(ClanError):
|
||||
# Not a valid value
|
||||
from_dict(Person, {"name": "open"})
|
||||
|
||||
|
||||
def test_enum_roundtrip() -> None:
|
||||
from enum import Enum
|
||||
|
||||
class MyEnum(Enum):
|
||||
FOO = "abc"
|
||||
BAR = 2
|
||||
|
||||
@dataclass
|
||||
class Person:
|
||||
name: MyEnum
|
||||
|
||||
# Both are equivalent
|
||||
data = {"name": "abc"} # JSON Representation
|
||||
expected = Person(name=MyEnum.FOO) # Data representation
|
||||
|
||||
assert from_dict(Person, data) == expected
|
||||
|
||||
assert dataclass_to_dict(expected) == data
|
||||
|
||||
# Same test for integer values
|
||||
data2 = {"name": 2} # JSON Representation
|
||||
expected2 = Person(name=MyEnum.BAR) # Data representation
|
||||
|
||||
assert from_dict(Person, data2) == expected2
|
||||
|
||||
assert dataclass_to_dict(expected2) == data2
|
||||
|
||||
Reference in New Issue
Block a user