inventory/api: init smart update for inventory

This commit is contained in:
Johannes Kirschbauer
2024-12-06 11:07:36 +01:00
parent a032c446e1
commit 6dd1ecb044
2 changed files with 315 additions and 13 deletions

View File

@@ -14,6 +14,7 @@ Operate on the returned inventory to make changes
import contextlib
import json
from collections import Counter
from pathlib import Path
from typing import Any
@@ -94,7 +95,6 @@ def load_inventory_eval(flake_dir: str | Path) -> Inventory:
def flatten_data(data: dict, parent_key: str = "", separator: str = ".") -> dict:
"""
Recursively flattens a nested dictionary structure where keys are joined by the separator.
The flattened dictionary contains only entries with "__prio".
Args:
data (dict): The nested dictionary structure.
@@ -102,24 +102,23 @@ def flatten_data(data: dict, parent_key: str = "", separator: str = ".") -> dict
separator (str): The string to use for joining keys.
Returns:
dict: A flattened dictionary with "__prio" values.
dict: A flattened dictionary with all values. Directly in the root.
"""
flattened = {}
for key, value in data.items():
new_key = f"{parent_key}{separator}{key}" if parent_key else key
if isinstance(value, dict) and "__prio" in value:
flattened[new_key] = {"__prio": value["__prio"]}
if isinstance(value, dict):
# Recursively flatten the nested dictionary
flattened.update(flatten_data(value, new_key, separator))
else:
flattened[new_key] = value
return flattened
def unmerge_lists(curr: list, prev: list) -> list:
def unmerge_lists(all_items: list, filter_items: list) -> list:
"""
Unmerge the current list. Given a previous list.
@@ -127,12 +126,80 @@ def unmerge_lists(curr: list, prev: list) -> list:
The other list.
"""
# Unmerge the lists
unmerged = []
for value in curr:
if value not in prev:
unmerged.append(value)
res = []
for value in all_items:
if value not in filter_items:
res.append(value)
return unmerged
return res
def find_duplicates(string_list: list[str]) -> list[str]:
count = Counter(string_list)
duplicates = [item for item, freq in count.items() if freq > 1]
return duplicates
def calc_patches(
persisted: dict, update: dict, all_values: dict, writeables: dict
) -> dict[str, Any]:
"""
Calculate the patches to apply to the inventory.
Given its current state and the update to apply.
Filters out nix-values so it doesnt matter if the anyone sends them.
: param persisted: The current state of the inventory.
: param update: The update to apply.
: param writeable: The writeable keys. Use 'determine_writeability'.
Example: {'writeable': {'foo', 'foo.bar'}, 'non_writeable': {'foo.nix'}}
: param all_values: All values in the inventory retrieved from the flake evaluation.
"""
persisted_flat = flatten_data(persisted)
update_flat = flatten_data(update)
all_values_flat = flatten_data(all_values)
patchset = {}
for update_key, update_data in update_flat.items():
if update_key in writeables["non_writeable"]:
if update_data != all_values_flat.get(update_key):
msg = f"Key '{update_key}' is not writeable."
raise ClanError(msg)
continue
if update_key in writeables["writeable"]:
if type(update_data) is not type(all_values_flat.get(update_key)):
msg = f"Type mismatch for key '{update_key}'. Cannot update {type(all_values_flat.get(update_key))} with {type(update_data)}"
raise ClanError(msg)
# Handle list seperation
if isinstance(update_data, list):
duplicates = find_duplicates(update_data)
if duplicates:
msg = f"Key '{update_key}' contains duplicates: {duplicates}. This not supported yet."
raise ClanError(msg)
# List of current values
persisted_data = persisted_flat.get(update_key, [])
# List including nix values
all_list = all_values_flat.get(update_key, [])
nix_list = unmerge_lists(all_list, persisted_data)
if update_data != all_list:
patchset[update_key] = unmerge_lists(update_data, nix_list)
elif update_data != persisted_flat.get(update_key, None):
patchset[update_key] = update_data
continue
if update_key not in all_values_flat:
msg = f"Key '{update_key}' cannot be set. It does not exist."
raise ClanError(msg)
msg = f"Cannot determine writeability for key '{update_key}'"
raise ClanError(msg)
return patchset
def determine_writeability(