Merge pull request 'vars: improve generator execution pipeline' (#2046) from DavHau/clan-core:DavHau-dave into main

This commit is contained in:
clan-bot
2024-09-06 13:43:03 +00:00
8 changed files with 108 additions and 43 deletions

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import subprocess import subprocess
from pathlib import Path from pathlib import Path

View File

@@ -20,7 +20,6 @@ from . import (
history, history,
secrets, secrets,
state, state,
vars,
vms, vms,
) )
from .clan_uri import FlakeId from .clan_uri import FlakeId
@@ -32,6 +31,7 @@ from .hyperlink import help_hyperlink
from .machines import cli as machines from .machines import cli as machines
from .profiler import profile from .profiler import profile
from .ssh import cli as ssh_cli from .ssh import cli as ssh_cli
from .vars import cli as vars_cli
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -293,7 +293,7 @@ For more detailed information, visit: {help_hyperlink("secrets", "https://docs.c
), ),
formatter_class=argparse.RawTextHelpFormatter, formatter_class=argparse.RawTextHelpFormatter,
) )
vars.register_parser(parser_vars) vars_cli.register_parser(parser_vars)
parser_machine = subparsers.add_parser( parser_machine = subparsers.add_parser(
"machines", "machines",

View File

@@ -1,6 +1,8 @@
import importlib
import json import json
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import cached_property
from pathlib import Path from pathlib import Path
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from typing import Any, Literal from typing import Any, Literal
@@ -10,6 +12,8 @@ from clan_cli.cmd import run_no_stdout
from clan_cli.errors import ClanError from clan_cli.errors import ClanError
from clan_cli.nix import nix_build, nix_config, nix_eval, nix_metadata from clan_cli.nix import nix_build, nix_config, nix_eval, nix_metadata
from clan_cli.ssh import Host, parse_deployment_address from clan_cli.ssh import Host, parse_deployment_address
from clan_cli.vars.public_modules import FactStoreBase
from clan_cli.vars.secret_modules import SecretStoreBase
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -90,6 +94,16 @@ class Machine:
def public_vars_module(self) -> str: def public_vars_module(self) -> str:
return self.deployment["vars"]["publicModule"] return self.deployment["vars"]["publicModule"]
@cached_property
def secret_vars_store(self) -> SecretStoreBase:
module = importlib.import_module(self.secret_vars_module)
return module.SecretStore(machine=self)
@cached_property
def public_vars_store(self) -> FactStoreBase:
module = importlib.import_module(self.public_vars_module)
return module.FactStore(machine=self)
@property @property
def facts_data(self) -> dict[str, dict[str, Any]]: def facts_data(self) -> dict[str, dict[str, Any]]:
if self.deployment["facts"]["services"]: if self.deployment["facts"]["services"]:

View File

@@ -1,11 +1,10 @@
# !/usr/bin/env python3
import json import json
import shutil import shutil
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from clan_cli.machines.machines import Machine from clan_cli.machines import machines
@dataclass @dataclass
@@ -18,6 +17,7 @@ class Prompt:
previous_value: str | None = None previous_value: str | None = None
# TODO: add flag 'pending' generator needs to be executed
@dataclass @dataclass
class Generator: class Generator:
name: str name: str
@@ -72,7 +72,7 @@ class Var:
class StoreBase(ABC): class StoreBase(ABC):
def __init__(self, machine: Machine) -> None: def __init__(self, machine: "machines.Machine") -> None:
self.machine = machine self.machine = machine
@property @property

View File

@@ -1,5 +1,4 @@
import argparse import argparse
import importlib
import logging import logging
import os import os
import sys import sys
@@ -197,9 +196,9 @@ def execute_generator(
return True return True
def _get_subgraph(graph: dict[str, set], vertex: str) -> dict[str, set]: def _get_subgraph(graph: dict[str, set], vertices: list[str]) -> dict[str, set]:
visited = set() visited = set()
queue = [vertex] queue = vertices
while queue: while queue:
vertex = queue.pop(0) vertex = queue.pop(0)
if vertex not in visited: if vertex not in visited:
@@ -208,50 +207,90 @@ def _get_subgraph(graph: dict[str, set], vertex: str) -> dict[str, set]:
return {k: v for k, v in graph.items() if k in visited} return {k: v for k, v in graph.items() if k in visited}
def _dependency_graph(
machine: Machine, entry_nodes: None | list[str] = None
) -> dict[str, set]:
graph = {
gen_name: set(generator["dependencies"])
for gen_name, generator in machine.vars_generators.items()
}
if entry_nodes:
return _get_subgraph(graph, entry_nodes)
return graph
def _reverse_dependency_graph(
machine: Machine, entry_nodes: None | list[str] = None
) -> dict[str, set]:
graph = _dependency_graph(machine)
reverse_graph: dict[str, set] = {gen_name: set() for gen_name in graph}
for gen_name, dependencies in graph.items():
for dep in dependencies:
reverse_graph[dep].add(gen_name)
if entry_nodes:
return _get_subgraph(reverse_graph, entry_nodes)
return reverse_graph
def _required_generators(
machine: Machine,
desired_generators: list[str],
) -> list[str]:
"""
Receives list fo desired generators to update and returns list of required generators to update.
This is needed because some generators might depend on others, so we need to update them first.
The returned list is sorted topologically.
"""
dependency_graph = _dependency_graph(machine)
# extract sub-graph if specific generators selected
dependency_graph = _get_subgraph(dependency_graph, desired_generators)
# check if all dependencies actually exist
for gen_name, dependencies in dependency_graph.items():
for dep in dependencies:
if dep not in dependency_graph:
msg = f"Generator {gen_name} has a dependency on {dep}, which does not exist"
raise ClanError(msg)
# ensure that all dependents are regenerated as well as their vars might depend on the current generator
reverse_dependency_graph = _reverse_dependency_graph(machine, desired_generators)
final_graph = _dependency_graph(
machine, entry_nodes=list(reverse_dependency_graph.keys())
)
# process generators in topological order (dependencies first)
sorter = TopologicalSorter(final_graph)
return list(sorter.static_order())
def _generate_vars_for_machine( def _generate_vars_for_machine(
machine: Machine, machine: Machine,
generator_name: str | None, generator_name: str | None,
regenerate: bool, regenerate: bool,
) -> bool: ) -> bool:
secret_vars_module = importlib.import_module(machine.secret_vars_module) return _generate_vars_for_machine_multi(
secret_vars_store = secret_vars_module.SecretStore(machine=machine) machine, [generator_name] if generator_name else [], regenerate
)
public_vars_module = importlib.import_module(machine.public_vars_module)
public_vars_store = public_vars_module.FactStore(machine=machine)
def _generate_vars_for_machine_multi(
machine: Machine,
generator_names: list[str],
regenerate: bool,
) -> bool:
machine_updated = False machine_updated = False
if generator_name and generator_name not in machine.vars_generators: generators_to_update = _required_generators(machine, generator_names)
generators = list(machine.vars_generators.keys()) for generator_name in generators_to_update:
msg = f"Could not find generator with name: {generator_name}. The following generators are available: {generators}"
raise ClanError(msg)
graph = {
gen_name: set(generator["dependencies"])
for gen_name, generator in machine.vars_generators.items()
}
# extract sub-graph if specific generator selected
if generator_name:
graph = _get_subgraph(graph, generator_name)
# check if all dependencies actually exist
for gen_name, dependencies in graph.items():
for dep in dependencies:
if dep not in graph:
msg = f"Generator {gen_name} has a dependency on {dep}, which does not exist"
raise ClanError(msg)
# process generators in topological order
sorter = TopologicalSorter(graph)
for generator_name in sorter.static_order():
assert generator_name is not None assert generator_name is not None
machine_updated |= execute_generator( machine_updated |= execute_generator(
machine=machine, machine=machine,
generator_name=generator_name, generator_name=generator_name,
regenerate=regenerate, regenerate=regenerate,
secret_vars_store=secret_vars_store, secret_vars_store=machine.secret_vars_store,
public_vars_store=public_vars_store, public_vars_store=machine.public_vars_store,
) )
if machine_updated: if machine_updated:
# flush caches to make sure the new secrets are available in evaluation # flush caches to make sure the new secrets are available in evaluation

View File

@@ -30,8 +30,10 @@ def get_vars(machine: Machine) -> list[Var]:
return pub_store.get_all() + sec_store.get_all() return pub_store.get_all() + sec_store.get_all()
def _get_prompt_value( def _get_previous_value(
machine: Machine, generator: Generator, prompt: Prompt machine: Machine,
generator: Generator,
prompt: Prompt,
) -> str | None: ) -> str | None:
if not prompt.has_file: if not prompt.has_file:
return None return None
@@ -40,10 +42,16 @@ def _get_prompt_value(
return pub_store.get( return pub_store.get(
generator.name, prompt.name, shared=generator.share generator.name, prompt.name, shared=generator.share
).decode() ).decode()
sec_store = secret_store(machine)
if sec_store.exists(generator.name, prompt.name, shared=generator.share):
return sec_store.get(
generator.name, prompt.name, shared=generator.share
).decode()
return None return None
@API.register @API.register
# TODO: use machine_name
def get_prompts(machine: Machine) -> list[Generator]: def get_prompts(machine: Machine) -> list[Generator]:
generators = [] generators = []
for gen_name, generator in machine.vars_generators.items(): for gen_name, generator in machine.vars_generators.items():
@@ -61,7 +69,7 @@ def get_prompts(machine: Machine) -> list[Generator]:
has_file=prompt["createFile"], has_file=prompt["createFile"],
generator=gen_name, generator=gen_name,
) )
prompt.previous_value = _get_prompt_value(machine, gen, prompt) prompt.previous_value = _get_previous_value(machine, gen, prompt)
prompts.append(prompt) prompts.append(prompt)
generators.append(gen) generators.append(gen)
@@ -69,6 +77,8 @@ def get_prompts(machine: Machine) -> list[Generator]:
# TODO: Ensure generator dependencies are met (executed in correct order etc.) # TODO: Ensure generator dependencies are met (executed in correct order etc.)
# TODO: for missing prompts, default to existing values
# TODO: raise error if mandatory prompt not provided
@API.register @API.register
def set_prompts(machine: Machine, updates: list[GeneratorUpdate]) -> None: def set_prompts(machine: Machine, updates: list[GeneratorUpdate]) -> None:
for update in updates: for update in updates:

View File

@@ -27,12 +27,12 @@ def test_get_subgraph() -> None:
"c": set(), "c": set(),
"d": set(), "d": set(),
} }
assert _get_subgraph(graph, "a") == { assert _get_subgraph(graph, ["a"]) == {
"a": {"b", "c"}, "a": {"b", "c"},
"b": {"c"}, "b": {"c"},
"c": set(), "c": set(),
} }
assert _get_subgraph(graph, "b") == {"b": {"c"}, "c": set()} assert _get_subgraph(graph, ["b"]) == {"b": {"c"}, "c": set()}
def test_dependencies_as_files() -> None: def test_dependencies_as_files() -> None: