API: fix serialization of union types

Due to this bug in serde.py, the run_generators API id not work for the frontend
This commit is contained in:
DavHau
2025-08-26 15:16:55 +07:00
parent d11d83f699
commit 2f1dc3a33d
2 changed files with 49 additions and 4 deletions

View File

@@ -30,6 +30,7 @@ Note: This module assumes the presence of other modules and classes such as `Cla
import dataclasses
import inspect
import traceback
from dataclasses import dataclass, fields, is_dataclass
from enum import Enum
from pathlib import Path
@@ -40,6 +41,7 @@ from typing import (
Literal,
TypeVar,
Union,
cast,
get_args,
get_origin,
is_typeddict,
@@ -180,6 +182,19 @@ def unwrap_none_type(type_hint: type | UnionType) -> type:
return type_hint # type: ignore
def unwrap_union_type(type_hint: type | UnionType) -> list[type]:
"""Takes a type union and returns the first non-None type.
None | str
=>
str
"""
if is_union_type(type_hint):
# Return the first non-None type
return list(get_args(type_hint))
return [type_hint] # type: ignore
JsonValue = str | float | dict[str, Any] | list[Any] | None
@@ -259,9 +274,19 @@ def construct_value(
# Union types construct the first non-None type
if is_union_type(t):
# Unwrap the union type
inner = unwrap_none_type(t)
inner_types = unwrap_union_type(t)
# Construct the field value
return construct_value(inner, field_value)
errors = []
for t in inner_types:
try:
return construct_value(t, field_value, loc)
except ClanError as exc:
errors.append(exc)
continue
msg = f"Cannot construct field of type {t} while constructing a union type from value: {field_value}"
for e in errors:
traceback.print_exception(e)
raise ClanError(msg, location=f"{loc}")
# Nested types
# list
@@ -348,7 +373,6 @@ def construct_dataclass[T: Any](
continue
# The first type in a Union
# str <- None | str | Path
field_type: type[Any] = unwrap_none_type(field.type) # type: ignore
data_field_name = field.metadata.get("alias", field.name)
if (
@@ -367,7 +391,9 @@ def construct_dataclass[T: Any](
):
field_values[field.name] = None
else:
field_values[field.name] = construct_value(field_type, field_value)
field_values[field.name] = construct_value(
cast(type, field.type), field_value
)
# Check that all required field are present.
for field_name in required: