Merge pull request 'vars: improve generator execution pipeline' (#2046) from DavHau/clan-core:DavHau-dave into main
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ from . import (
|
||||
history,
|
||||
secrets,
|
||||
state,
|
||||
vars,
|
||||
vms,
|
||||
)
|
||||
from .clan_uri import FlakeId
|
||||
@@ -32,6 +31,7 @@ from .hyperlink import help_hyperlink
|
||||
from .machines import cli as machines
|
||||
from .profiler import profile
|
||||
from .ssh import cli as ssh_cli
|
||||
from .vars import cli as vars_cli
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -293,7 +293,7 @@ For more detailed information, visit: {help_hyperlink("secrets", "https://docs.c
|
||||
),
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
vars.register_parser(parser_vars)
|
||||
vars_cli.register_parser(parser_vars)
|
||||
|
||||
parser_machine = subparsers.add_parser(
|
||||
"machines",
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Any, Literal
|
||||
@@ -10,6 +12,8 @@ from clan_cli.cmd import run_no_stdout
|
||||
from clan_cli.errors import ClanError
|
||||
from clan_cli.nix import nix_build, nix_config, nix_eval, nix_metadata
|
||||
from clan_cli.ssh import Host, parse_deployment_address
|
||||
from clan_cli.vars.public_modules import FactStoreBase
|
||||
from clan_cli.vars.secret_modules import SecretStoreBase
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -90,6 +94,16 @@ class Machine:
|
||||
def public_vars_module(self) -> str:
|
||||
return self.deployment["vars"]["publicModule"]
|
||||
|
||||
@cached_property
|
||||
def secret_vars_store(self) -> SecretStoreBase:
|
||||
module = importlib.import_module(self.secret_vars_module)
|
||||
return module.SecretStore(machine=self)
|
||||
|
||||
@cached_property
|
||||
def public_vars_store(self) -> FactStoreBase:
|
||||
module = importlib.import_module(self.public_vars_module)
|
||||
return module.FactStore(machine=self)
|
||||
|
||||
@property
|
||||
def facts_data(self) -> dict[str, dict[str, Any]]:
|
||||
if self.deployment["facts"]["services"]:
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
# !/usr/bin/env python3
|
||||
import json
|
||||
import shutil
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from clan_cli.machines.machines import Machine
|
||||
from clan_cli.machines import machines
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -18,6 +17,7 @@ class Prompt:
|
||||
previous_value: str | None = None
|
||||
|
||||
|
||||
# TODO: add flag 'pending' generator needs to be executed
|
||||
@dataclass
|
||||
class Generator:
|
||||
name: str
|
||||
@@ -72,7 +72,7 @@ class Var:
|
||||
|
||||
|
||||
class StoreBase(ABC):
|
||||
def __init__(self, machine: Machine) -> None:
|
||||
def __init__(self, machine: "machines.Machine") -> None:
|
||||
self.machine = machine
|
||||
|
||||
@property
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import argparse
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@@ -197,9 +196,9 @@ def execute_generator(
|
||||
return True
|
||||
|
||||
|
||||
def _get_subgraph(graph: dict[str, set], vertex: str) -> dict[str, set]:
|
||||
def _get_subgraph(graph: dict[str, set], vertices: list[str]) -> dict[str, set]:
|
||||
visited = set()
|
||||
queue = [vertex]
|
||||
queue = vertices
|
||||
while queue:
|
||||
vertex = queue.pop(0)
|
||||
if vertex not in visited:
|
||||
@@ -208,50 +207,90 @@ def _get_subgraph(graph: dict[str, set], vertex: str) -> dict[str, set]:
|
||||
return {k: v for k, v in graph.items() if k in visited}
|
||||
|
||||
|
||||
def _dependency_graph(
|
||||
machine: Machine, entry_nodes: None | list[str] = None
|
||||
) -> dict[str, set]:
|
||||
graph = {
|
||||
gen_name: set(generator["dependencies"])
|
||||
for gen_name, generator in machine.vars_generators.items()
|
||||
}
|
||||
if entry_nodes:
|
||||
return _get_subgraph(graph, entry_nodes)
|
||||
return graph
|
||||
|
||||
|
||||
def _reverse_dependency_graph(
|
||||
machine: Machine, entry_nodes: None | list[str] = None
|
||||
) -> dict[str, set]:
|
||||
graph = _dependency_graph(machine)
|
||||
reverse_graph: dict[str, set] = {gen_name: set() for gen_name in graph}
|
||||
for gen_name, dependencies in graph.items():
|
||||
for dep in dependencies:
|
||||
reverse_graph[dep].add(gen_name)
|
||||
if entry_nodes:
|
||||
return _get_subgraph(reverse_graph, entry_nodes)
|
||||
return reverse_graph
|
||||
|
||||
|
||||
def _required_generators(
|
||||
machine: Machine,
|
||||
desired_generators: list[str],
|
||||
) -> list[str]:
|
||||
"""
|
||||
Receives list fo desired generators to update and returns list of required generators to update.
|
||||
|
||||
This is needed because some generators might depend on others, so we need to update them first.
|
||||
The returned list is sorted topologically.
|
||||
"""
|
||||
|
||||
dependency_graph = _dependency_graph(machine)
|
||||
# extract sub-graph if specific generators selected
|
||||
dependency_graph = _get_subgraph(dependency_graph, desired_generators)
|
||||
|
||||
# check if all dependencies actually exist
|
||||
for gen_name, dependencies in dependency_graph.items():
|
||||
for dep in dependencies:
|
||||
if dep not in dependency_graph:
|
||||
msg = f"Generator {gen_name} has a dependency on {dep}, which does not exist"
|
||||
raise ClanError(msg)
|
||||
|
||||
# ensure that all dependents are regenerated as well as their vars might depend on the current generator
|
||||
reverse_dependency_graph = _reverse_dependency_graph(machine, desired_generators)
|
||||
final_graph = _dependency_graph(
|
||||
machine, entry_nodes=list(reverse_dependency_graph.keys())
|
||||
)
|
||||
|
||||
# process generators in topological order (dependencies first)
|
||||
sorter = TopologicalSorter(final_graph)
|
||||
return list(sorter.static_order())
|
||||
|
||||
|
||||
def _generate_vars_for_machine(
|
||||
machine: Machine,
|
||||
generator_name: str | None,
|
||||
regenerate: bool,
|
||||
) -> bool:
|
||||
secret_vars_module = importlib.import_module(machine.secret_vars_module)
|
||||
secret_vars_store = secret_vars_module.SecretStore(machine=machine)
|
||||
return _generate_vars_for_machine_multi(
|
||||
machine, [generator_name] if generator_name else [], regenerate
|
||||
)
|
||||
|
||||
public_vars_module = importlib.import_module(machine.public_vars_module)
|
||||
public_vars_store = public_vars_module.FactStore(machine=machine)
|
||||
|
||||
def _generate_vars_for_machine_multi(
|
||||
machine: Machine,
|
||||
generator_names: list[str],
|
||||
regenerate: bool,
|
||||
) -> bool:
|
||||
machine_updated = False
|
||||
|
||||
if generator_name and generator_name not in machine.vars_generators:
|
||||
generators = list(machine.vars_generators.keys())
|
||||
msg = f"Could not find generator with name: {generator_name}. The following generators are available: {generators}"
|
||||
raise ClanError(msg)
|
||||
|
||||
graph = {
|
||||
gen_name: set(generator["dependencies"])
|
||||
for gen_name, generator in machine.vars_generators.items()
|
||||
}
|
||||
|
||||
# extract sub-graph if specific generator selected
|
||||
if generator_name:
|
||||
graph = _get_subgraph(graph, generator_name)
|
||||
|
||||
# check if all dependencies actually exist
|
||||
for gen_name, dependencies in graph.items():
|
||||
for dep in dependencies:
|
||||
if dep not in graph:
|
||||
msg = f"Generator {gen_name} has a dependency on {dep}, which does not exist"
|
||||
raise ClanError(msg)
|
||||
|
||||
# process generators in topological order
|
||||
sorter = TopologicalSorter(graph)
|
||||
for generator_name in sorter.static_order():
|
||||
generators_to_update = _required_generators(machine, generator_names)
|
||||
for generator_name in generators_to_update:
|
||||
assert generator_name is not None
|
||||
machine_updated |= execute_generator(
|
||||
machine=machine,
|
||||
generator_name=generator_name,
|
||||
regenerate=regenerate,
|
||||
secret_vars_store=secret_vars_store,
|
||||
public_vars_store=public_vars_store,
|
||||
secret_vars_store=machine.secret_vars_store,
|
||||
public_vars_store=machine.public_vars_store,
|
||||
)
|
||||
if machine_updated:
|
||||
# flush caches to make sure the new secrets are available in evaluation
|
||||
|
||||
@@ -30,8 +30,10 @@ def get_vars(machine: Machine) -> list[Var]:
|
||||
return pub_store.get_all() + sec_store.get_all()
|
||||
|
||||
|
||||
def _get_prompt_value(
|
||||
machine: Machine, generator: Generator, prompt: Prompt
|
||||
def _get_previous_value(
|
||||
machine: Machine,
|
||||
generator: Generator,
|
||||
prompt: Prompt,
|
||||
) -> str | None:
|
||||
if not prompt.has_file:
|
||||
return None
|
||||
@@ -40,10 +42,16 @@ def _get_prompt_value(
|
||||
return pub_store.get(
|
||||
generator.name, prompt.name, shared=generator.share
|
||||
).decode()
|
||||
sec_store = secret_store(machine)
|
||||
if sec_store.exists(generator.name, prompt.name, shared=generator.share):
|
||||
return sec_store.get(
|
||||
generator.name, prompt.name, shared=generator.share
|
||||
).decode()
|
||||
return None
|
||||
|
||||
|
||||
@API.register
|
||||
# TODO: use machine_name
|
||||
def get_prompts(machine: Machine) -> list[Generator]:
|
||||
generators = []
|
||||
for gen_name, generator in machine.vars_generators.items():
|
||||
@@ -61,7 +69,7 @@ def get_prompts(machine: Machine) -> list[Generator]:
|
||||
has_file=prompt["createFile"],
|
||||
generator=gen_name,
|
||||
)
|
||||
prompt.previous_value = _get_prompt_value(machine, gen, prompt)
|
||||
prompt.previous_value = _get_previous_value(machine, gen, prompt)
|
||||
prompts.append(prompt)
|
||||
|
||||
generators.append(gen)
|
||||
@@ -69,6 +77,8 @@ def get_prompts(machine: Machine) -> list[Generator]:
|
||||
|
||||
|
||||
# TODO: Ensure generator dependencies are met (executed in correct order etc.)
|
||||
# TODO: for missing prompts, default to existing values
|
||||
# TODO: raise error if mandatory prompt not provided
|
||||
@API.register
|
||||
def set_prompts(machine: Machine, updates: list[GeneratorUpdate]) -> None:
|
||||
for update in updates:
|
||||
|
||||
@@ -27,12 +27,12 @@ def test_get_subgraph() -> None:
|
||||
"c": set(),
|
||||
"d": set(),
|
||||
}
|
||||
assert _get_subgraph(graph, "a") == {
|
||||
assert _get_subgraph(graph, ["a"]) == {
|
||||
"a": {"b", "c"},
|
||||
"b": {"c"},
|
||||
"c": set(),
|
||||
}
|
||||
assert _get_subgraph(graph, "b") == {"b": {"c"}, "c": set()}
|
||||
assert _get_subgraph(graph, ["b"]) == {"b": {"c"}, "c": set()}
|
||||
|
||||
|
||||
def test_dependencies_as_files() -> None:
|
||||
|
||||
Reference in New Issue
Block a user