feat(classgen): make type generation more predictable across

This commit is contained in:
Johannes Kirschbauer
2025-05-27 10:26:19 +02:00
parent ad1f3bfa92
commit 7957fbaa4f

View File

@@ -3,7 +3,7 @@ import argparse
import json import json
import logging import logging
import sys import sys
from collections.abc import Callable from collections.abc import Callable, Iterable
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@@ -15,58 +15,70 @@ class Error(Exception):
pass pass
def sort_types(items: Iterable[str]) -> list[str]:
def sort_key(item: str) -> tuple[int, str]:
# Priority order: lower number = higher priority
if item.startswith(("dict", "list")):
return (0, item) # Highest priority, dicts and lists should be first
if item == "None":
return (2, item) # Lowest priority, None should be last
return (1, item) # Middle priority, sorted alphabetically
return sorted(items, key=sort_key)
# Function to map JSON schemas and types to Python types # Function to map JSON schemas and types to Python types
def map_json_type( def map_json_type(
json_type: Any, nested_types: set[str] | None = None, parent: Any = None json_type: Any, nested_types: list[str] | None = None, parent: Any = None
) -> set[str]: ) -> list[str]:
if nested_types is None: if nested_types is None:
nested_types = {"Any"} nested_types = ["Any"]
if isinstance(json_type, list): if isinstance(json_type, list):
res = set() res: list[str] = []
for t in json_type: for t in json_type:
res |= map_json_type(t) res.extend(map_json_type(t))
return res return sort_types(set(res))
if isinstance(json_type, dict): if isinstance(json_type, dict):
items = json_type.get("items") items = json_type.get("items")
if items: if items:
nested_types = map_json_type(items) nested_types = map_json_type(items)
return map_json_type(json_type.get("type"), nested_types)
if not json_type.get("type") and json_type.get("tsType") == "unknown":
return ["Unknown"]
return sort_types(map_json_type(json_type.get("type"), nested_types))
if json_type == "string": if json_type == "string":
return {"str"} return ["str"]
if json_type == "integer": if json_type == "integer":
return {"int"} return ["int"]
if json_type == "number": if json_type == "number":
return {"float"} return ["float"]
if json_type == "boolean": if json_type == "boolean":
return {"bool"} return ["bool"]
# In Python, "number" is analogous to the float type. # In Python, "number" is analogous to the float type.
# https://json-schema.org/understanding-json-schema/reference/numeric#number # https://json-schema.org/understanding-json-schema/reference/numeric#number
if json_type == "number": if json_type == "number":
return {"float"} return ["float"]
if json_type == "array": if json_type == "array":
assert nested_types, f"Array type not found for {parent}" assert nested_types, f"Array type not found for {parent}"
return {f"""list[{" | ".join(nested_types)}]"""} return [f"""list[{" | ".join(sort_types(nested_types))}]"""]
if json_type == "object": if json_type == "object":
assert nested_types, f"dict type not found for {parent}" assert nested_types, f"dict type not found for {parent}"
return {f"""dict[str, {" | ".join(nested_types)}]"""} return [f"""dict[str, {" | ".join(sort_types(nested_types))}]"""]
if json_type == "null": if json_type == "null":
return {"None"} return ["None"]
msg = f"Python type not found for {json_type}" msg = f"Python type not found for {json_type}"
raise Error(msg) raise Error(msg)
known_classes = set() known_classes = set()
root_class = "Inventory"
# TODO: make this configurable # TODO: make this configurable
# For now this only includes static top-level attributes of the inventory. root_class = "Clan"
attrs = ["machines", "meta", "services", "instances"]
static: dict[str, str] = {"Service": "dict[str, Any]"}
def field_def_from_default_type( def field_def_from_default_type(
field_name: str, field_name: str,
field_types: set[str], field_types: list[str],
class_name: str, class_name: str,
finalize_field: Callable[..., tuple[str, str]], finalize_field: Callable[..., tuple[str, str]],
) -> tuple[str, str] | None: ) -> tuple[str, str] | None:
@@ -128,7 +140,7 @@ def field_def_from_default_type(
def field_def_from_default_value( def field_def_from_default_value(
default_value: Any, default_value: Any,
field_name: str, field_name: str,
field_types: set[str], field_types: list[str],
nested_class_name: str, nested_class_name: str,
finalize_field: Callable[..., tuple[str, str]], finalize_field: Callable[..., tuple[str, str]],
) -> tuple[str, str] | None: ) -> tuple[str, str] | None:
@@ -141,7 +153,7 @@ def field_def_from_default_value(
) )
if default_value is None: if default_value is None:
return finalize_field( return finalize_field(
field_types=field_types | {"None"}, field_types=[*field_types, "None"],
default="None", default="None",
) )
if isinstance(default_value, list): if isinstance(default_value, list):
@@ -189,7 +201,7 @@ def field_def_from_default_value(
def get_field_def( def get_field_def(
field_name: str, field_name: str,
field_meta: str | None, field_meta: str | None,
field_types: set[str], field_types: list[str],
default: str | None = None, default: str | None = None,
default_factory: str | None = None, default_factory: str | None = None,
type_appendix: str = "", type_appendix: str = "",
@@ -197,10 +209,10 @@ def get_field_def(
if "None" in field_types or default or default_factory: if "None" in field_types or default or default_factory:
if "None" in field_types: if "None" in field_types:
field_types.remove("None") field_types.remove("None")
serialised_types = " | ".join(field_types) + type_appendix serialised_types = " | ".join(sort_types(field_types)) + type_appendix
serialised_types = f"{serialised_types}" serialised_types = f"{serialised_types}"
else: else:
serialised_types = " | ".join(field_types) + type_appendix serialised_types = " | ".join(sort_types(field_types)) + type_appendix
return (field_name, serialised_types) return (field_name, serialised_types)
@@ -217,26 +229,21 @@ def generate_dataclass(
fields_with_default: list[tuple[str, str]] = [] fields_with_default: list[tuple[str, str]] = []
nested_classes: list[str] = [] nested_classes: list[str] = []
# if We are at the top level, and the attribute name is in shallow
# return f"{class_name} = dict[str, Any]"
if class_name in static:
return f"{class_name} = {static[class_name]}"
for prop, prop_info in properties.items(): for prop, prop_info in properties.items():
# If we are at the top level, and the attribute name is not explicitly included we only do shallow # If we are at the top level, and the attribute name is not explicitly included we only do shallow
field_name = prop.replace("-", "_") field_name = prop.replace("-", "_")
if len(attr_path) == 0 and prop not in attrs: # if len(attr_path) == 0 and prop in shallow_attrs:
field_def = field_name, "dict[str, Any]" # field_def = field_name, "dict[str, Any]"
fields_with_default.append(field_def) # fields_with_default.append(field_def)
continue # continue
prop_type = prop_info.get("type", None) prop_type = prop_info.get("type", None)
union_variants = prop_info.get("oneOf", []) union_variants = prop_info.get("oneOf", [])
enum_variants = prop_info.get("enum", []) enum_variants = prop_info.get("enum", [])
# Collect all types # Collect all types
field_types = set() field_types: list[str] = []
title = prop_info.get("title", prop.removesuffix("s")) title = prop_info.get("title", prop.removesuffix("s"))
title_sanitized = "".join([p.capitalize() for p in title.split("-")]) title_sanitized = "".join([p.capitalize() for p in title.split("-")])
@@ -259,13 +266,13 @@ def generate_dataclass(
) )
elif enum := prop_info.get("enum"): elif enum := prop_info.get("enum"):
literals = ", ".join([f'"{string}"' for string in enum]) literals = ", ".join([f'"{string}"' for string in enum])
field_types = {f"""Literal[{literals}]"""} field_types = [f"""Literal[{literals}]"""]
elif prop_type == "object": elif prop_type == "object":
inner_type = prop_info.get("additionalProperties") inner_type = prop_info.get("additionalProperties")
if inner_type and inner_type.get("type") == "object": if inner_type and inner_type.get("type") == "object":
# Inner type is a class # Inner type is a class
field_types = map_json_type(prop_type, {nested_class_name}, field_name) field_types = map_json_type(prop_type, [nested_class_name], field_name)
if nested_class_name not in known_classes: if nested_class_name not in known_classes:
nested_classes.append( nested_classes.append(
@@ -278,13 +285,13 @@ def generate_dataclass(
elif inner_type and inner_type.get("type") != "object": elif inner_type and inner_type.get("type") != "object":
# Trivial type: # Trivial type:
# dict[str, inner_type] # dict[str, inner_type]
field_types = { field_types = [
f"""dict[str, {" | ".join(map_json_type(inner_type))}]""" f"""dict[str, {" | ".join(map_json_type(inner_type))}]"""
} ]
elif not inner_type: elif not inner_type:
# The type is a class # The type is a class
field_types = {nested_class_name} field_types = [nested_class_name]
if nested_class_name not in known_classes: if nested_class_name not in known_classes:
nested_classes.append( nested_classes.append(
generate_dataclass( generate_dataclass(
@@ -293,11 +300,11 @@ def generate_dataclass(
) )
known_classes.add(nested_class_name) known_classes.add(nested_class_name)
elif prop_type == "Unknown": elif prop_type == "Unknown":
field_types = {"Unknown"} field_types = ["Unknown"]
else: else:
field_types = map_json_type( field_types = map_json_type(
prop_type, prop_type,
nested_types=set(), nested_types=[],
parent=field_name, parent=field_name,
) )
@@ -309,6 +316,9 @@ def generate_dataclass(
finalize_field = partial(get_field_def, field_name, field_meta) finalize_field = partial(get_field_def, field_name, field_meta)
# Sort and remove potential duplicates
field_types_sorted = sort_types(set(field_types))
if "default" in prop_info or field_name not in prop_info.get("required", []): if "default" in prop_info or field_name not in prop_info.get("required", []):
if prop_info.get("type") == "object": if prop_info.get("type") == "object":
prop_info.update({"default": {}}) prop_info.update({"default": {}})
@@ -318,7 +328,7 @@ def generate_dataclass(
field_def = field_def_from_default_value( field_def = field_def_from_default_value(
default_value=default_value, default_value=default_value,
field_name=field_name, field_name=field_name,
field_types=field_types, field_types=field_types_sorted,
nested_class_name=nested_class_name, nested_class_name=nested_class_name,
finalize_field=finalize_field, finalize_field=finalize_field,
) )
@@ -328,7 +338,7 @@ def generate_dataclass(
if not field_def: if not field_def:
# Finalize without the default value # Finalize without the default value
field_def = finalize_field( field_def = finalize_field(
field_types=field_types, field_types=field_types_sorted,
) )
required_fields.append(field_def) required_fields.append(field_def)
@@ -337,7 +347,7 @@ def generate_dataclass(
# Trying to infer default value from type # Trying to infer default value from type
field_def = field_def_from_default_type( field_def = field_def_from_default_type(
field_name=field_name, field_name=field_name,
field_types=field_types, field_types=field_types_sorted,
class_name=class_name, class_name=class_name,
finalize_field=finalize_field, finalize_field=finalize_field,
) )
@@ -346,13 +356,13 @@ def generate_dataclass(
fields_with_default.append(field_def) fields_with_default.append(field_def)
if not field_def: if not field_def:
field_def = finalize_field( field_def = finalize_field(
field_types=field_types, field_types=field_types_sorted,
) )
required_fields.append(field_def) required_fields.append(field_def)
else: else:
field_def = finalize_field( field_def = finalize_field(
field_types=field_types, field_types=field_types_sorted,
) )
required_fields.append(field_def) required_fields.append(field_def)