feat(classgen): make type generation more predictable across
This commit is contained in:
@@ -3,7 +3,7 @@ import argparse
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable, Iterable
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -15,58 +15,70 @@ class Error(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def sort_types(items: Iterable[str]) -> list[str]:
|
||||||
|
def sort_key(item: str) -> tuple[int, str]:
|
||||||
|
# Priority order: lower number = higher priority
|
||||||
|
if item.startswith(("dict", "list")):
|
||||||
|
return (0, item) # Highest priority, dicts and lists should be first
|
||||||
|
if item == "None":
|
||||||
|
return (2, item) # Lowest priority, None should be last
|
||||||
|
return (1, item) # Middle priority, sorted alphabetically
|
||||||
|
|
||||||
|
return sorted(items, key=sort_key)
|
||||||
|
|
||||||
|
|
||||||
# Function to map JSON schemas and types to Python types
|
# Function to map JSON schemas and types to Python types
|
||||||
def map_json_type(
|
def map_json_type(
|
||||||
json_type: Any, nested_types: set[str] | None = None, parent: Any = None
|
json_type: Any, nested_types: list[str] | None = None, parent: Any = None
|
||||||
) -> set[str]:
|
) -> list[str]:
|
||||||
if nested_types is None:
|
if nested_types is None:
|
||||||
nested_types = {"Any"}
|
nested_types = ["Any"]
|
||||||
if isinstance(json_type, list):
|
if isinstance(json_type, list):
|
||||||
res = set()
|
res: list[str] = []
|
||||||
for t in json_type:
|
for t in json_type:
|
||||||
res |= map_json_type(t)
|
res.extend(map_json_type(t))
|
||||||
return res
|
return sort_types(set(res))
|
||||||
if isinstance(json_type, dict):
|
if isinstance(json_type, dict):
|
||||||
items = json_type.get("items")
|
items = json_type.get("items")
|
||||||
if items:
|
if items:
|
||||||
nested_types = map_json_type(items)
|
nested_types = map_json_type(items)
|
||||||
return map_json_type(json_type.get("type"), nested_types)
|
|
||||||
|
if not json_type.get("type") and json_type.get("tsType") == "unknown":
|
||||||
|
return ["Unknown"]
|
||||||
|
|
||||||
|
return sort_types(map_json_type(json_type.get("type"), nested_types))
|
||||||
if json_type == "string":
|
if json_type == "string":
|
||||||
return {"str"}
|
return ["str"]
|
||||||
if json_type == "integer":
|
if json_type == "integer":
|
||||||
return {"int"}
|
return ["int"]
|
||||||
if json_type == "number":
|
if json_type == "number":
|
||||||
return {"float"}
|
return ["float"]
|
||||||
if json_type == "boolean":
|
if json_type == "boolean":
|
||||||
return {"bool"}
|
return ["bool"]
|
||||||
# In Python, "number" is analogous to the float type.
|
# In Python, "number" is analogous to the float type.
|
||||||
# https://json-schema.org/understanding-json-schema/reference/numeric#number
|
# https://json-schema.org/understanding-json-schema/reference/numeric#number
|
||||||
if json_type == "number":
|
if json_type == "number":
|
||||||
return {"float"}
|
return ["float"]
|
||||||
if json_type == "array":
|
if json_type == "array":
|
||||||
assert nested_types, f"Array type not found for {parent}"
|
assert nested_types, f"Array type not found for {parent}"
|
||||||
return {f"""list[{" | ".join(nested_types)}]"""}
|
return [f"""list[{" | ".join(sort_types(nested_types))}]"""]
|
||||||
if json_type == "object":
|
if json_type == "object":
|
||||||
assert nested_types, f"dict type not found for {parent}"
|
assert nested_types, f"dict type not found for {parent}"
|
||||||
return {f"""dict[str, {" | ".join(nested_types)}]"""}
|
return [f"""dict[str, {" | ".join(sort_types(nested_types))}]"""]
|
||||||
if json_type == "null":
|
if json_type == "null":
|
||||||
return {"None"}
|
return ["None"]
|
||||||
msg = f"Python type not found for {json_type}"
|
msg = f"Python type not found for {json_type}"
|
||||||
raise Error(msg)
|
raise Error(msg)
|
||||||
|
|
||||||
|
|
||||||
known_classes = set()
|
known_classes = set()
|
||||||
root_class = "Inventory"
|
|
||||||
# TODO: make this configurable
|
# TODO: make this configurable
|
||||||
# For now this only includes static top-level attributes of the inventory.
|
root_class = "Clan"
|
||||||
attrs = ["machines", "meta", "services", "instances"]
|
|
||||||
|
|
||||||
static: dict[str, str] = {"Service": "dict[str, Any]"}
|
|
||||||
|
|
||||||
|
|
||||||
def field_def_from_default_type(
|
def field_def_from_default_type(
|
||||||
field_name: str,
|
field_name: str,
|
||||||
field_types: set[str],
|
field_types: list[str],
|
||||||
class_name: str,
|
class_name: str,
|
||||||
finalize_field: Callable[..., tuple[str, str]],
|
finalize_field: Callable[..., tuple[str, str]],
|
||||||
) -> tuple[str, str] | None:
|
) -> tuple[str, str] | None:
|
||||||
@@ -128,7 +140,7 @@ def field_def_from_default_type(
|
|||||||
def field_def_from_default_value(
|
def field_def_from_default_value(
|
||||||
default_value: Any,
|
default_value: Any,
|
||||||
field_name: str,
|
field_name: str,
|
||||||
field_types: set[str],
|
field_types: list[str],
|
||||||
nested_class_name: str,
|
nested_class_name: str,
|
||||||
finalize_field: Callable[..., tuple[str, str]],
|
finalize_field: Callable[..., tuple[str, str]],
|
||||||
) -> tuple[str, str] | None:
|
) -> tuple[str, str] | None:
|
||||||
@@ -141,7 +153,7 @@ def field_def_from_default_value(
|
|||||||
)
|
)
|
||||||
if default_value is None:
|
if default_value is None:
|
||||||
return finalize_field(
|
return finalize_field(
|
||||||
field_types=field_types | {"None"},
|
field_types=[*field_types, "None"],
|
||||||
default="None",
|
default="None",
|
||||||
)
|
)
|
||||||
if isinstance(default_value, list):
|
if isinstance(default_value, list):
|
||||||
@@ -189,7 +201,7 @@ def field_def_from_default_value(
|
|||||||
def get_field_def(
|
def get_field_def(
|
||||||
field_name: str,
|
field_name: str,
|
||||||
field_meta: str | None,
|
field_meta: str | None,
|
||||||
field_types: set[str],
|
field_types: list[str],
|
||||||
default: str | None = None,
|
default: str | None = None,
|
||||||
default_factory: str | None = None,
|
default_factory: str | None = None,
|
||||||
type_appendix: str = "",
|
type_appendix: str = "",
|
||||||
@@ -197,10 +209,10 @@ def get_field_def(
|
|||||||
if "None" in field_types or default or default_factory:
|
if "None" in field_types or default or default_factory:
|
||||||
if "None" in field_types:
|
if "None" in field_types:
|
||||||
field_types.remove("None")
|
field_types.remove("None")
|
||||||
serialised_types = " | ".join(field_types) + type_appendix
|
serialised_types = " | ".join(sort_types(field_types)) + type_appendix
|
||||||
serialised_types = f"{serialised_types}"
|
serialised_types = f"{serialised_types}"
|
||||||
else:
|
else:
|
||||||
serialised_types = " | ".join(field_types) + type_appendix
|
serialised_types = " | ".join(sort_types(field_types)) + type_appendix
|
||||||
|
|
||||||
return (field_name, serialised_types)
|
return (field_name, serialised_types)
|
||||||
|
|
||||||
@@ -217,26 +229,21 @@ def generate_dataclass(
|
|||||||
fields_with_default: list[tuple[str, str]] = []
|
fields_with_default: list[tuple[str, str]] = []
|
||||||
nested_classes: list[str] = []
|
nested_classes: list[str] = []
|
||||||
|
|
||||||
# if We are at the top level, and the attribute name is in shallow
|
|
||||||
# return f"{class_name} = dict[str, Any]"
|
|
||||||
if class_name in static:
|
|
||||||
return f"{class_name} = {static[class_name]}"
|
|
||||||
|
|
||||||
for prop, prop_info in properties.items():
|
for prop, prop_info in properties.items():
|
||||||
# If we are at the top level, and the attribute name is not explicitly included we only do shallow
|
# If we are at the top level, and the attribute name is not explicitly included we only do shallow
|
||||||
field_name = prop.replace("-", "_")
|
field_name = prop.replace("-", "_")
|
||||||
|
|
||||||
if len(attr_path) == 0 and prop not in attrs:
|
# if len(attr_path) == 0 and prop in shallow_attrs:
|
||||||
field_def = field_name, "dict[str, Any]"
|
# field_def = field_name, "dict[str, Any]"
|
||||||
fields_with_default.append(field_def)
|
# fields_with_default.append(field_def)
|
||||||
continue
|
# continue
|
||||||
|
|
||||||
prop_type = prop_info.get("type", None)
|
prop_type = prop_info.get("type", None)
|
||||||
union_variants = prop_info.get("oneOf", [])
|
union_variants = prop_info.get("oneOf", [])
|
||||||
enum_variants = prop_info.get("enum", [])
|
enum_variants = prop_info.get("enum", [])
|
||||||
|
|
||||||
# Collect all types
|
# Collect all types
|
||||||
field_types = set()
|
field_types: list[str] = []
|
||||||
|
|
||||||
title = prop_info.get("title", prop.removesuffix("s"))
|
title = prop_info.get("title", prop.removesuffix("s"))
|
||||||
title_sanitized = "".join([p.capitalize() for p in title.split("-")])
|
title_sanitized = "".join([p.capitalize() for p in title.split("-")])
|
||||||
@@ -259,13 +266,13 @@ def generate_dataclass(
|
|||||||
)
|
)
|
||||||
elif enum := prop_info.get("enum"):
|
elif enum := prop_info.get("enum"):
|
||||||
literals = ", ".join([f'"{string}"' for string in enum])
|
literals = ", ".join([f'"{string}"' for string in enum])
|
||||||
field_types = {f"""Literal[{literals}]"""}
|
field_types = [f"""Literal[{literals}]"""]
|
||||||
|
|
||||||
elif prop_type == "object":
|
elif prop_type == "object":
|
||||||
inner_type = prop_info.get("additionalProperties")
|
inner_type = prop_info.get("additionalProperties")
|
||||||
if inner_type and inner_type.get("type") == "object":
|
if inner_type and inner_type.get("type") == "object":
|
||||||
# Inner type is a class
|
# Inner type is a class
|
||||||
field_types = map_json_type(prop_type, {nested_class_name}, field_name)
|
field_types = map_json_type(prop_type, [nested_class_name], field_name)
|
||||||
|
|
||||||
if nested_class_name not in known_classes:
|
if nested_class_name not in known_classes:
|
||||||
nested_classes.append(
|
nested_classes.append(
|
||||||
@@ -278,13 +285,13 @@ def generate_dataclass(
|
|||||||
elif inner_type and inner_type.get("type") != "object":
|
elif inner_type and inner_type.get("type") != "object":
|
||||||
# Trivial type:
|
# Trivial type:
|
||||||
# dict[str, inner_type]
|
# dict[str, inner_type]
|
||||||
field_types = {
|
field_types = [
|
||||||
f"""dict[str, {" | ".join(map_json_type(inner_type))}]"""
|
f"""dict[str, {" | ".join(map_json_type(inner_type))}]"""
|
||||||
}
|
]
|
||||||
|
|
||||||
elif not inner_type:
|
elif not inner_type:
|
||||||
# The type is a class
|
# The type is a class
|
||||||
field_types = {nested_class_name}
|
field_types = [nested_class_name]
|
||||||
if nested_class_name not in known_classes:
|
if nested_class_name not in known_classes:
|
||||||
nested_classes.append(
|
nested_classes.append(
|
||||||
generate_dataclass(
|
generate_dataclass(
|
||||||
@@ -293,11 +300,11 @@ def generate_dataclass(
|
|||||||
)
|
)
|
||||||
known_classes.add(nested_class_name)
|
known_classes.add(nested_class_name)
|
||||||
elif prop_type == "Unknown":
|
elif prop_type == "Unknown":
|
||||||
field_types = {"Unknown"}
|
field_types = ["Unknown"]
|
||||||
else:
|
else:
|
||||||
field_types = map_json_type(
|
field_types = map_json_type(
|
||||||
prop_type,
|
prop_type,
|
||||||
nested_types=set(),
|
nested_types=[],
|
||||||
parent=field_name,
|
parent=field_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -309,6 +316,9 @@ def generate_dataclass(
|
|||||||
|
|
||||||
finalize_field = partial(get_field_def, field_name, field_meta)
|
finalize_field = partial(get_field_def, field_name, field_meta)
|
||||||
|
|
||||||
|
# Sort and remove potential duplicates
|
||||||
|
field_types_sorted = sort_types(set(field_types))
|
||||||
|
|
||||||
if "default" in prop_info or field_name not in prop_info.get("required", []):
|
if "default" in prop_info or field_name not in prop_info.get("required", []):
|
||||||
if prop_info.get("type") == "object":
|
if prop_info.get("type") == "object":
|
||||||
prop_info.update({"default": {}})
|
prop_info.update({"default": {}})
|
||||||
@@ -318,7 +328,7 @@ def generate_dataclass(
|
|||||||
field_def = field_def_from_default_value(
|
field_def = field_def_from_default_value(
|
||||||
default_value=default_value,
|
default_value=default_value,
|
||||||
field_name=field_name,
|
field_name=field_name,
|
||||||
field_types=field_types,
|
field_types=field_types_sorted,
|
||||||
nested_class_name=nested_class_name,
|
nested_class_name=nested_class_name,
|
||||||
finalize_field=finalize_field,
|
finalize_field=finalize_field,
|
||||||
)
|
)
|
||||||
@@ -328,7 +338,7 @@ def generate_dataclass(
|
|||||||
if not field_def:
|
if not field_def:
|
||||||
# Finalize without the default value
|
# Finalize without the default value
|
||||||
field_def = finalize_field(
|
field_def = finalize_field(
|
||||||
field_types=field_types,
|
field_types=field_types_sorted,
|
||||||
)
|
)
|
||||||
required_fields.append(field_def)
|
required_fields.append(field_def)
|
||||||
|
|
||||||
@@ -337,7 +347,7 @@ def generate_dataclass(
|
|||||||
# Trying to infer default value from type
|
# Trying to infer default value from type
|
||||||
field_def = field_def_from_default_type(
|
field_def = field_def_from_default_type(
|
||||||
field_name=field_name,
|
field_name=field_name,
|
||||||
field_types=field_types,
|
field_types=field_types_sorted,
|
||||||
class_name=class_name,
|
class_name=class_name,
|
||||||
finalize_field=finalize_field,
|
finalize_field=finalize_field,
|
||||||
)
|
)
|
||||||
@@ -346,13 +356,13 @@ def generate_dataclass(
|
|||||||
fields_with_default.append(field_def)
|
fields_with_default.append(field_def)
|
||||||
if not field_def:
|
if not field_def:
|
||||||
field_def = finalize_field(
|
field_def = finalize_field(
|
||||||
field_types=field_types,
|
field_types=field_types_sorted,
|
||||||
)
|
)
|
||||||
required_fields.append(field_def)
|
required_fields.append(field_def)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
field_def = finalize_field(
|
field_def = finalize_field(
|
||||||
field_types=field_types,
|
field_types=field_types_sorted,
|
||||||
)
|
)
|
||||||
required_fields.append(field_def)
|
required_fields.append(field_def)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user