From 6c5f9ca6db2e3476305d9f45a1c812a79ea75710 Mon Sep 17 00:00:00 2001 From: Johannes Kirschbauer Date: Wed, 31 Jul 2024 13:00:28 +0200 Subject: [PATCH] Deserializer: add Literal; Annotated fields --- pkgs/clan-cli/clan_cli/api/serde.py | 19 ++++++++++++++++++- pkgs/clan-cli/tests/test_deserializers.py | 17 +++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/pkgs/clan-cli/clan_cli/api/serde.py b/pkgs/clan-cli/clan_cli/api/serde.py index 24b9786d5..298cbe9de 100644 --- a/pkgs/clan-cli/clan_cli/api/serde.py +++ b/pkgs/clan-cli/clan_cli/api/serde.py @@ -35,7 +35,9 @@ from dataclasses import dataclass, fields, is_dataclass from pathlib import Path from types import UnionType from typing import ( + Annotated, Any, + Literal, TypeVar, Union, get_args, @@ -130,7 +132,7 @@ def construct_field(t: type, field_value: JsonValue, loc: list[str] = []) -> Any # If the field expects a path # Field_value must be a string - elif issubclass(t, Path) or is_type_in_union(t, Path): + elif is_type_in_union(t, Path): if not isinstance(field_value, str): raise ClanError( f"Expected string, cannot construct pathlib.Path() from: {field_value} ", @@ -150,6 +152,8 @@ def construct_field(t: type, field_value: JsonValue, loc: list[str] = []) -> Any return int(field_value) # type: ignore elif t is float and not isinstance(field_value, str): return float(field_value) # type: ignore + elif t is bool and isinstance(field_value, bool): + return field_value # type: ignore # Union types construct the first non-None type elif is_union_type(t): @@ -171,6 +175,19 @@ def construct_field(t: type, field_value: JsonValue, loc: list[str] = []) -> Any key: construct_field(get_args(t)[1], value) for key, value in field_value.items() } + elif get_origin(t) is Literal: + valid_values = get_args(t) + if field_value not in valid_values: + raise ClanError( + f"Expected one of {valid_values}, got {field_value}", location=f"{loc}" + ) + return field_value + + elif get_origin(t) is Annotated: + (base_type,) = get_args(t) + return construct_field(base_type, field_value) + + # elif get_origin(t) is Union: # Unhandled else: diff --git a/pkgs/clan-cli/tests/test_deserializers.py b/pkgs/clan-cli/tests/test_deserializers.py index 3d6d7583d..e858b5398 100644 --- a/pkgs/clan-cli/tests/test_deserializers.py +++ b/pkgs/clan-cli/tests/test_deserializers.py @@ -1,5 +1,6 @@ from dataclasses import dataclass, field from pathlib import Path +from typing import Literal import pytest @@ -200,3 +201,19 @@ def test_private_public_fields() -> None: assert from_dict(Person, data) == expected assert dataclass_to_dict(expected) == data + + +def test_literal_field() -> None: + @dataclass + class Person: + name: Literal["open_file", "select_folder", "save"] + + data = {"name": "open_file"} + expected = Person(name="open_file") + assert from_dict(Person, data) == expected + + assert dataclass_to_dict(expected) == data + + with pytest.raises(ClanError): + # Not a valid value + from_dict(Person, {"name": "open"})