Merge pull request 'API: fix serialization of union types' (#4963) from serde into main

Reviewed-on: https://git.clan.lol/clan/clan-core/pulls/4963
This commit is contained in:
DavHau
2025-08-26 08:26:13 +00:00
3 changed files with 71 additions and 6 deletions

View File

@@ -671,34 +671,50 @@ def test_prompt(
monkeypatch: pytest.MonkeyPatch,
flake_with_sops: ClanFlake,
) -> None:
"""Test that generators can use prompts to collect user input and store the values appropriately."""
flake = flake_with_sops
# Configure the machine and generator
config = flake.machines["my_machine"]
config["nixpkgs"]["hostPlatform"] = "x86_64-linux"
my_generator = config["clan"]["core"]["vars"]["generators"]["my_generator"]
my_generator["files"]["line_value"]["secret"] = False
my_generator["files"]["multiline_value"]["secret"] = False
# Define output files - these will contain the prompt responses
my_generator["files"]["line_value"]["secret"] = False # Public file
my_generator["files"]["multiline_value"]["secret"] = False # Public file
# Configure prompts that will collect user input
# prompt1: Single line input, not persisted (temporary)
my_generator["prompts"]["prompt1"]["description"] = "dream2nix"
my_generator["prompts"]["prompt1"]["persist"] = False
my_generator["prompts"]["prompt1"]["type"] = "line"
# prompt2: Single line input, not persisted (temporary)
my_generator["prompts"]["prompt2"]["description"] = "dream2nix"
my_generator["prompts"]["prompt2"]["persist"] = False
my_generator["prompts"]["prompt2"]["type"] = "line"
# prompt_persist: This prompt will be stored as a secret for reuse
my_generator["prompts"]["prompt_persist"]["persist"] = True
# Script that reads prompt responses and writes them to output files
my_generator["script"] = (
'cat "$prompts"/prompt1 > "$out"/line_value; cat "$prompts"/prompt2 > "$out"/multiline_value'
)
flake.refresh()
monkeypatch.chdir(flake.path)
# Mock the prompt responses to simulate user input
monkeypatch.setattr(
"clan_cli.vars.prompt.MOCK_PROMPT_RESPONSE",
iter(["line input", "my\nmultiline\ninput\n", "prompt_persist"]),
)
# Run the generator which will collect prompts and generate vars
cli.run(["vars", "generate", "--flake", str(flake.path), "my_machine"])
# Set up objects for testing the results
flake_obj = Flake(str(flake.path))
my_generator = Generator("my_generator", machine="my_machine", _flake=flake_obj)
my_generator_with_details = Generator(
@@ -708,6 +724,8 @@ def test_prompt(
machine="my_machine",
_flake=flake_obj,
)
# Verify that non-persistent prompts created public vars correctly
in_repo_store = in_repo.FactStore(flake=flake_obj)
assert in_repo_store.exists(my_generator, "line_value")
assert in_repo_store.get(my_generator, "line_value").decode() == "line input"
@@ -717,6 +735,8 @@ def test_prompt(
in_repo_store.get(my_generator, "multiline_value").decode()
== "my\nmultiline\ninput\n"
)
# Verify that persistent prompt was stored as a secret
sops_store = sops.SecretStore(flake=flake_obj)
assert sops_store.exists(my_generator_with_details, "prompt_persist")
assert sops_store.get(my_generator, "prompt_persist").decode() == "prompt_persist"

View File

@@ -30,6 +30,7 @@ Note: This module assumes the presence of other modules and classes such as `Cla
import dataclasses
import inspect
import traceback
from dataclasses import dataclass, fields, is_dataclass
from enum import Enum
from pathlib import Path
@@ -40,6 +41,7 @@ from typing import (
Literal,
TypeVar,
Union,
cast,
get_args,
get_origin,
is_typeddict,
@@ -180,6 +182,19 @@ def unwrap_none_type(type_hint: type | UnionType) -> type:
return type_hint # type: ignore
def unwrap_union_type(type_hint: type | UnionType) -> list[type]:
"""Takes a type union and returns the first non-None type.
None | str
=>
str
"""
if is_union_type(type_hint):
# Return the first non-None type
return list(get_args(type_hint))
return [type_hint] # type: ignore
JsonValue = str | float | dict[str, Any] | list[Any] | None
@@ -259,9 +274,19 @@ def construct_value(
# Union types construct the first non-None type
if is_union_type(t):
# Unwrap the union type
inner = unwrap_none_type(t)
inner_types = unwrap_union_type(t)
# Construct the field value
return construct_value(inner, field_value)
errors = []
for t in inner_types:
try:
return construct_value(t, field_value, loc)
except ClanError as exc:
errors.append(exc)
continue
msg = f"Cannot construct field of type {t} while constructing a union type from value: {field_value}"
for e in errors:
traceback.print_exception(e)
raise ClanError(msg, location=f"{loc}")
# Nested types
# list
@@ -348,7 +373,6 @@ def construct_dataclass[T: Any](
continue
# The first type in a Union
# str <- None | str | Path
field_type: type[Any] = unwrap_none_type(field.type) # type: ignore
data_field_name = field.metadata.get("alias", field.name)
if (
@@ -367,7 +391,9 @@ def construct_dataclass[T: Any](
):
field_values[field.name] = None
else:
field_values[field.name] = construct_value(field_type, field_value)
field_values[field.name] = construct_value(
cast(type, field.type), field_value
)
# Check that all required field are present.
for field_name in required:

View File

@@ -387,3 +387,22 @@ def test_unknown_serialize() -> None:
person = dataclass_to_dict(data)
assert person == {"name": ["a", "b"]}
def test_union_dataclass() -> None:
@dataclass
class A:
val: str | list[str] | None = None
data1 = {"val": "hello"}
expected1 = A(val="hello")
assert from_dict(A, data1) == expected1
data2 = {"val": ["a", "b"]}
expected2 = A(val=["a", "b"])
assert from_dict(A, data2) == expected2
data3 = {"val": None}
expected3 = A(val=None)
assert from_dict(A, data3) == expected3
data4: dict[str, object] = {}
expected4 = A(val=None)
assert from_dict(A, data4) == expected4