classgen: produce typedDict instead of dataclass

This commit is contained in:
Johannes Kirschbauer
2024-12-06 17:47:50 +01:00
parent 8fd4d82f1d
commit f018f4e68e

View File

@@ -148,8 +148,7 @@ def field_def_from_default_value(
default_factory="dict", default_factory="dict",
type_appendix=" | dict[str,Any]", type_appendix=" | dict[str,Any]",
) )
if default_value == "name":
return None
# Primitive types # Primitive types
if isinstance(default_value, str): if isinstance(default_value, str):
return finalize_field( return finalize_field(
@@ -176,23 +175,15 @@ def get_field_def(
default_factory: str | None = None, default_factory: str | None = None,
type_appendix: str = "", type_appendix: str = "",
) -> str: ) -> str:
sorted_field_types = sorted(field_types) if "None" in field_types or default or default_factory:
serialised_types = " | ".join(sorted_field_types) + type_appendix if "None" in field_types:
if not default and not default_factory and not field_meta: field_types.remove("None")
return f"{field_name}: {serialised_types}" serialised_types = " | ".join(field_types) + type_appendix
field_init = "field(" serialised_types = f"NotRequired[{serialised_types}]"
else:
serialised_types = " | ".join(field_types) + type_appendix
init_args = [] return f"{field_name}: {serialised_types}"
if default:
init_args.append(f"default = {default}")
if default_factory:
init_args.append(f"default_factory = {default_factory}")
if field_meta:
init_args.append(f"metadata = {field_meta}")
field_init += ", ".join(init_args) + ")"
return f"{field_name}: {serialised_types} = {field_init}"
# Recursive function to generate dataclasses from JSON schema # Recursive function to generate dataclasses from JSON schema
@@ -281,6 +272,9 @@ def generate_dataclass(schema: dict[str, Any], class_name: str = root_class) ->
finalize_field = partial(get_field_def, field_name, field_meta) finalize_field = partial(get_field_def, field_name, field_meta)
if "default" in prop_info or field_name not in prop_info.get("required", []): if "default" in prop_info or field_name not in prop_info.get("required", []):
if prop_info.get("type") == "object":
prop_info.update({"default": {}})
if "default" in prop_info: if "default" in prop_info:
default_value = prop_info.get("default") default_value = prop_info.get("default")
field_def = field_def_from_default_value( field_def = field_def_from_default_value(
@@ -327,11 +321,11 @@ def generate_dataclass(schema: dict[str, Any], class_name: str = root_class) ->
fields_str = "\n ".join(required_fields + fields_with_default) fields_str = "\n ".join(required_fields + fields_with_default)
nested_classes_str = "\n\n".join(nested_classes) nested_classes_str = "\n\n".join(nested_classes)
class_def = f"@dataclass\nclass {class_name}:\n" class_def = f"\nclass {class_name}(TypedDict):\n"
if not required_fields + fields_with_default: if not required_fields + fields_with_default:
class_def += " pass" class_def += " pass"
else: else:
class_def += f" {fields_str}\n" class_def += f" {fields_str}"
return f"{nested_classes_str}\n\n{class_def}" if nested_classes_str else class_def return f"{nested_classes_str}\n\n{class_def}" if nested_classes_str else class_def
@@ -356,8 +350,7 @@ def run_gen(args: argparse.Namespace) -> None:
# ruff: noqa: N806 # ruff: noqa: N806
# ruff: noqa: F401 # ruff: noqa: F401
# fmt: off # fmt: off
from dataclasses import dataclass, field from typing import Any, Literal, TypedDict, NotRequired\n
from typing import Any, Literal\n\n
""" """
) )
f.write(dataclass_code) f.write(dataclass_code)