API/serde: add support for TypedDict
This commit is contained in:
@@ -3,15 +3,19 @@ import dataclasses
|
|||||||
import pathlib
|
import pathlib
|
||||||
from dataclasses import MISSING
|
from dataclasses import MISSING
|
||||||
from enum import EnumType
|
from enum import EnumType
|
||||||
|
from inspect import get_annotations
|
||||||
from types import NoneType, UnionType
|
from types import NoneType, UnionType
|
||||||
from typing import (
|
from typing import (
|
||||||
Annotated,
|
Annotated,
|
||||||
Any,
|
Any,
|
||||||
Literal,
|
Literal,
|
||||||
|
NotRequired,
|
||||||
|
Required,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
get_args,
|
get_args,
|
||||||
get_origin,
|
get_origin,
|
||||||
|
is_typeddict,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -68,6 +72,33 @@ def apply_annotations(schema: dict[str, Any], annotations: list[Any]) -> dict[st
|
|||||||
return schema
|
return schema
|
||||||
|
|
||||||
|
|
||||||
|
def is_typed_dict(t: type) -> bool:
|
||||||
|
return is_typeddict(t)
|
||||||
|
|
||||||
|
|
||||||
|
# Function to get member names and their types
|
||||||
|
def get_typed_dict_fields(typed_dict_class: type, scope: str) -> dict[str, type]:
|
||||||
|
"""Retrieve member names and their types from a TypedDict."""
|
||||||
|
if not hasattr(typed_dict_class, "__annotations__"):
|
||||||
|
msg = f"{typed_dict_class} is not a TypedDict."
|
||||||
|
raise JSchemaTypeError(msg, scope)
|
||||||
|
return get_annotations(typed_dict_class)
|
||||||
|
|
||||||
|
|
||||||
|
def is_type_in_union(union_type: type | UnionType, target_type: type) -> bool:
|
||||||
|
if get_origin(union_type) is UnionType:
|
||||||
|
return any(issubclass(arg, target_type) for arg in get_args(union_type))
|
||||||
|
return union_type == target_type
|
||||||
|
|
||||||
|
|
||||||
|
def is_total(typed_dict_class: type) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a TypedDict has total=true
|
||||||
|
https://typing.readthedocs.io/en/latest/spec/typeddict.html#interaction-with-total-false
|
||||||
|
"""
|
||||||
|
return getattr(typed_dict_class, "__total__", True) # Default to True if not set
|
||||||
|
|
||||||
|
|
||||||
def type_to_dict(
|
def type_to_dict(
|
||||||
t: Any, scope: str = "", type_map: dict[TypeVar, type] | None = None
|
t: Any, scope: str = "", type_map: dict[TypeVar, type] | None = None
|
||||||
) -> dict:
|
) -> dict:
|
||||||
@@ -116,6 +147,28 @@ def type_to_dict(
|
|||||||
"additionalProperties": False,
|
"additionalProperties": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if is_typed_dict(t):
|
||||||
|
dict_fields = get_typed_dict_fields(t, scope)
|
||||||
|
dict_properties: dict = {}
|
||||||
|
dict_required: list[str] = []
|
||||||
|
for field_name, field_type in dict_fields.items():
|
||||||
|
if (
|
||||||
|
not is_type_in_union(field_type, type(None))
|
||||||
|
and get_origin(field_type) is not NotRequired
|
||||||
|
) or get_origin(field_type) is Required:
|
||||||
|
dict_required.append(field_name)
|
||||||
|
|
||||||
|
dict_properties[field_name] = type_to_dict(
|
||||||
|
field_type, f"{scope} {t.__name__}.{field_name}", type_map
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": dict_properties,
|
||||||
|
"required": dict_required if is_total(t) else [],
|
||||||
|
"additionalProperties": False,
|
||||||
|
}
|
||||||
|
|
||||||
if type(t) is UnionType:
|
if type(t) is UnionType:
|
||||||
return {
|
return {
|
||||||
"oneOf": [type_to_dict(arg, scope, type_map) for arg in t.__args__],
|
"oneOf": [type_to_dict(arg, scope, type_map) for arg in t.__args__],
|
||||||
@@ -164,6 +217,11 @@ def type_to_dict(
|
|||||||
"items": type_to_dict(t.__args__[0], scope, type_map),
|
"items": type_to_dict(t.__args__[0], scope, type_map),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Used to mark optional fields in TypedDict
|
||||||
|
# Here we just unwrap the type and return the schema for the inner type
|
||||||
|
if origin is NotRequired or origin is Required:
|
||||||
|
return type_to_dict(t.__args__[0], scope, type_map)
|
||||||
|
|
||||||
if issubclass(origin, dict):
|
if issubclass(origin, dict):
|
||||||
value_type = t.__args__[1]
|
value_type = t.__args__[1]
|
||||||
if value_type is Any:
|
if value_type is Any:
|
||||||
|
|||||||
@@ -384,7 +384,7 @@ def patch_inventory_with(base_dir: Path, section: str, content: dict[str, Any])
|
|||||||
def set_inventory(
|
def set_inventory(
|
||||||
inventory: Inventory | dict[str, Any], flake_dir: str | Path, message: str
|
inventory: Inventory | dict[str, Any], flake_dir: str | Path, message: str
|
||||||
) -> None:
|
) -> None:
|
||||||
""" "
|
"""
|
||||||
Write the inventory to the flake directory
|
Write the inventory to the flake directory
|
||||||
and commit it to git with the given message
|
and commit it to git with the given message
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user