diff --git a/pkgs/clan-cli/clan_cli/vars/generate.py b/pkgs/clan-cli/clan_cli/vars/generate.py
--- a/pkgs/clan-cli/clan_cli/vars/generate.py
+++ b/pkgs/clan-cli/clan_cli/vars/generate.py
@@ -2,7 +2,6 @@ import argparse
 import logging
 import os
 import sys
-from graphlib import TopologicalSorter
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Any
@@ -19,7 +18,13 @@ from clan_cli.machines.inventory import get_all_machines, get_selected_machines
 from clan_cli.machines.machines import Machine
 from clan_cli.nix import nix_shell
 
-from .check import check_vars
+from .graph import (
+    Generator,
+    all_missing_closure,
+    full_closure,
+    minimal_closure,
+    requested_closure,
+)
 from .prompt import ask
 from .public_modules import FactStoreBase
 from .secret_modules import SecretStoreBase
@@ -95,17 +100,10 @@ def dependencies_as_dir(
 def execute_generator(
     machine: Machine,
     generator_name: str,
-    regenerate: bool,
     secret_vars_store: SecretStoreBase,
     public_vars_store: FactStoreBase,
-    prompt_values: dict[str, str] | None,
-) -> bool:
-    prompt_values = {} if prompt_values is None else prompt_values
-    # check if all secrets exist and generate them if at least one is missing
-    needs_regeneration = not check_vars(machine, generator_name=generator_name)
-    log.debug(f"{generator_name} needs_regeneration: {needs_regeneration}")
-    if not (needs_regeneration or regenerate):
-        return False
+    prompt_values: dict[str, str],
+) -> None:
     if not isinstance(machine.flake, Path):
         msg = f"flake is not a Path: {machine.flake}"
         msg += "fact/secret generation is only supported for local flakes"
@@ -188,76 +186,6 @@ def execute_generator(
         machine.flake_dir,
         f"Update facts/secrets for service {generator_name} in machine {machine.name}",
     )
