import json
import logging
import pickle
import re
from dataclasses import dataclass
from hashlib import sha1
from pathlib import Path
from typing import Any, cast

from clan_cli.cmd import run
from clan_cli.dirs import user_cache_dir
from clan_cli.errors import ClanError
from clan_cli.nix import nix_build, nix_command, nix_config, nix_test_store

log = logging.getLogger(__name__)


class AllSelector:
    """Marker selector meaning "every element of this list or dict"."""


Selector = str | int | AllSelector | set[int] | set[str]


def split_selector(selector: str) -> list[Selector]:
    """
    takes a string and returns a list of selectors.

    the string is split on dots, except inside double quotes.
    each part is turned into a selector as follows:
    - the string "*" becomes an AllSelector
    - a part consisting only of digits becomes a one-element set of int
    - a part of the form "{a,b}" becomes a set of the comma separated strings
    - a quoted part becomes a string with the quotes removed
    - anything else is kept as a plain string key
    """
    pattern = r'"[^"]*"|[^.]+'
    matches = re.findall(pattern, selector)

    # Extract the matched groups (either quoted or unquoted parts)
    selectors: list[Selector] = []
    for part in matches:
        if part == "*":
            selectors.append(AllSelector())
        elif part.isdigit():
            selectors.append({int(part)})
        elif part.startswith("{") and part.endswith("}"):
            sub_selectors = set(part[1:-1].split(","))
            selectors.append(sub_selectors)
        elif part.startswith('"') and part.endswith('"'):
            selectors.append(part[1:-1])
        else:
            selectors.append(part)

    return selectors


