From 6a2e331861b7c79df9380d2e6398fc1984ebca72 Mon Sep 17 00:00:00 2001 From: Johannes Kirschbauer Date: Thu, 5 Dec 2024 16:18:46 +0100 Subject: [PATCH] inventory/eval: init determine writeability for single inventory options --- pkgs/clan-cli/clan_cli/inventory/__init__.py | 134 +++++++++++++++++++ pkgs/clan-cli/tests/test_patch_inventory.py | 93 ++++++++++++- 2 files changed, 226 insertions(+), 1 deletion(-) diff --git a/pkgs/clan-cli/clan_cli/inventory/__init__.py b/pkgs/clan-cli/clan_cli/inventory/__init__.py index 18f743e8f..5903aeec1 100644 --- a/pkgs/clan-cli/clan_cli/inventory/__init__.py +++ b/pkgs/clan-cli/clan_cli/inventory/__init__.py @@ -91,6 +91,138 @@ def load_inventory_eval(flake_dir: str | Path) -> Inventory: return inventory +def flatten_data(data: dict, parent_key: str = "", separator: str = ".") -> dict: + """ + Recursively flattens a nested dictionary structure where keys are joined by the separator. + The flattened dictionary contains only entries with "__prio". + + Args: + data (dict): The nested dictionary structure. + parent_key (str): The current path to the nested dictionary (used for recursion). + separator (str): The string to use for joining keys. + + Returns: + dict: A flattened dictionary with "__prio" values. + """ + flattened = {} + + for key, value in data.items(): + new_key = f"{parent_key}{separator}{key}" if parent_key else key + + if isinstance(value, dict) and "__prio" in value: + flattened[new_key] = {"__prio": value["__prio"]} + + if isinstance(value, dict): + # Recursively flatten the nested dictionary + flattened.update(flatten_data(value, new_key, separator)) + + return flattened + + +def determine_writeability( + data: dict, + correlated: dict, + parent_key: str = "", + parent_prio: int | None = None, + results: dict | None = None, + non_writeable: bool = False, +) -> dict: + if results is None: + results = {"writeable": set({}), "non_writeable": set({})} + + for key, value in data.items(): + if key == "__prio": + continue + + full_key = f"{parent_key}.{key}" if parent_key else key + + # Determine the priority for the current key + # Inherit from parent if no priority is defined + prio = value.get("__prio", None) + if prio is None: + prio = parent_prio + + # If priority is less than 100, all children are not writeable + # If the parent passed "non_writeable" earlier, this makes all children not writeable + if (prio is not None and prio < 100) or non_writeable: + results["non_writeable"].add(full_key) + if isinstance(value, dict): + determine_writeability( + value, + {}, # Children won't be writeable, so correlation doesn't matter here + full_key, + prio, # Pass the same priority down + results, + # Recursively mark all children as non-writeable + non_writeable=True, + ) + continue + + # Check if the key is writeable otherwise + key_in_correlated = key in correlated + if prio is None: + msg = f"Priority for key '{full_key}' is not defined. Cannot determine if it is writeable." + raise ClanError(msg) + + has_children = any(k != "__prio" for k in value) + is_writeable = prio > 100 or key_in_correlated or has_children + + # Append the result + if is_writeable: + results["writeable"].add(full_key) + else: + results["non_writeable"].add(full_key) + + # Recursive + if isinstance(value, dict): + determine_writeability( + value, + correlated.get(key, {}), + full_key, + prio, # Pass down current priority + results, + ) + + return results + + +def get_inventory_current_priority(flake_dir: str | Path) -> dict: + """ + Returns the current priority of the inventory values + + machines = { + __prio = 100; + flash-installer = { + __prio = 100; + deploy = { + targetHost = { __prio = 1500; }; + }; + description = { __prio = 1500; }; + icon = { __prio = 1500; }; + name = { __prio = 1500; }; + tags = { __prio = 1500; }; + }; + } + """ + cmd = nix_eval( + [ + f"{flake_dir}#clanInternals.inventoryValuesPrios", + "--json", + ] + ) + + proc = run_no_stdout(cmd) + + try: + res = proc.stdout.strip() + data = json.loads(res) + except json.JSONDecodeError as e: + msg = f"Error decoding inventory from flake: {e}" + raise ClanError(msg) from e + else: + return data + + @API.register def load_inventory_json( flake_dir: str | Path, default: Inventory = default_inventory @@ -123,6 +255,8 @@ def patch(d: dict[str, Any], path: str, content: Any) -> None: """ Update the value at a specific dot-separated path in a nested dictionary. + If the value didn't exist before, it will be created recursively. + :param d: The dictionary to update. :param path: The dot-separated path to the key (e.g., 'foo.bar'). :param content: The new value to set. diff --git a/pkgs/clan-cli/tests/test_patch_inventory.py b/pkgs/clan-cli/tests/test_patch_inventory.py index c788288bb..3a10088b0 100644 --- a/pkgs/clan-cli/tests/test_patch_inventory.py +++ b/pkgs/clan-cli/tests/test_patch_inventory.py @@ -1,7 +1,8 @@ # Functions to test -from clan_cli.inventory import patch +from clan_cli.inventory import determine_writeability, patch +# --------- Patching tests --------- def test_patch_nested() -> None: orig = {"a": 1, "b": {"a": 2.1, "b": 2.2}, "c": 3} @@ -34,3 +35,93 @@ def test_create_missing_paths() -> None: patch(orig, "a.b.c", "foo") assert orig == {"a": {"b": {"c": "foo"}}} + + +# --------- Write tests --------- +# + + +def test_write_simple() -> None: + prios = { + "foo": { + "__prio": 100, # <- writeable: "foo" + "bar": {"__prio": 1000}, # <- writeable: mkDefault "foo.bar" + }, + } + + data: dict = {} + res = determine_writeability(prios, data) + + assert res == {"writeable": {"foo", "foo.bar"}, "non_writeable": set({})} + + +def test_write_inherited() -> None: + prios = { + "foo": { + "__prio": 100, # <- writeable: "foo" + "bar": { + # Inherits prio from parent <- writeable: "foo.bar" + "baz": {"__prio": 1000}, # <- writeable: "foo.bar.baz" + }, + }, + } + + data: dict = {} + res = determine_writeability(prios, data) + assert res == { + "writeable": {"foo", "foo.bar", "foo.bar.baz"}, + "non_writeable": set(), + } + + +def test_non_write_inherited() -> None: + prios = { + "foo": { + "__prio": 50, # <- non writeable: mkForce "foo" = {...} + "bar": { + # Inherits prio from parent <- non writeable + "baz": {"__prio": 1000}, # <- non writeable: mkDefault "foo.bar.baz" + }, + }, + } + + data: dict = {} + res = determine_writeability(prios, data) + assert res == { + "writeable": set(), + "non_writeable": {"foo", "foo.bar", "foo.bar.baz"}, + } + + +def test_write_because_written() -> None: + prios = { + "foo": { + "__prio": 100, # <- writeable: "foo" + "bar": { + # Inherits prio from parent <- writeable + "baz": {"__prio": 100}, # <- non writeable usually + "foobar": {"__prio": 100}, # <- non writeable + }, + }, + } + + # Given the following data. {} + # Check that the non-writeable paths are correct. + res = determine_writeability(prios, {}) + assert res == { + "writeable": {"foo", "foo.bar"}, + "non_writeable": {"foo.bar.baz", "foo.bar.foobar"}, + } + + data: dict = { + "foo": { + "bar": { + "baz": "foo" # <- written. Since we created the data, we know we can write to it + } + } + } + res = determine_writeability(prios, data) + assert res == { + "writeable": {"foo", "foo.bar", "foo.bar.baz"}, + "non_writeable": {"foo.bar.foobar"}, + }