API/serde: add support for TypedDict

This commit is contained in:
Johannes Kirschbauer
2024-12-06 17:25:17 +01:00
parent acb0c666a2
commit 1306fa1616
2 changed files with 59 additions and 1 deletions

View File

@@ -3,15 +3,19 @@ import dataclasses
import pathlib import pathlib
from dataclasses import MISSING from dataclasses import MISSING
from enum import EnumType from enum import EnumType
from inspect import get_annotations
from types import NoneType, UnionType from types import NoneType, UnionType
from typing import ( from typing import (
Annotated, Annotated,
Any, Any,
Literal, Literal,
NotRequired,
Required,
TypeVar, TypeVar,
Union, Union,
get_args, get_args,
get_origin, get_origin,
is_typeddict,
) )
@@ -68,6 +72,33 @@ def apply_annotations(schema: dict[str, Any], annotations: list[Any]) -> dict[st
return schema return schema
def is_typed_dict(t: type) -> bool:
return is_typeddict(t)
# Function to get member names and their types
def get_typed_dict_fields(typed_dict_class: type, scope: str) -> dict[str, type]:
"""Retrieve member names and their types from a TypedDict."""
if not hasattr(typed_dict_class, "__annotations__"):
msg = f"{typed_dict_class} is not a TypedDict."
raise JSchemaTypeError(msg, scope)
return get_annotations(typed_dict_class)
def is_type_in_union(union_type: type | UnionType, target_type: type) -> bool:
if get_origin(union_type) is UnionType:
return any(issubclass(arg, target_type) for arg in get_args(union_type))
return union_type == target_type
def is_total(typed_dict_class: type) -> bool:
"""
Check if a TypedDict has total=true
https://typing.readthedocs.io/en/latest/spec/typeddict.html#interaction-with-total-false
"""
return getattr(typed_dict_class, "__total__", True) # Default to True if not set
def type_to_dict( def type_to_dict(
t: Any, scope: str = "", type_map: dict[TypeVar, type] | None = None t: Any, scope: str = "", type_map: dict[TypeVar, type] | None = None
) -> dict: ) -> dict:
@@ -116,6 +147,28 @@ def type_to_dict(
"additionalProperties": False, "additionalProperties": False,
} }
if is_typed_dict(t):
dict_fields = get_typed_dict_fields(t, scope)
dict_properties: dict = {}
dict_required: list[str] = []
for field_name, field_type in dict_fields.items():
if (
not is_type_in_union(field_type, type(None))
and get_origin(field_type) is not NotRequired
) or get_origin(field_type) is Required:
dict_required.append(field_name)
dict_properties[field_name] = type_to_dict(
field_type, f"{scope} {t.__name__}.{field_name}", type_map
)
return {
"type": "object",
"properties": dict_properties,
"required": dict_required if is_total(t) else [],
"additionalProperties": False,
}
if type(t) is UnionType: if type(t) is UnionType:
return { return {
"oneOf": [type_to_dict(arg, scope, type_map) for arg in t.__args__], "oneOf": [type_to_dict(arg, scope, type_map) for arg in t.__args__],
@@ -164,6 +217,11 @@ def type_to_dict(
"items": type_to_dict(t.__args__[0], scope, type_map), "items": type_to_dict(t.__args__[0], scope, type_map),
} }
# Used to mark optional fields in TypedDict
# Here we just unwrap the type and return the schema for the inner type
if origin is NotRequired or origin is Required:
return type_to_dict(t.__args__[0], scope, type_map)
if issubclass(origin, dict): if issubclass(origin, dict):
value_type = t.__args__[1] value_type = t.__args__[1]
if value_type is Any: if value_type is Any:

View File

@@ -384,7 +384,7 @@ def patch_inventory_with(base_dir: Path, section: str, content: dict[str, Any])
def set_inventory( def set_inventory(
inventory: Inventory | dict[str, Any], flake_dir: str | Path, message: str inventory: Inventory | dict[str, Any], flake_dir: str | Path, message: str
) -> None: ) -> None:
""" " """
Write the inventory to the flake directory Write the inventory to the flake directory
and commit it to git with the given message and commit it to git with the given message
""" """