fix serialisation of SopsKey type

This commit is contained in:
Jörg Thalheim
2024-10-01 15:40:50 +02:00
committed by Mic92
parent db065ea06b
commit 541a73692f
3 changed files with 29 additions and 15 deletions

View File

@@ -32,6 +32,7 @@ Note: This module assumes the presence of other modules and classes such as `Cla
import dataclasses import dataclasses
import json import json
from dataclasses import dataclass, fields, is_dataclass from dataclasses import dataclass, fields, is_dataclass
from enum import Enum
from pathlib import Path from pathlib import Path
from types import UnionType from types import UnionType
from typing import ( from typing import (
@@ -179,25 +180,31 @@ def construct_value(
# Nested types # Nested types
# list # list
# dict # dict
if get_origin(t) is list: origin = get_origin(t)
if origin is list:
if not isinstance(field_value, list): if not isinstance(field_value, list):
msg = f"Expected list, got {field_value}" msg = f"Expected list, got {field_value}"
raise ClanError(msg, location=f"{loc}") raise ClanError(msg, location=f"{loc}")
return [construct_value(get_args(t)[0], item) for item in field_value] return [construct_value(get_args(t)[0], item) for item in field_value]
if get_origin(t) is dict and isinstance(field_value, dict): if origin is dict and isinstance(field_value, dict):
return { return {
key: construct_value(get_args(t)[1], value) key: construct_value(get_args(t)[1], value)
for key, value in field_value.items() for key, value in field_value.items()
} }
if get_origin(t) is Literal: if origin is Literal:
valid_values = get_args(t) valid_values = get_args(t)
if field_value not in valid_values: if field_value not in valid_values:
msg = f"Expected one of {valid_values}, got {field_value}" msg = f"Expected one of {', '.join(valid_values)}, got {field_value}"
raise ClanError(msg, location=f"{loc}") raise ClanError(msg, location=f"{loc}")
return field_value return field_value
if get_origin(t) is Annotated: if origin is Enum:
if field_value not in origin.__members__:
msg = f"Expected one of {', '.join(origin.__members__)}, got {field_value}"
raise ClanError(msg, location=f"{loc}")
if origin is Annotated:
(base_type,) = get_args(t) (base_type,) = get_args(t)
return construct_value(base_type, field_value) return construct_value(base_type, field_value)

View File

@@ -2,6 +2,7 @@ import copy
import dataclasses import dataclasses
import pathlib import pathlib
from dataclasses import MISSING from dataclasses import MISSING
from enum import EnumType
from types import NoneType, UnionType from types import NoneType, UnionType
from typing import ( from typing import (
Annotated, Annotated,
@@ -77,13 +78,16 @@ def type_to_dict(
if dataclasses.is_dataclass(t): if dataclasses.is_dataclass(t):
fields = dataclasses.fields(t) fields = dataclasses.fields(t)
properties = { properties = {}
f.metadata.get("alias", f.name): type_to_dict( for f in fields:
if f.name.startswith("_"):
continue
assert not isinstance(
f.type, str
), f"Expected field type to be a type, got {f.type}, Have you imported `from __future__ import annotations`?"
properties[f.metadata.get("alias", f.name)] = type_to_dict(
f.type, f"{scope} {t.__name__}.{f.name}", type_map f.type, f"{scope} {t.__name__}.{f.name}", type_map
) )
for f in fields
if not f.name.startswith("_")
}
required = set() required = set()
for pn, pv in properties.items(): for pn, pv in properties.items():
@@ -192,6 +196,11 @@ def type_to_dict(
return {"type": "boolean"} return {"type": "boolean"}
if t is object: if t is object:
return {"type": "object"} return {"type": "object"}
if type(t) is EnumType:
return {
"type": "string",
"enum": list(t.__members__),
}
if t is Any: if t is Any:
msg = f"{scope} - Usage of the Any type is not supported for API functions. In: {scope}" msg = f"{scope} - Usage of the Any type is not supported for API functions. In: {scope}"
raise JSchemaTypeError(msg) raise JSchemaTypeError(msg)
@@ -208,7 +217,7 @@ def type_to_dict(
if t is NoneType: if t is NoneType:
return {"type": "null"} return {"type": "null"}
msg = f"{scope} - Error primitive type not supported {t!s}" msg = f"{scope} - Basic type '{t!s}' is not supported"
raise JSchemaTypeError(msg) raise JSchemaTypeError(msg)
msg = f"{scope} - Error type not supported {t!s}" msg = f"{scope} - Type '{t!s}' is not supported"
raise JSchemaTypeError(msg) raise JSchemaTypeError(msg)

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
import enum import enum
import io import io
import json import json
@@ -27,7 +25,7 @@ class KeyType(enum.Enum):
PGP = enum.auto() PGP = enum.auto()
@classmethod @classmethod
def validate(cls, value: str | None) -> KeyType | None: # noqa: ANN102 def validate(cls, value: str | None) -> "KeyType | None": # noqa: ANN102
if value: if value:
return cls.__members__.get(value.upper()) return cls.__members__.get(value.upper())
return None return None