diff --git a/pkgs/clan-cli/clan_cli/api/util.py b/pkgs/clan-cli/clan_cli/api/util.py index 3ff0ffc8c..d0812af06 100644 --- a/pkgs/clan-cli/clan_cli/api/util.py +++ b/pkgs/clan-cli/clan_cli/api/util.py @@ -3,15 +3,19 @@ import dataclasses import pathlib from dataclasses import MISSING from enum import EnumType +from inspect import get_annotations from types import NoneType, UnionType from typing import ( Annotated, Any, Literal, + NotRequired, + Required, TypeVar, Union, get_args, get_origin, + is_typeddict, ) @@ -68,6 +72,33 @@ def apply_annotations(schema: dict[str, Any], annotations: list[Any]) -> dict[st return schema +def is_typed_dict(t: type) -> bool: + return is_typeddict(t) + + +# Function to get member names and their types +def get_typed_dict_fields(typed_dict_class: type, scope: str) -> dict[str, type]: + """Retrieve member names and their types from a TypedDict.""" + if not hasattr(typed_dict_class, "__annotations__"): + msg = f"{typed_dict_class} is not a TypedDict." + raise JSchemaTypeError(msg, scope) + return get_annotations(typed_dict_class) + + +def is_type_in_union(union_type: type | UnionType, target_type: type) -> bool: + if get_origin(union_type) is UnionType: + return any(issubclass(arg, target_type) for arg in get_args(union_type)) + return union_type == target_type + + +def is_total(typed_dict_class: type) -> bool: + """ + Check if a TypedDict has total=true + https://typing.readthedocs.io/en/latest/spec/typeddict.html#interaction-with-total-false + """ + return getattr(typed_dict_class, "__total__", True) # Default to True if not set + + def type_to_dict( t: Any, scope: str = "", type_map: dict[TypeVar, type] | None = None ) -> dict: @@ -116,6 +147,28 @@ def type_to_dict( "additionalProperties": False, } + if is_typed_dict(t): + dict_fields = get_typed_dict_fields(t, scope) + dict_properties: dict = {} + dict_required: list[str] = [] + for field_name, field_type in dict_fields.items(): + if ( + not is_type_in_union(field_type, type(None)) + and get_origin(field_type) is not NotRequired + ) or get_origin(field_type) is Required: + dict_required.append(field_name) + + dict_properties[field_name] = type_to_dict( + field_type, f"{scope} {t.__name__}.{field_name}", type_map + ) + + return { + "type": "object", + "properties": dict_properties, + "required": dict_required if is_total(t) else [], + "additionalProperties": False, + } + if type(t) is UnionType: return { "oneOf": [type_to_dict(arg, scope, type_map) for arg in t.__args__], @@ -164,6 +217,11 @@ def type_to_dict( "items": type_to_dict(t.__args__[0], scope, type_map), } + # Used to mark optional fields in TypedDict + # Here we just unwrap the type and return the schema for the inner type + if origin is NotRequired or origin is Required: + return type_to_dict(t.__args__[0], scope, type_map) + if issubclass(origin, dict): value_type = t.__args__[1] if value_type is Any: diff --git a/pkgs/clan-cli/clan_cli/inventory/__init__.py b/pkgs/clan-cli/clan_cli/inventory/__init__.py index 35281425d..4a354746b 100644 --- a/pkgs/clan-cli/clan_cli/inventory/__init__.py +++ b/pkgs/clan-cli/clan_cli/inventory/__init__.py @@ -384,7 +384,7 @@ def patch_inventory_with(base_dir: Path, section: str, content: dict[str, Any]) def set_inventory( inventory: Inventory | dict[str, Any], flake_dir: str | Path, message: str ) -> None: - """ " + """ Write the inventory to the flake directory and commit it to git with the given message """