Merge pull request 'Rework cache to use json instead of pickle' (#3319) from validation-hash-2 into main

Reviewed-on: https://git.clan.lol/clan/clan-core/pulls/3319
This commit is contained in:
Mic92
2025-04-15 07:11:03 +00:00
2 changed files with 65 additions and 13 deletions

View File

@@ -1,10 +1,10 @@
import json
import logging
import pickle
import re
from dataclasses import dataclass
from hashlib import sha1
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, cast
from clan_cli.cmd import Log, RunOpts, run
@@ -329,6 +329,43 @@ class FlakeCacheEntry:
msg = f"value is a {type(self.value)}, so cannot subscribe"
raise TypeError(msg)
def as_json(self) -> dict[str, Any]:
    """Return a JSON-serializable representation of this cache entry.

    The result has exactly two keys:

    * ``"value"``: for a dict value, a dict mapping each key to the result
      of calling ``as_json()`` on the child; otherwise the value unchanged.
    * ``"selector"``: the string ``"all-selector"`` when the selector is an
      ``AllSelector``; otherwise the selector's members as a list.
    """
    if isinstance(self.value, dict):
        # Children serialize themselves recursively.
        serialized_value: Any = {
            key: child.as_json() for key, child in self.value.items()
        }
    else:  # == str | float | None
        serialized_value = self.value

    if isinstance(self.selector, AllSelector):
        serialized_selector: Any = "all-selector"
    else:  # == set[int] | set[str]
        # JSON has no set type, so the members are stored as a list.
        serialized_selector = list(self.selector)

    return {"value": serialized_value, "selector": serialized_selector}
@staticmethod
def from_json(json_data: dict[str, Any]) -> "FlakeCacheEntry":
    """Rebuild a cache entry from the structure produced by ``as_json``.

    Args:
        json_data: Mapping with a ``"selector"`` key (either the string
            ``"all-selector"`` or a list of selector members) and a
            ``"value"`` key (a nested mapping of child entries, or a
            plain JSON scalar).

    Returns:
        A new ``FlakeCacheEntry`` whose ``selector`` and ``value`` are
        restored from ``json_data``.

    Raises:
        TypeError: If the selector is missing or is neither
            ``"all-selector"`` nor a list.
    """
    raw_selector = json_data.get("selector")
    if raw_selector == "all-selector":
        selector: Any = AllSelector()
    elif isinstance(raw_selector, list):  # == set[int] | set[str]
        selector = set(raw_selector)
    else:
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, and e.g. a string selector would then silently be
        # turned into a set of its characters.
        msg = f"invalid selector in cache entry: {raw_selector!r}"
        raise TypeError(msg)

    raw_value = json_data.get("value")
    if isinstance(raw_value, dict):
        # Nested mappings are child entries; restore them recursively.
        value: Any = {
            key: FlakeCacheEntry.from_json(child) for key, child in raw_value.items()
        }
    else:  # == str | float | None
        value = raw_value

    # Create an entry with placeholder constructor arguments, then overwrite
    # the two fields that were actually persisted.
    entry = FlakeCacheEntry(None, [], is_out_path=False)
    entry.selector = selector
    entry.value = value
    return entry
def __repr__(self) -> str:
if isinstance(self.value, dict):
return f"FlakeCache {{{', '.join([str(k) for k in self.value])}}}"
@@ -362,14 +399,17 @@ class FlakeCache:
def save_to_file(self, path: Path) -> None:
    """Atomically write the cache to ``path`` as JSON.

    The data is written to a temporary file in the same directory and then
    moved over ``path``, so ``path`` is never truncated or left half-written.
    The parent directory is created if needed.

    Args:
        path: Destination file for the serialized cache.

    Raises:
        Any exception raised while serializing or writing is propagated;
        in that case the temporary file is removed and ``path`` is left
        untouched.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # Serialize before creating the temp file so a failure here leaves
    # nothing behind on disk.
    data = {"cache": self.cache.as_json()}

    # delete=False: the file must outlive the handle so it can be moved
    # into place after it has been closed.
    temp_file = NamedTemporaryFile(mode="w", dir=path.parent, delete=False)
    temp_path = Path(temp_file.name)
    try:
        with temp_file:
            json.dump(data, temp_file)
        # replace() overwrites an existing target on every platform.
        temp_path.replace(path)
    except BaseException:
        # Do not leak the temp file when writing or moving fails.
        temp_path.unlink(missing_ok=True)
        raise
def load_from_file(self, path: Path) -> None:
    """Replace the in-memory cache with the JSON cache stored at ``path``.

    Does nothing when ``path`` does not exist. The file is expected to
    contain a JSON object whose ``"cache"`` key holds the serialized root
    entry.

    Args:
        path: Location of the JSON cache file.

    Raises:
        Exceptions from reading or decoding the file (for example
        ``json.JSONDecodeError``, or ``KeyError`` when ``"cache"`` is
        missing) are propagated to the caller.
    """
    if not path.exists():
        return
    with path.open("r") as f:
        log.debug(f"Loading cache from {path}")
        data = json.load(f)
    # Assign only after the whole file has been decoded, so a malformed file
    # does not replace the existing in-memory cache.
    self.cache = FlakeCacheEntry.from_json(data["cache"])
@dataclass
@@ -418,6 +458,15 @@ class Flake:
assert isinstance(self._path, Path)
return self._path
def load_cache(self) -> None:
    """Load the on-disk eval cache into ``self._cache`` on a best-effort basis.

    Returns without doing anything when no cache path is set, no cache
    object exists, or the cache file is absent. Any failure while loading
    is logged as a warning and otherwise ignored.
    """
    path = self.flake_cache_path
    if path is None or self._cache is None or not path.exists():
        return
    try:
        self._cache.load_from_file(path)
    except Exception as e:
        # Deliberately broad: a load failure of any kind is only logged,
        # never raised to the caller.
        log.warning(f"Failed to load eval cache: {e}. Continuing without cache")
def prefetch(self) -> None:
"""
Run prefetch to flush the cache as well as initializing it.
@@ -443,9 +492,10 @@ class Flake:
self._cache = FlakeCache()
assert self.hash is not None
hashed_hash = sha1(self.hash.encode()).hexdigest()
self.flake_cache_path = Path(user_cache_dir()) / "clan" / "flakes" / hashed_hash
if self.flake_cache_path.exists():
self._cache.load_from_file(self.flake_cache_path)
self.flake_cache_path = (
Path(user_cache_dir()) / "clan" / "flakes-v2" / hashed_hash
)
self.load_cache()
if "original" not in flake_metadata:
flake_metadata = nix_metadata(self.identifier)
@@ -500,11 +550,11 @@ class Flake:
if len(outputs) != len(selectors):
    # Each selector is paired with the output at the same index below, so
    # the expected count is the number of selectors.
    msg = f"flake_prepare_cache: Expected {len(selectors)} outputs, got {len(outputs)}"
    raise ClanError(msg)
assert self.flake_cache_path is not None
self._cache.load_from_file(self.flake_cache_path)
self.load_cache()
for i, selector in enumerate(selectors):
self._cache.insert(outputs[i], selector)
self._cache.save_to_file(self.flake_cache_path)
if self.flake_cache_path:
self._cache.save_to_file(self.flake_cache_path)
def select(
self,

View File

@@ -1,5 +1,6 @@
import json
import subprocess
import sys
from contextlib import ExitStack
import pytest
@@ -15,6 +16,7 @@ from clan_cli.vms.run import inspect_vm, spawn_vm
@pytest.mark.impure
@pytest.mark.skipif(sys.platform == "darwin", reason="preload doesn't work on darwin")
def test_vm_deployment(
flake: ClanFlake,
nix_config: dict[str, ConfigItem],