API/serde: add support for TypedDict

This commit is contained in:
Johannes Kirschbauer
2024-12-06 17:25:17 +01:00
parent acb0c666a2
commit 1306fa1616
2 changed files with 59 additions and 1 deletions

View File

@@ -3,15 +3,19 @@ import dataclasses
import pathlib
from dataclasses import MISSING
from enum import EnumType
from inspect import get_annotations
from types import NoneType, UnionType
from typing import (
Annotated,
Any,
Literal,
NotRequired,
Required,
TypeVar,
Union,
get_args,
get_origin,
is_typeddict,
)
@@ -68,6 +72,33 @@ def apply_annotations(schema: dict[str, Any], annotations: list[Any]) -> dict[st
return schema
def is_typed_dict(t: type) -> bool:
return is_typeddict(t)
# Function to get member names and their types
def get_typed_dict_fields(typed_dict_class: type, scope: str) -> dict[str, type]:
"""Retrieve member names and their types from a TypedDict."""
if not hasattr(typed_dict_class, "__annotations__"):
msg = f"{typed_dict_class} is not a TypedDict."
raise JSchemaTypeError(msg, scope)
return get_annotations(typed_dict_class)
def is_type_in_union(union_type: type | UnionType, target_type: type) -> bool:
if get_origin(union_type) is UnionType:
return any(issubclass(arg, target_type) for arg in get_args(union_type))
return union_type == target_type
def is_total(typed_dict_class: type) -> bool:
"""
Check if a TypedDict has total=true
https://typing.readthedocs.io/en/latest/spec/typeddict.html#interaction-with-total-false
"""
return getattr(typed_dict_class, "__total__", True) # Default to True if not set
def type_to_dict(
t: Any, scope: str = "", type_map: dict[TypeVar, type] | None = None
) -> dict:
@@ -116,6 +147,28 @@ def type_to_dict(
"additionalProperties": False,
}
if is_typed_dict(t):
dict_fields = get_typed_dict_fields(t, scope)
dict_properties: dict = {}
dict_required: list[str] = []
for field_name, field_type in dict_fields.items():
if (
not is_type_in_union(field_type, type(None))
and get_origin(field_type) is not NotRequired
) or get_origin(field_type) is Required:
dict_required.append(field_name)
dict_properties[field_name] = type_to_dict(
field_type, f"{scope} {t.__name__}.{field_name}", type_map
)
return {
"type": "object",
"properties": dict_properties,
"required": dict_required if is_total(t) else [],
"additionalProperties": False,
}
if type(t) is UnionType:
return {
"oneOf": [type_to_dict(arg, scope, type_map) for arg in t.__args__],
@@ -164,6 +217,11 @@ def type_to_dict(
"items": type_to_dict(t.__args__[0], scope, type_map),
}
# Used to mark optional fields in TypedDict
# Here we just unwrap the type and return the schema for the inner type
if origin is NotRequired or origin is Required:
return type_to_dict(t.__args__[0], scope, type_map)
if issubclass(origin, dict):
value_type = t.__args__[1]
if value_type is Any:

View File

@@ -384,7 +384,7 @@ def patch_inventory_with(base_dir: Path, section: str, content: dict[str, Any])
def set_inventory(
inventory: Inventory | dict[str, Any], flake_dir: str | Path, message: str
) -> None:
""" "
"""
Write the inventory to the flake directory
and commit it to git with the given message
"""