From 4b2d1b7923ce7491b66cee6864fecfd4fb15963b Mon Sep 17 00:00:00 2001 From: Johannes Kirschbauer Date: Tue, 30 Jul 2024 14:16:03 +0200 Subject: [PATCH] Deserializer: add Nullable fields --- pkgs/clan-cli/clan_cli/api/serde.py | 14 ++++++++++---- pkgs/clan-cli/tests/test_deserializers.py | 23 +++++++++++++++++++++++ 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/pkgs/clan-cli/clan_cli/api/serde.py b/pkgs/clan-cli/clan_cli/api/serde.py index 605b04099..24b9786d5 100644 --- a/pkgs/clan-cli/clan_cli/api/serde.py +++ b/pkgs/clan-cli/clan_cli/api/serde.py @@ -96,7 +96,7 @@ def is_union_type(type_hint: type | UnionType) -> bool: def is_type_in_union(union_type: type | UnionType, target_type: type) -> bool: if get_origin(union_type) is UnionType: return any(issubclass(arg, target_type) for arg in get_args(union_type)) - return False + return union_type == target_type def unwrap_none_type(type_hint: type | UnionType) -> type: @@ -121,6 +121,8 @@ def construct_field(t: type, field_value: JsonValue, loc: list[str] = []) -> Any """ Construct a field value from a type hint and a field value. """ + if t is None and field_value: + raise ClanError(f"Expected None but got: {field_value}", location=f"{loc}") # If the field is another dataclass # Field_value must be a dictionary if is_dataclass(t) and isinstance(field_value, dict): @@ -195,8 +197,6 @@ def from_dict(t: type[T], data: dict[str, Any], path: list[str] = []) -> T: field_type: type[Any] = unwrap_none_type(field.type) # type: ignore data_field_name = field.metadata.get("alias", field.name) - # Check if the field is required - # breakpoint() if ( field.default is dataclasses.MISSING and field.default_factory is dataclasses.MISSING @@ -207,7 +207,13 @@ def from_dict(t: type[T], data: dict[str, Any], path: list[str] = []) -> T: # if present in the data if data_field_name in data: field_value = data.get(data_field_name) - field_values[field.name] = construct_field(field_type, field_value) + + if field_value is None and ( + field.type is None or is_type_in_union(field.type, type(None)) + ): + field_values[field.name] = None + else: + field_values[field.name] = construct_field(field_type, field_value) # Check that all required field are present. for field_name in required: diff --git a/pkgs/clan-cli/tests/test_deserializers.py b/pkgs/clan-cli/tests/test_deserializers.py index 7526f342e..3d6d7583d 100644 --- a/pkgs/clan-cli/tests/test_deserializers.py +++ b/pkgs/clan-cli/tests/test_deserializers.py @@ -83,6 +83,29 @@ def test_simple_field_missing() -> None: from_dict(Person, person_dict) +def test_nullable() -> None: + @dataclass + class Person: + name: None + + person_dict = { + "name": None, + } + + from_dict(Person, person_dict) + + +def test_nullable_non_exist() -> None: + @dataclass + class Person: + name: None + + person_dict = {} + + with pytest.raises(ClanError): + from_dict(Person, person_dict) + + def test_deserialize_extensive_inventory() -> None: # TODO: Make this an abstract test, so it doesn't break the test if the inventory changes data = {