-    return True
-
-
-def _get_subgraph(graph: dict[str, set], vertices: list[str]) -> dict[str, set]:
-    visited = set()
-    queue = vertices
-    while queue:
-        vertex = queue.pop(0)
-        if vertex not in visited:
-            visited.add(vertex)
-            queue.extend(graph[vertex] - visited)
-    return {k: v for k, v in graph.items() if k in visited}
-
-
-def _dependency_graph(
-    machine: Machine, entry_nodes: None | list[str] = None
-) -> dict[str, set]:
-    graph = {
-        gen_name: set(generator["dependencies"])
-        for gen_name, generator in machine.vars_generators.items()
-    }
-    if entry_nodes:
-        return _get_subgraph(graph, entry_nodes)
-    return graph
-
-
-def _reverse_dependency_graph(
-    machine: Machine, entry_nodes: None | list[str] = None
-) -> dict[str, set]:
-    graph = _dependency_graph(machine)
-    reverse_graph: dict[str, set] = {gen_name: set() for gen_name in graph}
-    for gen_name, dependencies in graph.items():
-        for dep in dependencies:
-            reverse_graph[dep].add(gen_name)
-    if entry_nodes:
-        return _get_subgraph(reverse_graph, entry_nodes)
-    return reverse_graph
-
-
-def _required_generators(
-    machine: Machine,
-    desired_generators: list[str],
-) -> list[str]:
-    """
-    Receives list fo desired generators to update and returns list of required generators to update.
-
-    This is needed because some generators might depend on others, so we need to update them first.
-    The returned list is sorted topologically.
-    """
-
-    dependency_graph = _dependency_graph(machine)
-    # extract sub-graph if specific generators selected
-    dependency_graph = _get_subgraph(dependency_graph, desired_generators)
-
-    # check if all dependencies actually exist
-    for gen_name, dependencies in dependency_graph.items():
-        for dep in dependencies:
-            if dep not in dependency_graph:
-                msg = f"Generator {gen_name} has a dependency on {dep}, which does not exist"
-                raise ClanError(msg)
-
-    # ensure that all dependents are regenerated as well as their vars might depend on the current generator
-    reverse_dependency_graph = _reverse_dependency_graph(machine, desired_generators)
-    final_graph = _dependency_graph(
-        machine, entry_nodes=list(reverse_dependency_graph.keys())
-    )
-
-    # process generators in topological order (dependencies first)
-    sorter = TopologicalSorter(final_graph)
-    return list(sorter.static_order())
 
 
 def _ask_prompts(
@@ -276,30 +204,32 @@ def _ask_prompts(
     return prompt_values
 
 
-def _generate_vars_for_machine_multi(
+def get_closure(
     machine: Machine,
-    generator_names: list[str],
+    generator_name: str | None,
     regenerate: bool,
-) -> bool:
-    machine_updated = False
-
-    generators_to_update = _required_generators(machine, generator_names)
-    for generator_name in generators_to_update:
-        assert generator_name is not None
-        machine_updated |= execute_generator(
-            machine=machine,
-            generator_name=generator_name,
-            regenerate=regenerate,
-            secret_vars_store=machine.secret_vars_store,
-            public_vars_store=machine.public_vars_store,
-            prompt_values=_ask_prompts(machine, [generator_name]).get(
-                generator_name, {}
-            ),
-        )
-    if machine_updated:
-        # flush caches to make sure the new secrets are available in evaluation
-        machine.flush_caches()
-    return machine_updated
+) -> list[str]:
+    """Select the generators to execute for ``machine``, dependencies first.
+
+    Without a ``generator_name`` all generators are considered: every one of
+    them if ``regenerate`` is set, otherwise only the missing ones and their
+    dependents. With a ``generator_name`` the selection starts at that
+    generator: ``regenerate`` adds its missing dependencies and its dependents,
+    otherwise only what is missing is selected, which may be nothing.
+    """
+    vars_generators = machine.vars_generators
+    generators: dict[str, Generator] = {
+        name: Generator(name, generator["dependencies"], _machine=machine)
+        for name, generator in vars_generators.items()
+    }
+    if generator_name is None:  # all generators selected
+        if regenerate:
+            return full_closure(generators)
+        return all_missing_closure(generators)
+    # specific generator selected
+    if regenerate:
+        return requested_closure([generator_name], generators)
+    return minimal_closure([generator_name], generators)
 
 
 def _generate_vars_for_machine(
@@ -307,9 +237,26 @@ def _generate_vars_for_machine(
     generator_name: str | None,
     regenerate: bool,
 ) -> bool:
-    return _generate_vars_for_machine_multi(
-        machine, [generator_name] if generator_name else [], regenerate
-    )
+    """Run all generators selected by ``get_closure``.
+
+    Returns ``False`` without running anything if the closure is empty and
+    ``True`` once every selected generator was executed.
+    """
+    closure = get_closure(machine, generator_name, regenerate)
+    if not closure:
+        return False
+    prompt_values = _ask_prompts(machine, closure)
+    for gen_name in closure:
+        execute_generator(
+            machine,
+            gen_name,
+            machine.secret_vars_store,
+            machine.public_vars_store,
+            prompt_values.get(gen_name, {}),
+        )
+    # flush caches to make sure the new secrets are available in evaluation
+    machine.flush_caches()
+    return True
 
 
 def generate_vars(
diff --git a/pkgs/clan-cli/clan_cli/vars/graph.py b/pkgs/clan-cli/clan_cli/vars/graph.py
new file mode 100644
--- /dev/null
+++ b/pkgs/clan-cli/clan_cli/vars/graph.py
@@ -0,0 +1,125 @@
+from collections import deque
+from collections.abc import Iterable
+from dataclasses import dataclass
+from functools import cached_property
+from graphlib import TopologicalSorter
+
+from clan_cli.machines.machines import Machine
+
+from .check import check_vars
+
+
+@dataclass
+class Generator:
+    """A named vars generator with the names of its direct dependencies."""
+
+    name: str
+    dependencies: list[str]
+    _machine: Machine
+
+    @cached_property
+    def exists(self) -> bool:
+        """Result of ``check_vars`` for this generator, evaluated at most once."""
+        return check_vars(self._machine, generator_name=self.name)
+
+
+def missing_dependency_closure(
+    requested_generators: Iterable[str], generators: dict
+) -> set[str]:
+    """Return all transitive dependencies of the request that do not exist yet.
+
+    Dependencies that already exist are neither returned nor followed. The
+    requested generators themselves are never part of the result. Looking up a
+    name that is not a key of ``generators`` raises ``KeyError``.
+    """
+    closure = set(requested_generators)
+    dep_closure: set[str] = set()
+    queue = deque(closure)
+    while queue:
+        gen_name = queue.popleft()
+        for dep in generators[gen_name].dependencies:
+            # every dependency is expanded only once, which avoids redundant
+            # work on shared dependencies and endless loops on cyclic ones
+            if dep in closure or dep in dep_closure:
+                continue
+            if not generators[dep].exists:
+                dep_closure.add(dep)
+                queue.append(dep)
+    return dep_closure
+
+
+def add_missing_dependencies(
+    requested_generators: Iterable[str], generators: dict
+) -> set[str]:
+    """Return the requested generators plus their missing transitive dependencies."""
+    closure = set(requested_generators)
+    return missing_dependency_closure(closure, generators) | closure
+
+
+def add_dependents(requested_generators: Iterable[str], generators: dict) -> set[str]:
+    """Return the requested generators plus everything that depends on them."""
+    closure = set(requested_generators)
+    # build reverse dependency graph (graph of dependents)
+    dependents_graph: dict[str, set[str]] = {}
+    for gen_name, gen in generators.items():
+        for dep in gen.dependencies:
+            dependents_graph.setdefault(dep, set()).add(gen_name)
+    # extend the graph to include all dependents of the current closure
+    queue = deque(closure)
+    while queue:
+        gen_name = queue.popleft()
+        for dependent in dependents_graph.get(gen_name, ()):
+            if dependent not in closure:
+                closure.add(dependent)
+                queue.append(dependent)
+    return closure
+
+
+def toposort_closure(_closure: Iterable[str], generators: dict) -> list[str]:
+    """Return the closure as a list in execution order (dependencies first).
+
+    Dependencies that are not part of the closure are ignored for the ordering.
+    """
+    closure = set(_closure)
+    # sorted names and sorted dependencies keep the resulting order reproducible
+    final_dep_graph = {
+        gen_name: sorted(set(generators[gen_name].dependencies) & closure)
+        for gen_name in sorted(closure)
+    }
+    return list(TopologicalSorter(final_dep_graph).static_order())
+
+
+def full_closure(generators: dict) -> list[str]:
+    """Return all generators in topological order."""
+    return toposort_closure(generators, generators)
+
+
+def all_missing_closure(generators: dict) -> list[str]:
+    """Return the missing generators and their dependents in topological order."""
+    # collect all generators that are missing from disk
+    closure = {gen_name for gen_name, gen in generators.items() if not gen.exists}
+    closure = add_dependents(closure, generators)
+    return toposort_closure(closure, generators)
+
+
+def requested_closure(requested_generators: list[str], generators: dict) -> list[str]:
+    """Return the selected generators plus missing dependencies and dependents."""
+    closure = set(requested_generators)
+    # extend the graph to include all dependencies which are not on disk
+    closure = add_missing_dependencies(closure, generators)
+    closure = add_dependents(closure, generators)
+    return toposort_closure(closure, generators)
+
+
+def minimal_closure(requested_generators: list[str], generators: dict) -> list[str]:
+    """Return just enough to bring the selected generators into a consistent state.
+
+    The result is empty if nothing is missing.
+    """
+    closure = set(requested_generators)
+    final_closure = missing_dependency_closure(closure, generators)
+    # add the requested generators which do not exist yet
+    for gen_name in closure:
+        if not generators[gen_name].exists:
+            final_closure.add(gen_name)
+    return toposort_closure(final_closure, generators)
diff --git a/pkgs/clan-cli/clan_cli/vars/list.py b/pkgs/clan-cli/clan_cli/vars/list.py
index 7e677bcea..699c42f7f 100644
--- a/pkgs/clan-cli/clan_cli/vars/list.py
+++ b/pkgs/clan-cli/clan_cli/vars/list.py
@@ -85,7 +85,6 @@ def set_prompts(machine: Machine, updates: list[GeneratorUpdate]) -> None:
         execute_generator(
             machine,
             update.generator,
-            regenerate=True,
             secret_vars_store=secret_store(machine),
             public_vars_store=public_store(machine),
             prompt_values=update.prompt_values,
diff --git a/pkgs/clan-cli/tests/test_vars.py b/pkgs/clan-cli/tests/test_vars.py
index d045f8e55..a666d4e1c 100644
--- a/pkgs/clan-cli/tests/test_vars.py
+++ b/pkgs/clan-cli/tests/test_vars.py
@@ -1,4 +1,5 @@
 import subprocess
+from dataclasses import dataclass
 from io import StringIO
 from pathlib import Path
 from tempfile import TemporaryDirectory
@@ -18,23 +19,6 @@ from helpers.nixos_config import nested_dict
 from root import CLAN_CORE
 
 
-def test_get_subgraph() -> None:
-    from clan_cli.vars.generate import _get_subgraph
-
-    graph = {
-        "a": {"b", "c"},
-        "b": {"c"},
-        "c": set(),
-        "d": set(),
-    }
-    assert _get_subgraph(graph, ["a"]) == {
-        "a": {"b", "c"},
-        "b": {"c"},
-        "c": set(),
-    }
-    assert _get_subgraph(graph, ["b"]) == {"b": {"c"}, "c": set()}
-
-
 def test_dependencies_as_files() -> None:
     from clan_cli.vars.generate import dependencies_as_dir
 
@@ -63,6 +47,34 @@ def test_dependencies_as_files() -> None:
         assert (dep_tmpdir / "gen_2" / "var_2b").stat().st_mode & 0o777 == 0o600
 
 
+def test_required_generators() -> None:
+    from clan_cli.vars.graph import all_missing_closure, requested_closure
+
+    @dataclass
+    class Generator:
+        dependencies: list[str]
+        exists: bool  # result is already on disk
+
+    generators = {
+        "gen_1": Generator([], True),
+        "gen_2": Generator(["gen_1"], False),
+        "gen_2a": Generator(["gen_2"], False),
+        "gen_2b": Generator(["gen_2"], True),
+    }
+
+    assert requested_closure(["gen_1"], generators) == [
+        "gen_1",
+        "gen_2",
+        "gen_2a",
+        "gen_2b",
+    ]
+    assert requested_closure(["gen_2"], generators) == ["gen_2", "gen_2a", "gen_2b"]
+    assert requested_closure(["gen_2a"], generators) == ["gen_2", "gen_2a", "gen_2b"]
+    assert requested_closure(["gen_2b"], generators) == ["gen_2", "gen_2a", "gen_2b"]
+
+    assert all_missing_closure(generators) == ["gen_2", "gen_2a", "gen_2b"]
+
+
 @pytest.mark.impure
 def test_generate_public_var(
     monkeypatch: pytest.MonkeyPatch,