rework cache to use json instead of pickle
Pickle can silently break if we migrate our data layout, and it also introduces unwanted behaviour, such as code injection, that we want to avoid.
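To make the code-injection concern concrete, here is a minimal sketch (illustration only, not part of this commit): unpickling can execute arbitrary callables via __reduce__, while json.load can only ever produce plain dicts, lists, strings, numbers, booleans and None.

    # Illustration only: why loading untrusted pickle data is risky.
    import json
    import pickle

    class Payload:
        def __reduce__(self):
            # pickle.loads will call whatever callable we return here.
            return (print, ("code executed during pickle.loads",))

    pickle.loads(pickle.dumps(Payload()))  # runs the callable -> code execution on load
    print(json.loads('{"value": 1, "selector": "all-selector"}'))  # data only, never code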
@@ -1,10 +1,10 @@
+import json
 import logging
-import pickle
 import re
 from dataclasses import dataclass
 from hashlib import sha1
 from pathlib import Path
 from tempfile import NamedTemporaryFile
 from typing import Any, cast
 
 from clan_cli.cmd import Log, RunOpts, run
@@ -329,6 +329,43 @@ class FlakeCacheEntry:
             msg = f"value is a {type(self.value)}, so cannot subscribe"
             raise TypeError(msg)
 
+    def as_json(self) -> dict[str, Any]:
+        json_data: Any = {}
+        if isinstance(self.value, dict):
+            value = json_data["value"] = {}
+            for k, v in self.value.items():
+                value[k] = v.as_json()
+        else:  # == str | float | None
+            json_data["value"] = self.value
+
+        if isinstance(self.selector, AllSelector):
+            json_data["selector"] = "all-selector"
+        else:  # == set[int] | set[str]
+            json_data["selector"] = list(self.selector)
+        return json_data
+
+    @staticmethod
+    def from_json(json_data: dict[str, Any]) -> "FlakeCacheEntry":
+        raw_selector = json_data.get("selector")
+        if raw_selector == "all-selector":
+            selector: Any = AllSelector()
+        else:  # == set[int] | set[str]
+            assert isinstance(raw_selector, list)
+            selector = set(raw_selector)
+
+        raw_value = json_data.get("value")
+        if isinstance(raw_value, dict):
+            value: Any = {}
+            for k, v in raw_value.items():
+                value[k] = FlakeCacheEntry.from_json(v)
+        else:  # == str | float | None
+            value = raw_value
+
+        entry = FlakeCacheEntry(None, [], is_out_path=False)
+        entry.selector = selector
+        entry.value = value
+        return entry
+
     def __repr__(self) -> str:
         if isinstance(self.value, dict):
             return f"FlakeCache {{{', '.join([str(k) for k in self.value])}}}"
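For orientation, a rough sketch of the on-disk shape that as_json produces (the attribute names here are made up for illustration): every entry stores its selector, either the literal string "all-selector" or a list, plus a value that is either a scalar or a dict of nested entries of the same shape.

    # Hypothetical serialized cache entry (names are illustrative only):
    example = {
        "selector": "all-selector",
        "value": {
            "someAttr": {
                "selector": ["leafA"],  # a set[str] becomes a list in JSON
                "value": {
                    "leafA": {"selector": "all-selector", "value": None},
                },
            },
        },
    }
    # save_to_file below wraps this as {"cache": ...}; from_json rebuilds the entry tree.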
@@ -362,14 +399,17 @@ class FlakeCache:
 
     def save_to_file(self, path: Path) -> None:
         path.parent.mkdir(parents=True, exist_ok=True)
-        with path.open("wb") as f:
-            pickle.dump(self.cache, f)
+        with NamedTemporaryFile(mode="w", dir=path.parent, delete=False) as temp_file:
+            data = {"cache": self.cache.as_json()}
+            json.dump(data, temp_file)
+            temp_file.close()
+            Path(temp_file.name).rename(path)
 
     def load_from_file(self, path: Path) -> None:
         if path.exists():
-            with path.open("rb") as f:
+            with path.open("r") as f:
                 log.debug(f"Loading cache from {path}")
-                self.cache = pickle.load(f)
+                data = json.load(f)
+                self.cache = FlakeCacheEntry.from_json(data["cache"])
 
 
 @dataclass
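The new save path follows the common write-to-a-temp-file-then-rename pattern. A generic sketch of that pattern (the helper name is made up), assuming a POSIX filesystem where a rename within the same directory is an atomic replace, so readers never observe a half-written cache file:

    import json
    from pathlib import Path
    from tempfile import NamedTemporaryFile

    def atomic_write_json(path: Path, data: dict) -> None:
        # Create the temp file next to the target so the rename stays on one filesystem.
        path.parent.mkdir(parents=True, exist_ok=True)
        with NamedTemporaryFile(mode="w", dir=path.parent, delete=False) as tmp:
            json.dump(data, tmp)
        Path(tmp.name).rename(path)  # atomic replace on POSIX

    atomic_write_json(Path("/tmp/demo-cache/cache.json"), {"cache": {"value": None}})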
@@ -418,6 +458,15 @@ class Flake:
         assert isinstance(self._path, Path)
         return self._path
 
+    def load_cache(self) -> None:
+        path = self.flake_cache_path
+        if path is None or self._cache is None or not path.exists():
+            return
+        try:
+            self._cache.load_from_file(path)
+        except Exception as e:
+            log.warning(f"Failed load eval cache: {e}. Continue without cache")
+
     def prefetch(self) -> None:
         """
         Run prefetch to flush the cache as well as initializing it.
@@ -443,9 +492,10 @@ class Flake:
         self._cache = FlakeCache()
         assert self.hash is not None
         hashed_hash = sha1(self.hash.encode()).hexdigest()
-        self.flake_cache_path = Path(user_cache_dir()) / "clan" / "flakes" / hashed_hash
-        if self.flake_cache_path.exists():
-            self._cache.load_from_file(self.flake_cache_path)
+        self.flake_cache_path = (
+            Path(user_cache_dir()) / "clan" / "flakes-v2" / hashed_hash
+        )
+        self.load_cache()
 
         if "original" not in flake_metadata:
             flake_metadata = nix_metadata(self.identifier)
@@ -500,10 +550,10 @@ class Flake:
         if len(outputs) != len(selectors):
             msg = f"flake_prepare_cache: Expected {len(outputs)} outputs, got {len(outputs)}"
             raise ClanError(msg)
         assert self.flake_cache_path is not None
-        self._cache.load_from_file(self.flake_cache_path)
+        self.load_cache()
         for i, selector in enumerate(selectors):
             self._cache.insert(outputs[i], selector)
         if self.flake_cache_path:
             self._cache.save_to_file(self.flake_cache_path)
 
     def select(