Classgen: export field type definitions
This commit is contained in:
@@ -65,8 +65,8 @@ def field_def_from_default_type(
|
||||
field_name: str,
|
||||
field_types: set[str],
|
||||
class_name: str,
|
||||
finalize_field: Callable[..., str],
|
||||
) -> str | None:
|
||||
finalize_field: Callable[..., tuple[str, str]],
|
||||
) -> tuple[str, str] | None:
|
||||
if "dict" in str(field_types):
|
||||
return finalize_field(
|
||||
field_types=field_types,
|
||||
@@ -127,8 +127,8 @@ def field_def_from_default_value(
|
||||
field_name: str,
|
||||
field_types: set[str],
|
||||
nested_class_name: str,
|
||||
finalize_field: Callable[..., str],
|
||||
) -> str | None:
|
||||
finalize_field: Callable[..., tuple[str, str]],
|
||||
) -> tuple[str, str] | None:
|
||||
# default_value = prop_info.get("default")
|
||||
if default_value is None:
|
||||
return finalize_field(
|
||||
@@ -184,7 +184,7 @@ def get_field_def(
|
||||
default: str | None = None,
|
||||
default_factory: str | None = None,
|
||||
type_appendix: str = "",
|
||||
) -> str:
|
||||
) -> tuple[str, str]:
|
||||
if "None" in field_types or default or default_factory:
|
||||
if "None" in field_types:
|
||||
field_types.remove("None")
|
||||
@@ -193,7 +193,7 @@ def get_field_def(
|
||||
else:
|
||||
serialised_types = " | ".join(field_types) + type_appendix
|
||||
|
||||
return f"{field_name}: {serialised_types}"
|
||||
return (field_name, serialised_types)
|
||||
|
||||
|
||||
# Recursive function to generate dataclasses from JSON schema
|
||||
@@ -204,8 +204,8 @@ def generate_dataclass(
|
||||
) -> str:
|
||||
properties = schema.get("properties", {})
|
||||
|
||||
required_fields = []
|
||||
fields_with_default = []
|
||||
required_fields: list[tuple[str, str]] = []
|
||||
fields_with_default: list[tuple[str, str]] = []
|
||||
nested_classes: list[str] = []
|
||||
|
||||
# if We are at the top level, and the attribute name is in shallow
|
||||
@@ -218,7 +218,7 @@ def generate_dataclass(
|
||||
field_name = prop.replace("-", "_")
|
||||
|
||||
if len(attr_path) == 0 and prop not in attrs:
|
||||
field_def = f"{field_name}: NotRequired[dict[str, Any]]"
|
||||
field_def = field_name, "NotRequired[dict[str, Any]]"
|
||||
fields_with_default.append(field_def)
|
||||
# breakpoint()
|
||||
continue
|
||||
@@ -345,9 +345,18 @@ def generate_dataclass(
|
||||
)
|
||||
required_fields.append(field_def)
|
||||
|
||||
# breakpoint()
|
||||
|
||||
fields_str = "\n ".join(required_fields + fields_with_default)
|
||||
# Join field name with type to form a complete field declaration
|
||||
# e.g. "name: str"
|
||||
all_field_declarations = [
|
||||
f"{n}: {t}" for n, t in (required_fields + fields_with_default)
|
||||
]
|
||||
hoisted_types: str = "\n".join(
|
||||
[
|
||||
f"{class_name}{n.capitalize()}Type = {x}"
|
||||
for n, x in (required_fields + fields_with_default)
|
||||
]
|
||||
)
|
||||
fields_str = "\n ".join(all_field_declarations)
|
||||
nested_classes_str = "\n\n".join(nested_classes)
|
||||
|
||||
class_def = f"\nclass {class_name}(TypedDict):\n"
|
||||
@@ -356,6 +365,8 @@ def generate_dataclass(
|
||||
else:
|
||||
class_def += f" {fields_str}"
|
||||
|
||||
class_def += f"\n\n{hoisted_types}\n"
|
||||
|
||||
return f"{nested_classes_str}\n\n{class_def}" if nested_classes_str else class_def
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user