refactor(lib/inventory): use util functions from clan_lib

This commit is contained in:
Johannes Kirschbauer
2025-05-14 10:50:47 +02:00
parent bdb075be11
commit 91034b66bf
2 changed files with 9 additions and 326 deletions

View File

@@ -4,7 +4,7 @@ DEPRECATED:
Don't use this module anymore Don't use this module anymore
Instead use: Instead use:
'clan_lib.persistence.inventoryStore' 'clan_lib.persist.inventoryStore'
Which is an abstraction over the inventory Which is an abstraction over the inventory
@@ -12,13 +12,18 @@ Interacting with 'clan_cli.inventory' is NOT recommended and will be removed
""" """
import json import json
from collections import Counter
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from clan_lib.api import API from clan_lib.api import API
from clan_lib.nix_models.inventory import Inventory from clan_lib.nix_models.inventory import Inventory
from clan_lib.persist.inventory_store import WriteInfo
from clan_lib.persist.util import (
calc_patches,
delete_by_path,
determine_writeability,
patch,
)
from clan_cli.cmd import run from clan_cli.cmd import run
from clan_cli.errors import ClanError from clan_cli.errors import ClanError
@@ -66,268 +71,6 @@ def load_inventory_eval(flake_dir: Flake) -> Inventory:
return inventory return inventory
def flatten_data(data: dict, parent_key: str = "", separator: str = ".") -> dict:
"""
Recursively flattens a nested dictionary structure where keys are joined by the separator.
Args:
data (dict): The nested dictionary structure.
parent_key (str): The current path to the nested dictionary (used for recursion).
separator (str): The string to use for joining keys.
Returns:
dict: A flattened dictionary with all values. Directly in the root.
"""
flattened = {}
for key, value in data.items():
new_key = f"{parent_key}{separator}{key}" if parent_key else key
if isinstance(value, dict):
# Recursively flatten the nested dictionary
flattened.update(flatten_data(value, new_key, separator))
else:
flattened[new_key] = value
return flattened
def unmerge_lists(all_items: list, filter_items: list) -> list:
"""
Unmerge the current list. Given a previous list.
Returns:
The other list.
"""
# Unmerge the lists
res = []
for value in all_items:
if value not in filter_items:
res.append(value)
return res
def find_duplicates(string_list: list[str]) -> list[str]:
count = Counter(string_list)
duplicates = [item for item, freq in count.items() if freq > 1]
return duplicates
def find_deleted_paths(
persisted: dict[str, Any], update: dict[str, Any], parent_key: str = ""
) -> set[str]:
"""
Recursively find keys (at any nesting level) that exist in persisted but do not
exist in update. If a nested dictionary is completely removed, return that dictionary key.
:param persisted: The original (persisted) nested dictionary.
:param update: The updated nested dictionary (some keys might be removed).
:param parent_key: The dotted path to the current dictionary's location.
:return: A set of dotted paths indicating keys or entire nested paths that were deleted.
"""
deleted_paths = set()
# Iterate over keys in persisted
for key, p_value in persisted.items():
current_path = f"{parent_key}.{key}" if parent_key else key
# Check if this key exists in update
if key not in update:
# Key doesn't exist at all -> entire branch deleted
deleted_paths.add(current_path)
else:
u_value = update[key]
# If persisted value is dict, check the update value
if isinstance(p_value, dict):
if isinstance(u_value, dict):
# If persisted dict is non-empty but updated dict is empty,
# that means everything under this branch is removed.
if p_value and not u_value:
# All children are removed
for child_key in p_value:
child_path = f"{current_path}.{child_key}"
deleted_paths.add(child_path)
else:
# Both are dicts, recurse deeper
deleted_paths |= find_deleted_paths(
p_value, u_value, current_path
)
else:
# Persisted was a dict, update is not a dict -> entire branch changed
# Consider this as a full deletion of the persisted branch
deleted_paths.add(current_path)
return deleted_paths
def calc_patches(
persisted: dict[str, Any],
update: dict[str, Any],
all_values: dict[str, Any],
writeables: dict[str, set[str]],
) -> tuple[dict[str, Any], set[str]]:
"""
Calculate the patches to apply to the inventory.
Given its current state and the update to apply.
Filters out nix-values so it doesnt matter if the anyone sends them.
: param persisted: The current state of the inventory.
: param update: The update to apply.
: param writeable: The writeable keys. Use 'determine_writeability'.
Example: {'writeable': {'foo', 'foo.bar'}, 'non_writeable': {'foo.nix'}}
: param all_values: All values in the inventory retrieved from the flake evaluation.
Returns a tuple with the SET and DELETE patches.
"""
persisted_flat = flatten_data(persisted)
update_flat = flatten_data(update)
all_values_flat = flatten_data(all_values)
def is_writeable_key(key: str) -> bool:
"""
Recursively check if a key is writeable.
key "machines.machine1.deploy.targetHost" is specified but writeability is only defined for "machines"
We pop the last key and check if the parent key is writeable/non-writeable.
"""
remaining = key.split(".")
while remaining:
if ".".join(remaining) in writeables["writeable"]:
return True
if ".".join(remaining) in writeables["non_writeable"]:
return False
remaining.pop()
msg = f"Cannot determine writeability for key '{key}'"
raise ClanError(msg, description="F001")
patchset = {}
for update_key, update_data in update_flat.items():
if not is_writeable_key(update_key):
if update_data != all_values_flat.get(update_key):
msg = f"Key '{update_key}' is not writeable."
raise ClanError(msg)
continue
if is_writeable_key(update_key):
prev_value = all_values_flat.get(update_key)
if prev_value and type(update_data) is not type(prev_value):
msg = f"Type mismatch for key '{update_key}'. Cannot update {type(all_values_flat.get(update_key))} with {type(update_data)}"
raise ClanError(msg)
# Handle list separation
if isinstance(update_data, list):
duplicates = find_duplicates(update_data)
if duplicates:
msg = f"Key '{update_key}' contains duplicates: {duplicates}. This not supported yet."
raise ClanError(msg)
# List of current values
persisted_data = persisted_flat.get(update_key, [])
# List including nix values
all_list = all_values_flat.get(update_key, [])
nix_list = unmerge_lists(all_list, persisted_data)
if update_data != all_list:
patchset[update_key] = unmerge_lists(update_data, nix_list)
elif update_data != persisted_flat.get(update_key, None):
patchset[update_key] = update_data
continue
msg = f"Cannot determine writeability for key '{update_key}'"
raise ClanError(msg)
delete_set = find_deleted_paths(persisted, update)
for delete_key in delete_set:
if not is_writeable_key(delete_key):
msg = f"Cannot delete: Key '{delete_key}' is not writeable."
raise ClanError(msg)
return patchset, delete_set
def determine_writeability(
priorities: dict[str, Any],
defaults: dict[str, Any],
persisted: dict[str, Any],
parent_key: str = "",
parent_prio: int | None = None,
results: dict | None = None,
non_writeable: bool = False,
) -> dict[str, set[str]]:
if results is None:
results = {"writeable": set({}), "non_writeable": set({})}
for key, value in priorities.items():
if key == "__prio":
continue
full_key = f"{parent_key}.{key}" if parent_key else key
# Determine the priority for the current key
# Inherit from parent if no priority is defined
prio = value.get("__prio", None)
if prio is None:
prio = parent_prio
# If priority is less than 100, all children are not writeable
# If the parent passed "non_writeable" earlier, this makes all children not writeable
if (prio is not None and prio < 100) or non_writeable:
results["non_writeable"].add(full_key)
if isinstance(value, dict):
determine_writeability(
value,
defaults,
{}, # Children won't be writeable, so correlation doesn't matter here
full_key,
prio, # Pass the same priority down
results,
# Recursively mark all children as non-writeable
non_writeable=True,
)
continue
# Check if the key is writeable otherwise
key_in_correlated = key in persisted
if prio is None:
msg = f"Priority for key '{full_key}' is not defined. Cannot determine if it is writeable."
raise ClanError(msg)
is_mergeable = False
if prio == 100:
default = defaults.get(key)
if isinstance(default, dict):
is_mergeable = True
if isinstance(default, list):
is_mergeable = True
if key_in_correlated:
is_mergeable = True
is_writeable = prio > 100 or is_mergeable
# Append the result
if is_writeable:
results["writeable"].add(full_key)
else:
results["non_writeable"].add(full_key)
# Recursive
if isinstance(value, dict):
determine_writeability(
value,
defaults.get(key, {}),
persisted.get(key, {}),
full_key,
prio, # Pass down current priority
results,
)
return results
def get_inventory_current_priority(flake: Flake) -> dict: def get_inventory_current_priority(flake: Flake) -> dict:
""" """
Returns the current priority of the inventory values Returns the current priority of the inventory values
@@ -392,59 +135,6 @@ def load_inventory_json(flake: Flake) -> Inventory:
return inventory return inventory
def delete_by_path(d: dict[str, Any], path: str) -> Any:
"""
Deletes the nested entry specified by a dot-separated path from the dictionary using pop().
:param data: The dictionary to modify.
:param path: A dot-separated string indicating the nested key to delete.
e.g., "foo.bar.baz" will attempt to delete data["foo"]["bar"]["baz"].
:raises KeyError: If any intermediate key is missing or not a dictionary,
or if the final key to delete is not found.
"""
if not path:
msg = "Cannot delete. Path is empty."
raise KeyError(msg)
keys = path.split(".")
current = d
# Navigate to the parent dictionary of the final key
for key in keys[:-1]:
if key not in current or not isinstance(current[key], dict):
msg = f"Cannot delete. Key '{path}' not found or not a dictionary '{d}'"
raise KeyError(msg)
current = current[key]
# Attempt to pop the final key
last_key = keys[-1]
try:
value = current.pop(last_key)
except KeyError as exc:
msg = f"Cannot delete. Path '{path}' not found in data '{d}'"
raise KeyError(msg) from exc
else:
return {last_key: value}
def patch(d: dict[str, Any], path: str, content: Any) -> None:
"""
Update the value at a specific dot-separated path in a nested dictionary.
If the value didn't exist before, it will be created recursively.
:param d: The dictionary to update.
:param path: The dot-separated path to the key (e.g., 'foo.bar').
:param content: The new value to set.
"""
keys = path.split(".")
current = d
for key in keys[:-1]:
current = current.setdefault(key, {})
current[keys[-1]] = content
@API.register @API.register
def patch_inventory_with(flake: Flake, section: str, content: dict[str, Any]) -> None: def patch_inventory_with(flake: Flake, section: str, content: dict[str, Any]) -> None:
""" """
@@ -470,13 +160,6 @@ def patch_inventory_with(flake: Flake, section: str, content: dict[str, Any]) ->
) )
@dataclass
class WriteInfo:
writeables: dict[str, set[str]]
data_eval: Inventory
data_disk: Inventory
@API.register @API.register
def get_inventory_with_writeable_keys( def get_inventory_with_writeable_keys(
flake: Flake, flake: Flake,

View File

@@ -3,7 +3,7 @@ from typing import Any
import pytest import pytest
from clan_cli.errors import ClanError from clan_cli.errors import ClanError
from clan_cli.inventory import ( from clan_lib.persist.util import (
calc_patches, calc_patches,
delete_by_path, delete_by_path,
determine_writeability, determine_writeability,