diff --git a/pkgs/clan-cli/clan_cli/api/serde.py b/pkgs/clan-cli/clan_cli/api/serde.py index 78de6a533..85c5f796a 100644 --- a/pkgs/clan-cli/clan_cli/api/serde.py +++ b/pkgs/clan-cli/clan_cli/api/serde.py @@ -319,5 +319,4 @@ def from_dict( msg = f"{data} is not a dict. Expected {t}" raise ClanError(msg) return construct_dataclass(t, data, path) # type: ignore - # breakpoint() return construct_value(t, data, path) diff --git a/pkgs/clan-cli/clan_cli/inventory/__init__.py b/pkgs/clan-cli/clan_cli/inventory/__init__.py index 5903aeec1..25333a9c6 100644 --- a/pkgs/clan-cli/clan_cli/inventory/__init__.py +++ b/pkgs/clan-cli/clan_cli/inventory/__init__.py @@ -119,9 +119,26 @@ def flatten_data(data: dict, parent_key: str = "", separator: str = ".") -> dict return flattened +def unmerge_lists(curr: list, prev: list) -> list: + """ + Unmerge the current list. Given a previous list. + + Returns: + The other list. + """ + # Unmerge the lists + unmerged = [] + for value in curr: + if value not in prev: + unmerged.append(value) + + return unmerged + + def determine_writeability( - data: dict, - correlated: dict, + priorities: dict, + defaults: dict, + persisted: dict, parent_key: str = "", parent_prio: int | None = None, results: dict | None = None, @@ -130,7 +147,7 @@ def determine_writeability( if results is None: results = {"writeable": set({}), "non_writeable": set({})} - for key, value in data.items(): + for key, value in priorities.items(): if key == "__prio": continue @@ -149,6 +166,7 @@ def determine_writeability( if isinstance(value, dict): determine_writeability( value, + defaults, {}, # Children won't be writeable, so correlation doesn't matter here full_key, prio, # Pass the same priority down @@ -159,13 +177,22 @@ def determine_writeability( continue # Check if the key is writeable otherwise - key_in_correlated = key in correlated + key_in_correlated = key in persisted if prio is None: msg = f"Priority for key '{full_key}' is not defined. Cannot determine if it is writeable." raise ClanError(msg) - has_children = any(k != "__prio" for k in value) - is_writeable = prio > 100 or key_in_correlated or has_children + is_mergeable = False + if prio == 100: + default = defaults.get(key) + if isinstance(default, dict): + is_mergeable = True + if isinstance(default, list): + is_mergeable = True + if key_in_correlated: + is_mergeable = True + + is_writeable = prio > 100 or is_mergeable # Append the result if is_writeable: @@ -177,7 +204,8 @@ def determine_writeability( if isinstance(value, dict): determine_writeability( value, - correlated.get(key, {}), + defaults.get(key, {}), + persisted.get(key, {}), full_key, prio, # Pass down current priority results, diff --git a/pkgs/clan-cli/tests/test_patch_inventory.py b/pkgs/clan-cli/tests/test_patch_inventory.py index 3a10088b0..76246a924 100644 --- a/pkgs/clan-cli/tests/test_patch_inventory.py +++ b/pkgs/clan-cli/tests/test_patch_inventory.py @@ -1,5 +1,5 @@ # Functions to test -from clan_cli.inventory import determine_writeability, patch +from clan_cli.inventory import determine_writeability, patch, unmerge_lists # --------- Patching tests --------- @@ -49,8 +49,9 @@ def test_write_simple() -> None: }, } + default: dict = {"foo": {}} data: dict = {} - res = determine_writeability(prios, data) + res = determine_writeability(prios, default, data) assert res == {"writeable": {"foo", "foo.bar"}, "non_writeable": set({})} @@ -67,7 +68,7 @@ def test_write_inherited() -> None: } data: dict = {} - res = determine_writeability(prios, data) + res = determine_writeability(prios, {"foo": {"bar": {}}}, data) assert res == { "writeable": {"foo", "foo.bar", "foo.bar.baz"}, "non_writeable": set(), @@ -86,13 +87,34 @@ def test_non_write_inherited() -> None: } data: dict = {} - res = determine_writeability(prios, data) + res = determine_writeability(prios, {}, data) assert res == { "writeable": set(), "non_writeable": {"foo", "foo.bar", "foo.bar.baz"}, } +def test_write_list() -> None: + prios = { + "foo": { + "__prio": 100, + }, + } + + data: dict = {} + default: dict = { + "foo": [ + "a", + "b", + ] # <- writeable: because lists are merged. Filtering out nix-values comes later + } + res = determine_writeability(prios, default, data) + assert res == { + "writeable": {"foo"}, + "non_writeable": set(), + } + + def test_write_because_written() -> None: prios = { "foo": { @@ -107,7 +129,7 @@ def test_write_because_written() -> None: # Given the following data. {} # Check that the non-writeable paths are correct. - res = determine_writeability(prios, {}) + res = determine_writeability(prios, {"foo": {"bar": {}}}, {}) assert res == { "writeable": {"foo", "foo.bar"}, "non_writeable": {"foo.bar.baz", "foo.bar.foobar"}, @@ -120,8 +142,19 @@ def test_write_because_written() -> None: } } } - res = determine_writeability(prios, data) + res = determine_writeability(prios, {}, data) assert res == { "writeable": {"foo", "foo.bar", "foo.bar.baz"}, "non_writeable": {"foo.bar.foobar"}, } + + +# --------- List unmerge tests --------- + + +def test_list_unmerge() -> None: + all_machines = ["machineA", "machineB"] + inventory = ["machineB"] + + nix_machines = unmerge_lists(inventory, all_machines) + assert nix_machines == ["machineA"]