Serde: fix enum type conversion, ensure roundtrip stability
This commit is contained in:
@@ -233,16 +233,18 @@ def construct_value(
|
|||||||
|
|
||||||
# Enums
|
# Enums
|
||||||
if origin is Enum:
|
if origin is Enum:
|
||||||
if field_value not in origin.__members__:
|
try:
|
||||||
msg = f"Expected one of {', '.join(origin.__members__)}, got {field_value}"
|
return t(field_value) # type: ignore
|
||||||
raise ClanError(msg, location=f"{loc}")
|
except ValueError:
|
||||||
return origin.__members__[field_value] # type: ignore
|
msg = f"Expected one of {', '.join(str(origin))}, got {field_value}"
|
||||||
|
raise ClanError(msg, location=f"{loc}") from ValueError
|
||||||
|
|
||||||
if isinstance(t, type) and issubclass(t, Enum):
|
if isinstance(t, type) and issubclass(t, Enum):
|
||||||
if field_value not in t.__members__:
|
try:
|
||||||
|
return t(field_value) # type: ignore
|
||||||
|
except ValueError:
|
||||||
msg = f"Expected one of {', '.join(t.__members__)}, got {field_value}"
|
msg = f"Expected one of {', '.join(t.__members__)}, got {field_value}"
|
||||||
raise ClanError(msg, location=f"{loc}")
|
raise ClanError(msg, location=f"{loc}") from ValueError
|
||||||
return t.__members__[field_value] # type: ignore
|
|
||||||
|
|
||||||
if origin is Annotated:
|
if origin is Annotated:
|
||||||
(base_type,) = get_args(t)
|
(base_type,) = get_args(t)
|
||||||
|
|||||||
@@ -18,6 +18,8 @@ from typing import (
|
|||||||
is_typeddict,
|
is_typeddict,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from clan_cli.api.serde import dataclass_to_dict
|
||||||
|
|
||||||
|
|
||||||
class JSchemaTypeError(Exception):
|
class JSchemaTypeError(Exception):
|
||||||
pass
|
pass
|
||||||
@@ -257,7 +259,8 @@ def type_to_dict(
|
|||||||
if type(t) is EnumType:
|
if type(t) is EnumType:
|
||||||
return {
|
return {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": list(t.__members__),
|
# Construct every enum value and use the same method as the serde module for converting it into the same literal string
|
||||||
|
"enum": [dataclass_to_dict(t(value)) for value in t], # type: ignore
|
||||||
}
|
}
|
||||||
if t is Any:
|
if t is Any:
|
||||||
msg = f"{scope} - Usage of the Any type is not supported for API functions. In: {scope}"
|
msg = f"{scope} - Usage of the Any type is not supported for API functions. In: {scope}"
|
||||||
|
|||||||
@@ -266,3 +266,31 @@ def test_literal_field() -> None:
|
|||||||
with pytest.raises(ClanError):
|
with pytest.raises(ClanError):
|
||||||
# Not a valid value
|
# Not a valid value
|
||||||
from_dict(Person, {"name": "open"})
|
from_dict(Person, {"name": "open"})
|
||||||
|
|
||||||
|
|
||||||
|
def test_enum_roundtrip() -> None:
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
class MyEnum(Enum):
|
||||||
|
FOO = "abc"
|
||||||
|
BAR = 2
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Person:
|
||||||
|
name: MyEnum
|
||||||
|
|
||||||
|
# Both are equivalent
|
||||||
|
data = {"name": "abc"} # JSON Representation
|
||||||
|
expected = Person(name=MyEnum.FOO) # Data representation
|
||||||
|
|
||||||
|
assert from_dict(Person, data) == expected
|
||||||
|
|
||||||
|
assert dataclass_to_dict(expected) == data
|
||||||
|
|
||||||
|
# Same test for integer values
|
||||||
|
data2 = {"name": 2} # JSON Representation
|
||||||
|
expected2 = Person(name=MyEnum.BAR) # Data representation
|
||||||
|
|
||||||
|
assert from_dict(Person, data2) == expected2
|
||||||
|
|
||||||
|
assert dataclass_to_dict(expected2) == data2
|
||||||
|
|||||||
Reference in New Issue
Block a user