Deserializer: add Nullable fields

This commit is contained in:
Johannes Kirschbauer
2024-07-30 14:16:03 +02:00
parent 5d8fa57f23
commit 9db6cb8b6f
2 changed files with 33 additions and 4 deletions

View File

@@ -96,7 +96,7 @@ def is_union_type(type_hint: type | UnionType) -> bool:
def is_type_in_union(union_type: type | UnionType, target_type: type) -> bool:
if get_origin(union_type) is UnionType:
return any(issubclass(arg, target_type) for arg in get_args(union_type))
return False
return union_type == target_type
def unwrap_none_type(type_hint: type | UnionType) -> type:
@@ -121,6 +121,8 @@ def construct_field(t: type, field_value: JsonValue, loc: list[str] = []) -> Any
"""
Construct a field value from a type hint and a field value.
"""
if t is None and field_value:
raise ClanError(f"Expected None but got: {field_value}", location=f"{loc}")
# If the field is another dataclass
# Field_value must be a dictionary
if is_dataclass(t) and isinstance(field_value, dict):
@@ -195,8 +197,6 @@ def from_dict(t: type[T], data: dict[str, Any], path: list[str] = []) -> T:
field_type: type[Any] = unwrap_none_type(field.type) # type: ignore
data_field_name = field.metadata.get("alias", field.name)
# Check if the field is required
# breakpoint()
if (
field.default is dataclasses.MISSING
and field.default_factory is dataclasses.MISSING
@@ -207,7 +207,13 @@ def from_dict(t: type[T], data: dict[str, Any], path: list[str] = []) -> T:
# if present in the data
if data_field_name in data:
field_value = data.get(data_field_name)
field_values[field.name] = construct_field(field_type, field_value)
if field_value is None and (
field.type is None or is_type_in_union(field.type, type(None))
):
field_values[field.name] = None
else:
field_values[field.name] = construct_field(field_type, field_value)
# Check that all required field are present.
for field_name in required:

View File

@@ -83,6 +83,29 @@ def test_simple_field_missing() -> None:
from_dict(Person, person_dict)
def test_nullable() -> None:
@dataclass
class Person:
name: None
person_dict = {
"name": None,
}
from_dict(Person, person_dict)
def test_nullable_non_exist() -> None:
@dataclass
class Person:
name: None
person_dict = {}
with pytest.raises(ClanError):
from_dict(Person, person_dict)
def test_deserialize_extensive_inventory() -> None:
# TODO: Make this an abstract test, so it doesn't break the test if the inventory changes
data = {