Merge pull request 'Classgen: support literal enums' (#2068) from hsjobeki/clan-core:hsjobeki-main into main

This commit is contained in:
clan-bot
2024-09-12 07:33:27 +00:00
3 changed files with 28 additions and 8 deletions

View File

@@ -5,7 +5,7 @@
# ruff: noqa: F401 # ruff: noqa: F401
# fmt: off # fmt: off
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any, Literal
@dataclass @dataclass

View File

@@ -132,7 +132,7 @@
echo "Classes file is up to date" echo "Classes file is up to date"
else else
echo "Classes file is out of date or has been modified" echo "Classes file is out of date or has been modified"
echo "run ./update.sh in the inventory directory to update the classes file" echo "run 'direnv reload' in the pkgs/clan-cli directory to refresh the classes file"
echo "--------------------------------\n" echo "--------------------------------\n"
diff "$file1" "$file2" diff "$file1" "$file2"
echo "--------------------------------\n\n" echo "--------------------------------\n\n"

View File

@@ -146,13 +146,21 @@ def field_def_from_default_value(
) )
if default_value == "name": if default_value == "name":
return None return None
# Primitive types
if isinstance(default_value, str): if isinstance(default_value, str):
return finalize_field( return finalize_field(
field_types=field_types, field_types=field_types,
default=f"'{default_value}'", default=f"'{default_value}'",
) )
if isinstance(default_value, bool | int | float):
# Bool must be checked before int
return finalize_field(
field_types=field_types,
default=f"{default_value}",
)
# Other default values unhandled yet. # Other default values unhandled yet.
msg = f"Unhandled default value for field '{field_name}' - default value: {default_value}" msg = f"Unhandled default value for field '{field_name}' - default value: {default_value}. ( In {nested_class_name} )"
raise Error(msg) raise Error(msg)
@@ -189,13 +197,15 @@ def generate_dataclass(schema: dict[str, Any], class_name: str = root_class) ->
required_fields = [] required_fields = []
fields_with_default = [] fields_with_default = []
nested_classes = [] nested_classes: list[str] = []
for prop, prop_info in properties.items(): for prop, prop_info in properties.items():
field_name = prop.replace("-", "_") field_name = prop.replace("-", "_")
prop_type = prop_info.get("type", None) prop_type = prop_info.get("type", None)
union_variants = prop_info.get("oneOf", []) union_variants = prop_info.get("oneOf", [])
enum_variants = prop_info.get("enum", [])
# Collect all types # Collect all types
field_types = set() field_types = set()
@@ -203,7 +213,7 @@ def generate_dataclass(schema: dict[str, Any], class_name: str = root_class) ->
title_sanitized = "".join([p.capitalize() for p in title.split("-")]) title_sanitized = "".join([p.capitalize() for p in title.split("-")])
nested_class_name = f"""{class_name if class_name != root_class and not prop_info.get("title") else ""}{title_sanitized}""" nested_class_name = f"""{class_name if class_name != root_class and not prop_info.get("title") else ""}{title_sanitized}"""
if (prop_type is None) and (not union_variants): if not prop_type and not union_variants and not enum_variants:
msg = f"Type not found for property {prop} {prop_info}" msg = f"Type not found for property {prop} {prop_info}"
raise Error(msg) raise Error(msg)
@@ -217,6 +227,9 @@ def generate_dataclass(schema: dict[str, Any], class_name: str = root_class) ->
field_types = map_json_type( field_types = map_json_type(
prop_type, map_json_type(item_schema), field_name prop_type, map_json_type(item_schema), field_name
) )
elif enum := prop_info.get("enum"):
literals = ", ".join([f'"{string}"' for string in enum])
field_types = {f"""Literal[{literals}]"""}
elif prop_type == "object": elif prop_type == "object":
inner_type = prop_info.get("additionalProperties") inner_type = prop_info.get("additionalProperties")
@@ -224,7 +237,6 @@ def generate_dataclass(schema: dict[str, Any], class_name: str = root_class) ->
# Inner type is a class # Inner type is a class
field_types = map_json_type(prop_type, {nested_class_name}, field_name) field_types = map_json_type(prop_type, {nested_class_name}, field_name)
#
if nested_class_name not in known_classes: if nested_class_name not in known_classes:
nested_classes.append( nested_classes.append(
generate_dataclass(inner_type, nested_class_name) generate_dataclass(inner_type, nested_class_name)
@@ -260,6 +272,9 @@ def generate_dataclass(schema: dict[str, Any], class_name: str = root_class) ->
field_meta = f"""{{"alias": "{prop}"}}""" field_meta = f"""{{"alias": "{prop}"}}"""
finalize_field = partial(get_field_def, field_name, field_meta) finalize_field = partial(get_field_def, field_name, field_meta)
# if class_name == "DyndnsConfig":
# if class_name == "ServiceDyndnMachine":
# breakpoint()
if "default" in prop_info or field_name not in prop_info.get("required", []): if "default" in prop_info or field_name not in prop_info.get("required", []):
if "default" in prop_info: if "default" in prop_info:
@@ -308,7 +323,12 @@ def generate_dataclass(schema: dict[str, Any], class_name: str = root_class) ->
fields_str = "\n ".join(required_fields + fields_with_default) fields_str = "\n ".join(required_fields + fields_with_default)
nested_classes_str = "\n\n".join(nested_classes) nested_classes_str = "\n\n".join(nested_classes)
class_def = f"@dataclass\nclass {class_name}:\n {fields_str}\n" class_def = f"@dataclass\nclass {class_name}:\n"
if not required_fields + fields_with_default:
class_def += " pass"
else:
class_def += f" {fields_str}\n"
return f"{nested_classes_str}\n\n{class_def}" if nested_classes_str else class_def return f"{nested_classes_str}\n\n{class_def}" if nested_classes_str else class_def
@@ -328,7 +348,7 @@ def run_gen(args: argparse.Namespace) -> None:
# ruff: noqa: F401 # ruff: noqa: F401
# fmt: off # fmt: off
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any\n\n from typing import Any, Literal\n\n
""" """
) )
f.write(dataclass_code) f.write(dataclass_code)