Deserializer: add Nullable fields
This commit is contained in:
@@ -96,7 +96,7 @@ def is_union_type(type_hint: type | UnionType) -> bool:
|
|||||||
def is_type_in_union(union_type: type | UnionType, target_type: type) -> bool:
|
def is_type_in_union(union_type: type | UnionType, target_type: type) -> bool:
|
||||||
if get_origin(union_type) is UnionType:
|
if get_origin(union_type) is UnionType:
|
||||||
return any(issubclass(arg, target_type) for arg in get_args(union_type))
|
return any(issubclass(arg, target_type) for arg in get_args(union_type))
|
||||||
return False
|
return union_type == target_type
|
||||||
|
|
||||||
|
|
||||||
def unwrap_none_type(type_hint: type | UnionType) -> type:
|
def unwrap_none_type(type_hint: type | UnionType) -> type:
|
||||||
@@ -121,6 +121,8 @@ def construct_field(t: type, field_value: JsonValue, loc: list[str] = []) -> Any
|
|||||||
"""
|
"""
|
||||||
Construct a field value from a type hint and a field value.
|
Construct a field value from a type hint and a field value.
|
||||||
"""
|
"""
|
||||||
|
if t is None and field_value:
|
||||||
|
raise ClanError(f"Expected None but got: {field_value}", location=f"{loc}")
|
||||||
# If the field is another dataclass
|
# If the field is another dataclass
|
||||||
# Field_value must be a dictionary
|
# Field_value must be a dictionary
|
||||||
if is_dataclass(t) and isinstance(field_value, dict):
|
if is_dataclass(t) and isinstance(field_value, dict):
|
||||||
@@ -195,8 +197,6 @@ def from_dict(t: type[T], data: dict[str, Any], path: list[str] = []) -> T:
|
|||||||
field_type: type[Any] = unwrap_none_type(field.type) # type: ignore
|
field_type: type[Any] = unwrap_none_type(field.type) # type: ignore
|
||||||
data_field_name = field.metadata.get("alias", field.name)
|
data_field_name = field.metadata.get("alias", field.name)
|
||||||
|
|
||||||
# Check if the field is required
|
|
||||||
# breakpoint()
|
|
||||||
if (
|
if (
|
||||||
field.default is dataclasses.MISSING
|
field.default is dataclasses.MISSING
|
||||||
and field.default_factory is dataclasses.MISSING
|
and field.default_factory is dataclasses.MISSING
|
||||||
@@ -207,7 +207,13 @@ def from_dict(t: type[T], data: dict[str, Any], path: list[str] = []) -> T:
|
|||||||
# if present in the data
|
# if present in the data
|
||||||
if data_field_name in data:
|
if data_field_name in data:
|
||||||
field_value = data.get(data_field_name)
|
field_value = data.get(data_field_name)
|
||||||
field_values[field.name] = construct_field(field_type, field_value)
|
|
||||||
|
if field_value is None and (
|
||||||
|
field.type is None or is_type_in_union(field.type, type(None))
|
||||||
|
):
|
||||||
|
field_values[field.name] = None
|
||||||
|
else:
|
||||||
|
field_values[field.name] = construct_field(field_type, field_value)
|
||||||
|
|
||||||
# Check that all required field are present.
|
# Check that all required field are present.
|
||||||
for field_name in required:
|
for field_name in required:
|
||||||
|
|||||||
@@ -83,6 +83,29 @@ def test_simple_field_missing() -> None:
|
|||||||
from_dict(Person, person_dict)
|
from_dict(Person, person_dict)
|
||||||
|
|
||||||
|
|
||||||
|
def test_nullable() -> None:
|
||||||
|
@dataclass
|
||||||
|
class Person:
|
||||||
|
name: None
|
||||||
|
|
||||||
|
person_dict = {
|
||||||
|
"name": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
from_dict(Person, person_dict)
|
||||||
|
|
||||||
|
|
||||||
|
def test_nullable_non_exist() -> None:
|
||||||
|
@dataclass
|
||||||
|
class Person:
|
||||||
|
name: None
|
||||||
|
|
||||||
|
person_dict = {}
|
||||||
|
|
||||||
|
with pytest.raises(ClanError):
|
||||||
|
from_dict(Person, person_dict)
|
||||||
|
|
||||||
|
|
||||||
def test_deserialize_extensive_inventory() -> None:
|
def test_deserialize_extensive_inventory() -> None:
|
||||||
# TODO: Make this an abstract test, so it doesn't break the test if the inventory changes
|
# TODO: Make this an abstract test, so it doesn't break the test if the inventory changes
|
||||||
data = {
|
data = {
|
||||||
|
|||||||
Reference in New Issue
Block a user