From e21bfbc25745c47356bcc718418fe89993e7670d Mon Sep 17 00:00:00 2001 From: Johannes Kirschbauer Date: Tue, 30 Jul 2024 12:15:24 +0200 Subject: [PATCH 1/2] Deserializer: replace pydantic --- pkgs/clan-cli/clan_cli/api/serde.py | 160 ++++++++++++++++++---- pkgs/clan-cli/default.nix | 3 - pkgs/clan-cli/tests/test_deserializers.py | 2 +- 3 files changed, 138 insertions(+), 27 deletions(-) diff --git a/pkgs/clan-cli/clan_cli/api/serde.py b/pkgs/clan-cli/clan_cli/api/serde.py index 57345c3db..605b04099 100644 --- a/pkgs/clan-cli/clan_cli/api/serde.py +++ b/pkgs/clan-cli/clan_cli/api/serde.py @@ -29,17 +29,19 @@ Dependencies: Note: This module assumes the presence of other modules and classes such as `ClanError` and `ErrorDetails` from the `clan_cli.errors` module. """ +import dataclasses import json from dataclasses import dataclass, fields, is_dataclass from pathlib import Path +from types import UnionType from typing import ( Any, TypeVar, + Union, + get_args, + get_origin, ) -from pydantic import TypeAdapter, ValidationError -from pydantic_core import ErrorDetails - from clan_cli.errors import ClanError @@ -83,24 +85,136 @@ def dataclass_to_dict(obj: Any, *, use_alias: bool = True) -> Any: T = TypeVar("T", bound=dataclass) # type: ignore -def from_dict(t: type[T], data: Any) -> T: - """ - Dynamically instantiate a data class from a dictionary, handling nested data classes. - We use dataclasses. But the deserialization logic of pydantic takes a lot of complexity. - """ - adapter = TypeAdapter(t) - try: - return adapter.validate_python( - data, - ) - except ValidationError as e: - fst_error: ErrorDetails = e.errors()[0] - if not fst_error: - raise ClanError(msg=str(e)) +def is_union_type(type_hint: type | UnionType) -> bool: + return ( + type(type_hint) is UnionType + or isinstance(type_hint, UnionType) + or get_origin(type_hint) is Union + ) - msg = fst_error.get("msg") - loc = fst_error.get("loc") - field_path = "Unknown" - if loc: - field_path = str(loc) - raise ClanError(msg=msg, location=f"{t!s}: {field_path}", description=str(e)) + +def is_type_in_union(union_type: type | UnionType, target_type: type) -> bool: + if get_origin(union_type) is UnionType: + return any(issubclass(arg, target_type) for arg in get_args(union_type)) + return False + + +def unwrap_none_type(type_hint: type | UnionType) -> type: + """ + Takes a type union and returns the first non-None type. + None | str + => + str + """ + + if is_union_type(type_hint): + # Return the first non-None type + return next(t for t in get_args(type_hint) if t is not type(None)) + + return type_hint # type: ignore + + +JsonValue = str | float | dict[str, Any] | list[Any] | None + + +def construct_field(t: type, field_value: JsonValue, loc: list[str] = []) -> Any: + """ + Construct a field value from a type hint and a field value. + """ + # If the field is another dataclass + # Field_value must be a dictionary + if is_dataclass(t) and isinstance(field_value, dict): + return from_dict(t, field_value) + + # If the field expects a path + # Field_value must be a string + elif issubclass(t, Path) or is_type_in_union(t, Path): + if not isinstance(field_value, str): + raise ClanError( + f"Expected string, cannot construct pathlib.Path() from: {field_value} ", + location=f"{loc}", + ) + + return Path(field_value) + + # Trivial values + elif t is str: + if not isinstance(field_value, str): + raise ClanError(f"Expected string, got {field_value}", location=f"{loc}") + + return field_value + + elif t is int and not isinstance(field_value, str): + return int(field_value) # type: ignore + elif t is float and not isinstance(field_value, str): + return float(field_value) # type: ignore + + # Union types construct the first non-None type + elif is_union_type(t): + # Unwrap the union type + t = unwrap_none_type(t) + # Construct the field value + return construct_field(t, field_value) + + # Nested types + # list + # dict + elif get_origin(t) is list: + if not isinstance(field_value, list): + raise ClanError(f"Expected list, got {field_value}", location=f"{loc}") + + return [construct_field(get_args(t)[0], item) for item in field_value] + elif get_origin(t) is dict and isinstance(field_value, dict): + return { + key: construct_field(get_args(t)[1], value) + for key, value in field_value.items() + } + + # Unhandled + else: + raise ClanError(f"Unhandled field type {t} with value {field_value}") + + +def from_dict(t: type[T], data: dict[str, Any], path: list[str] = []) -> T: + """ + type t MUST be a dataclass + Dynamically instantiate a data class from a dictionary, handling nested data classes. + """ + if not is_dataclass(t): + raise ClanError(f"{t.__name__} is not a dataclass") + + # Attempt to create an instance of the data_class# + field_values: dict[str, Any] = {} + required: list[str] = [] + + for field in fields(t): + if field.name.startswith("_"): + continue + # The first type in a Union + # str <- None | str | Path + field_type: type[Any] = unwrap_none_type(field.type) # type: ignore + data_field_name = field.metadata.get("alias", field.name) + + # Check if the field is required + # breakpoint() + if ( + field.default is dataclasses.MISSING + and field.default_factory is dataclasses.MISSING + ): + required.append(field.name) + + # Populate the field_values dictionary with the field value + # if present in the data + if data_field_name in data: + field_value = data.get(data_field_name) + field_values[field.name] = construct_field(field_type, field_value) + + # Check that all required field are present. + for field_name in required: + if field_name not in field_values: + formatted_path = " ".join(path) + raise ClanError( + f"Required field missing: '{field_name}' in {t} {formatted_path}, got Value: {data}" + ) + + return t(**field_values) # type: ignore diff --git a/pkgs/clan-cli/default.nix b/pkgs/clan-cli/default.nix index 12d8db646..794869db6 100644 --- a/pkgs/clan-cli/default.nix +++ b/pkgs/clan-cli/default.nix @@ -17,8 +17,6 @@ setuptools, stdenv, - pydantic, - # custom args clan-core-path, nixpkgs, @@ -30,7 +28,6 @@ let pythonDependencies = [ argcomplete # Enables shell completions - pydantic # Dataclass deserialisation / validation / schemas ]; # load nixpkgs runtime dependencies from a json file diff --git a/pkgs/clan-cli/tests/test_deserializers.py b/pkgs/clan-cli/tests/test_deserializers.py index f87b65adc..7526f342e 100644 --- a/pkgs/clan-cli/tests/test_deserializers.py +++ b/pkgs/clan-cli/tests/test_deserializers.py @@ -45,11 +45,11 @@ def test_nested() -> None: class Person: name: str # deeply nested dataclasses + home: Path | str | None age: Age age_list: list[Age] age_dict: dict[str, Age] # Optional field - home: Path | None person_dict = { "name": "John", From 4b2d1b7923ce7491b66cee6864fecfd4fb15963b Mon Sep 17 00:00:00 2001 From: Johannes Kirschbauer Date: Tue, 30 Jul 2024 14:16:03 +0200 Subject: [PATCH 2/2] Deserializer: add Nullable fields --- pkgs/clan-cli/clan_cli/api/serde.py | 14 ++++++++++---- pkgs/clan-cli/tests/test_deserializers.py | 23 +++++++++++++++++++++++ 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/pkgs/clan-cli/clan_cli/api/serde.py b/pkgs/clan-cli/clan_cli/api/serde.py index 605b04099..24b9786d5 100644 --- a/pkgs/clan-cli/clan_cli/api/serde.py +++ b/pkgs/clan-cli/clan_cli/api/serde.py @@ -96,7 +96,7 @@ def is_union_type(type_hint: type | UnionType) -> bool: def is_type_in_union(union_type: type | UnionType, target_type: type) -> bool: if get_origin(union_type) is UnionType: return any(issubclass(arg, target_type) for arg in get_args(union_type)) - return False + return union_type == target_type def unwrap_none_type(type_hint: type | UnionType) -> type: @@ -121,6 +121,8 @@ def construct_field(t: type, field_value: JsonValue, loc: list[str] = []) -> Any """ Construct a field value from a type hint and a field value. """ + if t is None and field_value: + raise ClanError(f"Expected None but got: {field_value}", location=f"{loc}") # If the field is another dataclass # Field_value must be a dictionary if is_dataclass(t) and isinstance(field_value, dict): @@ -195,8 +197,6 @@ def from_dict(t: type[T], data: dict[str, Any], path: list[str] = []) -> T: field_type: type[Any] = unwrap_none_type(field.type) # type: ignore data_field_name = field.metadata.get("alias", field.name) - # Check if the field is required - # breakpoint() if ( field.default is dataclasses.MISSING and field.default_factory is dataclasses.MISSING @@ -207,7 +207,13 @@ def from_dict(t: type[T], data: dict[str, Any], path: list[str] = []) -> T: # if present in the data if data_field_name in data: field_value = data.get(data_field_name) - field_values[field.name] = construct_field(field_type, field_value) + + if field_value is None and ( + field.type is None or is_type_in_union(field.type, type(None)) + ): + field_values[field.name] = None + else: + field_values[field.name] = construct_field(field_type, field_value) # Check that all required field are present. for field_name in required: diff --git a/pkgs/clan-cli/tests/test_deserializers.py b/pkgs/clan-cli/tests/test_deserializers.py index 7526f342e..3d6d7583d 100644 --- a/pkgs/clan-cli/tests/test_deserializers.py +++ b/pkgs/clan-cli/tests/test_deserializers.py @@ -83,6 +83,29 @@ def test_simple_field_missing() -> None: from_dict(Person, person_dict) +def test_nullable() -> None: + @dataclass + class Person: + name: None + + person_dict = { + "name": None, + } + + from_dict(Person, person_dict) + + +def test_nullable_non_exist() -> None: + @dataclass + class Person: + name: None + + person_dict = {} + + with pytest.raises(ClanError): + from_dict(Person, person_dict) + + def test_deserialize_extensive_inventory() -> None: # TODO: Make this an abstract test, so it doesn't break the test if the inventory changes data = {