# NOTE(review): the class declares no dataclass fields and defines its own
# __init__, so @dataclass contributes little here; it is kept unchanged so
# that existing behaviour (and pickled caches) are not affected.
@dataclass
class FlakeCacheEntry:
    """
    a recursive structure to store the cache, with a value and a selector

    `value` is either a scalar (str, float, int, None) or a dict mapping
    keys / list indices to child FlakeCacheEntry objects.
    `selector` records which keys / indices of this level have been requested
    so far: a set of keys, or AllSelector when everything was requested.
    """

    def __init__(
        self,
        value: str | float | dict[str, Any] | list[Any] | None,
        selectors: list[Selector],
        is_out_path: bool = False,
    ) -> None:
        """
        Build an entry for `value`, which was fetched using `selectors`.

        The first selector applies to this level, the remaining ones are
        passed down to the children.

        Raises ValueError when the selector does not fit the value (indexing
        an outPath, non-str set on a dict, non-int set on a list, int selector
        on a non-singleton list) and TypeError for an unsupported selector on
        a list.
        """
        self.value: str | float | int | None | dict[str | int, FlakeCacheEntry]
        self.selector: set[int] | set[str] | AllSelector
        selector: Selector = AllSelector()

        if selectors == []:
            self.selector = AllSelector()
        elif isinstance(selectors[0], set):
            self.selector = selectors[0]
            selector = selectors[0]
        elif isinstance(selectors[0], int):
            self.selector = {int(selectors[0])}
            selector = int(selectors[0])
        elif isinstance(selectors[0], str):
            self.selector = {(selectors[0])}
            selector = selectors[0]
        elif isinstance(selectors[0], AllSelector):
            self.selector = AllSelector()

        if is_out_path:
            # an outPath is a leaf: it stores the store path string itself
            if selectors != []:
                msg = "Cannot index outPath"
                raise ValueError(msg)
            if not isinstance(value, str):
                msg = "outPath must be a string"
                raise ValueError(msg)
            self.value = value

        elif isinstance(selector, str):
            # a plain key: the fetched value belongs to that single key
            self.value = {selector: FlakeCacheEntry(value, selectors[1:])}

        elif isinstance(value, dict):
            if isinstance(self.selector, set):
                if not all(isinstance(v, str) for v in self.selector):
                    msg = "Cannot index dict with non-str set"
                    raise ValueError(msg)
            self.value = {}
            for key, value_ in value.items():
                if key == "outPath":
                    self.value[key] = FlakeCacheEntry(
                        value_, selectors[1:], is_out_path=True
                    )
                else:
                    self.value[key] = FlakeCacheEntry(value_, selectors[1:])

        elif isinstance(value, list):
            if isinstance(selector, int):
                if len(value) != 1:
                    msg = "Cannot index list with int selector when value is not singleton"
                    raise ValueError(msg)
                self.value = {
                    int(selector): FlakeCacheEntry(value[0], selectors[1:]),
                }
            # `elif` (not `if`): otherwise a valid int selector would fall
            # through to the final `else` below and raise TypeError.
            elif isinstance(selector, set):
                if all(isinstance(v, int) for v in selector):
                    self.value = {}
                    # iterate the requested indices themselves; the i-th
                    # requested index is paired with the i-th fetched element
                    for i, v in enumerate(selector):
                        assert isinstance(v, int)
                        self.value[int(v)] = FlakeCacheEntry(value[i], selectors[1:])
                else:
                    msg = "Cannot index list with non-int set"
                    raise ValueError(msg)
            elif isinstance(self.selector, AllSelector):
                self.value = {}
                for i, v in enumerate(value):
                    if isinstance(v, dict | list | str | float | int):
                        self.value[i] = FlakeCacheEntry(v, selectors[1:])
            else:
                msg = f"expected integer selector or all for type list, but got {type(selector)}"
                raise TypeError(msg)

        elif isinstance(value, str) and value.startswith("/nix/store/"):
            # a bare store path is stored like a dict with only an outPath
            self.value = {}
            self.selector = {"outPath"}
            self.value["outPath"] = FlakeCacheEntry(
                value, selectors[1:], is_out_path=True
            )

        elif isinstance(value, (str | float | int | None)):
            self.value = value

    def insert(
        self, value: str | float | dict[str, Any] | list[Any], selectors: list[Selector]
    ) -> None:
        """
        Merge `value`, which was fetched using `selectors`, into this entry.

        Raises TypeError when the selector or value cannot be merged into the
        existing entry (including a scalar that differs from the cached one)
        and ValueError when sets of different element types would be merged
        or a list is indexed with a non-int set.
        """
        selector: Selector
        if selectors == []:
            selector = AllSelector()
        else:
            selector = selectors[0]

        if isinstance(selector, str):
            if isinstance(self.value, dict):
                if selector in self.value:
                    self.value[selector].insert(value, selectors[1:])
                else:
                    self.value[selector] = FlakeCacheEntry(value, selectors[1:])
                return
            msg = f"Cannot insert {selector} into non dict value"
            raise TypeError(msg)

        # first update the record of what has been requested on this level
        if isinstance(selector, AllSelector):
            self.selector = AllSelector()
        elif isinstance(self.selector, set) and isinstance(selector, set):
            if all(isinstance(v, str) for v in self.selector) and all(
                isinstance(v, str) for v in selector
            ):
                selector = cast(set[str], selector)
                self.selector = cast(set[str], self.selector)
                self.selector = self.selector.union(selector)
            elif all(isinstance(v, int) for v in self.selector) and all(
                isinstance(v, int) for v in selector
            ):
                selector = cast(set[int], selector)
                self.selector = cast(set[int], self.selector)
                self.selector = self.selector.union(selector)
            else:
                msg = "Cannot union set of different types"
                raise ValueError(msg)
        elif isinstance(self.selector, set) and isinstance(selector, int):
            if all(isinstance(v, int) for v in self.selector):
                self.selector = cast(set[int], self.selector)
                self.selector.add(selector)
        elif isinstance(self.selector, set) and isinstance(selector, str):
            if all(isinstance(v, str) for v in self.selector):
                self.selector = cast(set[str], self.selector)
                self.selector.add(selector)
        else:
            msg = f"Cannot insert {selector} into {self.selector}"
            raise TypeError(msg)

        # then merge the value itself
        if isinstance(self.value, dict) and isinstance(value, dict):
            for key, value_ in value.items():
                if key in self.value:
                    self.value[key].insert(value_, selectors[1:])
                else:
                    self.value[key] = FlakeCacheEntry(value_, selectors[1:])

        elif isinstance(self.value, dict) and isinstance(value, list):
            if isinstance(selector, set):
                if not all(isinstance(v, int) for v in selector):
                    msg = "Cannot list with non-int set"
                    raise ValueError(msg)
                for realindex, requested_index in enumerate(selector):
                    assert isinstance(requested_index, int)
                    if requested_index in self.value:
                        self.value[requested_index].insert(
                            value[realindex], selectors[1:]
                        )
                    else:
                        # the index was already added to self.selector above,
                        # so it must be stored; otherwise is_cached would
                        # look up a key that does not exist
                        self.value[requested_index] = FlakeCacheEntry(
                            value[realindex], selectors[1:]
                        )
            elif isinstance(selector, AllSelector):
                for index, v in enumerate(value):
                    if index in self.value:
                        self.value[index].insert(v, selectors[1:])
                    else:
                        self.value[index] = FlakeCacheEntry(v, selectors[1:])
            elif isinstance(selector, int):
                if selector in self.value:
                    self.value[selector].insert(value[0], selectors[1:])
                else:
                    self.value[selector] = FlakeCacheEntry(value[0], selectors[1:])

        elif isinstance(value, str) and value.startswith("/nix/store/"):
            self.value = {}
            self.value["outPath"] = FlakeCacheEntry(
                value, selectors[1:], is_out_path=True
            )

        elif isinstance(value, (str | float | int)):
            # NOTE(review): a scalar is only compared against a truthy cached
            # value; it is never written when the cached value is falsy.
            if self.value:
                if self.value != value:
                    msg = "value mismatch in cache, something is fishy"
                    raise TypeError(msg)

        else:
            msg = f"Cannot insert value of type {type(value)} into cache"
            raise TypeError(msg)

    def is_cached(self, selectors: list[Selector]) -> bool:
        """Return True if everything addressed by `selectors` is in the cache."""
        selector: Selector
        if selectors == []:
            selector = AllSelector()
        else:
            selector = selectors[0]

        # a scalar is only a hit when nothing below it is requested
        if isinstance(self.value, str | float | int | None):
            return selectors == []
        if isinstance(selector, AllSelector):
            if isinstance(self.selector, AllSelector):
                return all(
                    self.value[sel].is_cached(selectors[1:]) for sel in self.value
                )
            # TODO: check if we already have all the keys anyway?
            return False
        if (
            isinstance(selector, set)
            and isinstance(self.selector, set)
            and isinstance(self.value, dict)
        ):
            if not selector.issubset(self.selector):
                return False
            # a key that was requested but has no stored value counts as a
            # miss instead of raising KeyError
            return all(
                sel in self.value and self.value[sel].is_cached(selectors[1:])
                for sel in selector
            )
        if isinstance(selector, str | int) and isinstance(self.value, dict):
            if selector in self.value:
                return self.value[selector].is_cached(selectors[1:])
            return False

        return False

    def select(self, selectors: list[Selector]) -> Any:
        """
        Return the cached data addressed by `selectors` as plain python values.

        A dict containing an "outPath" key that is selected as a whole is
        returned as its outPath value. Raises TypeError if the selector
        cannot be applied to the stored value and KeyError if a single key is
        not present.
        """
        selector: Selector
        if selectors == []:
            selector = AllSelector()
        else:
            selector = selectors[0]

        if selectors == [] and isinstance(self.value, dict) and "outPath" in self.value:
            return self.value["outPath"].value

        if isinstance(self.value, str | float | int | None):
            return self.value
        if isinstance(self.value, dict):
            if isinstance(selector, AllSelector):
                return {k: v.select(selectors[1:]) for k, v in self.value.items()}
            if isinstance(selector, set):
                return {
                    k: v.select(selectors[1:])
                    for k, v in self.value.items()
                    if k in selector
                }
            if isinstance(selector, str | int):
                return self.value[selector].select(selectors[1:])
        msg = f"Cannot select {selector} from type {type(self.value)}"
        raise TypeError(msg)

    def __getitem__(self, name: str) -> "FlakeCacheEntry":
        """Return the child entry stored under `name`; TypeError for scalars."""
        if isinstance(self.value, dict):
            return self.value[name]
        msg = f"value is a {type(self.value)}, so cannot subscribe"
        raise TypeError(msg)

    def __repr__(self) -> str:
        if isinstance(self.value, dict):
            return f"FlakeCache {{{', '.join([str(k) for k in self.value])}}}"
        return f"FlakeCache {self.value}"


