From 91034b66bfd5a71d2949436b932e69de9d8cc3e6 Mon Sep 17 00:00:00 2001 From: Johannes Kirschbauer Date: Wed, 14 May 2025 10:50:47 +0200 Subject: [PATCH] refactor(lib/inventory): use util functions from clan_lib --- pkgs/clan-cli/clan_cli/inventory/__init__.py | 333 +----------------- .../clan_cli/tests/test_patch_inventory.py | 2 +- 2 files changed, 9 insertions(+), 326 deletions(-) diff --git a/pkgs/clan-cli/clan_cli/inventory/__init__.py b/pkgs/clan-cli/clan_cli/inventory/__init__.py index ceaf4abdc..d51963676 100644 --- a/pkgs/clan-cli/clan_cli/inventory/__init__.py +++ b/pkgs/clan-cli/clan_cli/inventory/__init__.py @@ -4,7 +4,7 @@ DEPRECATED: Don't use this module anymore Instead use: -'clan_lib.persistence.inventoryStore' +'clan_lib.persist.inventoryStore' Which is an abstraction over the inventory @@ -12,13 +12,18 @@ Interacting with 'clan_cli.inventory' is NOT recommended and will be removed """ import json -from collections import Counter -from dataclasses import dataclass from pathlib import Path from typing import Any from clan_lib.api import API from clan_lib.nix_models.inventory import Inventory +from clan_lib.persist.inventory_store import WriteInfo +from clan_lib.persist.util import ( + calc_patches, + delete_by_path, + determine_writeability, + patch, +) from clan_cli.cmd import run from clan_cli.errors import ClanError @@ -66,268 +71,6 @@ def load_inventory_eval(flake_dir: Flake) -> Inventory: return inventory -def flatten_data(data: dict, parent_key: str = "", separator: str = ".") -> dict: - """ - Recursively flattens a nested dictionary structure where keys are joined by the separator. - - Args: - data (dict): The nested dictionary structure. - parent_key (str): The current path to the nested dictionary (used for recursion). - separator (str): The string to use for joining keys. - - Returns: - dict: A flattened dictionary with all values. Directly in the root. - """ - flattened = {} - - for key, value in data.items(): - new_key = f"{parent_key}{separator}{key}" if parent_key else key - - if isinstance(value, dict): - # Recursively flatten the nested dictionary - flattened.update(flatten_data(value, new_key, separator)) - else: - flattened[new_key] = value - - return flattened - - -def unmerge_lists(all_items: list, filter_items: list) -> list: - """ - Unmerge the current list. Given a previous list. - - Returns: - The other list. - """ - # Unmerge the lists - res = [] - for value in all_items: - if value not in filter_items: - res.append(value) - - return res - - -def find_duplicates(string_list: list[str]) -> list[str]: - count = Counter(string_list) - duplicates = [item for item, freq in count.items() if freq > 1] - return duplicates - - -def find_deleted_paths( - persisted: dict[str, Any], update: dict[str, Any], parent_key: str = "" -) -> set[str]: - """ - Recursively find keys (at any nesting level) that exist in persisted but do not - exist in update. If a nested dictionary is completely removed, return that dictionary key. - - :param persisted: The original (persisted) nested dictionary. - :param update: The updated nested dictionary (some keys might be removed). - :param parent_key: The dotted path to the current dictionary's location. - :return: A set of dotted paths indicating keys or entire nested paths that were deleted. - """ - deleted_paths = set() - - # Iterate over keys in persisted - for key, p_value in persisted.items(): - current_path = f"{parent_key}.{key}" if parent_key else key - # Check if this key exists in update - if key not in update: - # Key doesn't exist at all -> entire branch deleted - deleted_paths.add(current_path) - else: - u_value = update[key] - # If persisted value is dict, check the update value - if isinstance(p_value, dict): - if isinstance(u_value, dict): - # If persisted dict is non-empty but updated dict is empty, - # that means everything under this branch is removed. - if p_value and not u_value: - # All children are removed - for child_key in p_value: - child_path = f"{current_path}.{child_key}" - deleted_paths.add(child_path) - else: - # Both are dicts, recurse deeper - deleted_paths |= find_deleted_paths( - p_value, u_value, current_path - ) - else: - # Persisted was a dict, update is not a dict -> entire branch changed - # Consider this as a full deletion of the persisted branch - deleted_paths.add(current_path) - - return deleted_paths - - -def calc_patches( - persisted: dict[str, Any], - update: dict[str, Any], - all_values: dict[str, Any], - writeables: dict[str, set[str]], -) -> tuple[dict[str, Any], set[str]]: - """ - Calculate the patches to apply to the inventory. - - Given its current state and the update to apply. - - Filters out nix-values so it doesnt matter if the anyone sends them. - - : param persisted: The current state of the inventory. - : param update: The update to apply. - : param writeable: The writeable keys. Use 'determine_writeability'. - Example: {'writeable': {'foo', 'foo.bar'}, 'non_writeable': {'foo.nix'}} - : param all_values: All values in the inventory retrieved from the flake evaluation. - - Returns a tuple with the SET and DELETE patches. - """ - persisted_flat = flatten_data(persisted) - update_flat = flatten_data(update) - all_values_flat = flatten_data(all_values) - - def is_writeable_key(key: str) -> bool: - """ - Recursively check if a key is writeable. - key "machines.machine1.deploy.targetHost" is specified but writeability is only defined for "machines" - We pop the last key and check if the parent key is writeable/non-writeable. - """ - remaining = key.split(".") - while remaining: - if ".".join(remaining) in writeables["writeable"]: - return True - if ".".join(remaining) in writeables["non_writeable"]: - return False - - remaining.pop() - - msg = f"Cannot determine writeability for key '{key}'" - raise ClanError(msg, description="F001") - - patchset = {} - for update_key, update_data in update_flat.items(): - if not is_writeable_key(update_key): - if update_data != all_values_flat.get(update_key): - msg = f"Key '{update_key}' is not writeable." - raise ClanError(msg) - continue - - if is_writeable_key(update_key): - prev_value = all_values_flat.get(update_key) - if prev_value and type(update_data) is not type(prev_value): - msg = f"Type mismatch for key '{update_key}'. Cannot update {type(all_values_flat.get(update_key))} with {type(update_data)}" - raise ClanError(msg) - - # Handle list separation - if isinstance(update_data, list): - duplicates = find_duplicates(update_data) - if duplicates: - msg = f"Key '{update_key}' contains duplicates: {duplicates}. This not supported yet." - raise ClanError(msg) - # List of current values - persisted_data = persisted_flat.get(update_key, []) - # List including nix values - all_list = all_values_flat.get(update_key, []) - nix_list = unmerge_lists(all_list, persisted_data) - if update_data != all_list: - patchset[update_key] = unmerge_lists(update_data, nix_list) - - elif update_data != persisted_flat.get(update_key, None): - patchset[update_key] = update_data - - continue - - msg = f"Cannot determine writeability for key '{update_key}'" - raise ClanError(msg) - - delete_set = find_deleted_paths(persisted, update) - - for delete_key in delete_set: - if not is_writeable_key(delete_key): - msg = f"Cannot delete: Key '{delete_key}' is not writeable." - raise ClanError(msg) - - return patchset, delete_set - - -def determine_writeability( - priorities: dict[str, Any], - defaults: dict[str, Any], - persisted: dict[str, Any], - parent_key: str = "", - parent_prio: int | None = None, - results: dict | None = None, - non_writeable: bool = False, -) -> dict[str, set[str]]: - if results is None: - results = {"writeable": set({}), "non_writeable": set({})} - - for key, value in priorities.items(): - if key == "__prio": - continue - - full_key = f"{parent_key}.{key}" if parent_key else key - - # Determine the priority for the current key - # Inherit from parent if no priority is defined - prio = value.get("__prio", None) - if prio is None: - prio = parent_prio - - # If priority is less than 100, all children are not writeable - # If the parent passed "non_writeable" earlier, this makes all children not writeable - if (prio is not None and prio < 100) or non_writeable: - results["non_writeable"].add(full_key) - if isinstance(value, dict): - determine_writeability( - value, - defaults, - {}, # Children won't be writeable, so correlation doesn't matter here - full_key, - prio, # Pass the same priority down - results, - # Recursively mark all children as non-writeable - non_writeable=True, - ) - continue - - # Check if the key is writeable otherwise - key_in_correlated = key in persisted - if prio is None: - msg = f"Priority for key '{full_key}' is not defined. Cannot determine if it is writeable." - raise ClanError(msg) - - is_mergeable = False - if prio == 100: - default = defaults.get(key) - if isinstance(default, dict): - is_mergeable = True - if isinstance(default, list): - is_mergeable = True - if key_in_correlated: - is_mergeable = True - - is_writeable = prio > 100 or is_mergeable - - # Append the result - if is_writeable: - results["writeable"].add(full_key) - else: - results["non_writeable"].add(full_key) - - # Recursive - if isinstance(value, dict): - determine_writeability( - value, - defaults.get(key, {}), - persisted.get(key, {}), - full_key, - prio, # Pass down current priority - results, - ) - - return results - - def get_inventory_current_priority(flake: Flake) -> dict: """ Returns the current priority of the inventory values @@ -392,59 +135,6 @@ def load_inventory_json(flake: Flake) -> Inventory: return inventory -def delete_by_path(d: dict[str, Any], path: str) -> Any: - """ - Deletes the nested entry specified by a dot-separated path from the dictionary using pop(). - - :param data: The dictionary to modify. - :param path: A dot-separated string indicating the nested key to delete. - e.g., "foo.bar.baz" will attempt to delete data["foo"]["bar"]["baz"]. - - :raises KeyError: If any intermediate key is missing or not a dictionary, - or if the final key to delete is not found. - """ - if not path: - msg = "Cannot delete. Path is empty." - raise KeyError(msg) - - keys = path.split(".") - current = d - - # Navigate to the parent dictionary of the final key - for key in keys[:-1]: - if key not in current or not isinstance(current[key], dict): - msg = f"Cannot delete. Key '{path}' not found or not a dictionary '{d}'" - raise KeyError(msg) - current = current[key] - - # Attempt to pop the final key - last_key = keys[-1] - try: - value = current.pop(last_key) - except KeyError as exc: - msg = f"Cannot delete. Path '{path}' not found in data '{d}'" - raise KeyError(msg) from exc - else: - return {last_key: value} - - -def patch(d: dict[str, Any], path: str, content: Any) -> None: - """ - Update the value at a specific dot-separated path in a nested dictionary. - - If the value didn't exist before, it will be created recursively. - - :param d: The dictionary to update. - :param path: The dot-separated path to the key (e.g., 'foo.bar'). - :param content: The new value to set. - """ - keys = path.split(".") - current = d - for key in keys[:-1]: - current = current.setdefault(key, {}) - current[keys[-1]] = content - - @API.register def patch_inventory_with(flake: Flake, section: str, content: dict[str, Any]) -> None: """ @@ -470,13 +160,6 @@ def patch_inventory_with(flake: Flake, section: str, content: dict[str, Any]) -> ) -@dataclass -class WriteInfo: - writeables: dict[str, set[str]] - data_eval: Inventory - data_disk: Inventory - - @API.register def get_inventory_with_writeable_keys( flake: Flake, diff --git a/pkgs/clan-cli/clan_cli/tests/test_patch_inventory.py b/pkgs/clan-cli/clan_cli/tests/test_patch_inventory.py index 376c21dad..1c19b15d9 100644 --- a/pkgs/clan-cli/clan_cli/tests/test_patch_inventory.py +++ b/pkgs/clan-cli/clan_cli/tests/test_patch_inventory.py @@ -3,7 +3,7 @@ from typing import Any import pytest from clan_cli.errors import ClanError -from clan_cli.inventory import ( +from clan_lib.persist.util import ( calc_patches, delete_by_path, determine_writeability,