Classgen: support literal enums

This commit is contained in:
Johannes Kirschbauer
2024-09-11 15:13:56 +02:00
parent 8d27e0412d
commit ef18b3e2e9

View File

@@ -146,13 +146,21 @@ def field_def_from_default_value(
) )
if default_value == "name": if default_value == "name":
return None return None
# Primitive types
if isinstance(default_value, str): if isinstance(default_value, str):
return finalize_field( return finalize_field(
field_types=field_types, field_types=field_types,
default=f"'{default_value}'", default=f"'{default_value}'",
) )
if isinstance(default_value, bool | int | float):
# Bool must be checked before int
return finalize_field(
field_types=field_types,
default=f"{default_value}",
)
# Other default values unhandled yet. # Other default values unhandled yet.
msg = f"Unhandled default value for field '{field_name}' - default value: {default_value}" msg = f"Unhandled default value for field '{field_name}' - default value: {default_value}. ( In {nested_class_name} )"
raise Error(msg) raise Error(msg)
@@ -189,13 +197,15 @@ def generate_dataclass(schema: dict[str, Any], class_name: str = root_class) ->
required_fields = [] required_fields = []
fields_with_default = [] fields_with_default = []
nested_classes = [] nested_classes: list[str] = []
for prop, prop_info in properties.items(): for prop, prop_info in properties.items():
field_name = prop.replace("-", "_") field_name = prop.replace("-", "_")
prop_type = prop_info.get("type", None) prop_type = prop_info.get("type", None)
union_variants = prop_info.get("oneOf", []) union_variants = prop_info.get("oneOf", [])
enum_variants = prop_info.get("enum", [])
# Collect all types # Collect all types
field_types = set() field_types = set()
@@ -203,7 +213,7 @@ def generate_dataclass(schema: dict[str, Any], class_name: str = root_class) ->
title_sanitized = "".join([p.capitalize() for p in title.split("-")]) title_sanitized = "".join([p.capitalize() for p in title.split("-")])
nested_class_name = f"""{class_name if class_name != root_class and not prop_info.get("title") else ""}{title_sanitized}""" nested_class_name = f"""{class_name if class_name != root_class and not prop_info.get("title") else ""}{title_sanitized}"""
if (prop_type is None) and (not union_variants): if not prop_type and not union_variants and not enum_variants:
msg = f"Type not found for property {prop} {prop_info}" msg = f"Type not found for property {prop} {prop_info}"
raise Error(msg) raise Error(msg)
@@ -217,6 +227,9 @@ def generate_dataclass(schema: dict[str, Any], class_name: str = root_class) ->
field_types = map_json_type( field_types = map_json_type(
prop_type, map_json_type(item_schema), field_name prop_type, map_json_type(item_schema), field_name
) )
elif enum := prop_info.get("enum"):
literals = ", ".join([f'"{string}"' for string in enum])
field_types = {f"""Literal[{literals}]"""}
elif prop_type == "object": elif prop_type == "object":
inner_type = prop_info.get("additionalProperties") inner_type = prop_info.get("additionalProperties")
@@ -224,7 +237,6 @@ def generate_dataclass(schema: dict[str, Any], class_name: str = root_class) ->
# Inner type is a class # Inner type is a class
field_types = map_json_type(prop_type, {nested_class_name}, field_name) field_types = map_json_type(prop_type, {nested_class_name}, field_name)
#
if nested_class_name not in known_classes: if nested_class_name not in known_classes:
nested_classes.append( nested_classes.append(
generate_dataclass(inner_type, nested_class_name) generate_dataclass(inner_type, nested_class_name)
@@ -260,6 +272,9 @@ def generate_dataclass(schema: dict[str, Any], class_name: str = root_class) ->
field_meta = f"""{{"alias": "{prop}"}}""" field_meta = f"""{{"alias": "{prop}"}}"""
finalize_field = partial(get_field_def, field_name, field_meta) finalize_field = partial(get_field_def, field_name, field_meta)
# if class_name == "DyndnsConfig":
# if class_name == "ServiceDyndnMachine":
# breakpoint()
if "default" in prop_info or field_name not in prop_info.get("required", []): if "default" in prop_info or field_name not in prop_info.get("required", []):
if "default" in prop_info: if "default" in prop_info:
@@ -308,7 +323,12 @@ def generate_dataclass(schema: dict[str, Any], class_name: str = root_class) ->
fields_str = "\n ".join(required_fields + fields_with_default) fields_str = "\n ".join(required_fields + fields_with_default)
nested_classes_str = "\n\n".join(nested_classes) nested_classes_str = "\n\n".join(nested_classes)
class_def = f"@dataclass\nclass {class_name}:\n {fields_str}\n" class_def = f"@dataclass\nclass {class_name}:\n"
if not required_fields + fields_with_default:
class_def += " pass"
else:
class_def += f" {fields_str}\n"
return f"{nested_classes_str}\n\n{class_def}" if nested_classes_str else class_def return f"{nested_classes_str}\n\n{class_def}" if nested_classes_str else class_def
@@ -328,7 +348,7 @@ def run_gen(args: argparse.Namespace) -> None:
# ruff: noqa: F401 # ruff: noqa: F401
# fmt: off # fmt: off
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any\n\n from typing import Any, Literal\n\n
""" """
) )
f.write(dataclass_code) f.write(dataclass_code)