class FlakeCache:
    """
    an in-memory cache for flake outputs, uses a recursive FlakeCacheEntry structure
    """

    def __init__(self) -> None:
        self.cache: FlakeCacheEntry = FlakeCacheEntry({}, [])

    def insert(self, data: dict[str, Any], selector_str: str) -> None:
        """Store `data`, which is the result of selecting `selector_str`."""
        if selector_str:
            selectors = split_selector(selector_str)
        else:
            selectors = []

        self.cache.insert(data, selectors)

    def select(self, selector_str: str) -> Any:
        """Return the cached value for `selector_str`."""
        selectors = split_selector(selector_str)
        return self.cache.select(selectors)

    def is_cached(self, selector_str: str) -> bool:
        """Return True if `selector_str` can be answered from the cache."""
        selectors = split_selector(selector_str)
        return self.cache.is_cached(selectors)

    def save_to_file(self, path: Path) -> None:
        """Pickle the cache to `path`, creating parent directories as needed."""
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("wb") as f:
            pickle.dump(self.cache, f)

    def load_from_file(self, path: Path) -> None:
        """Replace the cache with the one pickled at `path`, if that file exists."""
        if path.exists():
            with path.open("rb") as f:
                # SECURITY NOTE(review): pickle.load executes arbitrary code
                # from the file. This is only acceptable as long as the cache
                # file cannot be written by an untrusted party.
                self.cache = pickle.load(f)


