Deserializer: add Literal; Annotated fields

This commit is contained in:
Johannes Kirschbauer
2024-07-31 13:00:28 +02:00
parent 6814946efa
commit 8ddfaba599
2 changed files with 35 additions and 1 deletions

View File

@@ -35,7 +35,9 @@ from dataclasses import dataclass, fields, is_dataclass
from pathlib import Path from pathlib import Path
from types import UnionType from types import UnionType
from typing import ( from typing import (
Annotated,
Any, Any,
Literal,
TypeVar, TypeVar,
Union, Union,
get_args, get_args,
@@ -130,7 +132,7 @@ def construct_field(t: type, field_value: JsonValue, loc: list[str] = []) -> Any
# If the field expects a path # If the field expects a path
# Field_value must be a string # Field_value must be a string
elif issubclass(t, Path) or is_type_in_union(t, Path): elif is_type_in_union(t, Path):
if not isinstance(field_value, str): if not isinstance(field_value, str):
raise ClanError( raise ClanError(
f"Expected string, cannot construct pathlib.Path() from: {field_value} ", f"Expected string, cannot construct pathlib.Path() from: {field_value} ",
@@ -150,6 +152,8 @@ def construct_field(t: type, field_value: JsonValue, loc: list[str] = []) -> Any
return int(field_value) # type: ignore return int(field_value) # type: ignore
elif t is float and not isinstance(field_value, str): elif t is float and not isinstance(field_value, str):
return float(field_value) # type: ignore return float(field_value) # type: ignore
elif t is bool and isinstance(field_value, bool):
return field_value # type: ignore
# Union types construct the first non-None type # Union types construct the first non-None type
elif is_union_type(t): elif is_union_type(t):
@@ -171,6 +175,19 @@ def construct_field(t: type, field_value: JsonValue, loc: list[str] = []) -> Any
key: construct_field(get_args(t)[1], value) key: construct_field(get_args(t)[1], value)
for key, value in field_value.items() for key, value in field_value.items()
} }
elif get_origin(t) is Literal:
valid_values = get_args(t)
if field_value not in valid_values:
raise ClanError(
f"Expected one of {valid_values}, got {field_value}", location=f"{loc}"
)
return field_value
elif get_origin(t) is Annotated:
(base_type,) = get_args(t)
return construct_field(base_type, field_value)
# elif get_origin(t) is Union:
# Unhandled # Unhandled
else: else:

View File

@@ -1,5 +1,6 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Literal
import pytest import pytest
@@ -200,3 +201,19 @@ def test_private_public_fields() -> None:
assert from_dict(Person, data) == expected assert from_dict(Person, data) == expected
assert dataclass_to_dict(expected) == data assert dataclass_to_dict(expected) == data
def test_literal_field() -> None:
@dataclass
class Person:
name: Literal["open_file", "select_folder", "save"]
data = {"name": "open_file"}
expected = Person(name="open_file")
assert from_dict(Person, data) == expected
assert dataclass_to_dict(expected) == data
with pytest.raises(ClanError):
# Not a valid value
from_dict(Person, {"name": "open"})