Merge pull request 'refactor(cli/inventory): move functions and tests into clan_lib' (#3641) from hsjobeki/clan-core:persistence-1 into main

Reviewed-on: https://git.clan.lol/clan/clan-core/pulls/3641
This commit is contained in:
hsjobeki
2025-05-14 11:47:37 +00:00
4 changed files with 30 additions and 355 deletions

View File

@@ -10,7 +10,7 @@ from clan_lib.nix_models.inventory import Inventory
from clan_cli.cmd import CmdOut, RunOpts, run
from clan_cli.errors import ClanError
from clan_cli.flake import Flake
from clan_cli.inventory import init_inventory
from clan_cli.inventory import set_inventory
from clan_cli.nix import nix_command, nix_metadata, nix_shell
from clan_cli.templates import (
InputPrio,
@@ -108,7 +108,11 @@ def create_clan(opts: CreateOptions) -> CreateClanResponse:
response.flake_update = flake_update
if opts.initial:
init_inventory(Flake(str(opts.dest)), init=opts.initial)
set_inventory(
flake=Flake(str(opts.dest)),
inventory=opts.initial,
message="Init inventory",
)
return response

View File

@@ -4,25 +4,29 @@ DEPRECATED:
Don't use this module anymore
Instead use:
'clan_lib.persistence.inventoryStore'
'clan_lib.persist.inventoryStore'
Which is an abstraction over the inventory
Interacting with 'clan_cli.inventory' is NOT recommended and will be removed
"""
import contextlib
import json
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from clan_lib.api import API
from clan_lib.nix_models.inventory import Inventory
from clan_lib.persist.inventory_store import WriteInfo
from clan_lib.persist.util import (
calc_patches,
delete_by_path,
determine_writeability,
patch,
)
from clan_cli.cmd import run
from clan_cli.errors import ClanCmdError, ClanError
from clan_cli.errors import ClanError
from clan_cli.flake import Flake
from clan_cli.git import commit_file
from clan_cli.nix import nix_eval
@@ -67,268 +71,6 @@ def load_inventory_eval(flake_dir: Flake) -> Inventory:
return inventory
def flatten_data(data: dict, parent_key: str = "", separator: str = ".") -> dict:
"""
Recursively flattens a nested dictionary structure where keys are joined by the separator.
Args:
data (dict): The nested dictionary structure.
parent_key (str): The current path to the nested dictionary (used for recursion).
separator (str): The string to use for joining keys.
Returns:
dict: A flattened dictionary with all values. Directly in the root.
"""
flattened = {}
for key, value in data.items():
new_key = f"{parent_key}{separator}{key}" if parent_key else key
if isinstance(value, dict):
# Recursively flatten the nested dictionary
flattened.update(flatten_data(value, new_key, separator))
else:
flattened[new_key] = value
return flattened
def unmerge_lists(all_items: list, filter_items: list) -> list:
"""
Unmerge the current list. Given a previous list.
Returns:
The other list.
"""
# Unmerge the lists
res = []
for value in all_items:
if value not in filter_items:
res.append(value)
return res
def find_duplicates(string_list: list[str]) -> list[str]:
count = Counter(string_list)
duplicates = [item for item, freq in count.items() if freq > 1]
return duplicates
def find_deleted_paths(
persisted: dict[str, Any], update: dict[str, Any], parent_key: str = ""
) -> set[str]:
"""
Recursively find keys (at any nesting level) that exist in persisted but do not
exist in update. If a nested dictionary is completely removed, return that dictionary key.
:param persisted: The original (persisted) nested dictionary.
:param update: The updated nested dictionary (some keys might be removed).
:param parent_key: The dotted path to the current dictionary's location.
:return: A set of dotted paths indicating keys or entire nested paths that were deleted.
"""
deleted_paths = set()
# Iterate over keys in persisted
for key, p_value in persisted.items():
current_path = f"{parent_key}.{key}" if parent_key else key
# Check if this key exists in update
if key not in update:
# Key doesn't exist at all -> entire branch deleted
deleted_paths.add(current_path)
else:
u_value = update[key]
# If persisted value is dict, check the update value
if isinstance(p_value, dict):
if isinstance(u_value, dict):
# If persisted dict is non-empty but updated dict is empty,
# that means everything under this branch is removed.
if p_value and not u_value:
# All children are removed
for child_key in p_value:
child_path = f"{current_path}.{child_key}"
deleted_paths.add(child_path)
else:
# Both are dicts, recurse deeper
deleted_paths |= find_deleted_paths(
p_value, u_value, current_path
)
else:
# Persisted was a dict, update is not a dict -> entire branch changed
# Consider this as a full deletion of the persisted branch
deleted_paths.add(current_path)
return deleted_paths
def calc_patches(
persisted: dict[str, Any],
update: dict[str, Any],
all_values: dict[str, Any],
writeables: dict[str, set[str]],
) -> tuple[dict[str, Any], set[str]]:
"""
Calculate the patches to apply to the inventory.
Given its current state and the update to apply.
Filters out nix-values so it doesnt matter if the anyone sends them.
: param persisted: The current state of the inventory.
: param update: The update to apply.
: param writeable: The writeable keys. Use 'determine_writeability'.
Example: {'writeable': {'foo', 'foo.bar'}, 'non_writeable': {'foo.nix'}}
: param all_values: All values in the inventory retrieved from the flake evaluation.
Returns a tuple with the SET and DELETE patches.
"""
persisted_flat = flatten_data(persisted)
update_flat = flatten_data(update)
all_values_flat = flatten_data(all_values)
def is_writeable_key(key: str) -> bool:
"""
Recursively check if a key is writeable.
key "machines.machine1.deploy.targetHost" is specified but writeability is only defined for "machines"
We pop the last key and check if the parent key is writeable/non-writeable.
"""
remaining = key.split(".")
while remaining:
if ".".join(remaining) in writeables["writeable"]:
return True
if ".".join(remaining) in writeables["non_writeable"]:
return False
remaining.pop()
msg = f"Cannot determine writeability for key '{key}'"
raise ClanError(msg, description="F001")
patchset = {}
for update_key, update_data in update_flat.items():
if not is_writeable_key(update_key):
if update_data != all_values_flat.get(update_key):
msg = f"Key '{update_key}' is not writeable."
raise ClanError(msg)
continue
if is_writeable_key(update_key):
prev_value = all_values_flat.get(update_key)
if prev_value and type(update_data) is not type(prev_value):
msg = f"Type mismatch for key '{update_key}'. Cannot update {type(all_values_flat.get(update_key))} with {type(update_data)}"
raise ClanError(msg)
# Handle list separation
if isinstance(update_data, list):
duplicates = find_duplicates(update_data)
if duplicates:
msg = f"Key '{update_key}' contains duplicates: {duplicates}. This not supported yet."
raise ClanError(msg)
# List of current values
persisted_data = persisted_flat.get(update_key, [])
# List including nix values
all_list = all_values_flat.get(update_key, [])
nix_list = unmerge_lists(all_list, persisted_data)
if update_data != all_list:
patchset[update_key] = unmerge_lists(update_data, nix_list)
elif update_data != persisted_flat.get(update_key, None):
patchset[update_key] = update_data
continue
msg = f"Cannot determine writeability for key '{update_key}'"
raise ClanError(msg)
delete_set = find_deleted_paths(persisted, update)
for delete_key in delete_set:
if not is_writeable_key(delete_key):
msg = f"Cannot delete: Key '{delete_key}' is not writeable."
raise ClanError(msg)
return patchset, delete_set
def determine_writeability(
priorities: dict[str, Any],
defaults: dict[str, Any],
persisted: dict[str, Any],
parent_key: str = "",
parent_prio: int | None = None,
results: dict | None = None,
non_writeable: bool = False,
) -> dict[str, set[str]]:
if results is None:
results = {"writeable": set({}), "non_writeable": set({})}
for key, value in priorities.items():
if key == "__prio":
continue
full_key = f"{parent_key}.{key}" if parent_key else key
# Determine the priority for the current key
# Inherit from parent if no priority is defined
prio = value.get("__prio", None)
if prio is None:
prio = parent_prio
# If priority is less than 100, all children are not writeable
# If the parent passed "non_writeable" earlier, this makes all children not writeable
if (prio is not None and prio < 100) or non_writeable:
results["non_writeable"].add(full_key)
if isinstance(value, dict):
determine_writeability(
value,
defaults,
{}, # Children won't be writeable, so correlation doesn't matter here
full_key,
prio, # Pass the same priority down
results,
# Recursively mark all children as non-writeable
non_writeable=True,
)
continue
# Check if the key is writeable otherwise
key_in_correlated = key in persisted
if prio is None:
msg = f"Priority for key '{full_key}' is not defined. Cannot determine if it is writeable."
raise ClanError(msg)
is_mergeable = False
if prio == 100:
default = defaults.get(key)
if isinstance(default, dict):
is_mergeable = True
if isinstance(default, list):
is_mergeable = True
if key_in_correlated:
is_mergeable = True
is_writeable = prio > 100 or is_mergeable
# Append the result
if is_writeable:
results["writeable"].add(full_key)
else:
results["non_writeable"].add(full_key)
# Recursive
if isinstance(value, dict):
determine_writeability(
value,
defaults.get(key, {}),
persisted.get(key, {}),
full_key,
prio, # Pass down current priority
results,
)
return results
def get_inventory_current_priority(flake: Flake) -> dict:
"""
Returns the current priority of the inventory values
@@ -393,59 +135,6 @@ def load_inventory_json(flake: Flake) -> Inventory:
return inventory
def delete_by_path(d: dict[str, Any], path: str) -> Any:
"""
Deletes the nested entry specified by a dot-separated path from the dictionary using pop().
:param data: The dictionary to modify.
:param path: A dot-separated string indicating the nested key to delete.
e.g., "foo.bar.baz" will attempt to delete data["foo"]["bar"]["baz"].
:raises KeyError: If any intermediate key is missing or not a dictionary,
or if the final key to delete is not found.
"""
if not path:
msg = "Cannot delete. Path is empty."
raise KeyError(msg)
keys = path.split(".")
current = d
# Navigate to the parent dictionary of the final key
for key in keys[:-1]:
if key not in current or not isinstance(current[key], dict):
msg = f"Cannot delete. Key '{path}' not found or not a dictionary '{d}'"
raise KeyError(msg)
current = current[key]
# Attempt to pop the final key
last_key = keys[-1]
try:
value = current.pop(last_key)
except KeyError as exc:
msg = f"Cannot delete. Path '{path}' not found in data '{d}'"
raise KeyError(msg) from exc
else:
return {last_key: value}
def patch(d: dict[str, Any], path: str, content: Any) -> None:
"""
Update the value at a specific dot-separated path in a nested dictionary.
If the value didn't exist before, it will be created recursively.
:param d: The dictionary to update.
:param path: The dot-separated path to the key (e.g., 'foo.bar').
:param content: The new value to set.
"""
keys = path.split(".")
current = d
for key in keys[:-1]:
current = current.setdefault(key, {})
current[keys[-1]] = content
@API.register
def patch_inventory_with(flake: Flake, section: str, content: dict[str, Any]) -> None:
"""
@@ -471,13 +160,6 @@ def patch_inventory_with(flake: Flake, section: str, content: dict[str, Any]) ->
)
@dataclass
class WriteInfo:
writeables: dict[str, set[str]]
data_eval: Inventory
data_disk: Inventory
@API.register
def get_inventory_with_writeable_keys(
flake: Flake,
@@ -560,22 +242,6 @@ def delete(flake: Flake, delete_set: set[str]) -> None:
)
def init_inventory(flake: Flake, init: Inventory | None = None) -> None:
inventory = None
# Try reading the current flake
if init is None:
with contextlib.suppress(ClanCmdError):
inventory = load_inventory_eval(flake)
if init is not None:
inventory = init
# Write inventory.json file
if inventory is not None:
# Persist creates a commit message for each change
set_inventory(inventory, flake, "Init inventory")
@API.register
def get_inventory(flake: Flake) -> Inventory:
return load_inventory_eval(flake)

View File

@@ -3,7 +3,8 @@ from typing import Any
import pytest
from clan_cli.errors import ClanError
from clan_cli.inventory import (
from clan_lib.persist.util import (
calc_patches,
delete_by_path,
determine_writeability,

View File

@@ -7,7 +7,7 @@ name = "clan"
description = "clan cli tool"
dynamic = ["version"]
scripts = { clan = "clan_cli:main" }
license = {text = "MIT"}
license = { text = "MIT" }
[project.urls]
Homepage = "https://clan.lol/"
@@ -20,12 +20,12 @@ exclude = ["clan_cli.nixpkgs*", "result"]
[tool.setuptools.package-data]
clan_cli = [
"**/allowed-packages.json",
"py.typed",
"templates/**/*",
"vms/mimetypes/**/*",
"webui/assets/**/*",
"flash/*.sh"
"**/allowed-packages.json",
"py.typed",
"templates/**/*",
"vms/mimetypes/**/*",
"webui/assets/**/*",
"flash/*.sh",
]
[tool.pytest.ini_options]
@@ -44,6 +44,12 @@ markers = ["impure", "with_core"]
filterwarnings = "default::ResourceWarning"
python_files = ["test_*.py", "*_test.py"]
# TODO: cov seems to conflict with xdist
# [tool.coverage.run]
# branch = true
# source = ["clan_lib"]
# omit = ["*/tests/*", "*/test_*.py", "*/*_test.py", "*/conftest.py", "docs.py"]
[tool.mypy]
python_version = "3.12"
warn_redundant_casts = true
@@ -51,5 +57,3 @@ disallow_untyped_calls = true
disallow_untyped_defs = true
no_implicit_optional = true
exclude = "clan_cli.nixpkgs"