diff --git a/lib/inventory/build-inventory/interface.nix b/lib/inventory/build-inventory/interface.nix index ea3e45422..2cf922197 100644 --- a/lib/inventory/build-inventory/interface.nix +++ b/lib/inventory/build-inventory/interface.nix @@ -362,7 +362,7 @@ in services.borgbackup."instance_1" = { roles.client.machines = ["machineA"]; - machineA.config = { + machines.machineA.config = { # Additional specific config for the machine # This is merged with all other config places }; diff --git a/pkgs/clan-cli/clan_cli/api/serde.py b/pkgs/clan-cli/clan_cli/api/serde.py index 78de6a533..85c5f796a 100644 --- a/pkgs/clan-cli/clan_cli/api/serde.py +++ b/pkgs/clan-cli/clan_cli/api/serde.py @@ -319,5 +319,4 @@ def from_dict( msg = f"{data} is not a dict. Expected {t}" raise ClanError(msg) return construct_dataclass(t, data, path) # type: ignore - # breakpoint() return construct_value(t, data, path) diff --git a/pkgs/clan-cli/clan_cli/inventory/__init__.py b/pkgs/clan-cli/clan_cli/inventory/__init__.py index 5903aeec1..1aa634d72 100644 --- a/pkgs/clan-cli/clan_cli/inventory/__init__.py +++ b/pkgs/clan-cli/clan_cli/inventory/__init__.py @@ -14,6 +14,7 @@ Operate on the returned inventory to make changes import contextlib import json +from collections import Counter from pathlib import Path from typing import Any @@ -94,7 +95,6 @@ def load_inventory_eval(flake_dir: str | Path) -> Inventory: def flatten_data(data: dict, parent_key: str = "", separator: str = ".") -> dict: """ Recursively flattens a nested dictionary structure where keys are joined by the separator. - The flattened dictionary contains only entries with "__prio". Args: data (dict): The nested dictionary structure. @@ -102,26 +102,110 @@ def flatten_data(data: dict, parent_key: str = "", separator: str = ".") -> dict separator (str): The string to use for joining keys. Returns: - dict: A flattened dictionary with "__prio" values. + dict: A flattened dictionary with all values. Directly in the root. """ flattened = {} for key, value in data.items(): new_key = f"{parent_key}{separator}{key}" if parent_key else key - if isinstance(value, dict) and "__prio" in value: - flattened[new_key] = {"__prio": value["__prio"]} - if isinstance(value, dict): # Recursively flatten the nested dictionary flattened.update(flatten_data(value, new_key, separator)) + else: + flattened[new_key] = value return flattened +def unmerge_lists(all_items: list, filter_items: list) -> list: + """ + Unmerge the current list. Given a previous list. + + Returns: + The other list. + """ + # Unmerge the lists + res = [] + for value in all_items: + if value not in filter_items: + res.append(value) + + return res + + +def find_duplicates(string_list: list[str]) -> list[str]: + count = Counter(string_list) + duplicates = [item for item, freq in count.items() if freq > 1] + return duplicates + + +def calc_patches( + persisted: dict, update: dict, all_values: dict, writeables: dict +) -> dict[str, Any]: + """ + Calculate the patches to apply to the inventory. + + Given its current state and the update to apply. + + Filters out nix-values so it doesnt matter if the anyone sends them. + + : param persisted: The current state of the inventory. + : param update: The update to apply. + : param writeable: The writeable keys. Use 'determine_writeability'. + Example: {'writeable': {'foo', 'foo.bar'}, 'non_writeable': {'foo.nix'}} + : param all_values: All values in the inventory retrieved from the flake evaluation. + """ + persisted_flat = flatten_data(persisted) + update_flat = flatten_data(update) + all_values_flat = flatten_data(all_values) + + patchset = {} + for update_key, update_data in update_flat.items(): + if update_key in writeables["non_writeable"]: + if update_data != all_values_flat.get(update_key): + msg = f"Key '{update_key}' is not writeable." + raise ClanError(msg) + continue + + if update_key in writeables["writeable"]: + if type(update_data) is not type(all_values_flat.get(update_key)): + msg = f"Type mismatch for key '{update_key}'. Cannot update {type(all_values_flat.get(update_key))} with {type(update_data)}" + raise ClanError(msg) + + # Handle list seperation + if isinstance(update_data, list): + duplicates = find_duplicates(update_data) + if duplicates: + msg = f"Key '{update_key}' contains duplicates: {duplicates}. This not supported yet." + raise ClanError(msg) + # List of current values + persisted_data = persisted_flat.get(update_key, []) + # List including nix values + all_list = all_values_flat.get(update_key, []) + nix_list = unmerge_lists(all_list, persisted_data) + if update_data != all_list: + patchset[update_key] = unmerge_lists(update_data, nix_list) + + elif update_data != persisted_flat.get(update_key, None): + patchset[update_key] = update_data + + continue + + if update_key not in all_values_flat: + msg = f"Key '{update_key}' cannot be set. It does not exist." + raise ClanError(msg) + + msg = f"Cannot determine writeability for key '{update_key}'" + raise ClanError(msg) + + return patchset + + def determine_writeability( - data: dict, - correlated: dict, + priorities: dict, + defaults: dict, + persisted: dict, parent_key: str = "", parent_prio: int | None = None, results: dict | None = None, @@ -130,7 +214,7 @@ def determine_writeability( if results is None: results = {"writeable": set({}), "non_writeable": set({})} - for key, value in data.items(): + for key, value in priorities.items(): if key == "__prio": continue @@ -149,6 +233,7 @@ def determine_writeability( if isinstance(value, dict): determine_writeability( value, + defaults, {}, # Children won't be writeable, so correlation doesn't matter here full_key, prio, # Pass the same priority down @@ -159,13 +244,22 @@ def determine_writeability( continue # Check if the key is writeable otherwise - key_in_correlated = key in correlated + key_in_correlated = key in persisted if prio is None: msg = f"Priority for key '{full_key}' is not defined. Cannot determine if it is writeable." raise ClanError(msg) - has_children = any(k != "__prio" for k in value) - is_writeable = prio > 100 or key_in_correlated or has_children + is_mergeable = False + if prio == 100: + default = defaults.get(key) + if isinstance(default, dict): + is_mergeable = True + if isinstance(default, list): + is_mergeable = True + if key_in_correlated: + is_mergeable = True + + is_writeable = prio > 100 or is_mergeable # Append the result if is_writeable: @@ -177,7 +271,8 @@ def determine_writeability( if isinstance(value, dict): determine_writeability( value, - correlated.get(key, {}), + defaults.get(key, {}), + persisted.get(key, {}), full_key, prio, # Pass down current priority results, diff --git a/pkgs/clan-cli/tests/test_patch_inventory.py b/pkgs/clan-cli/tests/test_patch_inventory.py index 3a10088b0..969ecfc36 100644 --- a/pkgs/clan-cli/tests/test_patch_inventory.py +++ b/pkgs/clan-cli/tests/test_patch_inventory.py @@ -1,5 +1,12 @@ # Functions to test -from clan_cli.inventory import determine_writeability, patch +import pytest +from clan_cli.errors import ClanError +from clan_cli.inventory import ( + calc_patches, + determine_writeability, + patch, + unmerge_lists, +) # --------- Patching tests --------- @@ -49,8 +56,9 @@ def test_write_simple() -> None: }, } + default: dict = {"foo": {}} data: dict = {} - res = determine_writeability(prios, data) + res = determine_writeability(prios, default, data) assert res == {"writeable": {"foo", "foo.bar"}, "non_writeable": set({})} @@ -67,7 +75,7 @@ def test_write_inherited() -> None: } data: dict = {} - res = determine_writeability(prios, data) + res = determine_writeability(prios, {"foo": {"bar": {}}}, data) assert res == { "writeable": {"foo", "foo.bar", "foo.bar.baz"}, "non_writeable": set(), @@ -86,13 +94,34 @@ def test_non_write_inherited() -> None: } data: dict = {} - res = determine_writeability(prios, data) + res = determine_writeability(prios, {}, data) assert res == { "writeable": set(), "non_writeable": {"foo", "foo.bar", "foo.bar.baz"}, } +def test_write_list() -> None: + prios = { + "foo": { + "__prio": 100, + }, + } + + data: dict = {} + default: dict = { + "foo": [ + "a", + "b", + ] # <- writeable: because lists are merged. Filtering out nix-values comes later + } + res = determine_writeability(prios, default, data) + assert res == { + "writeable": {"foo"}, + "non_writeable": set(), + } + + def test_write_because_written() -> None: prios = { "foo": { @@ -107,7 +136,7 @@ def test_write_because_written() -> None: # Given the following data. {} # Check that the non-writeable paths are correct. - res = determine_writeability(prios, {}) + res = determine_writeability(prios, {"foo": {"bar": {}}}, {}) assert res == { "writeable": {"foo", "foo.bar"}, "non_writeable": {"foo.bar.baz", "foo.bar.foobar"}, @@ -120,8 +149,247 @@ def test_write_because_written() -> None: } } } - res = determine_writeability(prios, data) + res = determine_writeability(prios, {}, data) assert res == { "writeable": {"foo", "foo.bar", "foo.bar.baz"}, "non_writeable": {"foo.bar.foobar"}, } + + +# --------- List unmerge tests --------- + + +def test_list_unmerge() -> None: + all_machines = ["machineA", "machineB"] + inventory = ["machineB"] + + nix_machines = unmerge_lists(all_machines, inventory) + assert nix_machines == ["machineA"] + + +# --------- Write tests --------- + + +def test_update_simple() -> None: + prios = { + "foo": { + "__prio": 100, # <- writeable: "foo" + "bar": {"__prio": 1000}, # <- writeable: mkDefault "foo.bar" + "nix": {"__prio": 100}, # <- non writeable: "foo.bar" (defined in nix) + }, + } + + data_eval = {"foo": {"bar": "baz", "nix": "this is set in nix"}} + + data_disk: dict = {} + + writeables = determine_writeability(prios, data_eval, data_disk) + + assert writeables == {"writeable": {"foo", "foo.bar"}, "non_writeable": {"foo.nix"}} + + update = { + "foo": { + "bar": "new value", # <- user sets this value + "nix": "this is set in nix", # <- user didnt touch this value + # If the user would have set this value, it would trigger an error + } + } + patchset = calc_patches( + data_disk, update, all_values=data_eval, writeables=writeables + ) + + assert patchset == {"foo.bar": "new value"} + + +def test_update_many() -> None: + prios = { + "foo": { + "__prio": 100, # <- writeable: "foo" + "bar": {"__prio": 100}, # <- + "nix": {"__prio": 100}, # <- non writeable: "foo.bar" (defined in nix) + "nested": { + "__prio": 100, + "x": {"__prio": 100}, # <- writeable: "foo.nested.x" + "y": {"__prio": 100}, # <- non-writeable: "foo.nested.y" + }, + }, + } + + data_eval = { + "foo": { + "bar": "baz", + "nix": "this is set in nix", + "nested": {"x": "x", "y": "y"}, + } + } + + data_disk = {"foo": {"bar": "baz", "nested": {"x": "x"}}} + + writeables = determine_writeability(prios, data_eval, data_disk) + + assert writeables == { + "writeable": {"foo.nested", "foo", "foo.bar", "foo.nested.x"}, + "non_writeable": {"foo.nix", "foo.nested.y"}, + } + + update = { + "foo": { + "bar": "new value for bar", # <- user sets this value + "nix": "this is set in nix", # <- user cannot set this value + "nested": { + "x": "new value for x", # <- user sets this value + "y": "y", # <- user cannot set this value + }, + } + } + patchset = calc_patches( + data_disk, update, all_values=data_eval, writeables=writeables + ) + + assert patchset == { + "foo.bar": "new value for bar", + "foo.nested.x": "new value for x", + } + + +def test_update_parent_non_writeable() -> None: + prios = { + "foo": { + "__prio": 50, # <- non-writeable: "foo" + "bar": {"__prio": 1000}, # <- writeable: mkDefault "foo.bar" + }, + } + + data_eval = { + "foo": { + "bar": "baz", + } + } + + data_disk = { + "foo": { + "bar": "baz", + } + } + + writeables = determine_writeability(prios, data_eval, data_disk) + + assert writeables == {"writeable": set(), "non_writeable": {"foo", "foo.bar"}} + + update = { + "foo": { + "bar": "new value", # <- user sets this value + } + } + with pytest.raises(ClanError) as error: + calc_patches(data_disk, update, all_values=data_eval, writeables=writeables) + + assert str(error.value) == "Key 'foo.bar' is not writeable." + + +def test_update_list() -> None: + prios = { + "foo": { + "__prio": 100, # <- writeable: "foo" + }, + } + + data_eval = { + # [ "A" ] is defined in nix. + "foo": ["A", "B"] + } + + data_disk = {"foo": ["B"]} + + writeables = determine_writeability(prios, data_eval, data_disk) + + assert writeables == {"writeable": {"foo"}, "non_writeable": set()} + + # Add "C" to the list + update = { + "foo": ["A", "B", "C"] # User wants to add "C" + } + + patchset = calc_patches( + data_disk, update, all_values=data_eval, writeables=writeables + ) + + assert patchset == {"foo": ["B", "C"]} + + # Remove "B" from the list + update = { + "foo": ["A"] # User wants to remove "B" + } + + patchset = calc_patches( + data_disk, update, all_values=data_eval, writeables=writeables + ) + + assert patchset == {"foo": []} + + +def test_update_list_duplicates() -> None: + prios = { + "foo": { + "__prio": 100, # <- writeable: "foo" + }, + } + + data_eval = { + # [ "A" ] is defined in nix. + "foo": ["A", "B"] + } + + data_disk = {"foo": ["B"]} + + writeables = determine_writeability(prios, data_eval, data_disk) + + assert writeables == {"writeable": {"foo"}, "non_writeable": set()} + + # Add "A" to the list + update = { + "foo": ["A", "B", "A"] # User wants to add duplicate "A" + } + + with pytest.raises(ClanError) as error: + calc_patches(data_disk, update, all_values=data_eval, writeables=writeables) + + assert ( + str(error.value) + == "Key 'foo' contains duplicates: ['A']. This not supported yet." + ) + + +def test_update_mismatching_update_type() -> None: + prios = { + "foo": { + "__prio": 100, # <- writeable: "foo" + }, + } + + data_eval = {"foo": ["A", "B"]} + + data_disk: dict = {} + + writeables = determine_writeability(prios, data_eval, data_disk) + + assert writeables == {"writeable": {"foo"}, "non_writeable": set()} + + # set foo.A which doesnt exist + update_1 = {"foo": {"A": "B"}} + + with pytest.raises(ClanError) as error: + calc_patches(data_disk, update_1, all_values=data_eval, writeables=writeables) + + assert str(error.value) == "Key 'foo.A' cannot be set. It does not exist." + + # set foo to an int but it is a list + update_2: dict = {"foo": 1} + + with pytest.raises(ClanError) as error: + calc_patches(data_disk, update_2, all_values=data_eval, writeables=writeables) + + assert ( + str(error.value) + == "Type mismatch for key 'foo'. Cannot update with " + )