diff --git a/pkgs/classgen/main.py b/pkgs/classgen/main.py index e8dac154e..bc15fb94c 100644 --- a/pkgs/classgen/main.py +++ b/pkgs/classgen/main.py @@ -3,7 +3,7 @@ import argparse import json import logging import sys -from collections.abc import Callable +from collections.abc import Callable, Iterable from functools import partial from pathlib import Path from typing import Any @@ -15,58 +15,70 @@ class Error(Exception): pass +def sort_types(items: Iterable[str]) -> list[str]: + def sort_key(item: str) -> tuple[int, str]: + # Priority order: lower number = higher priority + if item.startswith(("dict", "list")): + return (0, item) # Highest priority, dicts and lists should be first + if item == "None": + return (2, item) # Lowest priority, None should be last + return (1, item) # Middle priority, sorted alphabetically + + return sorted(items, key=sort_key) + + # Function to map JSON schemas and types to Python types def map_json_type( - json_type: Any, nested_types: set[str] | None = None, parent: Any = None -) -> set[str]: + json_type: Any, nested_types: list[str] | None = None, parent: Any = None +) -> list[str]: if nested_types is None: - nested_types = {"Any"} + nested_types = ["Any"] if isinstance(json_type, list): - res = set() + res: list[str] = [] for t in json_type: - res |= map_json_type(t) - return res + res.extend(map_json_type(t)) + return sort_types(set(res)) if isinstance(json_type, dict): items = json_type.get("items") if items: nested_types = map_json_type(items) - return map_json_type(json_type.get("type"), nested_types) + + if not json_type.get("type") and json_type.get("tsType") == "unknown": + return ["Unknown"] + + return sort_types(map_json_type(json_type.get("type"), nested_types)) if json_type == "string": - return {"str"} + return ["str"] if json_type == "integer": - return {"int"} + return ["int"] if json_type == "number": - return {"float"} + return ["float"] if json_type == "boolean": - return {"bool"} + return ["bool"] # In Python, "number" is analogous to the float type. # https://json-schema.org/understanding-json-schema/reference/numeric#number if json_type == "number": - return {"float"} + return ["float"] if json_type == "array": assert nested_types, f"Array type not found for {parent}" - return {f"""list[{" | ".join(nested_types)}]"""} + return [f"""list[{" | ".join(sort_types(nested_types))}]"""] if json_type == "object": assert nested_types, f"dict type not found for {parent}" - return {f"""dict[str, {" | ".join(nested_types)}]"""} + return [f"""dict[str, {" | ".join(sort_types(nested_types))}]"""] if json_type == "null": - return {"None"} + return ["None"] msg = f"Python type not found for {json_type}" raise Error(msg) known_classes = set() -root_class = "Inventory" # TODO: make this configurable -# For now this only includes static top-level attributes of the inventory. -attrs = ["machines", "meta", "services", "instances"] - -static: dict[str, str] = {"Service": "dict[str, Any]"} +root_class = "Clan" def field_def_from_default_type( field_name: str, - field_types: set[str], + field_types: list[str], class_name: str, finalize_field: Callable[..., tuple[str, str]], ) -> tuple[str, str] | None: @@ -128,7 +140,7 @@ def field_def_from_default_type( def field_def_from_default_value( default_value: Any, field_name: str, - field_types: set[str], + field_types: list[str], nested_class_name: str, finalize_field: Callable[..., tuple[str, str]], ) -> tuple[str, str] | None: @@ -141,7 +153,7 @@ def field_def_from_default_value( ) if default_value is None: return finalize_field( - field_types=field_types | {"None"}, + field_types=[*field_types, "None"], default="None", ) if isinstance(default_value, list): @@ -189,7 +201,7 @@ def field_def_from_default_value( def get_field_def( field_name: str, field_meta: str | None, - field_types: set[str], + field_types: list[str], default: str | None = None, default_factory: str | None = None, type_appendix: str = "", @@ -197,10 +209,10 @@ def get_field_def( if "None" in field_types or default or default_factory: if "None" in field_types: field_types.remove("None") - serialised_types = " | ".join(field_types) + type_appendix + serialised_types = " | ".join(sort_types(field_types)) + type_appendix serialised_types = f"{serialised_types}" else: - serialised_types = " | ".join(field_types) + type_appendix + serialised_types = " | ".join(sort_types(field_types)) + type_appendix return (field_name, serialised_types) @@ -217,26 +229,21 @@ def generate_dataclass( fields_with_default: list[tuple[str, str]] = [] nested_classes: list[str] = [] - # if We are at the top level, and the attribute name is in shallow - # return f"{class_name} = dict[str, Any]" - if class_name in static: - return f"{class_name} = {static[class_name]}" - for prop, prop_info in properties.items(): # If we are at the top level, and the attribute name is not explicitly included we only do shallow field_name = prop.replace("-", "_") - if len(attr_path) == 0 and prop not in attrs: - field_def = field_name, "dict[str, Any]" - fields_with_default.append(field_def) - continue + # if len(attr_path) == 0 and prop in shallow_attrs: + # field_def = field_name, "dict[str, Any]" + # fields_with_default.append(field_def) + # continue prop_type = prop_info.get("type", None) union_variants = prop_info.get("oneOf", []) enum_variants = prop_info.get("enum", []) # Collect all types - field_types = set() + field_types: list[str] = [] title = prop_info.get("title", prop.removesuffix("s")) title_sanitized = "".join([p.capitalize() for p in title.split("-")]) @@ -259,13 +266,13 @@ def generate_dataclass( ) elif enum := prop_info.get("enum"): literals = ", ".join([f'"{string}"' for string in enum]) - field_types = {f"""Literal[{literals}]"""} + field_types = [f"""Literal[{literals}]"""] elif prop_type == "object": inner_type = prop_info.get("additionalProperties") if inner_type and inner_type.get("type") == "object": # Inner type is a class - field_types = map_json_type(prop_type, {nested_class_name}, field_name) + field_types = map_json_type(prop_type, [nested_class_name], field_name) if nested_class_name not in known_classes: nested_classes.append( @@ -278,13 +285,13 @@ def generate_dataclass( elif inner_type and inner_type.get("type") != "object": # Trivial type: # dict[str, inner_type] - field_types = { + field_types = [ f"""dict[str, {" | ".join(map_json_type(inner_type))}]""" - } + ] elif not inner_type: # The type is a class - field_types = {nested_class_name} + field_types = [nested_class_name] if nested_class_name not in known_classes: nested_classes.append( generate_dataclass( @@ -293,11 +300,11 @@ def generate_dataclass( ) known_classes.add(nested_class_name) elif prop_type == "Unknown": - field_types = {"Unknown"} + field_types = ["Unknown"] else: field_types = map_json_type( prop_type, - nested_types=set(), + nested_types=[], parent=field_name, ) @@ -309,6 +316,9 @@ def generate_dataclass( finalize_field = partial(get_field_def, field_name, field_meta) + # Sort and remove potential duplicates + field_types_sorted = sort_types(set(field_types)) + if "default" in prop_info or field_name not in prop_info.get("required", []): if prop_info.get("type") == "object": prop_info.update({"default": {}}) @@ -318,7 +328,7 @@ def generate_dataclass( field_def = field_def_from_default_value( default_value=default_value, field_name=field_name, - field_types=field_types, + field_types=field_types_sorted, nested_class_name=nested_class_name, finalize_field=finalize_field, ) @@ -328,7 +338,7 @@ def generate_dataclass( if not field_def: # Finalize without the default value field_def = finalize_field( - field_types=field_types, + field_types=field_types_sorted, ) required_fields.append(field_def) @@ -337,7 +347,7 @@ def generate_dataclass( # Trying to infer default value from type field_def = field_def_from_default_type( field_name=field_name, - field_types=field_types, + field_types=field_types_sorted, class_name=class_name, finalize_field=finalize_field, ) @@ -346,13 +356,13 @@ def generate_dataclass( fields_with_default.append(field_def) if not field_def: field_def = finalize_field( - field_types=field_types, + field_types=field_types_sorted, ) required_fields.append(field_def) else: field_def = finalize_field( - field_types=field_types, + field_types=field_types_sorted, ) required_fields.append(field_def)