From 160fe825764195c814e88b561affe71a0aff9c03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rg=20Thalheim?= Date: Tue, 1 Oct 2024 15:40:50 +0200 Subject: [PATCH] fix serialisation of SopsKey type --- pkgs/clan-cli/clan_cli/api/serde.py | 17 ++++++++++++----- pkgs/clan-cli/clan_cli/api/util.py | 23 ++++++++++++++++------- pkgs/clan-cli/clan_cli/secrets/sops.py | 4 +--- 3 files changed, 29 insertions(+), 15 deletions(-) diff --git a/pkgs/clan-cli/clan_cli/api/serde.py b/pkgs/clan-cli/clan_cli/api/serde.py index b69f57d3f..61e4f1196 100644 --- a/pkgs/clan-cli/clan_cli/api/serde.py +++ b/pkgs/clan-cli/clan_cli/api/serde.py @@ -32,6 +32,7 @@ Note: This module assumes the presence of other modules and classes such as `Cla import dataclasses import json from dataclasses import dataclass, fields, is_dataclass +from enum import Enum from pathlib import Path from types import UnionType from typing import ( @@ -179,25 +180,31 @@ def construct_value( # Nested types # list # dict - if get_origin(t) is list: + origin = get_origin(t) + if origin is list: if not isinstance(field_value, list): msg = f"Expected list, got {field_value}" raise ClanError(msg, location=f"{loc}") return [construct_value(get_args(t)[0], item) for item in field_value] - if get_origin(t) is dict and isinstance(field_value, dict): + if origin is dict and isinstance(field_value, dict): return { key: construct_value(get_args(t)[1], value) for key, value in field_value.items() } - if get_origin(t) is Literal: + if origin is Literal: valid_values = get_args(t) if field_value not in valid_values: - msg = f"Expected one of {valid_values}, got {field_value}" + msg = f"Expected one of {', '.join(valid_values)}, got {field_value}" raise ClanError(msg, location=f"{loc}") return field_value - if get_origin(t) is Annotated: + if origin is Enum: + if field_value not in origin.__members__: + msg = f"Expected one of {', '.join(origin.__members__)}, got {field_value}" + raise ClanError(msg, location=f"{loc}") + + if origin is Annotated: (base_type,) = get_args(t) return construct_value(base_type, field_value) diff --git a/pkgs/clan-cli/clan_cli/api/util.py b/pkgs/clan-cli/clan_cli/api/util.py index 45c762a22..3ff0ffc8c 100644 --- a/pkgs/clan-cli/clan_cli/api/util.py +++ b/pkgs/clan-cli/clan_cli/api/util.py @@ -2,6 +2,7 @@ import copy import dataclasses import pathlib from dataclasses import MISSING +from enum import EnumType from types import NoneType, UnionType from typing import ( Annotated, @@ -77,13 +78,16 @@ def type_to_dict( if dataclasses.is_dataclass(t): fields = dataclasses.fields(t) - properties = { - f.metadata.get("alias", f.name): type_to_dict( + properties = {} + for f in fields: + if f.name.startswith("_"): + continue + assert not isinstance( + f.type, str + ), f"Expected field type to be a type, got {f.type}, Have you imported `from __future__ import annotations`?" + properties[f.metadata.get("alias", f.name)] = type_to_dict( f.type, f"{scope} {t.__name__}.{f.name}", type_map ) - for f in fields - if not f.name.startswith("_") - } required = set() for pn, pv in properties.items(): @@ -192,6 +196,11 @@ def type_to_dict( return {"type": "boolean"} if t is object: return {"type": "object"} + if type(t) is EnumType: + return { + "type": "string", + "enum": list(t.__members__), + } if t is Any: msg = f"{scope} - Usage of the Any type is not supported for API functions. In: {scope}" raise JSchemaTypeError(msg) @@ -208,7 +217,7 @@ def type_to_dict( if t is NoneType: return {"type": "null"} - msg = f"{scope} - Error primitive type not supported {t!s}" + msg = f"{scope} - Basic type '{t!s}' is not supported" raise JSchemaTypeError(msg) - msg = f"{scope} - Error type not supported {t!s}" + msg = f"{scope} - Type '{t!s}' is not supported" raise JSchemaTypeError(msg) diff --git a/pkgs/clan-cli/clan_cli/secrets/sops.py b/pkgs/clan-cli/clan_cli/secrets/sops.py index d342b8f3f..1007faaaf 100644 --- a/pkgs/clan-cli/clan_cli/secrets/sops.py +++ b/pkgs/clan-cli/clan_cli/secrets/sops.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import enum import io import json @@ -27,7 +25,7 @@ class KeyType(enum.Enum): PGP = enum.auto() @classmethod - def validate(cls, value: str | None) -> KeyType | None: # noqa: ANN102 + def validate(cls, value: str | None) -> "KeyType | None": # noqa: ANN102 if value: return cls.__members__.get(value.upper()) return None