rework cache to use json instead of pickle

Pickle can silently break if migrate our data layout and also introduces
unwanted behaviour such as code injection that we want to avoid.
This commit is contained in:
Jörg Thalheim
2025-04-15 06:16:41 +00:00
parent 833798f650
commit 949536bb2b

View File

@@ -1,10 +1,10 @@
import json
import logging
import pickle
import re
from dataclasses import dataclass
from hashlib import sha1
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, cast
from clan_cli.cmd import Log, RunOpts, run
@@ -329,6 +329,43 @@ class FlakeCacheEntry:
msg = f"value is a {type(self.value)}, so cannot subscribe"
raise TypeError(msg)
def as_json(self) -> dict[str, Any]:
json_data: Any = {}
if isinstance(self.value, dict):
value = json_data["value"] = {}
for k, v in self.value.items():
value[k] = v.as_json()
else: # == str | float | None
json_data["value"] = self.value
if isinstance(self.selector, AllSelector):
json_data["selector"] = "all-selector"
else: # == set[int] | set[str]
json_data["selector"] = list(self.selector)
return json_data
@staticmethod
def from_json(json_data: dict[str, Any]) -> "FlakeCacheEntry":
raw_selector = json_data.get("selector")
if raw_selector == "all-selector":
selector: Any = AllSelector()
else: # == set[int] | set[str]
assert isinstance(raw_selector, list)
selector = set(raw_selector)
raw_value = json_data.get("value")
if isinstance(raw_value, dict):
value: Any = {}
for k, v in raw_value.items():
value[k] = FlakeCacheEntry.from_json(v)
else: # == str | float | None
value = raw_value
entry = FlakeCacheEntry(None, [], is_out_path=False)
entry.selector = selector
entry.value = value
return entry
def __repr__(self) -> str:
if isinstance(self.value, dict):
return f"FlakeCache {{{', '.join([str(k) for k in self.value])}}}"
@@ -362,14 +399,17 @@ class FlakeCache:
def save_to_file(self, path: Path) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("wb") as f:
pickle.dump(self.cache, f)
with NamedTemporaryFile(mode="w", dir=path.parent, delete=False) as temp_file:
data = {"cache": self.cache.as_json()}
json.dump(data, temp_file)
temp_file.close()
Path(temp_file.name).rename(path)
def load_from_file(self, path: Path) -> None:
if path.exists():
with path.open("rb") as f:
with path.open("r") as f:
log.debug(f"Loading cache from {path}")
self.cache = pickle.load(f)
data = json.load(f)
self.cache = FlakeCacheEntry.from_json(data["cache"])
@dataclass
@@ -418,6 +458,15 @@ class Flake:
assert isinstance(self._path, Path)
return self._path
def load_cache(self) -> None:
path = self.flake_cache_path
if path is None or self._cache is None or not path.exists():
return
try:
self._cache.load_from_file(path)
except Exception as e:
log.warning(f"Failed load eval cache: {e}. Continue without cache")
def prefetch(self) -> None:
"""
Run prefetch to flush the cache as well as initializing it.
@@ -443,9 +492,10 @@ class Flake:
self._cache = FlakeCache()
assert self.hash is not None
hashed_hash = sha1(self.hash.encode()).hexdigest()
self.flake_cache_path = Path(user_cache_dir()) / "clan" / "flakes" / hashed_hash
if self.flake_cache_path.exists():
self._cache.load_from_file(self.flake_cache_path)
self.flake_cache_path = (
Path(user_cache_dir()) / "clan" / "flakes-v2" / hashed_hash
)
self.load_cache()
if "original" not in flake_metadata:
flake_metadata = nix_metadata(self.identifier)
@@ -500,10 +550,10 @@ class Flake:
if len(outputs) != len(selectors):
msg = f"flake_prepare_cache: Expected {len(outputs)} outputs, got {len(outputs)}"
raise ClanError(msg)
assert self.flake_cache_path is not None
self._cache.load_from_file(self.flake_cache_path)
self.load_cache()
for i, selector in enumerate(selectors):
self._cache.insert(outputs[i], selector)
if self.flake_cache_path:
self._cache.save_to_file(self.flake_cache_path)
def select(