From fe2cfd3b37d1f4ce356c441f3ad11a75a597bec7 Mon Sep 17 00:00:00 2001 From: lassulus Date: Tue, 28 Jan 2025 10:15:10 +0100 Subject: [PATCH] clan-cli: add a Flake class with caching --- lib/build-clan/interface.nix | 1 + lib/build-clan/module.nix | 4 + lib/default.nix | 1 + lib/select.nix | 68 +++++ pkgs/clan-cli/clan_cli/flake.py | 312 ++++++++++++++++++++++ pkgs/clan-cli/tests/test_flake_caching.py | 15 ++ 6 files changed, 401 insertions(+) create mode 100644 lib/select.nix create mode 100644 pkgs/clan-cli/clan_cli/flake.py create mode 100644 pkgs/clan-cli/tests/test_flake_caching.py diff --git a/lib/build-clan/interface.nix b/lib/build-clan/interface.nix index eff0d0627..c1e534296 100644 --- a/lib/build-clan/interface.nix +++ b/lib/build-clan/interface.nix @@ -130,6 +130,7 @@ in clanModules = lib.mkOption { type = lib.types.raw; }; source = lib.mkOption { type = lib.types.raw; }; meta = lib.mkOption { type = lib.types.raw; }; + lib = lib.mkOption { type = lib.types.raw; }; all-machines-json = lib.mkOption { type = lib.types.raw; }; machines = lib.mkOption { type = lib.types.raw; }; machinesFunc = lib.mkOption { type = lib.types.raw; }; diff --git a/lib/build-clan/module.nix b/lib/build-clan/module.nix index e8a5be5fa..16d22f158 100644 --- a/lib/build-clan/module.nix +++ b/lib/build-clan/module.nix @@ -217,6 +217,10 @@ in templates = config.templates; inventory = config.inventory; meta = config.inventory.meta; + lib = { + inherit (clan-core.lib) select; + }; + source = "${clan-core}"; # machine specifics diff --git a/lib/default.nix b/lib/default.nix index b48b4582e..692579a99 100644 --- a/lib/default.nix +++ b/lib/default.nix @@ -21,4 +21,5 @@ in inherit lib; self = clan-core; }; + select = import ./select.nix; } diff --git a/lib/select.nix b/lib/select.nix new file mode 100644 index 000000000..d291f50d0 --- /dev/null +++ b/lib/select.nix @@ -0,0 +1,68 @@ +let + recursiveSelect = + selectorIndex: selectorList: target: + let + selector = builtins.elemAt selectorList selectorIndex; + in + + # selector is empty, we are done + if selectorIndex + 1 > builtins.length selectorList then + target + + else if builtins.isList target then + # support bla.* for lists and recurse into all elements + if selector == "*" then + builtins.map (v: recursiveSelect (selectorIndex + 1) selectorList v) target + # support bla.3 for lists and recurse into the 4th element + else if (builtins.match "[[:digit:]]*" selector) == [ ] then + recursiveSelect (selectorIndex + 1) selectorList ( + builtins.elemAt target (builtins.fromJSON selector) + ) + else + throw "only * or a number is allowed in list selector" + + else if builtins.isAttrs target then + # handle the case bla.x.*.z where x is an attrset and we recurse into all elements + if selector == "*" then + builtins.mapAttrs (_: v: recursiveSelect (selectorIndex + 1) selectorList v) target + # support bla.{x,y,z}.world where we get world from each of x, y and z + else if (builtins.match ''^\{([^}]*)}$'' selector) != null then + let + attrsAsList = ( + builtins.filter (x: !builtins.isList x) ( + builtins.split "," (builtins.head (builtins.match ''^\{([^}]*)}$'' selector)) + ) + ); + dummyAttrSet = builtins.listToAttrs ( + map (x: { + name = x; + value = null; + }) attrsAsList + ); + filteredAttrs = builtins.intersectAttrs dummyAttrSet target; + in + builtins.mapAttrs (_: v: recursiveSelect (selectorIndex + 1) selectorList v) filteredAttrs + else + recursiveSelect (selectorIndex + 1) selectorList (builtins.getAttr selector target) + else + throw "Expected a list or an attrset"; + + parseSelector = + selector: + let + splitByQuote = x: builtins.filter (x: !builtins.isList x) (builtins.split ''"'' x); + splitByDot = + x: + builtins.filter (x: x != "") ( + map (builtins.replaceStrings [ "." ] [ "" ]) ( + builtins.filter (x: !builtins.isList x) (builtins.split ''\.'' x) + ) + ); + handleQuoted = + x: if x == [ ] then [ ] else [ (builtins.head x) ] ++ handleUnquoted (builtins.tail x); + handleUnquoted = + x: if x == [ ] then [ ] else splitByDot (builtins.head x) ++ handleQuoted (builtins.tail x); + in + handleUnquoted (splitByQuote selector); +in +selector: target: recursiveSelect 0 (parseSelector selector) target diff --git a/pkgs/clan-cli/clan_cli/flake.py b/pkgs/clan-cli/clan_cli/flake.py new file mode 100644 index 000000000..7546a7a3e --- /dev/null +++ b/pkgs/clan-cli/clan_cli/flake.py @@ -0,0 +1,312 @@ +import json +import logging +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from clan_cli.cmd import run +from clan_cli.errors import ClanError +from clan_cli.nix import nix_build, nix_config + +log = logging.getLogger(__name__) + + +class AllSelector: + pass + + +Selector = str | int | AllSelector | set[int] | set[str] + + +def split_selector(selector: str) -> list[Selector]: + """ + takes a string and returns a list of selectors. + + a selector can be: + - a string, which is a key in a dict + - an integer, which is an index in a list + - a set of strings, which are keys in a dict + - a set of integers, which are indices in a list + - a quoted string, which is a key in a dict + - the string "*", which selects all elements in a list or dict + """ + pattern = r'"[^"]*"|[^.]+' + matches = re.findall(pattern, selector) + + # Extract the matched groups (either quoted or unquoted parts) + selectors: list[Selector] = [] + for selector in matches: + if selector == "*": + selectors.append(AllSelector()) + elif selector.isdigit(): + selectors.append({int(selector)}) + elif selector.startswith("{") and selector.endswith("}"): + sub_selectors = set(selector[1:-1].split(",")) + selectors.append(sub_selectors) + elif selector.startswith('"') and selector.endswith('"'): + selectors.append(selector[1:-1]) + else: + selectors.append(selector) + + return selectors + + +@dataclass +class FlakeCacheEntry: + """ + a recrusive structure to store the cache, with a value and a selector + """ + + def __init__( + self, + value: str | float | dict[str, Any] | list[Any], + selectors: list[Selector], + ) -> None: + self.value: str | float | int | dict[str | int, FlakeCacheEntry] + self.selector: Selector + + if selectors == []: + self.selector = AllSelector() + elif isinstance(selectors[0], str): + self.selector = selectors[0] + self.value = {self.selector: FlakeCacheEntry(value, selectors[1:])} + return + else: + self.selector = selectors[0] + + if isinstance(value, dict): + if isinstance(self.selector, set): + if not all(isinstance(v, str) for v in self.selector): + msg = "Cannot index dict with non-str set" + raise ValueError(msg) + self.value = {} + for key, value_ in value.items(): + self.value[key] = FlakeCacheEntry(value_, selectors[1:]) + + elif isinstance(value, list): + if isinstance(self.selector, int): + if len(value) != 1: + msg = "Cannot index list with int selector when value is not singleton" + raise ValueError(msg) + self.value = {} + self.value[int(self.selector)] = FlakeCacheEntry( + value[0], selectors[1:] + ) + if isinstance(self.selector, set): + if all(isinstance(v, int) for v in self.selector): + self.value = {} + for i, v in enumerate(self.selector): + assert isinstance(v, int) + self.value[int(v)] = FlakeCacheEntry(value[i], selectors[1:]) + else: + msg = "Cannot index list with non-int set" + raise ValueError(msg) + elif isinstance(self.selector, AllSelector): + self.value = {} + for i, v in enumerate(value): + if isinstance(v, dict | list | str | float | int): + self.value[i] = FlakeCacheEntry(v, selectors[1:]) + else: + msg = f"expected integer selector or all for type list, but got {type(selectors[0])}" + raise TypeError(msg) + + elif isinstance(value, (str | float | int)): + self.value = value + + def insert( + self, value: str | float | dict[str, Any] | list[Any], selectors: list[Selector] + ) -> None: + selector: Selector + if selectors == []: + selector = AllSelector() + else: + selector = selectors[0] + + if isinstance(selector, str): + if isinstance(self.value, dict): + if selector in self.value: + self.value[selector].insert(value, selectors[1:]) + else: + self.value[selector] = FlakeCacheEntry(value, selectors[1:]) + return + msg = f"Cannot insert {selector} into non dict value" + raise TypeError(msg) + + if isinstance(selector, AllSelector): + self.selector = AllSelector() + elif isinstance(self.selector, set) and isinstance(selector, set): + self.selector.union(selector) + + if isinstance(self.value, dict) and isinstance(value, dict): + for key, value_ in value.items(): + if key in self.value: + self.value[key].insert(value_, selectors[1:]) + else: + self.value[key] = FlakeCacheEntry(value_, selectors[1:]) + + elif isinstance(self.value, dict) and isinstance(value, list): + if isinstance(selector, set): + if not all(isinstance(v, int) for v in selector): + msg = "Cannot list with non-int set" + raise ValueError(msg) + for realindex, requested_index in enumerate(selector): + assert isinstance(requested_index, int) + if requested_index in self.value: + self.value[requested_index].insert( + value[realindex], selectors[1:] + ) + elif isinstance(selector, AllSelector): + for index, v in enumerate(value): + if index in self.value: + self.value[index].insert(v, selectors[1:]) + else: + self.value[index] = FlakeCacheEntry(v, selectors[1:]) + elif isinstance(selector, int): + if selector in self.value: + self.value[selector].insert(value[0], selectors[1:]) + else: + self.value[selector] = FlakeCacheEntry(value[0], selectors[1:]) + + elif isinstance(value, (str | float | int)): + if self.value: + if self.value != value: + msg = "value mismatch in cache, something is fishy" + raise TypeError(msg) + else: + msg = f"Cannot insert value of type {type(value)} into cache" + raise TypeError(msg) + + def is_cached(self, selectors: list[Selector]) -> bool: + selector: Selector + if selectors == []: + selector = AllSelector() + else: + selector = selectors[0] + + if isinstance(self.value, str | float | int): + return selectors == [] + if isinstance(selector, AllSelector): + if isinstance(self.selector, AllSelector): + return all( + self.value[sel].is_cached(selectors[1:]) for sel in self.value + ) + # TODO: check if we already have all the keys anyway? + print("not cached because self.selector is not all") + return False + if ( + isinstance(selector, set) + and isinstance(self.selector, set) + and isinstance(self.value, dict) + ): + if not selector.issubset(self.selector): + print("not cached because selector is not subset of self.selector") + return False + return all(self.value[sel].is_cached(selectors[1:]) for sel in selector) + if isinstance(selector, str | int) and isinstance(self.value, dict): + if selector in self.value: + return self.value[selector].is_cached(selectors[1:]) + print("not cached because selector is not in self.value") + return False + print("not cached because of unknown reason") + return False + + def select(self, selectors: list[Selector]) -> Any: + selector: Selector + if selectors == []: + selector = AllSelector() + else: + selector = selectors[0] + + if isinstance(self.value, str | float | int): + return self.value + if isinstance(self.value, dict): + if isinstance(selector, AllSelector): + return {k: v.select(selectors[1:]) for k, v in self.value.items()} + if isinstance(selector, set): + return { + k: v.select(selectors[1:]) + for k, v in self.value.items() + if k in selector + } + if isinstance(selector, str | int): + return self.value[selector].select(selectors[1:]) + msg = f"Cannot select {selector} from type {type(self.value)}" + raise TypeError(msg) + + def __getitem__(self, name: str) -> "FlakeCacheEntry": + if isinstance(self.value, dict): + return self.value[name] + msg = f"value is a {type(self.value)}, so cannot subscribe" + raise TypeError(msg) + + def __repr__(self) -> str: + if isinstance(self.value, dict): + return f"FlakeCache {{{', '.join([str(k) for k in self.value])}}}" + return f"FlakeCache {self.value}" + + +class FlakeCache: + """ + an in-memory cache for flake outputs, uses a recursive FLakeCacheEntry structure + """ + + def __init__(self) -> None: + self.cache: FlakeCacheEntry = FlakeCacheEntry({}, []) + + def insert(self, data: dict[str, Any], selector_str: str) -> None: + if selector_str: + selectors = split_selector(selector_str) + else: + selectors = [] + + self.cache.insert(data, selectors) + + def select(self, selector_str: str) -> Any: + selectors = split_selector(selector_str) + return self.cache.select(selectors) + + def is_cached(self, selector_str: str) -> bool: + selectors = split_selector(selector_str) + return self.cache.is_cached(selectors) + + +@dataclass +class Flake: + """ + This class represents a flake, and is used to interact with it. + values can be accessed using the select method, which will fetch the value from the cache if it is present. + """ + + identifier: str + cache: FlakeCache = field(default_factory=FlakeCache) + + def __post_init__(self) -> None: + flake_prefetch = run(["nix", "flake", "prefetch", "--json", self.identifier]) + flake_metadata = json.loads(flake_prefetch.stdout) + self.store_path = flake_metadata["storePath"] + self.hash = flake_metadata["hash"] + self.cache = FlakeCache() + + def prepare_cache(self, selectors: list[str]) -> None: + config = nix_config() + nix_code = f""" + let + flake = builtins.getFlake("path:{self.store_path}?narHash={self.hash}"); + in + flake.inputs.nixpkgs.legacyPackages.{config["system"]}.writeText "clan-flake-select" (builtins.toJSON [ ({" ".join([f'flake.clanInternals.lib.select "{attr}" flake' for attr in selectors])}) ]) + """ + build_output = run(nix_build(["--expr", nix_code])).stdout.strip() + outputs = json.loads(Path(build_output).read_text()) + if len(outputs) != len(selectors): + msg = f"flake_prepare_cache: Expected {len(outputs)} outputs, got {len(outputs)}" + raise ClanError(msg) + for i, selector in enumerate(selectors): + self.cache.insert(outputs[i], selector) + + def select(self, selector: str) -> Any: + if not self.cache.is_cached(selector): + log.info(f"Cache miss for {selector}") + print(f"Cache miss for {selector}") + self.prepare_cache([selector]) + return self.cache.select(selector) diff --git a/pkgs/clan-cli/tests/test_flake_caching.py b/pkgs/clan-cli/tests/test_flake_caching.py new file mode 100644 index 000000000..f19df7493 --- /dev/null +++ b/pkgs/clan-cli/tests/test_flake_caching.py @@ -0,0 +1,15 @@ +from clan_cli.flake import FlakeCacheEntry +from fixtures_flakes import ClanFlake + + +def test_flake_caching(test_flake: ClanFlake) -> None: + testdict = {"x": {"y": [123, 345, 456], "z": "bla"}} + test_cache = FlakeCacheEntry(testdict, []) + assert test_cache["x"]["z"].value == "bla" + assert test_cache.is_cached(["x", "z"]) + assert test_cache.select(["x", "y", 0]) == 123 + assert not test_cache.is_cached(["x", "z", 1]) + # TODO check this, but test_flake is not a real clan flake (no clan-core, no clanInternals) + # cmd.run(["nix", "flake", "lock"], cmd.RunOpts(cwd=test_flake.path)) + # flake = Flake(str(test_flake.path)) + # hostnames = flake.select("nixosConfigurations.*.config.networking.hostName")