Merge pull request 'vars: optimize generate - reduce cache misses' (#5348) from dave into main

Reviewed-on: https://git.clan.lol/clan/clan-core/pulls/5348
This commit is contained in:
DavHau
2025-10-02 11:50:26 +00:00
5 changed files with 150 additions and 36 deletions

View File

@@ -2,6 +2,7 @@ import json
import logging
import shutil
import subprocess
import time
from pathlib import Path
import pytest
@@ -28,6 +29,17 @@ from clan_lib.vars.generate import (
)
def invalidate_flake_cache(flake_path: Path) -> None:
    """Force flake cache invalidation by modifying the git repository.

    This adds a dummy file to git which changes the NAR hash of the flake,
    forcing a cache invalidation.
    """
    # A unique filename per call guarantees the repository content (and thus
    # the flake's NAR hash) actually changes on every invocation.
    marker = flake_path / f".cache_invalidation_{time.time()}"
    marker.write_text("invalidate")
    git_add_cmd = ["git", "add", str(marker)]
    run(git_add_cmd)
def test_dependencies_as_files(temp_dir: Path) -> None:
decrypted_dependencies = {
"gen_1": {
@@ -1264,37 +1276,71 @@ def test_share_mode_switch_regenerates_secret(
@pytest.mark.with_core
def test_cache_misses_for_vars_list(
def test_cache_misses_for_vars_operations(
monkeypatch: pytest.MonkeyPatch,
flake: ClanFlake,
) -> None:
"""Test that listing vars results in exactly one cache miss."""
"""Test that vars operations result in minimal cache misses."""
config = flake.machines["my_machine"]
config["nixpkgs"]["hostPlatform"] = "x86_64-linux"
# Set up a simple generator
# Set up a simple generator with a public value
my_generator = config["clan"]["core"]["vars"]["generators"]["my_generator"]
my_generator["files"]["my_value"]["secret"] = False
my_generator["script"] = 'echo -n "test" > "$out"/my_value'
my_generator["script"] = 'echo -n "test_value" > "$out"/my_value'
flake.refresh()
monkeypatch.chdir(flake.path)
# # Generate the vars first
# cli.run(["vars", "generate", "--flake", str(flake.path), "my_machine"])
# Create a fresh machine object to ensure clean cache state
machine = Machine(name="my_machine", flake=Flake(str(flake.path)))
# Record initial cache misses
initial_cache_misses = machine.flake._cache_misses
# Test 1: Running vars generate with a fresh cache should result in exactly 3 cache misses
# Expected cache misses:
# 1. One for getting the list of generators
# 2. One for getting the final script of our test generator (my_generator)
# 3. One for getting the final script of the state version generator (added by default)
# TODO: The third cache miss is undesired in tests. disable state version module for tests
run_generators(
machines=[machine],
generators=None, # Generate all
)
# Print stack traces if we have more than 3 cache misses
if machine.flake._cache_misses != 3:
machine.flake.print_cache_miss_analysis(
title="Cache miss analysis for vars generate"
)
assert machine.flake._cache_misses == 3, (
f"Expected exactly 3 cache misses for vars generate, got {machine.flake._cache_misses}"
)
# Verify the value was generated correctly
var_value = get_machine_var(machine, "my_generator/my_value")
assert var_value.printable_value == "test_value"
# Test 2: List all vars should result in exactly 1 cache miss
# Force cache invalidation (this also resets cache miss tracking)
invalidate_flake_cache(flake.path)
machine.flake.invalidate_cache()
# List all vars - this should result in exactly one cache miss
stringify_all_vars(machine)
assert machine.flake._cache_misses == 1, (
f"Expected exactly 1 cache miss for vars list, got {machine.flake._cache_misses}"
)
# Assert we had exactly one cache miss for the efficient lookup
assert machine.flake._cache_misses == initial_cache_misses + 1, (
f"Expected exactly 1 cache miss for vars list, got {machine.flake._cache_misses - initial_cache_misses}"
# Test 3: Getting a specific var with a fresh cache should result in exactly 1 cache miss
# Force cache invalidation (this also resets cache miss tracking)
invalidate_flake_cache(flake.path)
machine.flake.invalidate_cache()
var_value = get_machine_var(machine, "my_generator/my_value")
assert var_value.printable_value == "test_value"
assert machine.flake._cache_misses == 1, (
f"Expected exactly 1 cache miss for vars get with fresh cache, got {machine.flake._cache_misses}"
)

View File

@@ -91,6 +91,36 @@ class Generator:
self
) and self._public_store.hash_is_valid(self)
@classmethod
def get_machine_selectors(
    cls: type["Generator"],
    machine_names: Iterable[str],
) -> list[str]:
    """Get all selectors needed to fetch generators and files for the given machines.

    Args:
        machine_names: The names of the machines.

    Returns:
        list[str]: A list of selectors to fetch all generators and files for the machines.
    """
    system = nix_config()["system"]
    generators_selector = "config.clan.core.vars.generators.*.{share,dependencies,migrateFact,prompts,validationHash}"
    files_selector = "config.clan.core.vars.generators.*.files.*.{secret,deploy,owner,group,mode,neededFor}"
    # Four selectors are emitted per machine, all rooted at the machine's
    # attribute path under clanInternals.
    per_machine_suffixes = (
        generators_selector,
        files_selector,
        "config.clan.core.vars.settings.publicModule",
        "config.clan.core.vars.settings.secretModule",
    )
    return [
        f'clanInternals.machines."{system}"."{machine_name}".{suffix}'
        for machine_name in machine_names
        for suffix in per_machine_suffixes
    ]
@classmethod
def get_machine_generators(
cls: type["Generator"],
@@ -109,22 +139,9 @@ class Generator:
list[Generator]: A list of (unsorted) generators for the machine.
"""
config = nix_config()
system = config["system"]
generators_selector = "config.clan.core.vars.generators.*.{share,dependencies,migrateFact,prompts,validationHash}"
files_selector = "config.clan.core.vars.generators.*.files.*.{secret,deploy,owner,group,mode,neededFor}"
# precache all machines generators and files to avoid multiple calls to nix
all_selectors = []
for machine_name in machine_names:
all_selectors += [
f'clanInternals.machines."{system}"."{machine_name}".{generators_selector}',
f'clanInternals.machines."{system}"."{machine_name}".{files_selector}',
f'clanInternals.machines."{system}"."{machine_name}".config.clan.core.vars.settings.publicModule',
f'clanInternals.machines."{system}"."{machine_name}".config.clan.core.vars.settings.secretModule',
]
flake.precache(all_selectors)
flake.precache(cls.get_machine_selectors(machine_names))
generators = []

View File

@@ -3,6 +3,7 @@ import logging
import os
import re
import shlex
import traceback
from dataclasses import asdict, dataclass, field
from enum import Enum
from functools import cache
@@ -792,7 +793,7 @@ class Flake:
_cache: FlakeCache | None = field(init=False, default=None)
_path: Path | None = field(init=False, default=None)
_is_local: bool | None = field(init=False, default=None)
_cache_misses: int = field(init=False, default=0)
_cache_miss_stack_traces: list[str] = field(init=False, default_factory=list)
@classmethod
def from_json(
@@ -814,6 +815,34 @@ class Flake:
return NotImplemented
return self.identifier == other.identifier
def _record_cache_miss(self, selector_info: str) -> None:
    """Record a cache miss with its stack trace."""
    # Capture the current call stack so tests can attribute each miss to
    # the code path that triggered the uncached nix evaluation.
    frames = traceback.format_stack()
    entry = f"{selector_info}\n{''.join(frames)}"
    self._cache_miss_stack_traces.append(entry)
@property
def _cache_misses(self) -> int:
    """Get the count of cache misses from the stack trace list."""
    # Derived from the trace list so the count and the recorded traces
    # can never disagree (there is no separate counter to keep in sync).
    return len(self._cache_miss_stack_traces)
def print_cache_miss_analysis(self, title: str = "Cache miss analysis") -> None:
    """Print detailed analysis of cache misses with stack traces.

    Args:
        title: Title for the analysis output
    """
    traces = self._cache_miss_stack_traces
    # Nothing to report when no misses were recorded.
    if not traces:
        return
    print(f"\n=== {title} ===")
    print(f"Total cache misses: {len(traces)}")
    print("\nStack traces for all cache misses:")
    for index, trace in enumerate(traces, 1):
        print(f"\n--- Cache miss #{index} ---")
        print(trace)
    print("=" * 50)
@property
def is_local(self) -> bool:
if self._is_local is None:
@@ -886,10 +915,13 @@ class Flake:
"""Invalidate the cache and reload it.
This method is used to refresh the cache by reloading it from the flake.
Also resets cache miss tracking.
"""
self.prefetch()
self._cache = FlakeCache()
# Reset cache miss tracking when invalidating cache
self._cache_miss_stack_traces.clear()
if self.hash is None:
msg = "Hash cannot be None"
raise ClanError(msg)
@@ -1063,8 +1095,10 @@ class Flake:
]
if not_fetched_selectors:
# Increment cache miss counter for each selector that wasn't cached
self._cache_misses += 1
# Record cache miss with stack trace
self._record_cache_miss(
f"Cache miss for selectors: {not_fetched_selectors}"
)
self.get_from_nix(not_fetched_selectors)
def select(
@@ -1090,7 +1124,8 @@ class Flake:
if not self._cache.is_cached(selector):
log.debug(f"(cached) $ clan select {shlex.quote(selector)}")
log.debug(f"Cache miss for {selector}")
self._cache_misses += 1
# Record cache miss with stack trace
self._record_cache_miss(f"Cache miss for selector: {selector}")
self.get_from_nix([selector])
else:
log.debug(f"$ clan select {shlex.quote(selector)}")

View File

@@ -129,10 +129,21 @@ class InventoryStore:
self._allowed_path_transforms = _allowed_path_transforms
if _keys is None:
_keys = list(InventorySnapshot.__annotations__.keys())
_keys = self.default_keys()
self._keys = _keys
@classmethod
def default_keys(cls) -> list[str]:
    """Return the attribute names declared on InventorySnapshot."""
    # Iterating __annotations__ yields its keys in declaration order.
    return [*InventorySnapshot.__annotations__]
@classmethod
def default_selectors(cls) -> list[str]:
    """Return one flake selector per default inventory key."""
    prefix = "clanInternals.inventoryClass.inventory"
    return [f"{prefix}.{key}" for key in cls.default_keys()]
def _load_merged_inventory(self) -> InventorySnapshot:
"""Loads the evaluated inventory.
After all merge operations with eventual nix code in buildClan.

View File

@@ -9,6 +9,7 @@ from clan_cli.vars.migration import check_can_migrate, migrate_files
from clan_lib.api import API
from clan_lib.errors import ClanError
from clan_lib.machines.machines import Machine
from clan_lib.persist.inventory_store import InventoryStore
log = logging.getLogger(__name__)
@@ -37,18 +38,22 @@ def get_generators(
if not machines:
msg = "At least one machine must be provided"
raise ClanError(msg)
all_machines = machines[0].flake.list_machines().keys()
flake = machines[0].flake
flake.precache(
InventoryStore.default_selectors()
+ Generator.get_machine_selectors(m.name for m in machines)
)
all_machines = flake.list_machines().keys()
requested_machines = [machine.name for machine in machines]
all_generators_list = Generator.get_machine_generators(
all_machines,
machines[0].flake,
flake,
include_previous_values=include_previous_values,
)
requested_generators_list = Generator.get_machine_generators(
requested_machines,
machines[0].flake,
flake,
include_previous_values=include_previous_values,
)