feat(classgen): defer NotRequired type wrapper

This commit is contained in:
Johannes Kirschbauer
2025-05-20 15:51:10 +02:00
parent 9eeae6e229
commit c86f39ba6b

View File

@@ -198,7 +198,7 @@ def get_field_def(
if "None" in field_types: if "None" in field_types:
field_types.remove("None") field_types.remove("None")
serialised_types = " | ".join(field_types) + type_appendix serialised_types = " | ".join(field_types) + type_appendix
serialised_types = f"NotRequired[{serialised_types}]" serialised_types = f"{serialised_types}"
else: else:
serialised_types = " | ".join(field_types) + type_appendix serialised_types = " | ".join(field_types) + type_appendix
@@ -227,7 +227,7 @@ def generate_dataclass(
field_name = prop.replace("-", "_") field_name = prop.replace("-", "_")
if len(attr_path) == 0 and prop not in attrs: if len(attr_path) == 0 and prop not in attrs:
field_def = field_name, "NotRequired[dict[str, Any]]" field_def = field_name, "dict[str, Any]"
fields_with_default.append(field_def) fields_with_default.append(field_def)
# breakpoint() # breakpoint()
continue continue
@@ -244,8 +244,6 @@ def generate_dataclass(
nested_class_name = f"""{class_name if class_name != root_class and not prop_info.get("title") else ""}{title_sanitized}""" nested_class_name = f"""{class_name if class_name != root_class and not prop_info.get("title") else ""}{title_sanitized}"""
if not prop_type and not union_variants and not enum_variants: if not prop_type and not union_variants and not enum_variants:
msg = f"Type not found for property {prop} {prop_info}"
raise Error(msg)
msg = f"Type not found for property {prop} {prop_info}.\nConverting to unknown type.\n" msg = f"Type not found for property {prop} {prop_info}.\nConverting to unknown type.\n"
logger.warning(msg) logger.warning(msg)
prop_type = "Unknown" prop_type = "Unknown"
@@ -361,8 +359,9 @@ def generate_dataclass(
# Join field name with type to form a complete field declaration # Join field name with type to form a complete field declaration
# e.g. "name: str" # e.g. "name: str"
all_field_declarations = [ all_field_declarations = [f"{n}: {t}" for n, t in (required_fields)] + [
f"{n}: {t}" for n, t in (required_fields + fields_with_default) f"{n}: NotRequired[{class_name}{n.capitalize()}Type]"
for n, t in (fields_with_default)
] ]
hoisted_types: str = "\n".join( hoisted_types: str = "\n".join(
[ [
@@ -373,14 +372,13 @@ def generate_dataclass(
fields_str = "\n ".join(all_field_declarations) fields_str = "\n ".join(all_field_declarations)
nested_classes_str = "\n\n".join(nested_classes) nested_classes_str = "\n\n".join(nested_classes)
class_def = f"\nclass {class_name}(TypedDict):\n" class_def = f"\n\n{hoisted_types}\n"
class_def += f"\nclass {class_name}(TypedDict):\n"
if not required_fields + fields_with_default: if not required_fields + fields_with_default:
class_def += " pass" class_def += " pass"
else: else:
class_def += f" {fields_str}" class_def += f" {fields_str}"
class_def += f"\n\n{hoisted_types}\n"
return f"{nested_classes_str}\n\n{class_def}" if nested_classes_str else class_def return f"{nested_classes_str}\n\n{class_def}" if nested_classes_str else class_def