@dataclass
class Flake:
    """
    This class represents a flake, and is used to interact with it.
    values can be accessed using the select method, which will fetch the value from the cache if it is present.
    """

    identifier: str

    def __post_init__(self) -> None:
        # all of these are filled in lazily by prefetch()
        self._cache: FlakeCache | None = None
        self._path: Path | None = None
        self._is_local: bool | None = None

    @classmethod
    def from_json(cls: type["Flake"], data: dict[str, Any]) -> "Flake":
        """Create a Flake from a dict with an "identifier" key."""
        return cls(data["identifier"])

    def __str__(self) -> str:
        return self.identifier

    def __hash__(self) -> int:
        return hash(self.identifier)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Flake):
            return NotImplemented
        return self.identifier == other.identifier

    @property
    def is_local(self) -> bool:
        """Whether the flake was referenced by a file url or a path."""
        if self._is_local is None:
            self.prefetch()
        assert isinstance(self._is_local, bool)
        return self._is_local

    @property
    def path(self) -> Path:
        """Local path of the flake, or its store path if it is not local."""
        if self._path is None:
            self.prefetch()
        assert isinstance(self._path, Path)
        return self._path

    def prefetch(self) -> None:
        """
        Run prefetch to flush the cache as well as initializing it.

        Sets store_path, hash and flake_cache_path, creates a fresh in-memory
        cache (loaded from flake_cache_path if that file exists) and
        determines is_local / path from the prefetch metadata.
        """
        flake_prefetch = run(
            nix_command(
                [
                    "flake",
                    "prefetch",
                    "--json",
                    "--option",
                    "flake-registry",
                    "",
                    self.identifier,
                ]
            )
        )
        flake_metadata = json.loads(flake_prefetch.stdout)
        self.store_path = flake_metadata["storePath"]
        self.hash = flake_metadata["hash"]

        self._cache = FlakeCache()
        # the cache file name is derived from the flake hash
        hashed_hash = sha1(self.hash.encode()).hexdigest()
        self.flake_cache_path = Path(user_cache_dir()) / "clan" / "flakes" / hashed_hash
        if self.flake_cache_path.exists():
            self._cache.load_from_file(self.flake_cache_path)

        if flake_metadata["original"].get("url", "").startswith("file:"):
            self._is_local = True
            path = flake_metadata["original"]["url"].removeprefix("file://")
            path = path.removeprefix("file:")
            self._path = Path(path)
        elif flake_metadata["original"].get("path"):
            self._is_local = True
            self._path = Path(flake_metadata["original"]["path"])
        else:
            self._is_local = False
            self._path = Path(self.store_path)

    def get_from_nix(
        self,
        selectors: list[str],
        nix_options: list[str] | None = None,
    ) -> None:
        """
        Evaluate `selectors` with nix and store the results in the cache file.

        `nix_options` are passed on to the nix build invocation; the list
        given by the caller is not modified.
        Raises ClanError if nix returns a different number of results than
        selectors were requested.
        """
        if self._cache is None:
            self.prefetch()
        assert self._cache is not None

        # work on a copy: options are appended below and the caller (e.g.
        # select) may pass the same list again
        nix_options = list(nix_options) if nix_options is not None else []

        config = nix_config()
        system = config["system"]
        # every call needs its own parentheses to be a separate list element;
        # for a single selector the generated code is unchanged
        select_exprs = " ".join(
            f"(flake.clanInternals.lib.select ''{attr}'' flake)" for attr in selectors
        )
        nix_code = f"""
            let
              flake = builtins.getFlake("path:{self.store_path}?narHash={self.hash}");
            in
              flake.inputs.nixpkgs.legacyPackages.{system}.writeText "clan-flake-select" (
                builtins.toJSON [ {select_exprs} ]
              )
        """
        if tmp_store := nix_test_store():
            nix_options += ["--store", str(tmp_store)]
            nix_options.append("--impure")

        build_output = Path(
            run(nix_build(["--expr", nix_code, *nix_options])).stdout.strip()
        )

        if tmp_store:
            # re-root the returned absolute path below the test store
            build_output = tmp_store.joinpath(*build_output.parts[1:])
        outputs = json.loads(build_output.read_text())
        if len(outputs) != len(selectors):
            msg = f"flake_prepare_cache: Expected {len(selectors)} outputs, got {len(outputs)}"
            raise ClanError(msg)
        # reload first so that entries written to the file in the meantime are kept
        self._cache.load_from_file(self.flake_cache_path)
        for i, selector in enumerate(selectors):
            self._cache.insert(outputs[i], selector)
        self._cache.save_to_file(self.flake_cache_path)

    def select(
        self,
        selector: str,
        nix_options: list[str] | None = None,
    ) -> Any:
        """
        Return the value of `selector`, from the cache if possible.

        On a cache miss, or when the path check on the cached value raises
        ClanError (e.g. a referenced /nix/store path does not exist), the
        value is fetched from nix again.
        """
        if self._cache is None:
            self.prefetch()
        assert self._cache is not None

        self._cache.load_from_file(self.flake_cache_path)
        if not self._cache.is_cached(selector):
            log.info(f"Cache miss for {selector}")
            self.get_from_nix([selector], nix_options)
        value = self._cache.select(selector)

        num_loops = 0

        # Check if all nix store paths exist
        # FIXME: If the value is a highly nested structure, this will be slow
        # NOTE(review): num_loops counts visited leaf values, not nesting
        # depth, and exceeding the limit triggers a refresh below.
        def recursive_path_exists_check(val: Any) -> Any:
            nonlocal num_loops
            if isinstance(val, str):
                if val.startswith("/nix/store/"):
                    path = Path(val)
                    if not path.exists():
                        msg = f"{path} does not exist"
                        raise ClanError(msg)
            elif isinstance(val, dict):
                return {k: recursive_path_exists_check(v) for k, v in val.items()}
            elif isinstance(val, list):
                return [recursive_path_exists_check(v) for v in val]

            if num_loops > 75:
                msg = "Maximum recursion depth (75) exceeded while checking paths"
                log.warning(msg)
                raise ClanError(msg)
            num_loops += 1
            return val

        # If there are any paths that don't exist, refresh from nix
        try:
            recursive_path_exists_check(value)
        except ClanError as e:
            log.info(f"Path {e} for {selector} does not exist, refreshing from nix")
            self.get_from_nix([selector], nix_options)
            value = self._cache.select(selector)

        return value