Merge pull request 'vars: improve generator execution pipeline' (#2046) from DavHau/clan-core:DavHau-dave into main
This commit is contained in:
@@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import subprocess
|
import subprocess
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ from . import (
|
|||||||
history,
|
history,
|
||||||
secrets,
|
secrets,
|
||||||
state,
|
state,
|
||||||
vars,
|
|
||||||
vms,
|
vms,
|
||||||
)
|
)
|
||||||
from .clan_uri import FlakeId
|
from .clan_uri import FlakeId
|
||||||
@@ -32,6 +31,7 @@ from .hyperlink import help_hyperlink
|
|||||||
from .machines import cli as machines
|
from .machines import cli as machines
|
||||||
from .profiler import profile
|
from .profiler import profile
|
||||||
from .ssh import cli as ssh_cli
|
from .ssh import cli as ssh_cli
|
||||||
|
from .vars import cli as vars_cli
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -293,7 +293,7 @@ For more detailed information, visit: {help_hyperlink("secrets", "https://docs.c
|
|||||||
),
|
),
|
||||||
formatter_class=argparse.RawTextHelpFormatter,
|
formatter_class=argparse.RawTextHelpFormatter,
|
||||||
)
|
)
|
||||||
vars.register_parser(parser_vars)
|
vars_cli.register_parser(parser_vars)
|
||||||
|
|
||||||
parser_machine = subparsers.add_parser(
|
parser_machine = subparsers.add_parser(
|
||||||
"machines",
|
"machines",
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
|
import importlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from functools import cached_property
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tempfile import NamedTemporaryFile
|
from tempfile import NamedTemporaryFile
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
@@ -10,6 +12,8 @@ from clan_cli.cmd import run_no_stdout
|
|||||||
from clan_cli.errors import ClanError
|
from clan_cli.errors import ClanError
|
||||||
from clan_cli.nix import nix_build, nix_config, nix_eval, nix_metadata
|
from clan_cli.nix import nix_build, nix_config, nix_eval, nix_metadata
|
||||||
from clan_cli.ssh import Host, parse_deployment_address
|
from clan_cli.ssh import Host, parse_deployment_address
|
||||||
|
from clan_cli.vars.public_modules import FactStoreBase
|
||||||
|
from clan_cli.vars.secret_modules import SecretStoreBase
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -90,6 +94,16 @@ class Machine:
|
|||||||
def public_vars_module(self) -> str:
|
def public_vars_module(self) -> str:
|
||||||
return self.deployment["vars"]["publicModule"]
|
return self.deployment["vars"]["publicModule"]
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def secret_vars_store(self) -> SecretStoreBase:
|
||||||
|
module = importlib.import_module(self.secret_vars_module)
|
||||||
|
return module.SecretStore(machine=self)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def public_vars_store(self) -> FactStoreBase:
|
||||||
|
module = importlib.import_module(self.public_vars_module)
|
||||||
|
return module.FactStore(machine=self)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def facts_data(self) -> dict[str, dict[str, Any]]:
|
def facts_data(self) -> dict[str, dict[str, Any]]:
|
||||||
if self.deployment["facts"]["services"]:
|
if self.deployment["facts"]["services"]:
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
# !/usr/bin/env python3
|
|
||||||
import json
|
import json
|
||||||
import shutil
|
import shutil
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from clan_cli.machines.machines import Machine
|
from clan_cli.machines import machines
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -18,6 +17,7 @@ class Prompt:
|
|||||||
previous_value: str | None = None
|
previous_value: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: add flag 'pending' generator needs to be executed
|
||||||
@dataclass
|
@dataclass
|
||||||
class Generator:
|
class Generator:
|
||||||
name: str
|
name: str
|
||||||
@@ -72,7 +72,7 @@ class Var:
|
|||||||
|
|
||||||
|
|
||||||
class StoreBase(ABC):
|
class StoreBase(ABC):
|
||||||
def __init__(self, machine: Machine) -> None:
|
def __init__(self, machine: "machines.Machine") -> None:
|
||||||
self.machine = machine
|
self.machine = machine
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import importlib
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@@ -197,9 +196,9 @@ def execute_generator(
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _get_subgraph(graph: dict[str, set], vertex: str) -> dict[str, set]:
|
def _get_subgraph(graph: dict[str, set], vertices: list[str]) -> dict[str, set]:
|
||||||
visited = set()
|
visited = set()
|
||||||
queue = [vertex]
|
queue = vertices
|
||||||
while queue:
|
while queue:
|
||||||
vertex = queue.pop(0)
|
vertex = queue.pop(0)
|
||||||
if vertex not in visited:
|
if vertex not in visited:
|
||||||
@@ -208,50 +207,90 @@ def _get_subgraph(graph: dict[str, set], vertex: str) -> dict[str, set]:
|
|||||||
return {k: v for k, v in graph.items() if k in visited}
|
return {k: v for k, v in graph.items() if k in visited}
|
||||||
|
|
||||||
|
|
||||||
|
def _dependency_graph(
|
||||||
|
machine: Machine, entry_nodes: None | list[str] = None
|
||||||
|
) -> dict[str, set]:
|
||||||
|
graph = {
|
||||||
|
gen_name: set(generator["dependencies"])
|
||||||
|
for gen_name, generator in machine.vars_generators.items()
|
||||||
|
}
|
||||||
|
if entry_nodes:
|
||||||
|
return _get_subgraph(graph, entry_nodes)
|
||||||
|
return graph
|
||||||
|
|
||||||
|
|
||||||
|
def _reverse_dependency_graph(
|
||||||
|
machine: Machine, entry_nodes: None | list[str] = None
|
||||||
|
) -> dict[str, set]:
|
||||||
|
graph = _dependency_graph(machine)
|
||||||
|
reverse_graph: dict[str, set] = {gen_name: set() for gen_name in graph}
|
||||||
|
for gen_name, dependencies in graph.items():
|
||||||
|
for dep in dependencies:
|
||||||
|
reverse_graph[dep].add(gen_name)
|
||||||
|
if entry_nodes:
|
||||||
|
return _get_subgraph(reverse_graph, entry_nodes)
|
||||||
|
return reverse_graph
|
||||||
|
|
||||||
|
|
||||||
|
def _required_generators(
|
||||||
|
machine: Machine,
|
||||||
|
desired_generators: list[str],
|
||||||
|
) -> list[str]:
|
||||||
|
"""
|
||||||
|
Receives list fo desired generators to update and returns list of required generators to update.
|
||||||
|
|
||||||
|
This is needed because some generators might depend on others, so we need to update them first.
|
||||||
|
The returned list is sorted topologically.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dependency_graph = _dependency_graph(machine)
|
||||||
|
# extract sub-graph if specific generators selected
|
||||||
|
dependency_graph = _get_subgraph(dependency_graph, desired_generators)
|
||||||
|
|
||||||
|
# check if all dependencies actually exist
|
||||||
|
for gen_name, dependencies in dependency_graph.items():
|
||||||
|
for dep in dependencies:
|
||||||
|
if dep not in dependency_graph:
|
||||||
|
msg = f"Generator {gen_name} has a dependency on {dep}, which does not exist"
|
||||||
|
raise ClanError(msg)
|
||||||
|
|
||||||
|
# ensure that all dependents are regenerated as well as their vars might depend on the current generator
|
||||||
|
reverse_dependency_graph = _reverse_dependency_graph(machine, desired_generators)
|
||||||
|
final_graph = _dependency_graph(
|
||||||
|
machine, entry_nodes=list(reverse_dependency_graph.keys())
|
||||||
|
)
|
||||||
|
|
||||||
|
# process generators in topological order (dependencies first)
|
||||||
|
sorter = TopologicalSorter(final_graph)
|
||||||
|
return list(sorter.static_order())
|
||||||
|
|
||||||
|
|
||||||
def _generate_vars_for_machine(
|
def _generate_vars_for_machine(
|
||||||
machine: Machine,
|
machine: Machine,
|
||||||
generator_name: str | None,
|
generator_name: str | None,
|
||||||
regenerate: bool,
|
regenerate: bool,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
secret_vars_module = importlib.import_module(machine.secret_vars_module)
|
return _generate_vars_for_machine_multi(
|
||||||
secret_vars_store = secret_vars_module.SecretStore(machine=machine)
|
machine, [generator_name] if generator_name else [], regenerate
|
||||||
|
)
|
||||||
|
|
||||||
public_vars_module = importlib.import_module(machine.public_vars_module)
|
|
||||||
public_vars_store = public_vars_module.FactStore(machine=machine)
|
|
||||||
|
|
||||||
|
def _generate_vars_for_machine_multi(
|
||||||
|
machine: Machine,
|
||||||
|
generator_names: list[str],
|
||||||
|
regenerate: bool,
|
||||||
|
) -> bool:
|
||||||
machine_updated = False
|
machine_updated = False
|
||||||
|
|
||||||
if generator_name and generator_name not in machine.vars_generators:
|
generators_to_update = _required_generators(machine, generator_names)
|
||||||
generators = list(machine.vars_generators.keys())
|
for generator_name in generators_to_update:
|
||||||
msg = f"Could not find generator with name: {generator_name}. The following generators are available: {generators}"
|
|
||||||
raise ClanError(msg)
|
|
||||||
|
|
||||||
graph = {
|
|
||||||
gen_name: set(generator["dependencies"])
|
|
||||||
for gen_name, generator in machine.vars_generators.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
# extract sub-graph if specific generator selected
|
|
||||||
if generator_name:
|
|
||||||
graph = _get_subgraph(graph, generator_name)
|
|
||||||
|
|
||||||
# check if all dependencies actually exist
|
|
||||||
for gen_name, dependencies in graph.items():
|
|
||||||
for dep in dependencies:
|
|
||||||
if dep not in graph:
|
|
||||||
msg = f"Generator {gen_name} has a dependency on {dep}, which does not exist"
|
|
||||||
raise ClanError(msg)
|
|
||||||
|
|
||||||
# process generators in topological order
|
|
||||||
sorter = TopologicalSorter(graph)
|
|
||||||
for generator_name in sorter.static_order():
|
|
||||||
assert generator_name is not None
|
assert generator_name is not None
|
||||||
machine_updated |= execute_generator(
|
machine_updated |= execute_generator(
|
||||||
machine=machine,
|
machine=machine,
|
||||||
generator_name=generator_name,
|
generator_name=generator_name,
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
secret_vars_store=secret_vars_store,
|
secret_vars_store=machine.secret_vars_store,
|
||||||
public_vars_store=public_vars_store,
|
public_vars_store=machine.public_vars_store,
|
||||||
)
|
)
|
||||||
if machine_updated:
|
if machine_updated:
|
||||||
# flush caches to make sure the new secrets are available in evaluation
|
# flush caches to make sure the new secrets are available in evaluation
|
||||||
|
|||||||
@@ -30,8 +30,10 @@ def get_vars(machine: Machine) -> list[Var]:
|
|||||||
return pub_store.get_all() + sec_store.get_all()
|
return pub_store.get_all() + sec_store.get_all()
|
||||||
|
|
||||||
|
|
||||||
def _get_prompt_value(
|
def _get_previous_value(
|
||||||
machine: Machine, generator: Generator, prompt: Prompt
|
machine: Machine,
|
||||||
|
generator: Generator,
|
||||||
|
prompt: Prompt,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
if not prompt.has_file:
|
if not prompt.has_file:
|
||||||
return None
|
return None
|
||||||
@@ -40,10 +42,16 @@ def _get_prompt_value(
|
|||||||
return pub_store.get(
|
return pub_store.get(
|
||||||
generator.name, prompt.name, shared=generator.share
|
generator.name, prompt.name, shared=generator.share
|
||||||
).decode()
|
).decode()
|
||||||
|
sec_store = secret_store(machine)
|
||||||
|
if sec_store.exists(generator.name, prompt.name, shared=generator.share):
|
||||||
|
return sec_store.get(
|
||||||
|
generator.name, prompt.name, shared=generator.share
|
||||||
|
).decode()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@API.register
|
@API.register
|
||||||
|
# TODO: use machine_name
|
||||||
def get_prompts(machine: Machine) -> list[Generator]:
|
def get_prompts(machine: Machine) -> list[Generator]:
|
||||||
generators = []
|
generators = []
|
||||||
for gen_name, generator in machine.vars_generators.items():
|
for gen_name, generator in machine.vars_generators.items():
|
||||||
@@ -61,7 +69,7 @@ def get_prompts(machine: Machine) -> list[Generator]:
|
|||||||
has_file=prompt["createFile"],
|
has_file=prompt["createFile"],
|
||||||
generator=gen_name,
|
generator=gen_name,
|
||||||
)
|
)
|
||||||
prompt.previous_value = _get_prompt_value(machine, gen, prompt)
|
prompt.previous_value = _get_previous_value(machine, gen, prompt)
|
||||||
prompts.append(prompt)
|
prompts.append(prompt)
|
||||||
|
|
||||||
generators.append(gen)
|
generators.append(gen)
|
||||||
@@ -69,6 +77,8 @@ def get_prompts(machine: Machine) -> list[Generator]:
|
|||||||
|
|
||||||
|
|
||||||
# TODO: Ensure generator dependencies are met (executed in correct order etc.)
|
# TODO: Ensure generator dependencies are met (executed in correct order etc.)
|
||||||
|
# TODO: for missing prompts, default to existing values
|
||||||
|
# TODO: raise error if mandatory prompt not provided
|
||||||
@API.register
|
@API.register
|
||||||
def set_prompts(machine: Machine, updates: list[GeneratorUpdate]) -> None:
|
def set_prompts(machine: Machine, updates: list[GeneratorUpdate]) -> None:
|
||||||
for update in updates:
|
for update in updates:
|
||||||
|
|||||||
@@ -27,12 +27,12 @@ def test_get_subgraph() -> None:
|
|||||||
"c": set(),
|
"c": set(),
|
||||||
"d": set(),
|
"d": set(),
|
||||||
}
|
}
|
||||||
assert _get_subgraph(graph, "a") == {
|
assert _get_subgraph(graph, ["a"]) == {
|
||||||
"a": {"b", "c"},
|
"a": {"b", "c"},
|
||||||
"b": {"c"},
|
"b": {"c"},
|
||||||
"c": set(),
|
"c": set(),
|
||||||
}
|
}
|
||||||
assert _get_subgraph(graph, "b") == {"b": {"c"}, "c": set()}
|
assert _get_subgraph(graph, ["b"]) == {"b": {"c"}, "c": set()}
|
||||||
|
|
||||||
|
|
||||||
def test_dependencies_as_files() -> None:
|
def test_dependencies_as_files() -> None:
|
||||||
|
|||||||
Reference in New Issue
Block a user