classgen: produce typedDict instead of dataclass

This commit is contained in:
Johannes Kirschbauer
2024-12-06 17:47:50 +01:00
parent 8fd4d82f1d
commit f018f4e68e

View File

@@ -148,8 +148,7 @@ def field_def_from_default_value(
default_factory="dict",
type_appendix=" | dict[str,Any]",
)
if default_value == "name":
return None
# Primitive types
if isinstance(default_value, str):
return finalize_field(
@@ -176,23 +175,15 @@ def get_field_def(
default_factory: str | None = None,
type_appendix: str = "",
) -> str:
sorted_field_types = sorted(field_types)
serialised_types = " | ".join(sorted_field_types) + type_appendix
if not default and not default_factory and not field_meta:
return f"{field_name}: {serialised_types}"
field_init = "field("
if "None" in field_types or default or default_factory:
if "None" in field_types:
field_types.remove("None")
serialised_types = " | ".join(field_types) + type_appendix
serialised_types = f"NotRequired[{serialised_types}]"
else:
serialised_types = " | ".join(field_types) + type_appendix
init_args = []
if default:
init_args.append(f"default = {default}")
if default_factory:
init_args.append(f"default_factory = {default_factory}")
if field_meta:
init_args.append(f"metadata = {field_meta}")
field_init += ", ".join(init_args) + ")"
return f"{field_name}: {serialised_types} = {field_init}"
return f"{field_name}: {serialised_types}"
# Recursive function to generate dataclasses from JSON schema
@@ -281,6 +272,9 @@ def generate_dataclass(schema: dict[str, Any], class_name: str = root_class) ->
finalize_field = partial(get_field_def, field_name, field_meta)
if "default" in prop_info or field_name not in prop_info.get("required", []):
if prop_info.get("type") == "object":
prop_info.update({"default": {}})
if "default" in prop_info:
default_value = prop_info.get("default")
field_def = field_def_from_default_value(
@@ -327,11 +321,11 @@ def generate_dataclass(schema: dict[str, Any], class_name: str = root_class) ->
fields_str = "\n ".join(required_fields + fields_with_default)
nested_classes_str = "\n\n".join(nested_classes)
class_def = f"@dataclass\nclass {class_name}:\n"
class_def = f"\nclass {class_name}(TypedDict):\n"
if not required_fields + fields_with_default:
class_def += " pass"
else:
class_def += f" {fields_str}\n"
class_def += f" {fields_str}"
return f"{nested_classes_str}\n\n{class_def}" if nested_classes_str else class_def
@@ -356,8 +350,7 @@ def run_gen(args: argparse.Namespace) -> None:
# ruff: noqa: N806
# ruff: noqa: F401
# fmt: off
from dataclasses import dataclass, field
from typing import Any, Literal\n\n
from typing import Any, Literal, TypedDict, NotRequired\n
"""
)
f.write(dataclass_code)