Merge pull request 'vars: optimize generate - reduce cache misses' (#5348) from dave into main
Reviewed-on: https://git.clan.lol/clan/clan-core/pulls/5348
This commit is contained in:
@@ -2,6 +2,7 @@ import json
|
||||
import logging
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@@ -28,6 +29,17 @@ from clan_lib.vars.generate import (
|
||||
)
|
||||
|
||||
|
||||
def invalidate_flake_cache(flake_path: Path) -> None:
|
||||
"""Force flake cache invalidation by modifying the git repository.
|
||||
|
||||
This adds a dummy file to git which changes the NAR hash of the flake,
|
||||
forcing a cache invalidation.
|
||||
"""
|
||||
dummy_file = flake_path / f".cache_invalidation_{time.time()}"
|
||||
dummy_file.write_text("invalidate")
|
||||
run(["git", "add", str(dummy_file)])
|
||||
|
||||
|
||||
def test_dependencies_as_files(temp_dir: Path) -> None:
|
||||
decrypted_dependencies = {
|
||||
"gen_1": {
|
||||
@@ -1264,37 +1276,71 @@ def test_share_mode_switch_regenerates_secret(
|
||||
|
||||
|
||||
@pytest.mark.with_core
|
||||
def test_cache_misses_for_vars_list(
|
||||
def test_cache_misses_for_vars_operations(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
flake: ClanFlake,
|
||||
) -> None:
|
||||
"""Test that listing vars results in exactly one cache miss."""
|
||||
"""Test that vars operations result in minimal cache misses."""
|
||||
config = flake.machines["my_machine"]
|
||||
config["nixpkgs"]["hostPlatform"] = "x86_64-linux"
|
||||
|
||||
# Set up a simple generator
|
||||
# Set up a simple generator with a public value
|
||||
my_generator = config["clan"]["core"]["vars"]["generators"]["my_generator"]
|
||||
my_generator["files"]["my_value"]["secret"] = False
|
||||
my_generator["script"] = 'echo -n "test" > "$out"/my_value'
|
||||
my_generator["script"] = 'echo -n "test_value" > "$out"/my_value'
|
||||
|
||||
flake.refresh()
|
||||
monkeypatch.chdir(flake.path)
|
||||
|
||||
# # Generate the vars first
|
||||
# cli.run(["vars", "generate", "--flake", str(flake.path), "my_machine"])
|
||||
|
||||
# Create a fresh machine object to ensure clean cache state
|
||||
machine = Machine(name="my_machine", flake=Flake(str(flake.path)))
|
||||
|
||||
# Record initial cache misses
|
||||
initial_cache_misses = machine.flake._cache_misses
|
||||
# Test 1: Running vars generate with a fresh cache should result in exactly 3 cache misses
|
||||
# Expected cache misses:
|
||||
# 1. One for getting the list of generators
|
||||
# 2. One for getting the final script of our test generator (my_generator)
|
||||
# 3. One for getting the final script of the state version generator (added by default)
|
||||
# TODO: The third cache miss is undesired in tests. disable state version module for tests
|
||||
|
||||
run_generators(
|
||||
machines=[machine],
|
||||
generators=None, # Generate all
|
||||
)
|
||||
|
||||
# Print stack traces if we have more than 3 cache misses
|
||||
if machine.flake._cache_misses != 3:
|
||||
machine.flake.print_cache_miss_analysis(
|
||||
title="Cache miss analysis for vars generate"
|
||||
)
|
||||
|
||||
assert machine.flake._cache_misses == 3, (
|
||||
f"Expected exactly 3 cache misses for vars generate, got {machine.flake._cache_misses}"
|
||||
)
|
||||
|
||||
# Verify the value was generated correctly
|
||||
var_value = get_machine_var(machine, "my_generator/my_value")
|
||||
assert var_value.printable_value == "test_value"
|
||||
|
||||
# Test 2: List all vars should result in exactly 1 cache miss
|
||||
# Force cache invalidation (this also resets cache miss tracking)
|
||||
invalidate_flake_cache(flake.path)
|
||||
machine.flake.invalidate_cache()
|
||||
|
||||
# List all vars - this should result in exactly one cache miss
|
||||
stringify_all_vars(machine)
|
||||
assert machine.flake._cache_misses == 1, (
|
||||
f"Expected exactly 1 cache miss for vars list, got {machine.flake._cache_misses}"
|
||||
)
|
||||
|
||||
# Assert we had exactly one cache miss for the efficient lookup
|
||||
assert machine.flake._cache_misses == initial_cache_misses + 1, (
|
||||
f"Expected exactly 1 cache miss for vars list, got {machine.flake._cache_misses - initial_cache_misses}"
|
||||
# Test 3: Getting a specific var with a fresh cache should result in exactly 1 cache miss
|
||||
# Force cache invalidation (this also resets cache miss tracking)
|
||||
invalidate_flake_cache(flake.path)
|
||||
machine.flake.invalidate_cache()
|
||||
|
||||
var_value = get_machine_var(machine, "my_generator/my_value")
|
||||
assert var_value.printable_value == "test_value"
|
||||
|
||||
assert machine.flake._cache_misses == 1, (
|
||||
f"Expected exactly 1 cache miss for vars get with fresh cache, got {machine.flake._cache_misses}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -91,6 +91,36 @@ class Generator:
|
||||
self
|
||||
) and self._public_store.hash_is_valid(self)
|
||||
|
||||
@classmethod
|
||||
def get_machine_selectors(
|
||||
cls: type["Generator"],
|
||||
machine_names: Iterable[str],
|
||||
) -> list[str]:
|
||||
"""Get all selectors needed to fetch generators and files for the given machines.
|
||||
|
||||
Args:
|
||||
machine_names: The names of the machines.
|
||||
|
||||
Returns:
|
||||
list[str]: A list of selectors to fetch all generators and files for the machines.
|
||||
|
||||
"""
|
||||
config = nix_config()
|
||||
system = config["system"]
|
||||
|
||||
generators_selector = "config.clan.core.vars.generators.*.{share,dependencies,migrateFact,prompts,validationHash}"
|
||||
files_selector = "config.clan.core.vars.generators.*.files.*.{secret,deploy,owner,group,mode,neededFor}"
|
||||
|
||||
all_selectors = []
|
||||
for machine_name in machine_names:
|
||||
all_selectors += [
|
||||
f'clanInternals.machines."{system}"."{machine_name}".{generators_selector}',
|
||||
f'clanInternals.machines."{system}"."{machine_name}".{files_selector}',
|
||||
f'clanInternals.machines."{system}"."{machine_name}".config.clan.core.vars.settings.publicModule',
|
||||
f'clanInternals.machines."{system}"."{machine_name}".config.clan.core.vars.settings.secretModule',
|
||||
]
|
||||
return all_selectors
|
||||
|
||||
@classmethod
|
||||
def get_machine_generators(
|
||||
cls: type["Generator"],
|
||||
@@ -109,22 +139,9 @@ class Generator:
|
||||
list[Generator]: A list of (unsorted) generators for the machine.
|
||||
|
||||
"""
|
||||
config = nix_config()
|
||||
system = config["system"]
|
||||
|
||||
generators_selector = "config.clan.core.vars.generators.*.{share,dependencies,migrateFact,prompts,validationHash}"
|
||||
files_selector = "config.clan.core.vars.generators.*.files.*.{secret,deploy,owner,group,mode,neededFor}"
|
||||
|
||||
# precache all machines generators and files to avoid multiple calls to nix
|
||||
all_selectors = []
|
||||
for machine_name in machine_names:
|
||||
all_selectors += [
|
||||
f'clanInternals.machines."{system}"."{machine_name}".{generators_selector}',
|
||||
f'clanInternals.machines."{system}"."{machine_name}".{files_selector}',
|
||||
f'clanInternals.machines."{system}"."{machine_name}".config.clan.core.vars.settings.publicModule',
|
||||
f'clanInternals.machines."{system}"."{machine_name}".config.clan.core.vars.settings.secretModule',
|
||||
]
|
||||
flake.precache(all_selectors)
|
||||
flake.precache(cls.get_machine_selectors(machine_names))
|
||||
|
||||
generators = []
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import logging
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import traceback
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from enum import Enum
|
||||
from functools import cache
|
||||
@@ -792,7 +793,7 @@ class Flake:
|
||||
_cache: FlakeCache | None = field(init=False, default=None)
|
||||
_path: Path | None = field(init=False, default=None)
|
||||
_is_local: bool | None = field(init=False, default=None)
|
||||
_cache_misses: int = field(init=False, default=0)
|
||||
_cache_miss_stack_traces: list[str] = field(init=False, default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def from_json(
|
||||
@@ -814,6 +815,34 @@ class Flake:
|
||||
return NotImplemented
|
||||
return self.identifier == other.identifier
|
||||
|
||||
def _record_cache_miss(self, selector_info: str) -> None:
|
||||
"""Record a cache miss with its stack trace."""
|
||||
stack_trace = "".join(traceback.format_stack())
|
||||
self._cache_miss_stack_traces.append(f"{selector_info}\n{stack_trace}")
|
||||
|
||||
@property
|
||||
def _cache_misses(self) -> int:
|
||||
"""Get the count of cache misses from the stack trace list."""
|
||||
return len(self._cache_miss_stack_traces)
|
||||
|
||||
def print_cache_miss_analysis(self, title: str = "Cache miss analysis") -> None:
|
||||
"""Print detailed analysis of cache misses with stack traces.
|
||||
|
||||
Args:
|
||||
title: Title for the analysis output
|
||||
|
||||
"""
|
||||
if not self._cache_miss_stack_traces:
|
||||
return
|
||||
|
||||
print(f"\n=== {title} ===")
|
||||
print(f"Total cache misses: {len(self._cache_miss_stack_traces)}")
|
||||
print("\nStack traces for all cache misses:")
|
||||
for i, trace in enumerate(self._cache_miss_stack_traces, 1):
|
||||
print(f"\n--- Cache miss #{i} ---")
|
||||
print(trace)
|
||||
print("=" * 50)
|
||||
|
||||
@property
|
||||
def is_local(self) -> bool:
|
||||
if self._is_local is None:
|
||||
@@ -886,10 +915,13 @@ class Flake:
|
||||
"""Invalidate the cache and reload it.
|
||||
|
||||
This method is used to refresh the cache by reloading it from the flake.
|
||||
Also resets cache miss tracking.
|
||||
"""
|
||||
self.prefetch()
|
||||
|
||||
self._cache = FlakeCache()
|
||||
# Reset cache miss tracking when invalidating cache
|
||||
self._cache_miss_stack_traces.clear()
|
||||
if self.hash is None:
|
||||
msg = "Hash cannot be None"
|
||||
raise ClanError(msg)
|
||||
@@ -1063,8 +1095,10 @@ class Flake:
|
||||
]
|
||||
|
||||
if not_fetched_selectors:
|
||||
# Increment cache miss counter for each selector that wasn't cached
|
||||
self._cache_misses += 1
|
||||
# Record cache miss with stack trace
|
||||
self._record_cache_miss(
|
||||
f"Cache miss for selectors: {not_fetched_selectors}"
|
||||
)
|
||||
self.get_from_nix(not_fetched_selectors)
|
||||
|
||||
def select(
|
||||
@@ -1090,7 +1124,8 @@ class Flake:
|
||||
if not self._cache.is_cached(selector):
|
||||
log.debug(f"(cached) $ clan select {shlex.quote(selector)}")
|
||||
log.debug(f"Cache miss for {selector}")
|
||||
self._cache_misses += 1
|
||||
# Record cache miss with stack trace
|
||||
self._record_cache_miss(f"Cache miss for selector: {selector}")
|
||||
self.get_from_nix([selector])
|
||||
else:
|
||||
log.debug(f"$ clan select {shlex.quote(selector)}")
|
||||
|
||||
@@ -129,10 +129,21 @@ class InventoryStore:
|
||||
self._allowed_path_transforms = _allowed_path_transforms
|
||||
|
||||
if _keys is None:
|
||||
_keys = list(InventorySnapshot.__annotations__.keys())
|
||||
_keys = self.default_keys()
|
||||
|
||||
self._keys = _keys
|
||||
|
||||
@classmethod
|
||||
def default_keys(cls) -> list[str]:
|
||||
return list(InventorySnapshot.__annotations__.keys())
|
||||
|
||||
@classmethod
|
||||
def default_selectors(cls) -> list[str]:
|
||||
return [
|
||||
f"clanInternals.inventoryClass.inventory.{key}"
|
||||
for key in cls.default_keys()
|
||||
]
|
||||
|
||||
def _load_merged_inventory(self) -> InventorySnapshot:
|
||||
"""Loads the evaluated inventory.
|
||||
After all merge operations with eventual nix code in buildClan.
|
||||
|
||||
@@ -9,6 +9,7 @@ from clan_cli.vars.migration import check_can_migrate, migrate_files
|
||||
from clan_lib.api import API
|
||||
from clan_lib.errors import ClanError
|
||||
from clan_lib.machines.machines import Machine
|
||||
from clan_lib.persist.inventory_store import InventoryStore
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -37,18 +38,22 @@ def get_generators(
|
||||
if not machines:
|
||||
msg = "At least one machine must be provided"
|
||||
raise ClanError(msg)
|
||||
|
||||
all_machines = machines[0].flake.list_machines().keys()
|
||||
flake = machines[0].flake
|
||||
flake.precache(
|
||||
InventoryStore.default_selectors()
|
||||
+ Generator.get_machine_selectors(m.name for m in machines)
|
||||
)
|
||||
all_machines = flake.list_machines().keys()
|
||||
requested_machines = [machine.name for machine in machines]
|
||||
|
||||
all_generators_list = Generator.get_machine_generators(
|
||||
all_machines,
|
||||
machines[0].flake,
|
||||
flake,
|
||||
include_previous_values=include_previous_values,
|
||||
)
|
||||
requested_generators_list = Generator.get_machine_generators(
|
||||
requested_machines,
|
||||
machines[0].flake,
|
||||
flake,
|
||||
include_previous_values=include_previous_values,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user