diff --git a/pkgs/clan-cli/clan_cli/machines/create.py b/pkgs/clan-cli/clan_cli/machines/create.py
index 05fde576d..7b59d8b8f 100644
--- a/pkgs/clan-cli/clan_cli/machines/create.py
+++ b/pkgs/clan-cli/clan_cli/machines/create.py
@@ -2,7 +2,6 @@ import argparse
 import logging
 import re
 from dataclasses import dataclass
-from typing import TypeVar, cast
 
 from clan_lib.api import API
 from clan_lib.dirs import get_clan_flake_toplevel_or_env
@@ -12,7 +11,7 @@ from clan_lib.git import commit_file
 from clan_lib.nix_models.clan import InventoryMachine
 from clan_lib.nix_models.clan import InventoryMachineDeploy as MachineDeploy
 from clan_lib.persist.inventory_store import InventoryStore
-from clan_lib.persist.util import set_value_by_path
+from clan_lib.persist.util import merge_objects, set_value_by_path
 from clan_lib.templates.handler import machine_template
 
 from clan_cli.completions import add_dynamic_completer, complete_tags
@@ -28,41 +27,6 @@ class CreateOptions:
     target_host: str | None = None
 
 
-T = TypeVar("T")
-
-
-def merge_objects(obj1: T, obj2: T) -> T:
-    """
-    Updates values in obj2 by values of Obj1
-    The output contains values for all keys of Obj1 and Obj2 together
-
-    Lists are deduplicated and appended almost like in the nix module system.
-    """
-    result = {}
-    msg = f"cannot update non-dictionary values: {obj2} by {obj1}"
-    if not isinstance(obj1, dict):
-        raise ClanError(msg)
-    if not isinstance(obj2, dict):
-        raise ClanError(msg)
-
-    all_keys = set(obj1.keys()).union(obj2.keys())
-
-    for key in all_keys:
-        val1 = obj1.get(key)
-        val2 = obj2.get(key)
-
-        if isinstance(val1, dict) and isinstance(val2, dict):
-            result[key] = merge_objects(val1, val2)
-        elif isinstance(val1, list) and isinstance(val2, list):
-            result[key] = list(dict.fromkeys(val2 + val1))  # type: ignore
-        elif key in obj1:
-            result[key] = val1  # type: ignore
-        elif key in obj2:
-            result[key] = val2  # type: ignore
-
-    return cast(T, result)
-
-
 @API.register
 def create_machine(
     opts: CreateOptions,
@@ -122,7 +86,7 @@ def create_machine(
     inventory = inventory_store.read()
 
     curr_machine = inventory.get("machines", {}).get(machine_name, {})
-    new_machine = merge_objects(opts.machine, curr_machine)
+    new_machine = merge_objects(curr_machine, opts.machine)
 
     set_value_by_path(
         inventory,
diff --git a/pkgs/clan-cli/clan_lib/persist/util.py b/pkgs/clan-cli/clan_lib/persist/util.py
index 382827dc1..23bbd6585 100644
--- a/pkgs/clan-cli/clan_lib/persist/util.py
+++ b/pkgs/clan-cli/clan_lib/persist/util.py
@@ -3,11 +3,72 @@ Utilities for working with nested dictionaries, particularly for flattening, unmerging lists, finding
 duplicates, and calculating patches.
 """
 
+import json
 from collections import Counter
-from typing import Any
+from typing import Any, TypeVar, cast
 
 from clan_lib.errors import ClanError
 
+T = TypeVar("T")
+
+empty: list[str] = []
+
+
+def merge_objects(
+    curr: T, update: T, merge_lists: bool = True, path: list[str] = empty
+) -> T:
+    """
+    Updates values in curr by values of update
+    The output contains values for all keys of curr and update together.
+
+    Lists are deduplicated and appended almost like in the nix module system.
+    With merge_lists=False a list in update replaces the list in curr instead.
+    Keys keep a stable order: keys of curr first, then new keys of update.
+
+    path is the key path of curr inside the top-level object. It is only used
+    for the location of the ClanError raised on a type mismatch.
+
+    Example:
+
+    merge_objects({"a": 1}, {"a": None}) -> {"a": None}
+    merge_objects({"a": None}, {"a": 1}) -> {"a": 1}
+    """
+    if not isinstance(update, dict) or not isinstance(curr, dict):
+        # Build the message only on failure, not on every recursive call
+        msg = f"cannot update non-dictionary values: {curr} by {update}"
+        raise ClanError(msg)
+
+    result = {}
+    # dict.fromkeys (instead of a set) keeps the key order deterministic
+    all_keys = dict.fromkeys([*curr, *update])
+
+    for key in all_keys:
+        curr_val = curr.get(key)
+        update_val = update.get(key)
+
+        if isinstance(update_val, dict) and isinstance(curr_val, dict):
+            result[key] = merge_objects(
+                curr_val, update_val, merge_lists=merge_lists, path=[*path, key]
+            )
+        elif isinstance(update_val, list) and isinstance(curr_val, list):
+            if merge_lists:
+                result[key] = list(dict.fromkeys(curr_val + update_val))  # type: ignore
+            else:
+                result[key] = update_val  # type: ignore
+        elif (
+            update_val is not None
+            and curr_val is not None
+            and type(update_val) is not type(curr_val)
+        ):
+            msg = f"Type mismatch for key '{key}'. Cannot update {type(curr_val)} with {type(update_val)}"
+            raise ClanError(msg, location=json.dumps([*path, key]))
+        elif key in update:
+            result[key] = update_val  # type: ignore
+        elif key in curr:
+            result[key] = curr_val  # type: ignore
+
+    return cast(T, result)
+
 
 def path_match(path: list[str], whitelist_paths: list[list[str]]) -> bool:
     """
diff --git a/pkgs/clan-cli/clan_lib/persist/util_test.py b/pkgs/clan-cli/clan_lib/persist/util_test.py
index b0aacf6bb..84889a866 100644
--- a/pkgs/clan-cli/clan_lib/persist/util_test.py
+++ b/pkgs/clan-cli/clan_lib/persist/util_test.py
@@ -9,6 +9,7 @@ from clan_lib.persist.util import (
     calc_patches,
     delete_by_path,
     determine_writeability,
+    merge_objects,
     path_match,
     set_value_by_path,
     unmerge_lists,
@@ -669,3 +670,134 @@ def test_delete_non_existent_path_deep() -> None:
     assert "not found" in str(excinfo.value)
     # Data remains unchanged
     assert data == {"foo": {"bar": {"baz": 123}}}
+
+
+### Merge Objects Tests ###
+
+
+def test_merge_objects_empty() -> None:
+    obj1 = {}  # type: ignore
+    obj2 = {}  # type: ignore
+
+    merged = merge_objects(obj1, obj2)
+
+    assert merged == {}
+
+
+def test_merge_objects_basic() -> None:
+    obj1 = {"a": 1, "b": 2}
+    obj2 = {"b": 3, "c": 4}
+
+    merged = merge_objects(obj1, obj2)
+
+    assert merged == {"a": 1, "b": 3, "c": 4}
+
+
+def test_merge_objects_simple() -> None:
+    obj1 = {"a": 1}
+    obj2 = {"a": None}
+
+    # merge_objects should update obj1 with obj2
+    # Set a value to None
+    merged_order = merge_objects(obj1, obj2)
+
+    assert merged_order == {"a": None}
+
+    # Test reverse merge
+    # Set a value from None to 1
+    merged_reverse = merge_objects(obj2, obj1)
+
+    assert merged_reverse == {"a": 1}
+
+
+def test_merge_none_to_value() -> None:
+    obj1 = {
+        "a": None,
+    }
+    obj2 = {"a": {"b": 1}}
+
+    merged_obj = merge_objects(obj1, obj2)
+    assert merged_obj == {"a": {"b": 1}}
+
+    obj3 = {"a": [1, 2, 3]}
+    merged_list = merge_objects(obj1, obj3)
+    assert merged_list == {"a": [1, 2, 3]}
+
+    obj4 = {"a": 1}
+    merged_int = merge_objects(obj1, obj4)
+    assert merged_int == {"a": 1}
+
+    obj5 = {"a": "test"}
+    merged_str = merge_objects(obj1, obj5)
+    assert merged_str == {"a": "test"}
+
+    obj6 = {"a": True}
+    merged_bool = merge_objects(obj1, obj6)
+    assert merged_bool == {"a": True}
+
+
+def test_merge_objects_value_to_none() -> None:
+    obj1 = {"a": {"b": 1}}
+    obj2 = {"a": None}
+
+    merged_obj = merge_objects(obj1, obj2)
+    assert merged_obj == {"a": None}
+
+    obj3 = {"a": [1, 2, 3]}
+    merged_list = merge_objects(obj3, obj2)
+    assert merged_list == {"a": None}
+
+    obj4 = {"a": 1}
+    merged_int = merge_objects(obj4, obj2)
+    assert merged_int == {"a": None}
+
+    obj5 = {"a": "test"}
+    merged_str = merge_objects(obj5, obj2)
+    assert merged_str == {"a": None}
+
+    obj6 = {"a": True}
+    merged_bool = merge_objects(obj6, obj2)
+    assert merged_bool == {"a": None}
+
+
+def test_merge_objects_nested() -> None:
+    obj1 = {"a": {"b": 1, "c": 2}, "d": 3}
+    obj2 = {"a": {"b": 4}, "e": 5}
+
+    merged = merge_objects(obj1, obj2)
+
+    assert merged == {"a": {"b": 4, "c": 2}, "d": 3, "e": 5}
+
+
+def test_merge_objects_lists() -> None:
+    obj1 = {"a": [1, 2], "b": {"c": [3, 4]}}
+    obj2 = {"a": [2, 3], "b": {"c": [4, 5]}}
+
+    merged = merge_objects(obj1, obj2)
+
+    # Lists get merged and deduplicated
+    # Order is preserved: elements of obj1 first, then new elements of obj2
+    # (the lists are not sorted)
+    assert merged == {"a": [1, 2, 3], "b": {"c": [3, 4, 5]}}
+
+
+def test_merge_objects_unset_list_elements() -> None:
+    obj1 = {"a": [1, 2], "b": {"c": [3, 4]}}
+    obj2 = {"a": [], "b": {"c": [5]}}
+
+    merged = merge_objects(obj1, obj2, merge_lists=False)
+
+    # With merge_lists=False lists are replaced instead of merged
+    # An empty list in obj2 therefore clears the list
+    assert merged == {"a": [], "b": {"c": [5]}}
+
+
+def test_merge_objects_with_mismatching_nesting() -> None:
+    obj1 = {"a": {"b": 1}, "c": 2}
+    obj2 = {"a": 3}
+
+    # Merging should raise an error because obj1 and obj2 have different nesting for 'a'
+    with pytest.raises(ClanError) as excinfo:
+        merge_objects(obj1, obj2)
+
+    assert "Type mismatch for key 'a'" in str(excinfo.value)