Deserializer: add Literal; Annotated fields
This commit is contained in:
@@ -35,7 +35,9 @@ from dataclasses import dataclass, fields, is_dataclass
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import UnionType
|
from types import UnionType
|
||||||
from typing import (
|
from typing import (
|
||||||
|
Annotated,
|
||||||
Any,
|
Any,
|
||||||
|
Literal,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
get_args,
|
get_args,
|
||||||
@@ -130,7 +132,7 @@ def construct_field(t: type, field_value: JsonValue, loc: list[str] = []) -> Any
|
|||||||
|
|
||||||
# If the field expects a path
|
# If the field expects a path
|
||||||
# Field_value must be a string
|
# Field_value must be a string
|
||||||
elif issubclass(t, Path) or is_type_in_union(t, Path):
|
elif is_type_in_union(t, Path):
|
||||||
if not isinstance(field_value, str):
|
if not isinstance(field_value, str):
|
||||||
raise ClanError(
|
raise ClanError(
|
||||||
f"Expected string, cannot construct pathlib.Path() from: {field_value} ",
|
f"Expected string, cannot construct pathlib.Path() from: {field_value} ",
|
||||||
@@ -150,6 +152,8 @@ def construct_field(t: type, field_value: JsonValue, loc: list[str] = []) -> Any
|
|||||||
return int(field_value) # type: ignore
|
return int(field_value) # type: ignore
|
||||||
elif t is float and not isinstance(field_value, str):
|
elif t is float and not isinstance(field_value, str):
|
||||||
return float(field_value) # type: ignore
|
return float(field_value) # type: ignore
|
||||||
|
elif t is bool and isinstance(field_value, bool):
|
||||||
|
return field_value # type: ignore
|
||||||
|
|
||||||
# Union types construct the first non-None type
|
# Union types construct the first non-None type
|
||||||
elif is_union_type(t):
|
elif is_union_type(t):
|
||||||
@@ -171,6 +175,19 @@ def construct_field(t: type, field_value: JsonValue, loc: list[str] = []) -> Any
|
|||||||
key: construct_field(get_args(t)[1], value)
|
key: construct_field(get_args(t)[1], value)
|
||||||
for key, value in field_value.items()
|
for key, value in field_value.items()
|
||||||
}
|
}
|
||||||
|
elif get_origin(t) is Literal:
|
||||||
|
valid_values = get_args(t)
|
||||||
|
if field_value not in valid_values:
|
||||||
|
raise ClanError(
|
||||||
|
f"Expected one of {valid_values}, got {field_value}", location=f"{loc}"
|
||||||
|
)
|
||||||
|
return field_value
|
||||||
|
|
||||||
|
elif get_origin(t) is Annotated:
|
||||||
|
(base_type,) = get_args(t)
|
||||||
|
return construct_field(base_type, field_value)
|
||||||
|
|
||||||
|
# elif get_origin(t) is Union:
|
||||||
|
|
||||||
# Unhandled
|
# Unhandled
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -200,3 +201,19 @@ def test_private_public_fields() -> None:
|
|||||||
assert from_dict(Person, data) == expected
|
assert from_dict(Person, data) == expected
|
||||||
|
|
||||||
assert dataclass_to_dict(expected) == data
|
assert dataclass_to_dict(expected) == data
|
||||||
|
|
||||||
|
|
||||||
|
def test_literal_field() -> None:
|
||||||
|
@dataclass
|
||||||
|
class Person:
|
||||||
|
name: Literal["open_file", "select_folder", "save"]
|
||||||
|
|
||||||
|
data = {"name": "open_file"}
|
||||||
|
expected = Person(name="open_file")
|
||||||
|
assert from_dict(Person, data) == expected
|
||||||
|
|
||||||
|
assert dataclass_to_dict(expected) == data
|
||||||
|
|
||||||
|
with pytest.raises(ClanError):
|
||||||
|
# Not a valid value
|
||||||
|
from_dict(Person, {"name": "open"})
|
||||||
|
|||||||
Reference in New Issue
Block a user