Merge pull request 'API/serde: add support for TypedDict' (#2571) from hsjobeki/clan-core:hsjobeki-main into main
This commit is contained in:
@@ -3,15 +3,19 @@ import dataclasses
|
||||
import pathlib
|
||||
from dataclasses import MISSING
|
||||
from enum import EnumType
|
||||
from inspect import get_annotations
|
||||
from types import NoneType, UnionType
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Literal,
|
||||
NotRequired,
|
||||
Required,
|
||||
TypeVar,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
is_typeddict,
|
||||
)
|
||||
|
||||
|
||||
@@ -68,6 +72,33 @@ def apply_annotations(schema: dict[str, Any], annotations: list[Any]) -> dict[st
|
||||
return schema
|
||||
|
||||
|
||||
def is_typed_dict(t: type) -> bool:
|
||||
return is_typeddict(t)
|
||||
|
||||
|
||||
# Function to get member names and their types
|
||||
def get_typed_dict_fields(typed_dict_class: type, scope: str) -> dict[str, type]:
|
||||
"""Retrieve member names and their types from a TypedDict."""
|
||||
if not hasattr(typed_dict_class, "__annotations__"):
|
||||
msg = f"{typed_dict_class} is not a TypedDict."
|
||||
raise JSchemaTypeError(msg, scope)
|
||||
return get_annotations(typed_dict_class)
|
||||
|
||||
|
||||
def is_type_in_union(union_type: type | UnionType, target_type: type) -> bool:
|
||||
if get_origin(union_type) is UnionType:
|
||||
return any(issubclass(arg, target_type) for arg in get_args(union_type))
|
||||
return union_type == target_type
|
||||
|
||||
|
||||
def is_total(typed_dict_class: type) -> bool:
|
||||
"""
|
||||
Check if a TypedDict has total=true
|
||||
https://typing.readthedocs.io/en/latest/spec/typeddict.html#interaction-with-total-false
|
||||
"""
|
||||
return getattr(typed_dict_class, "__total__", True) # Default to True if not set
|
||||
|
||||
|
||||
def type_to_dict(
|
||||
t: Any, scope: str = "", type_map: dict[TypeVar, type] | None = None
|
||||
) -> dict:
|
||||
@@ -116,6 +147,28 @@ def type_to_dict(
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
if is_typed_dict(t):
|
||||
dict_fields = get_typed_dict_fields(t, scope)
|
||||
dict_properties: dict = {}
|
||||
dict_required: list[str] = []
|
||||
for field_name, field_type in dict_fields.items():
|
||||
if (
|
||||
not is_type_in_union(field_type, type(None))
|
||||
and get_origin(field_type) is not NotRequired
|
||||
) or get_origin(field_type) is Required:
|
||||
dict_required.append(field_name)
|
||||
|
||||
dict_properties[field_name] = type_to_dict(
|
||||
field_type, f"{scope} {t.__name__}.{field_name}", type_map
|
||||
)
|
||||
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": dict_properties,
|
||||
"required": dict_required if is_total(t) else [],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
if type(t) is UnionType:
|
||||
return {
|
||||
"oneOf": [type_to_dict(arg, scope, type_map) for arg in t.__args__],
|
||||
@@ -164,6 +217,11 @@ def type_to_dict(
|
||||
"items": type_to_dict(t.__args__[0], scope, type_map),
|
||||
}
|
||||
|
||||
# Used to mark optional fields in TypedDict
|
||||
# Here we just unwrap the type and return the schema for the inner type
|
||||
if origin is NotRequired or origin is Required:
|
||||
return type_to_dict(t.__args__[0], scope, type_map)
|
||||
|
||||
if issubclass(origin, dict):
|
||||
value_type = t.__args__[1]
|
||||
if value_type is Any:
|
||||
|
||||
@@ -384,7 +384,7 @@ def patch_inventory_with(base_dir: Path, section: str, content: dict[str, Any])
|
||||
def set_inventory(
|
||||
inventory: Inventory | dict[str, Any], flake_dir: str | Path, message: str
|
||||
) -> None:
|
||||
""" "
|
||||
"""
|
||||
Write the inventory to the flake directory
|
||||
and commit it to git with the given message
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user