inventory/api: init smart update for inventory
This commit is contained in:
@@ -14,6 +14,7 @@ Operate on the returned inventory to make changes
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@@ -94,7 +95,6 @@ def load_inventory_eval(flake_dir: str | Path) -> Inventory:
|
||||
def flatten_data(data: dict, parent_key: str = "", separator: str = ".") -> dict:
|
||||
"""
|
||||
Recursively flattens a nested dictionary structure where keys are joined by the separator.
|
||||
The flattened dictionary contains only entries with "__prio".
|
||||
|
||||
Args:
|
||||
data (dict): The nested dictionary structure.
|
||||
@@ -102,24 +102,23 @@ def flatten_data(data: dict, parent_key: str = "", separator: str = ".") -> dict
|
||||
separator (str): The string to use for joining keys.
|
||||
|
||||
Returns:
|
||||
dict: A flattened dictionary with "__prio" values.
|
||||
dict: A flattened dictionary with all values. Directly in the root.
|
||||
"""
|
||||
flattened = {}
|
||||
|
||||
for key, value in data.items():
|
||||
new_key = f"{parent_key}{separator}{key}" if parent_key else key
|
||||
|
||||
if isinstance(value, dict) and "__prio" in value:
|
||||
flattened[new_key] = {"__prio": value["__prio"]}
|
||||
|
||||
if isinstance(value, dict):
|
||||
# Recursively flatten the nested dictionary
|
||||
flattened.update(flatten_data(value, new_key, separator))
|
||||
else:
|
||||
flattened[new_key] = value
|
||||
|
||||
return flattened
|
||||
|
||||
|
||||
def unmerge_lists(curr: list, prev: list) -> list:
|
||||
def unmerge_lists(all_items: list, filter_items: list) -> list:
|
||||
"""
|
||||
Unmerge the current list. Given a previous list.
|
||||
|
||||
@@ -127,12 +126,80 @@ def unmerge_lists(curr: list, prev: list) -> list:
|
||||
The other list.
|
||||
"""
|
||||
# Unmerge the lists
|
||||
unmerged = []
|
||||
for value in curr:
|
||||
if value not in prev:
|
||||
unmerged.append(value)
|
||||
res = []
|
||||
for value in all_items:
|
||||
if value not in filter_items:
|
||||
res.append(value)
|
||||
|
||||
return unmerged
|
||||
return res
|
||||
|
||||
|
||||
def find_duplicates(string_list: list[str]) -> list[str]:
|
||||
count = Counter(string_list)
|
||||
duplicates = [item for item, freq in count.items() if freq > 1]
|
||||
return duplicates
|
||||
|
||||
|
||||
def calc_patches(
|
||||
persisted: dict, update: dict, all_values: dict, writeables: dict
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Calculate the patches to apply to the inventory.
|
||||
|
||||
Given its current state and the update to apply.
|
||||
|
||||
Filters out nix-values so it doesnt matter if the anyone sends them.
|
||||
|
||||
: param persisted: The current state of the inventory.
|
||||
: param update: The update to apply.
|
||||
: param writeable: The writeable keys. Use 'determine_writeability'.
|
||||
Example: {'writeable': {'foo', 'foo.bar'}, 'non_writeable': {'foo.nix'}}
|
||||
: param all_values: All values in the inventory retrieved from the flake evaluation.
|
||||
"""
|
||||
persisted_flat = flatten_data(persisted)
|
||||
update_flat = flatten_data(update)
|
||||
all_values_flat = flatten_data(all_values)
|
||||
|
||||
patchset = {}
|
||||
for update_key, update_data in update_flat.items():
|
||||
if update_key in writeables["non_writeable"]:
|
||||
if update_data != all_values_flat.get(update_key):
|
||||
msg = f"Key '{update_key}' is not writeable."
|
||||
raise ClanError(msg)
|
||||
continue
|
||||
|
||||
if update_key in writeables["writeable"]:
|
||||
if type(update_data) is not type(all_values_flat.get(update_key)):
|
||||
msg = f"Type mismatch for key '{update_key}'. Cannot update {type(all_values_flat.get(update_key))} with {type(update_data)}"
|
||||
raise ClanError(msg)
|
||||
|
||||
# Handle list seperation
|
||||
if isinstance(update_data, list):
|
||||
duplicates = find_duplicates(update_data)
|
||||
if duplicates:
|
||||
msg = f"Key '{update_key}' contains duplicates: {duplicates}. This not supported yet."
|
||||
raise ClanError(msg)
|
||||
# List of current values
|
||||
persisted_data = persisted_flat.get(update_key, [])
|
||||
# List including nix values
|
||||
all_list = all_values_flat.get(update_key, [])
|
||||
nix_list = unmerge_lists(all_list, persisted_data)
|
||||
if update_data != all_list:
|
||||
patchset[update_key] = unmerge_lists(update_data, nix_list)
|
||||
|
||||
elif update_data != persisted_flat.get(update_key, None):
|
||||
patchset[update_key] = update_data
|
||||
|
||||
continue
|
||||
|
||||
if update_key not in all_values_flat:
|
||||
msg = f"Key '{update_key}' cannot be set. It does not exist."
|
||||
raise ClanError(msg)
|
||||
|
||||
msg = f"Cannot determine writeability for key '{update_key}'"
|
||||
raise ClanError(msg)
|
||||
|
||||
return patchset
|
||||
|
||||
|
||||
def determine_writeability(
|
||||
|
||||
Reference in New Issue
Block a user