API: fix serialization of union types
Due to this bug in serde.py, the run_generators API id not work for the frontend
This commit is contained in:
@@ -30,6 +30,7 @@ Note: This module assumes the presence of other modules and classes such as `Cla
|
||||
|
||||
import dataclasses
|
||||
import inspect
|
||||
import traceback
|
||||
from dataclasses import dataclass, fields, is_dataclass
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
@@ -40,6 +41,7 @@ from typing import (
|
||||
Literal,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
is_typeddict,
|
||||
@@ -180,6 +182,19 @@ def unwrap_none_type(type_hint: type | UnionType) -> type:
|
||||
return type_hint # type: ignore
|
||||
|
||||
|
||||
def unwrap_union_type(type_hint: type | UnionType) -> list[type]:
|
||||
"""Takes a type union and returns the first non-None type.
|
||||
None | str
|
||||
=>
|
||||
str
|
||||
"""
|
||||
if is_union_type(type_hint):
|
||||
# Return the first non-None type
|
||||
return list(get_args(type_hint))
|
||||
|
||||
return [type_hint] # type: ignore
|
||||
|
||||
|
||||
JsonValue = str | float | dict[str, Any] | list[Any] | None
|
||||
|
||||
|
||||
@@ -259,9 +274,19 @@ def construct_value(
|
||||
# Union types construct the first non-None type
|
||||
if is_union_type(t):
|
||||
# Unwrap the union type
|
||||
inner = unwrap_none_type(t)
|
||||
inner_types = unwrap_union_type(t)
|
||||
# Construct the field value
|
||||
return construct_value(inner, field_value)
|
||||
errors = []
|
||||
for t in inner_types:
|
||||
try:
|
||||
return construct_value(t, field_value, loc)
|
||||
except ClanError as exc:
|
||||
errors.append(exc)
|
||||
continue
|
||||
msg = f"Cannot construct field of type {t} while constructing a union type from value: {field_value}"
|
||||
for e in errors:
|
||||
traceback.print_exception(e)
|
||||
raise ClanError(msg, location=f"{loc}")
|
||||
|
||||
# Nested types
|
||||
# list
|
||||
@@ -348,7 +373,6 @@ def construct_dataclass[T: Any](
|
||||
continue
|
||||
# The first type in a Union
|
||||
# str <- None | str | Path
|
||||
field_type: type[Any] = unwrap_none_type(field.type) # type: ignore
|
||||
data_field_name = field.metadata.get("alias", field.name)
|
||||
|
||||
if (
|
||||
@@ -367,7 +391,9 @@ def construct_dataclass[T: Any](
|
||||
):
|
||||
field_values[field.name] = None
|
||||
else:
|
||||
field_values[field.name] = construct_value(field_type, field_value)
|
||||
field_values[field.name] = construct_value(
|
||||
cast(type, field.type), field_value
|
||||
)
|
||||
|
||||
# Check that all required field are present.
|
||||
for field_name in required:
|
||||
|
||||
@@ -387,3 +387,22 @@ def test_unknown_serialize() -> None:
|
||||
|
||||
person = dataclass_to_dict(data)
|
||||
assert person == {"name": ["a", "b"]}
|
||||
|
||||
|
||||
def test_union_dataclass() -> None:
|
||||
@dataclass
|
||||
class A:
|
||||
val: str | list[str] | None = None
|
||||
|
||||
data1 = {"val": "hello"}
|
||||
expected1 = A(val="hello")
|
||||
assert from_dict(A, data1) == expected1
|
||||
data2 = {"val": ["a", "b"]}
|
||||
expected2 = A(val=["a", "b"])
|
||||
assert from_dict(A, data2) == expected2
|
||||
data3 = {"val": None}
|
||||
expected3 = A(val=None)
|
||||
assert from_dict(A, data3) == expected3
|
||||
data4: dict[str, object] = {}
|
||||
expected4 = A(val=None)
|
||||
assert from_dict(A, data4) == expected4
|
||||
|
||||
Reference in New Issue
Block a user