Serde: extend deserializer to accept anything

This commit is contained in:
Johannes Kirschbauer
2024-08-15 15:03:52 +02:00
parent ce560e05cd
commit 4940767fcc
3 changed files with 37 additions and 12 deletions

View File

@@ -86,6 +86,7 @@ def dataclass_to_dict(obj: Any, *, use_alias: bool = True) -> Any:
T = TypeVar("T", bound=dataclass) # type: ignore
G = TypeVar("G") # type: ignore
def is_union_type(type_hint: type | UnionType) -> bool:
@@ -120,7 +121,7 @@ def unwrap_none_type(type_hint: type | UnionType) -> type:
JsonValue = str | float | dict[str, Any] | list[Any] | None
def construct_field(t: type, field_value: JsonValue, loc: list[str] = []) -> Any:
def construct_value(t: type, field_value: JsonValue, loc: list[str] = []) -> Any:
"""
Construct a field value from a type hint and a field value.
"""
@@ -129,7 +130,7 @@ def construct_field(t: type, field_value: JsonValue, loc: list[str] = []) -> Any
# If the field is another dataclass
# Field_value must be a dictionary
if is_dataclass(t) and isinstance(field_value, dict):
return from_dict(t, field_value)
return construct_dataclass(t, field_value)
# If the field expects a path
# Field_value must be a string
@@ -161,7 +162,7 @@ def construct_field(t: type, field_value: JsonValue, loc: list[str] = []) -> Any
# Unwrap the union type
t = unwrap_none_type(t)
# Construct the field value
return construct_field(t, field_value)
return construct_value(t, field_value)
# Nested types
# list
@@ -170,10 +171,10 @@ def construct_field(t: type, field_value: JsonValue, loc: list[str] = []) -> Any
if not isinstance(field_value, list):
raise ClanError(f"Expected list, got {field_value}", location=f"{loc}")
return [construct_field(get_args(t)[0], item) for item in field_value]
return [construct_value(get_args(t)[0], item) for item in field_value]
elif get_origin(t) is dict and isinstance(field_value, dict):
return {
key: construct_field(get_args(t)[1], value)
key: construct_value(get_args(t)[1], value)
for key, value in field_value.items()
}
elif get_origin(t) is Literal:
@@ -186,7 +187,7 @@ def construct_field(t: type, field_value: JsonValue, loc: list[str] = []) -> Any
elif get_origin(t) is Annotated:
(base_type,) = get_args(t)
return construct_field(base_type, field_value)
return construct_value(base_type, field_value)
# elif get_origin(t) is Union:
@@ -195,7 +196,7 @@ def construct_field(t: type, field_value: JsonValue, loc: list[str] = []) -> Any
raise ClanError(f"Unhandled field type {t} with value {field_value}")
def from_dict(t: type[T], data: dict[str, Any], path: list[str] = []) -> T:
def construct_dataclass(t: type[T], data: dict[str, Any], path: list[str] = []) -> T:
"""
type t MUST be a dataclass
Dynamically instantiate a data class from a dictionary, handling nested data classes.
@@ -231,7 +232,7 @@ def from_dict(t: type[T], data: dict[str, Any], path: list[str] = []) -> T:
):
field_values[field.name] = None
else:
field_values[field.name] = construct_field(field_type, field_value)
field_values[field.name] = construct_value(field_type, field_value)
# Check that all required field are present.
for field_name in required:
@@ -242,3 +243,12 @@ def from_dict(t: type[T], data: dict[str, Any], path: list[str] = []) -> T:
)
return t(**field_values) # type: ignore
def from_dict(t: type[G], data: dict[str, Any] | Any, path: list[str] = []) -> G:
if is_dataclass(t):
if not isinstance(data, dict):
raise ClanError(f"{data} is not a dict. Expected {t}")
return construct_dataclass(t, data, path)
else:
return construct_value(t, data, path)