diff --git a/pkgs/classgen/main.py b/pkgs/classgen/main.py index 76bb115f0..856b1c545 100644 --- a/pkgs/classgen/main.py +++ b/pkgs/classgen/main.py @@ -65,8 +65,8 @@ def field_def_from_default_type( field_name: str, field_types: set[str], class_name: str, - finalize_field: Callable[..., str], -) -> str | None: + finalize_field: Callable[..., tuple[str, str]], +) -> tuple[str, str] | None: if "dict" in str(field_types): return finalize_field( field_types=field_types, @@ -127,8 +127,8 @@ def field_def_from_default_value( field_name: str, field_types: set[str], nested_class_name: str, - finalize_field: Callable[..., str], -) -> str | None: + finalize_field: Callable[..., tuple[str, str]], +) -> tuple[str, str] | None: # default_value = prop_info.get("default") if default_value is None: return finalize_field( @@ -184,7 +184,7 @@ def get_field_def( default: str | None = None, default_factory: str | None = None, type_appendix: str = "", -) -> str: +) -> tuple[str, str]: if "None" in field_types or default or default_factory: if "None" in field_types: field_types.remove("None") @@ -193,7 +193,7 @@ def get_field_def( else: serialised_types = " | ".join(field_types) + type_appendix - return f"{field_name}: {serialised_types}" + return (field_name, serialised_types) # Recursive function to generate dataclasses from JSON schema @@ -204,8 +204,8 @@ def generate_dataclass( ) -> str: properties = schema.get("properties", {}) - required_fields = [] - fields_with_default = [] + required_fields: list[tuple[str, str]] = [] + fields_with_default: list[tuple[str, str]] = [] nested_classes: list[str] = [] # if We are at the top level, and the attribute name is in shallow @@ -218,7 +218,7 @@ def generate_dataclass( field_name = prop.replace("-", "_") if len(attr_path) == 0 and prop not in attrs: - field_def = f"{field_name}: NotRequired[dict[str, Any]]" + field_def = field_name, "NotRequired[dict[str, Any]]" fields_with_default.append(field_def) # breakpoint() continue @@ -345,9 +345,18 @@ def generate_dataclass( ) required_fields.append(field_def) - # breakpoint() - - fields_str = "\n ".join(required_fields + fields_with_default) + # Join field name with type to form a complete field declaration + # e.g. "name: str" + all_field_declarations = [ + f"{n}: {t}" for n, t in (required_fields + fields_with_default) + ] + hoisted_types: str = "\n".join( + [ + f"{class_name}{n.capitalize()}Type = {x}" + for n, x in (required_fields + fields_with_default) + ] + ) + fields_str = "\n ".join(all_field_declarations) nested_classes_str = "\n\n".join(nested_classes) class_def = f"\nclass {class_name}(TypedDict):\n" @@ -356,6 +365,8 @@ def generate_dataclass( else: class_def += f" {fields_str}" + class_def += f"\n\n{hoisted_types}\n" + return f"{nested_classes_str}\n\n{class_def}" if nested_classes_str else class_def