Classgen: support literal enums

This commit is contained in:
Johannes Kirschbauer
2024-09-11 15:13:56 +02:00
parent cdfda2547b
commit 42b92132a7

View File

@@ -146,13 +146,21 @@ def field_def_from_default_value(
)
if default_value == "name":
return None
# Primitive types
if isinstance(default_value, str):
return finalize_field(
field_types=field_types,
default=f"'{default_value}'",
)
if isinstance(default_value, bool | int | float):
# Bool must be checked before int
return finalize_field(
field_types=field_types,
default=f"{default_value}",
)
# Other default values unhandled yet.
msg = f"Unhandled default value for field '{field_name}' - default value: {default_value}"
msg = f"Unhandled default value for field '{field_name}' - default value: {default_value}. ( In {nested_class_name} )"
raise Error(msg)
@@ -189,13 +197,15 @@ def generate_dataclass(schema: dict[str, Any], class_name: str = root_class) ->
required_fields = []
fields_with_default = []
nested_classes = []
nested_classes: list[str] = []
for prop, prop_info in properties.items():
field_name = prop.replace("-", "_")
prop_type = prop_info.get("type", None)
union_variants = prop_info.get("oneOf", [])
enum_variants = prop_info.get("enum", [])
# Collect all types
field_types = set()
@@ -203,7 +213,7 @@ def generate_dataclass(schema: dict[str, Any], class_name: str = root_class) ->
title_sanitized = "".join([p.capitalize() for p in title.split("-")])
nested_class_name = f"""{class_name if class_name != root_class and not prop_info.get("title") else ""}{title_sanitized}"""
if (prop_type is None) and (not union_variants):
if not prop_type and not union_variants and not enum_variants:
msg = f"Type not found for property {prop} {prop_info}"
raise Error(msg)
@@ -217,6 +227,9 @@ def generate_dataclass(schema: dict[str, Any], class_name: str = root_class) ->
field_types = map_json_type(
prop_type, map_json_type(item_schema), field_name
)
elif enum := prop_info.get("enum"):
literals = ", ".join([f'"{string}"' for string in enum])
field_types = {f"""Literal[{literals}]"""}
elif prop_type == "object":
inner_type = prop_info.get("additionalProperties")
@@ -224,7 +237,6 @@ def generate_dataclass(schema: dict[str, Any], class_name: str = root_class) ->
# Inner type is a class
field_types = map_json_type(prop_type, {nested_class_name}, field_name)
#
if nested_class_name not in known_classes:
nested_classes.append(
generate_dataclass(inner_type, nested_class_name)
@@ -260,6 +272,9 @@ def generate_dataclass(schema: dict[str, Any], class_name: str = root_class) ->
field_meta = f"""{{"alias": "{prop}"}}"""
finalize_field = partial(get_field_def, field_name, field_meta)
# if class_name == "DyndnsConfig":
# if class_name == "ServiceDyndnMachine":
# breakpoint()
if "default" in prop_info or field_name not in prop_info.get("required", []):
if "default" in prop_info:
@@ -308,7 +323,12 @@ def generate_dataclass(schema: dict[str, Any], class_name: str = root_class) ->
fields_str = "\n ".join(required_fields + fields_with_default)
nested_classes_str = "\n\n".join(nested_classes)
class_def = f"@dataclass\nclass {class_name}:\n {fields_str}\n"
class_def = f"@dataclass\nclass {class_name}:\n"
if not required_fields + fields_with_default:
class_def += " pass"
else:
class_def += f" {fields_str}\n"
return f"{nested_classes_str}\n\n{class_def}" if nested_classes_str else class_def
@@ -328,7 +348,7 @@ def run_gen(args: argparse.Namespace) -> None:
# ruff: noqa: F401
# fmt: off
from dataclasses import dataclass, field
from typing import Any\n\n
from typing import Any, Literal\n\n
"""
)
f.write(dataclass_code)