diff --git a/pkgs/clan-cli/clan_cli/flake.py b/pkgs/clan-cli/clan_cli/flake.py index 28658f926..97b4f6215 100644 --- a/pkgs/clan-cli/clan_cli/flake.py +++ b/pkgs/clan-cli/clan_cli/flake.py @@ -1,11 +1,13 @@ import json import logging +import pickle import re from dataclasses import dataclass from pathlib import Path from typing import Any from clan_cli.cmd import run +from clan_cli.dirs import user_cache_dir from clan_cli.errors import ClanError from clan_cli.nix import nix_build, nix_command, nix_config, nix_test_store @@ -290,6 +292,16 @@ class FlakeCache: selectors = split_selector(selector_str) return self.cache.is_cached(selectors) + def save_to_file(self, path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("wb") as f: + pickle.dump(self.cache, f) + + def load_from_file(self, path: Path) -> None: + if path.exists(): + with path.open("rb") as f: + self.cache = pickle.load(f) + @dataclass class Flake: @@ -335,7 +347,6 @@ class Flake: return self._path def prefetch(self) -> None: - self._cache = FlakeCache() flake_prefetch = run( nix_command( [ @@ -352,6 +363,12 @@ class Flake: flake_metadata = json.loads(flake_prefetch.stdout) self.store_path = flake_metadata["storePath"] self.hash = flake_metadata["hash"] + + self._cache = FlakeCache() + self.flake_cache_path = Path(user_cache_dir()) / "clan" / "flakes" / self.hash + if self.flake_cache_path.exists(): + self._cache.load_from_file(self.flake_cache_path) + if flake_metadata["original"].get("url", "").startswith("file:"): self._is_local = True path = flake_metadata["original"]["url"].removeprefix("file://") @@ -364,7 +381,7 @@ class Flake: self._is_local = False self._path = Path(self.store_path) - def prepare_cache(self, selectors: list[str]) -> None: + def get_from_nix(self, selectors: list[str]) -> None: if self._cache is None: self.prefetch() assert self._cache is not None @@ -383,15 +400,18 @@ class Flake: if len(outputs) != len(selectors): msg = f"flake_prepare_cache: Expected {len(outputs)} outputs, got {len(outputs)}" raise ClanError(msg) + self._cache.load_from_file(self.flake_cache_path) for i, selector in enumerate(selectors): self._cache.insert(outputs[i], selector) + self._cache.save_to_file(self.flake_cache_path) def select(self, selector: str) -> Any: if self._cache is None: self.prefetch() assert self._cache is not None + self._cache.load_from_file(self.flake_cache_path) if not self._cache.is_cached(selector): log.info(f"Cache miss for {selector}") - self.prepare_cache([selector]) + self.get_from_nix([selector]) return self._cache.select(selector) diff --git a/pkgs/clan-cli/tests/test_flake_caching.py b/pkgs/clan-cli/tests/test_flake_caching.py index b04dfffac..69f0223e1 100644 --- a/pkgs/clan-cli/tests/test_flake_caching.py +++ b/pkgs/clan-cli/tests/test_flake_caching.py @@ -1,5 +1,5 @@ import pytest -from clan_cli.flake import Flake, FlakeCacheEntry +from clan_cli.flake import Flake, FlakeCache, FlakeCacheEntry from fixtures_flakes import ClanFlake @@ -8,6 +8,7 @@ def test_select() -> None: test_cache = FlakeCacheEntry(testdict, []) assert test_cache["x"]["z"].value == "bla" assert test_cache.is_cached(["x", "z"]) + assert not test_cache.is_cached(["x", "y", "z"]) assert test_cache.select(["x", "y", 0]) == 123 assert not test_cache.is_cached(["x", "z", 1]) @@ -34,3 +35,23 @@ def test_flake_caching(flake: ClanFlake) -> None: "machine2": "machine2", "machine3": "machine3", } + + +@pytest.mark.with_core +def test_cache_persistance(flake: ClanFlake) -> None: + m1 = flake.machines["machine1"] + m1["nixpkgs"]["hostPlatform"] = "x86_64-linux" + flake.refresh() + + flake1 = Flake(str(flake.path)) + flake2 = Flake(str(flake.path)) + flake1.prefetch() + flake2.prefetch() + assert isinstance(flake1._cache, FlakeCache) # noqa: SLF001 + assert isinstance(flake2._cache, FlakeCache) # noqa: SLF001 + assert not flake1._cache.is_cached( # noqa: SLF001 + "nixosConfigurations.*.config.networking.hostName" + ) + flake1.select("nixosConfigurations.*.config.networking.hostName") + flake2.prefetch() + assert flake2._cache.is_cached("nixosConfigurations.*.config.networking.hostName") # noqa: SLF001