vars: optimize generate - reduce cache misses

Optimize the `clan vars generate` procedure by pre-caching more selectors.

To achieve this, helper functions are added to several classes.

A debugging feature is also added to the Flake class to track the stack traces of cache misses.
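
The idea behind the optimization: every selector that is not yet cached costs its own nix evaluation (a cache miss), so collecting all selectors an operation will need and evaluating them in a single batched call up front turns many misses into one. A minimal standalone sketch of that pattern (the SelectorCache class and the evaluate callback are illustrative only, not part of the clan codebase; the real logic lives in Flake.precache() and Flake.select()):

from collections.abc import Callable, Iterable
from typing import Any

class SelectorCache:
    def __init__(self, evaluate: Callable[[list[str]], dict[str, Any]]) -> None:
        self._evaluate = evaluate  # stands in for one batched `nix eval` call
        self._values: dict[str, Any] = {}
        self.misses = 0

    def precache(self, selectors: Iterable[str]) -> None:
        # Fetch every missing selector with a single evaluation (at most one miss).
        missing = [s for s in selectors if s not in self._values]
        if missing:
            self.misses += 1
            self._values.update(self._evaluate(missing))

    def select(self, selector: str) -> Any:
        # Fall back to one evaluation per selector that was not precached.
        if selector not in self._values:
            self.misses += 1
            self._values.update(self._evaluate([selector]))
        return self._values[selector]

Precaching the selectors for all machines before generating therefore costs one batched evaluation instead of one evaluation per selector.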
DavHau
2025-10-02 18:46:11 +07:00
parent b72145d4aa
commit 5a6ffbf916
5 changed files with 150 additions and 36 deletions

View File

@@ -2,6 +2,7 @@ import json
 import logging
 import shutil
 import subprocess
+import time
 from pathlib import Path
 
 import pytest
@@ -28,6 +29,17 @@ from clan_lib.vars.generate import (
 )
 
 
+def invalidate_flake_cache(flake_path: Path) -> None:
+    """Force flake cache invalidation by modifying the git repository.
+
+    This adds a dummy file to git which changes the NAR hash of the flake,
+    forcing a cache invalidation.
+    """
+    dummy_file = flake_path / f".cache_invalidation_{time.time()}"
+    dummy_file.write_text("invalidate")
+    run(["git", "add", str(dummy_file)])
+
+
 def test_dependencies_as_files(temp_dir: Path) -> None:
     decrypted_dependencies = {
         "gen_1": {
@@ -1264,37 +1276,71 @@ def test_share_mode_switch_regenerates_secret(
 @pytest.mark.with_core
-def test_cache_misses_for_vars_list(
+def test_cache_misses_for_vars_operations(
     monkeypatch: pytest.MonkeyPatch,
     flake: ClanFlake,
 ) -> None:
-    """Test that listing vars results in exactly one cache miss."""
+    """Test that vars operations result in minimal cache misses."""
     config = flake.machines["my_machine"]
     config["nixpkgs"]["hostPlatform"] = "x86_64-linux"
 
-    # Set up a simple generator
+    # Set up a simple generator with a public value
     my_generator = config["clan"]["core"]["vars"]["generators"]["my_generator"]
     my_generator["files"]["my_value"]["secret"] = False
-    my_generator["script"] = 'echo -n "test" > "$out"/my_value'
+    my_generator["script"] = 'echo -n "test_value" > "$out"/my_value'
 
     flake.refresh()
     monkeypatch.chdir(flake.path)
 
-    # # Generate the vars first
-    # cli.run(["vars", "generate", "--flake", str(flake.path), "my_machine"])
-
     # Create a fresh machine object to ensure clean cache state
     machine = Machine(name="my_machine", flake=Flake(str(flake.path)))
 
-    # Record initial cache misses
-    initial_cache_misses = machine.flake._cache_misses
+    # Test 1: Running vars generate with a fresh cache should result in exactly 3 cache misses
+    # Expected cache misses:
+    # 1. One for getting the list of generators
+    # 2. One for getting the final script of our test generator (my_generator)
+    # 3. One for getting the final script of the state version generator (added by default)
+    # TODO: The third cache miss is undesired in tests. disable state version module for tests
+    run_generators(
+        machines=[machine],
+        generators=None,  # Generate all
+    )
+
+    # Print stack traces if we have more than 3 cache misses
+    if machine.flake._cache_misses != 3:
+        machine.flake.print_cache_miss_analysis(
+            title="Cache miss analysis for vars generate"
+        )
+    assert machine.flake._cache_misses == 3, (
+        f"Expected exactly 3 cache misses for vars generate, got {machine.flake._cache_misses}"
+    )
+
+    # Verify the value was generated correctly
+    var_value = get_machine_var(machine, "my_generator/my_value")
+    assert var_value.printable_value == "test_value"
+
+    # Test 2: List all vars should result in exactly 1 cache miss
+    # Force cache invalidation (this also resets cache miss tracking)
+    invalidate_flake_cache(flake.path)
+    machine.flake.invalidate_cache()
 
     # List all vars - this should result in exactly one cache miss
     stringify_all_vars(machine)
+    assert machine.flake._cache_misses == 1, (
+        f"Expected exactly 1 cache miss for vars list, got {machine.flake._cache_misses}"
+    )
 
-    # Assert we had exactly one cache miss for the efficient lookup
-    assert machine.flake._cache_misses == initial_cache_misses + 1, (
-        f"Expected exactly 1 cache miss for vars list, got {machine.flake._cache_misses - initial_cache_misses}"
+    # Test 3: Getting a specific var with a fresh cache should result in exactly 1 cache miss
+    # Force cache invalidation (this also resets cache miss tracking)
+    invalidate_flake_cache(flake.path)
+    machine.flake.invalidate_cache()
+
+    var_value = get_machine_var(machine, "my_generator/my_value")
+    assert var_value.printable_value == "test_value"
+    assert machine.flake._cache_misses == 1, (
+        f"Expected exactly 1 cache miss for vars get with fresh cache, got {machine.flake._cache_misses}"
     )

View File

@@ -91,6 +91,36 @@ class Generator:
             self
         ) and self._public_store.hash_is_valid(self)
 
+    @classmethod
+    def get_machine_selectors(
+        cls: type["Generator"],
+        machine_names: Iterable[str],
+    ) -> list[str]:
+        """Get all selectors needed to fetch generators and files for the given machines.
+
+        Args:
+            machine_names: The names of the machines.
+
+        Returns:
+            list[str]: A list of selectors to fetch all generators and files for the machines.
+        """
+        config = nix_config()
+        system = config["system"]
+        generators_selector = "config.clan.core.vars.generators.*.{share,dependencies,migrateFact,prompts,validationHash}"
+        files_selector = "config.clan.core.vars.generators.*.files.*.{secret,deploy,owner,group,mode,neededFor}"
+
+        all_selectors = []
+        for machine_name in machine_names:
+            all_selectors += [
+                f'clanInternals.machines."{system}"."{machine_name}".{generators_selector}',
+                f'clanInternals.machines."{system}"."{machine_name}".{files_selector}',
+                f'clanInternals.machines."{system}"."{machine_name}".config.clan.core.vars.settings.publicModule',
+                f'clanInternals.machines."{system}"."{machine_name}".config.clan.core.vars.settings.secretModule',
+            ]
+
+        return all_selectors
+
     @classmethod
     def get_machine_generators(
         cls: type["Generator"],
@@ -109,22 +139,9 @@ class Generator:
             list[Generator]: A list of (unsorted) generators for the machine.
 
         """
-        config = nix_config()
-        system = config["system"]
         generators_selector = "config.clan.core.vars.generators.*.{share,dependencies,migrateFact,prompts,validationHash}"
         files_selector = "config.clan.core.vars.generators.*.files.*.{secret,deploy,owner,group,mode,neededFor}"
+        flake.precache(cls.get_machine_selectors(machine_names))
 
-        # precache all machines generators and files to avoid multiple calls to nix
-        all_selectors = []
-        for machine_name in machine_names:
-            all_selectors += [
-                f'clanInternals.machines."{system}"."{machine_name}".{generators_selector}',
-                f'clanInternals.machines."{system}"."{machine_name}".{files_selector}',
-                f'clanInternals.machines."{system}"."{machine_name}".config.clan.core.vars.settings.publicModule',
-                f'clanInternals.machines."{system}"."{machine_name}".config.clan.core.vars.settings.secretModule',
-            ]
-        flake.precache(all_selectors)
 
         generators = []
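
Roughly, the intended call pattern is to compute the selector list for all machines once and hand it to the flake cache in one batch before resolving generators; a usage sketch (the flake path and machine names below are placeholders):

flake = Flake("/path/to/clan")  # placeholder path
machine_names = ["machine_a", "machine_b"]  # placeholder names

# One batched nix evaluation covering every generator/file selector of both machines.
flake.precache(Generator.get_machine_selectors(machine_names))

# Subsequent generator lookups are then served from the cache.
generators = Generator.get_machine_generators(machine_names, flake)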

View File

@@ -3,6 +3,7 @@ import logging
 import os
 import re
 import shlex
+import traceback
 from dataclasses import asdict, dataclass, field
 from enum import Enum
 from functools import cache
@@ -792,7 +793,7 @@ class Flake:
     _cache: FlakeCache | None = field(init=False, default=None)
     _path: Path | None = field(init=False, default=None)
     _is_local: bool | None = field(init=False, default=None)
-    _cache_misses: int = field(init=False, default=0)
+    _cache_miss_stack_traces: list[str] = field(init=False, default_factory=list)
 
     @classmethod
     def from_json(
@@ -814,6 +815,34 @@ class Flake:
             return NotImplemented
         return self.identifier == other.identifier
 
+    def _record_cache_miss(self, selector_info: str) -> None:
+        """Record a cache miss with its stack trace."""
+        stack_trace = "".join(traceback.format_stack())
+        self._cache_miss_stack_traces.append(f"{selector_info}\n{stack_trace}")
+
+    @property
+    def _cache_misses(self) -> int:
+        """Get the count of cache misses from the stack trace list."""
+        return len(self._cache_miss_stack_traces)
+
+    def print_cache_miss_analysis(self, title: str = "Cache miss analysis") -> None:
+        """Print detailed analysis of cache misses with stack traces.
+
+        Args:
+            title: Title for the analysis output
+        """
+        if not self._cache_miss_stack_traces:
+            return
+
+        print(f"\n=== {title} ===")
+        print(f"Total cache misses: {len(self._cache_miss_stack_traces)}")
+        print("\nStack traces for all cache misses:")
+        for i, trace in enumerate(self._cache_miss_stack_traces, 1):
+            print(f"\n--- Cache miss #{i} ---")
+            print(trace)
+        print("=" * 50)
+
     @property
     def is_local(self) -> bool:
         if self._is_local is None:
@@ -886,10 +915,13 @@ class Flake:
         """Invalidate the cache and reload it.
 
         This method is used to refresh the cache by reloading it from the flake.
+        Also resets cache miss tracking.
         """
         self.prefetch()
         self._cache = FlakeCache()
+        # Reset cache miss tracking when invalidating cache
+        self._cache_miss_stack_traces.clear()
         if self.hash is None:
             msg = "Hash cannot be None"
             raise ClanError(msg)
@@ -1063,8 +1095,10 @@ class Flake:
         ]
 
         if not_fetched_selectors:
-            # Increment cache miss counter for each selector that wasn't cached
-            self._cache_misses += 1
+            # Record cache miss with stack trace
+            self._record_cache_miss(
+                f"Cache miss for selectors: {not_fetched_selectors}"
+            )
             self.get_from_nix(not_fetched_selectors)
 
     def select(
@@ -1090,7 +1124,8 @@ class Flake:
         if not self._cache.is_cached(selector):
             log.debug(f"(cached) $ clan select {shlex.quote(selector)}")
             log.debug(f"Cache miss for {selector}")
-            self._cache_misses += 1
+            # Record cache miss with stack trace
+            self._record_cache_miss(f"Cache miss for selector: {selector}")
             self.get_from_nix([selector])
         else:
             log.debug(f"$ clan select {shlex.quote(selector)}")
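
The debugging feature replaces the bare _cache_misses counter with a list of recorded stack traces and derives the count from that list, so existing assertions on _cache_misses keep working while print_cache_miss_analysis() shows where each miss originated. The same pattern in a stripped-down, standalone form (a sketch, not the Flake API itself):

import traceback

class MissTracker:
    def __init__(self) -> None:
        self._traces: list[str] = []

    def record(self, info: str) -> None:
        # Capture the full call stack at the point of the miss.
        self._traces.append(f"{info}\n{''.join(traceback.format_stack())}")

    @property
    def count(self) -> int:
        return len(self._traces)

    def print_analysis(self, title: str = "Cache miss analysis") -> None:
        if not self._traces:
            return
        print(f"=== {title} ===")
        print(f"Total cache misses: {self.count}")
        for i, trace in enumerate(self._traces, 1):
            print(f"--- Cache miss #{i} ---\n{trace}")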

View File

@@ -129,10 +129,21 @@ class InventoryStore:
         self._allowed_path_transforms = _allowed_path_transforms
         if _keys is None:
-            _keys = list(InventorySnapshot.__annotations__.keys())
+            _keys = self.default_keys()
         self._keys = _keys
 
+    @classmethod
+    def default_keys(cls) -> list[str]:
+        return list(InventorySnapshot.__annotations__.keys())
+
+    @classmethod
+    def default_selectors(cls) -> list[str]:
+        return [
+            f"clanInternals.inventoryClass.inventory.{key}"
+            for key in cls.default_keys()
+        ]
+
     def _load_merged_inventory(self) -> InventorySnapshot:
         """Loads the evaluated inventory.
 
         After all merge operations with eventual nix code in buildClan.
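
default_selectors() simply maps every top-level key of InventorySnapshot onto an inventory selector so the whole inventory can be precached in one batch. For illustration (the key names below are made up; the real ones come from InventorySnapshot.__annotations__):

# Hypothetical keys, for illustration only.
keys = ["machines", "meta"]
selectors = [f"clanInternals.inventoryClass.inventory.{key}" for key in keys]
# selectors == ["clanInternals.inventoryClass.inventory.machines",
#               "clanInternals.inventoryClass.inventory.meta"]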

View File

@@ -9,6 +9,7 @@ from clan_cli.vars.migration import check_can_migrate, migrate_files
 from clan_lib.api import API
 from clan_lib.errors import ClanError
 from clan_lib.machines.machines import Machine
+from clan_lib.persist.inventory_store import InventoryStore
 
 log = logging.getLogger(__name__)
@@ -37,18 +38,22 @@ def get_generators(
     if not machines:
         msg = "At least one machine must be provided"
         raise ClanError(msg)
+    flake = machines[0].flake
 
-    all_machines = machines[0].flake.list_machines().keys()
+    flake.precache(
+        InventoryStore.default_selectors()
+        + Generator.get_machine_selectors(m.name for m in machines)
+    )
+
+    all_machines = flake.list_machines().keys()
     requested_machines = [machine.name for machine in machines]
 
     all_generators_list = Generator.get_machine_generators(
         all_machines,
-        machines[0].flake,
+        flake,
         include_previous_values=include_previous_values,
     )
     requested_generators_list = Generator.get_machine_generators(
         requested_machines,
-        machines[0].flake,
+        flake,
         include_previous_values=include_previous_values,
     )