From 08f6cdc43f7ebd03b75bf58b53a70b2b56c8f197 Mon Sep 17 00:00:00 2001 From: Johannes Kirschbauer Date: Fri, 3 Jan 2025 16:31:25 +0100 Subject: [PATCH] Serde: fix enum type conversion, ensure roundtrip stability --- pkgs/clan-cli/clan_cli/api/serde.py | 16 +++++++------ pkgs/clan-cli/clan_cli/api/util.py | 5 +++- pkgs/clan-cli/tests/test_deserializers.py | 28 +++++++++++++++++++++++ 3 files changed, 41 insertions(+), 8 deletions(-) diff --git a/pkgs/clan-cli/clan_cli/api/serde.py b/pkgs/clan-cli/clan_cli/api/serde.py index 49b2f256b..e15be8e2d 100644 --- a/pkgs/clan-cli/clan_cli/api/serde.py +++ b/pkgs/clan-cli/clan_cli/api/serde.py @@ -233,16 +233,18 @@ def construct_value( # Enums if origin is Enum: - if field_value not in origin.__members__: - msg = f"Expected one of {', '.join(origin.__members__)}, got {field_value}" - raise ClanError(msg, location=f"{loc}") - return origin.__members__[field_value] # type: ignore + try: + return t(field_value) # type: ignore + except ValueError: + msg = f"Expected one of {', '.join(str(origin))}, got {field_value}" + raise ClanError(msg, location=f"{loc}") from ValueError if isinstance(t, type) and issubclass(t, Enum): - if field_value not in t.__members__: + try: + return t(field_value) # type: ignore + except ValueError: msg = f"Expected one of {', '.join(t.__members__)}, got {field_value}" - raise ClanError(msg, location=f"{loc}") - return t.__members__[field_value] # type: ignore + raise ClanError(msg, location=f"{loc}") from ValueError if origin is Annotated: (base_type,) = get_args(t) diff --git a/pkgs/clan-cli/clan_cli/api/util.py b/pkgs/clan-cli/clan_cli/api/util.py index d0812af06..404f5e39f 100644 --- a/pkgs/clan-cli/clan_cli/api/util.py +++ b/pkgs/clan-cli/clan_cli/api/util.py @@ -18,6 +18,8 @@ from typing import ( is_typeddict, ) +from clan_cli.api.serde import dataclass_to_dict + class JSchemaTypeError(Exception): pass @@ -257,7 +259,8 @@ def type_to_dict( if type(t) is EnumType: return { "type": "string", - "enum": list(t.__members__), + # Construct every enum value and use the same method as the serde module for converting it into the same literal string + "enum": [dataclass_to_dict(t(value)) for value in t], # type: ignore } if t is Any: msg = f"{scope} - Usage of the Any type is not supported for API functions. In: {scope}" diff --git a/pkgs/clan-cli/tests/test_deserializers.py b/pkgs/clan-cli/tests/test_deserializers.py index ecbbfa06a..bd6e577ff 100644 --- a/pkgs/clan-cli/tests/test_deserializers.py +++ b/pkgs/clan-cli/tests/test_deserializers.py @@ -266,3 +266,31 @@ def test_literal_field() -> None: with pytest.raises(ClanError): # Not a valid value from_dict(Person, {"name": "open"}) + + +def test_enum_roundtrip() -> None: + from enum import Enum + + class MyEnum(Enum): + FOO = "abc" + BAR = 2 + + @dataclass + class Person: + name: MyEnum + + # Both are equivalent + data = {"name": "abc"} # JSON Representation + expected = Person(name=MyEnum.FOO) # Data representation + + assert from_dict(Person, data) == expected + + assert dataclass_to_dict(expected) == data + + # Same test for integer values + data2 = {"name": 2} # JSON Representation + expected2 = Person(name=MyEnum.BAR) # Data representation + + assert from_dict(Person, data2) == expected2 + + assert dataclass_to_dict(expected2) == data2