diff --git a/pkgs/clan-cli/clan_lib/api/serde.py b/pkgs/clan-cli/clan_lib/api/serde.py index d85700dee..f5940e065 100644 --- a/pkgs/clan-cli/clan_lib/api/serde.py +++ b/pkgs/clan-cli/clan_lib/api/serde.py @@ -146,8 +146,31 @@ def is_union_type(type_hint: type | UnionType) -> bool: def is_type_in_union(union_type: type | UnionType, target_type: type) -> bool: - if get_origin(union_type) is UnionType: - return any(issubclass(arg, target_type) for arg in get_args(union_type)) + # Check for Union from typing module (Union[str, None]) or UnionType (str | None) + if get_origin(union_type) in (Union, UnionType): + args = get_args(union_type) + for arg in args: + # Handle None type specially since it's not a class + if arg is None or arg is type(None): + if target_type is type(None): + return True + # For generic types like dict[str, str], check their origin + elif get_origin(arg) is not None: + if get_origin(arg) == target_type: + return True + # Also check if target_type is a generic with same origin + elif get_origin(target_type) is not None and get_origin( + arg + ) == get_origin(target_type): + return True + # For actual classes, use issubclass + elif inspect.isclass(arg) and inspect.isclass(target_type): + if issubclass(arg, target_type): + return True + # For non-class types, use direct comparison + elif arg == target_type: + return True + return False return union_type == target_type diff --git a/pkgs/clan-cli/clan_lib/api/serde_deserialize_test.py b/pkgs/clan-cli/clan_lib/api/serde_deserialize_test.py index 6b8065154..817e5a86c 100644 --- a/pkgs/clan-cli/clan_lib/api/serde_deserialize_test.py +++ b/pkgs/clan-cli/clan_lib/api/serde_deserialize_test.py @@ -8,6 +8,7 @@ import pytest from clan_lib.api import dataclass_to_dict, from_dict from clan_lib.errors import ClanError from clan_lib.machines import machines +from clan_lib.api.serde import is_type_in_union def test_simple() -> None: @@ -216,6 +217,44 @@ def test_none_or_string() -> None: assert checked3 is None +def test_union_with_none_edge_cases() -> None: + """ + Test various union types with None to ensure issubclass() error is avoided. + This specifically tests the fix for the TypeError in is_type_in_union. + """ + # Test basic types with None + assert from_dict(str | None, None) is None + assert from_dict(str | None, "hello") == "hello" + + # Test dict with None - this was the specific case that failed + assert from_dict(dict[str, str] | None, None) is None + assert from_dict(dict[str, str] | None, {"key": "value"}) == {"key": "value"} + + # Test list with None + assert from_dict(list[str] | None, None) is None + assert from_dict(list[str] | None, ["a", "b"]) == ["a", "b"] + + # Test dataclass with None + @dataclass + class TestClass: + value: str + + assert from_dict(TestClass | None, None) is None + assert from_dict(TestClass | None, {"value": "test"}) == TestClass(value="test") + + # Test Path with None (since it's used in the original failing test) + assert from_dict(Path | None, None) is None + assert from_dict(Path | None, "/home/test") == Path("/home/test") + + # Test that the is_type_in_union function works correctly + # This is the core of what was fixed - ensuring None doesn't cause issubclass error + # These should not raise TypeError anymore + assert is_type_in_union(str | None, type(None)) is True + assert is_type_in_union(dict[str, str] | None, type(None)) is True + assert is_type_in_union(list[str] | None, type(None)) is True + assert is_type_in_union(Path | None, type(None)) is True + + def test_roundtrip_escape() -> None: assert from_dict(str, "\n") == "\n" assert dataclass_to_dict("\n") == "\n"