diff --git a/pkgs/classgen/main.py b/pkgs/classgen/main.py index 8cafe0b6c..c698c480a 100644 --- a/pkgs/classgen/main.py +++ b/pkgs/classgen/main.py @@ -165,14 +165,18 @@ def get_field_def( if not default and not default_factory and not field_meta: return f"{field_name}: {serialised_types}" field_init = "field(" - if default: - field_init += f"default = {default}" - if default_factory: - field_init += f"default_factory = {default_factory}" - if field_meta: - field_init += f", metadata = {field_meta}" - return f"{field_name}: {serialised_types} = {field_init})" + init_args = [] + if default: + init_args.append(f"default = {default}") + if default_factory: + init_args.append(f"default_factory = {default_factory}") + if field_meta: + init_args.append(f"metadata = {field_meta}") + + field_init += ", ".join(init_args) + ")" + + return f"{field_name}: {serialised_types} = {field_init}" # Recursive function to generate dataclasses from JSON schema @@ -223,8 +227,11 @@ def generate_dataclass(schema: dict[str, Any], class_name: str = root_class) -> known_classes.add(nested_class_name) elif inner_type and inner_type.get("type") != "object": - # Trivial type - field_types = map_json_type(inner_type) + # Trivial type: + # dict[str, inner_type] + field_types = { + f"""dict[str, {" | ".join(map_json_type(inner_type))}]""" + } elif not inner_type: # The type is a class