diff --git a/pkgs/clan-cli/clan_cli/flake.py b/pkgs/clan-cli/clan_cli/flake.py index 057454d60..6298dc12f 100644 --- a/pkgs/clan-cli/clan_cli/flake.py +++ b/pkgs/clan-cli/clan_cli/flake.py @@ -1,10 +1,10 @@ import json import logging -import pickle import re from dataclasses import dataclass from hashlib import sha1 from pathlib import Path +from tempfile import NamedTemporaryFile from typing import Any, cast from clan_cli.cmd import Log, RunOpts, run @@ -329,6 +329,43 @@ class FlakeCacheEntry: msg = f"value is a {type(self.value)}, so cannot subscribe" raise TypeError(msg) + def as_json(self) -> dict[str, Any]: + json_data: Any = {} + if isinstance(self.value, dict): + value = json_data["value"] = {} + for k, v in self.value.items(): + value[k] = v.as_json() + else: # == str | float | None + json_data["value"] = self.value + + if isinstance(self.selector, AllSelector): + json_data["selector"] = "all-selector" + else: # == set[int] | set[str] + json_data["selector"] = list(self.selector) + return json_data + + @staticmethod + def from_json(json_data: dict[str, Any]) -> "FlakeCacheEntry": + raw_selector = json_data.get("selector") + if raw_selector == "all-selector": + selector: Any = AllSelector() + else: # == set[int] | set[str] + assert isinstance(raw_selector, list) + selector = set(raw_selector) + + raw_value = json_data.get("value") + if isinstance(raw_value, dict): + value: Any = {} + for k, v in raw_value.items(): + value[k] = FlakeCacheEntry.from_json(v) + else: # == str | float | None + value = raw_value + + entry = FlakeCacheEntry(None, [], is_out_path=False) + entry.selector = selector + entry.value = value + return entry + def __repr__(self) -> str: if isinstance(self.value, dict): return f"FlakeCache {{{', '.join([str(k) for k in self.value])}}}" @@ -362,14 +399,17 @@ class FlakeCache: def save_to_file(self, path: Path) -> None: path.parent.mkdir(parents=True, exist_ok=True) - with path.open("wb") as f: - pickle.dump(self.cache, f) + with NamedTemporaryFile(mode="w", dir=path.parent, delete=False) as temp_file: + data = {"cache": self.cache.as_json()} + json.dump(data, temp_file) + temp_file.close() + Path(temp_file.name).rename(path) def load_from_file(self, path: Path) -> None: - if path.exists(): - with path.open("rb") as f: - log.debug(f"Loading cache from {path}") - self.cache = pickle.load(f) + with path.open("r") as f: + log.debug(f"Loading cache from {path}") + data = json.load(f) + self.cache = FlakeCacheEntry.from_json(data["cache"]) @dataclass @@ -418,6 +458,15 @@ class Flake: assert isinstance(self._path, Path) return self._path + def load_cache(self) -> None: + path = self.flake_cache_path + if path is None or self._cache is None or not path.exists(): + return + try: + self._cache.load_from_file(path) + except Exception as e: + log.warning(f"Failed load eval cache: {e}. Continue without cache") + def prefetch(self) -> None: """ Run prefetch to flush the cache as well as initializing it. @@ -443,9 +492,10 @@ class Flake: self._cache = FlakeCache() assert self.hash is not None hashed_hash = sha1(self.hash.encode()).hexdigest() - self.flake_cache_path = Path(user_cache_dir()) / "clan" / "flakes" / hashed_hash - if self.flake_cache_path.exists(): - self._cache.load_from_file(self.flake_cache_path) + self.flake_cache_path = ( + Path(user_cache_dir()) / "clan" / "flakes-v2" / hashed_hash + ) + self.load_cache() if "original" not in flake_metadata: flake_metadata = nix_metadata(self.identifier) @@ -500,11 +550,11 @@ class Flake: if len(outputs) != len(selectors): msg = f"flake_prepare_cache: Expected {len(outputs)} outputs, got {len(outputs)}" raise ClanError(msg) - assert self.flake_cache_path is not None - self._cache.load_from_file(self.flake_cache_path) + self.load_cache() for i, selector in enumerate(selectors): self._cache.insert(outputs[i], selector) - self._cache.save_to_file(self.flake_cache_path) + if self.flake_cache_path: + self._cache.save_to_file(self.flake_cache_path) def select( self,