diff --git a/pkgs/classgen/main.py b/pkgs/classgen/main.py index fee847ddf..fffcf1155 100644 --- a/pkgs/classgen/main.py +++ b/pkgs/classgen/main.py @@ -148,8 +148,7 @@ def field_def_from_default_value( default_factory="dict", type_appendix=" | dict[str,Any]", ) - if default_value == "‹name›": - return None + # Primitive types if isinstance(default_value, str): return finalize_field( @@ -176,23 +175,15 @@ def get_field_def( default_factory: str | None = None, type_appendix: str = "", ) -> str: - sorted_field_types = sorted(field_types) - serialised_types = " | ".join(sorted_field_types) + type_appendix - if not default and not default_factory and not field_meta: - return f"{field_name}: {serialised_types}" - field_init = "field(" + if "None" in field_types or default or default_factory: + if "None" in field_types: + field_types.remove("None") + serialised_types = " | ".join(field_types) + type_appendix + serialised_types = f"NotRequired[{serialised_types}]" + else: + serialised_types = " | ".join(field_types) + type_appendix - init_args = [] - if default: - init_args.append(f"default = {default}") - if default_factory: - init_args.append(f"default_factory = {default_factory}") - if field_meta: - init_args.append(f"metadata = {field_meta}") - - field_init += ", ".join(init_args) + ")" - - return f"{field_name}: {serialised_types} = {field_init}" + return f"{field_name}: {serialised_types}" # Recursive function to generate dataclasses from JSON schema @@ -281,6 +272,9 @@ def generate_dataclass(schema: dict[str, Any], class_name: str = root_class) -> finalize_field = partial(get_field_def, field_name, field_meta) if "default" in prop_info or field_name not in prop_info.get("required", []): + if prop_info.get("type") == "object": + prop_info.update({"default": {}}) + if "default" in prop_info: default_value = prop_info.get("default") field_def = field_def_from_default_value( @@ -327,11 +321,11 @@ def generate_dataclass(schema: dict[str, Any], class_name: str = root_class) -> fields_str = "\n ".join(required_fields + fields_with_default) nested_classes_str = "\n\n".join(nested_classes) - class_def = f"@dataclass\nclass {class_name}:\n" + class_def = f"\nclass {class_name}(TypedDict):\n" if not required_fields + fields_with_default: class_def += " pass" else: - class_def += f" {fields_str}\n" + class_def += f" {fields_str}" return f"{nested_classes_str}\n\n{class_def}" if nested_classes_str else class_def @@ -356,8 +350,7 @@ def run_gen(args: argparse.Namespace) -> None: # ruff: noqa: N806 # ruff: noqa: F401 # fmt: off -from dataclasses import dataclass, field -from typing import Any, Literal\n\n +from typing import Any, Literal, TypedDict, NotRequired\n """ ) f.write(dataclass_code)