Deserializer: add Nullable fields
This commit is contained in:
@@ -96,7 +96,7 @@ def is_union_type(type_hint: type | UnionType) -> bool:
|
||||
def is_type_in_union(union_type: type | UnionType, target_type: type) -> bool:
|
||||
if get_origin(union_type) is UnionType:
|
||||
return any(issubclass(arg, target_type) for arg in get_args(union_type))
|
||||
return False
|
||||
return union_type == target_type
|
||||
|
||||
|
||||
def unwrap_none_type(type_hint: type | UnionType) -> type:
|
||||
@@ -121,6 +121,8 @@ def construct_field(t: type, field_value: JsonValue, loc: list[str] = []) -> Any
|
||||
"""
|
||||
Construct a field value from a type hint and a field value.
|
||||
"""
|
||||
if t is None and field_value:
|
||||
raise ClanError(f"Expected None but got: {field_value}", location=f"{loc}")
|
||||
# If the field is another dataclass
|
||||
# Field_value must be a dictionary
|
||||
if is_dataclass(t) and isinstance(field_value, dict):
|
||||
@@ -195,8 +197,6 @@ def from_dict(t: type[T], data: dict[str, Any], path: list[str] = []) -> T:
|
||||
field_type: type[Any] = unwrap_none_type(field.type) # type: ignore
|
||||
data_field_name = field.metadata.get("alias", field.name)
|
||||
|
||||
# Check if the field is required
|
||||
# breakpoint()
|
||||
if (
|
||||
field.default is dataclasses.MISSING
|
||||
and field.default_factory is dataclasses.MISSING
|
||||
@@ -207,7 +207,13 @@ def from_dict(t: type[T], data: dict[str, Any], path: list[str] = []) -> T:
|
||||
# if present in the data
|
||||
if data_field_name in data:
|
||||
field_value = data.get(data_field_name)
|
||||
field_values[field.name] = construct_field(field_type, field_value)
|
||||
|
||||
if field_value is None and (
|
||||
field.type is None or is_type_in_union(field.type, type(None))
|
||||
):
|
||||
field_values[field.name] = None
|
||||
else:
|
||||
field_values[field.name] = construct_field(field_type, field_value)
|
||||
|
||||
# Check that all required field are present.
|
||||
for field_name in required:
|
||||
|
||||
@@ -83,6 +83,29 @@ def test_simple_field_missing() -> None:
|
||||
from_dict(Person, person_dict)
|
||||
|
||||
|
||||
def test_nullable() -> None:
|
||||
@dataclass
|
||||
class Person:
|
||||
name: None
|
||||
|
||||
person_dict = {
|
||||
"name": None,
|
||||
}
|
||||
|
||||
from_dict(Person, person_dict)
|
||||
|
||||
|
||||
def test_nullable_non_exist() -> None:
|
||||
@dataclass
|
||||
class Person:
|
||||
name: None
|
||||
|
||||
person_dict = {}
|
||||
|
||||
with pytest.raises(ClanError):
|
||||
from_dict(Person, person_dict)
|
||||
|
||||
|
||||
def test_deserialize_extensive_inventory() -> None:
|
||||
# TODO: Make this an abstract test, so it doesn't break the test if the inventory changes
|
||||
data = {
|
||||
|
||||
Reference in New Issue
Block a user