vars: implement prompts

This commit is contained in:
DavHau
2024-07-22 15:31:55 +07:00
parent a2c4130ebe
commit 0acf9178c8
4 changed files with 115 additions and 85 deletions

View File

@@ -39,7 +39,7 @@ in
vars = { vars = {
generators = lib.flip lib.mapAttrs config.clan.core.vars.generators ( generators = lib.flip lib.mapAttrs config.clan.core.vars.generators (
_name: generator: { _name: generator: {
inherit (generator) dependencies finalScript; inherit (generator) dependencies finalScript prompts;
files = lib.flip lib.mapAttrs generator.files (_name: file: { inherit (file) secret; }); files = lib.flip lib.mapAttrs generator.files (_name: file: { inherit (file) secret; });
} }
); );

View File

@@ -108,8 +108,9 @@ in
Prompts are available to the generator script as files. Prompts are available to the generator script as files.
For example, a prompt named 'prompt1' will be available via $prompts/prompt1 For example, a prompt named 'prompt1' will be available via $prompts/prompt1
''; '';
default = { };
type = attrsOf (submodule { type = attrsOf (submodule {
options = { options = options {
description = { description = {
description = '' description = ''
The description of the prompted value The description of the prompted value

View File

@@ -2,9 +2,8 @@ import argparse
import importlib import importlib
import logging import logging
import os import os
import subprocess
import sys import sys
from collections.abc import Callable from getpass import getpass
from graphlib import TopologicalSorter from graphlib import TopologicalSorter
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
@@ -29,17 +28,7 @@ from .secret_modules import SecretStoreBase
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def read_multiline_input(prompt: str = "Finish with Ctrl-D") -> str: def bubblewrap_cmd(generator: str, tmpdir: Path) -> list[str]:
"""
Read multi-line input from stdin.
"""
print(prompt, flush=True)
proc = subprocess.run(["cat"], stdout=subprocess.PIPE, text=True)
log.info("Input received. Processing...")
return proc.stdout
def bubblewrap_cmd(generator: str, generator_dir: Path, dep_tmpdir: Path) -> list[str]:
# fmt: off # fmt: off
return nix_shell( return nix_shell(
[ [
@@ -51,8 +40,7 @@ def bubblewrap_cmd(generator: str, generator_dir: Path, dep_tmpdir: Path) -> lis
"--ro-bind", "/nix/store", "/nix/store", "--ro-bind", "/nix/store", "/nix/store",
"--tmpfs", "/usr/lib/systemd", "--tmpfs", "/usr/lib/systemd",
"--dev", "/dev", "--dev", "/dev",
"--bind", str(generator_dir), str(generator_dir), "--bind", str(tmpdir), str(tmpdir),
"--ro-bind", str(dep_tmpdir), str(dep_tmpdir),
"--unshare-all", "--unshare-all",
"--unshare-user", "--unshare-user",
"--uid", "1000", "--uid", "1000",
@@ -92,7 +80,7 @@ def decrypt_dependencies(
def dependencies_as_dir( def dependencies_as_dir(
decrypted_dependencies: dict[str, dict[str, bytes]], decrypted_dependencies: dict[str, dict[str, bytes]],
tmpdir: Path, tmpdir: Path,
) -> Path: ) -> None:
for dep_generator, files in decrypted_dependencies.items(): for dep_generator, files in decrypted_dependencies.items():
dep_generator_dir = tmpdir / dep_generator dep_generator_dir = tmpdir / dep_generator
dep_generator_dir.mkdir() dep_generator_dir.mkdir()
@@ -102,7 +90,6 @@ def dependencies_as_dir(
file_path.touch() file_path.touch()
file_path.chmod(0o600) file_path.chmod(0o600)
file_path.write_bytes(file) file_path.write_bytes(file)
return tmpdir
def execute_generator( def execute_generator(
@@ -111,10 +98,7 @@ def execute_generator(
regenerate: bool, regenerate: bool,
secret_vars_store: SecretStoreBase, secret_vars_store: SecretStoreBase,
public_vars_store: FactStoreBase, public_vars_store: FactStoreBase,
dep_tmpdir: Path,
prompt: Callable[[str], str],
) -> bool: ) -> bool:
generator_dir = dep_tmpdir / generator_name
# check if all secrets exist and generate them if at least one is missing # check if all secrets exist and generate them if at least one is missing
needs_regeneration = not check_secrets(machine, generator_name=generator_name) needs_regeneration = not check_secrets(machine, generator_name=generator_name)
log.debug(f"{generator_name} needs_regeneration: {needs_regeneration}") log.debug(f"{generator_name} needs_regeneration: {needs_regeneration}")
@@ -124,51 +108,65 @@ def execute_generator(
msg = f"flake is not a Path: {machine.flake}" msg = f"flake is not a Path: {machine.flake}"
msg += "fact/secret generation is only supported for local flakes" msg += "fact/secret generation is only supported for local flakes"
# compatibility for old outputs.nix users
generator = machine.vars_generators[generator_name]["finalScript"] generator = machine.vars_generators[generator_name]["finalScript"]
# if machine.vars_data[generator_name]["generator"]["prompt"]:
# prompt_value = prompt(machine.vars_data[generator_name]["generator"]["prompt"])
# env["prompt_value"] = prompt_value
# build temporary file tree of dependencies # build temporary file tree of dependencies
decrypted_dependencies = decrypt_dependencies( decrypted_dependencies = decrypt_dependencies(
machine, generator_name, secret_vars_store, public_vars_store machine, generator_name, secret_vars_store, public_vars_store
) )
env = os.environ.copy() env = os.environ.copy()
generator_dir.mkdir(parents=True)
env["out"] = str(generator_dir)
with TemporaryDirectory() as tmp: with TemporaryDirectory() as tmp:
dep_tmpdir = dependencies_as_dir(decrypted_dependencies, Path(tmp)) tmpdir = Path(tmp)
env["in"] = str(dep_tmpdir) tmpdir_in = tmpdir / "in"
tmpdir_prompts = tmpdir / "prompts"
tmpdir_out = tmpdir / "out"
tmpdir_in.mkdir()
tmpdir_out.mkdir()
env["in"] = str(tmpdir_in)
env["out"] = str(tmpdir_out)
# populate dependency inputs
dependencies_as_dir(decrypted_dependencies, tmpdir_in)
# populate prompted values
# TODO: make prompts rest API friendly
if machine.vars_generators[generator_name]["prompts"]:
tmpdir_prompts.mkdir()
env["prompts"] = str(tmpdir_prompts)
for prompt_name, prompt in machine.vars_generators[generator_name][
"prompts"
].items():
prompt_file = tmpdir_prompts / prompt_name
value = prompt_func(prompt["description"], prompt["type"])
prompt_file.write_text(value)
if sys.platform == "linux": if sys.platform == "linux":
cmd = bubblewrap_cmd(generator, generator_dir, dep_tmpdir=dep_tmpdir) cmd = bubblewrap_cmd(generator, tmpdir)
else: else:
cmd = ["bash", "-c", generator] cmd = ["bash", "-c", generator]
run( run(
cmd, cmd,
env=env, env=env,
) )
files_to_commit = [] files_to_commit = []
# store secrets # store secrets
files = machine.vars_generators[generator_name]["files"] files = machine.vars_generators[generator_name]["files"]
for file_name, file in files.items(): for file_name, file in files.items():
groups = machine.deployment["sops"]["defaultGroups"] groups = machine.deployment["sops"]["defaultGroups"]
secret_file = generator_dir / file_name secret_file = tmpdir_out / file_name
if not secret_file.is_file(): if not secret_file.is_file():
msg = f"did not generate a file for '{file_name}' when running the following command:\n" msg = f"did not generate a file for '{file_name}' when running the following command:\n"
msg += generator msg += generator
raise ClanError(msg) raise ClanError(msg)
if file["secret"]: if file["secret"]:
file_path = secret_vars_store.set( file_path = secret_vars_store.set(
generator_name, file_name, secret_file.read_bytes(), groups generator_name, file_name, secret_file.read_bytes(), groups
) )
else: else:
file_path = public_vars_store.set( file_path = public_vars_store.set(
generator_name, file_name, secret_file.read_bytes() generator_name, file_name, secret_file.read_bytes()
) )
if file_path: if file_path:
files_to_commit.append(file_path) files_to_commit.append(file_path)
commit_files( commit_files(
files_to_commit, files_to_commit,
machine.flake_dir, machine.flake_dir,
@@ -177,9 +175,18 @@ def execute_generator(
return True return True
def prompt_func(text: str) -> str: def prompt_func(description: str, input_type: str) -> str:
print(f"{text}: ") if input_type == "line":
return read_multiline_input() result = input(f"Enter the value for {description}: ")
elif input_type == "multiline":
print(f"Enter the value for {description} (Finish with Ctrl-D): ")
result = sys.stdin.read()
elif input_type == "hidden":
result = getpass(f"Enter the value for {description} (hidden): ")
else:
raise ClanError(f"Unknown input type: {input_type} for prompt {description}")
log.info("Input received. Processing...")
return result
def _get_subgraph(graph: dict[str, set], vertex: str) -> dict[str, set]: def _get_subgraph(graph: dict[str, set], vertex: str) -> dict[str, set]:
@@ -197,11 +204,7 @@ def _generate_vars_for_machine(
machine: Machine, machine: Machine,
generator_name: str | None, generator_name: str | None,
regenerate: bool, regenerate: bool,
tmpdir: Path,
prompt: Callable[[str], str] = prompt_func,
) -> bool: ) -> bool:
local_temp = tmpdir / machine.name
local_temp.mkdir()
secret_vars_module = importlib.import_module(machine.secret_vars_module) secret_vars_module = importlib.import_module(machine.secret_vars_module)
secret_vars_store = secret_vars_module.SecretStore(machine=machine) secret_vars_store = secret_vars_module.SecretStore(machine=machine)
@@ -216,13 +219,6 @@ def _generate_vars_for_machine(
f"Could not find generator with name: {generator_name}. The following generators are available: {generators}" f"Could not find generator with name: {generator_name}. The following generators are available: {generators}"
) )
# if generator_name:
# machine_generator_facts = {
# generator_name: machine.vars_generators[generator_name]
# }
# else:
# machine_generator_facts = machine.vars_generators
graph = { graph = {
gen_name: set(generator["dependencies"]) gen_name: set(generator["dependencies"])
for gen_name, generator in machine.vars_generators.items() for gen_name, generator in machine.vars_generators.items()
@@ -250,8 +246,6 @@ def _generate_vars_for_machine(
regenerate=regenerate, regenerate=regenerate,
secret_vars_store=secret_vars_store, secret_vars_store=secret_vars_store,
public_vars_store=public_vars_store, public_vars_store=public_vars_store,
dep_tmpdir=local_temp,
prompt=prompt,
) )
if machine_updated: if machine_updated:
# flush caches to make sure the new secrets are available in evaluation # flush caches to make sure the new secrets are available in evaluation
@@ -263,25 +257,21 @@ def generate_vars(
machines: list[Machine], machines: list[Machine],
generator_name: str | None, generator_name: str | None,
regenerate: bool, regenerate: bool,
prompt: Callable[[str], str] = prompt_func,
) -> bool: ) -> bool:
was_regenerated = False was_regenerated = False
with TemporaryDirectory() as tmp: for machine in machines:
tmpdir = Path(tmp) errors = 0
try:
for machine in machines: was_regenerated |= _generate_vars_for_machine(
errors = 0 machine, generator_name, regenerate
try: )
was_regenerated |= _generate_vars_for_machine( except Exception as exc:
machine, generator_name, regenerate, tmpdir, prompt log.error(f"Failed to generate facts for {machine.name}: {exc}")
) errors += 1
except Exception as exc: if errors > 0:
log.error(f"Failed to generate facts for {machine.name}: {exc}") raise ClanError(
errors += 1 f"Failed to generate facts for {errors} hosts. Check the logs above"
if errors > 0: )
raise ClanError(
f"Failed to generate facts for {errors} hosts. Check the logs above"
)
if not was_regenerated: if not was_regenerated:
print("All secrets and facts are already up to date") print("All secrets and facts are already up to date")

View File

@@ -1,6 +1,7 @@
import os import os
from collections import defaultdict from collections import defaultdict
from collections.abc import Callable from collections.abc import Callable
from io import StringIO
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Any from typing import Any
@@ -55,7 +56,8 @@ def test_dependencies_as_files() -> None:
), ),
) )
with TemporaryDirectory() as tmpdir: with TemporaryDirectory() as tmpdir:
dep_tmpdir = dependencies_as_dir(decrypted_dependencies, Path(tmpdir)) dep_tmpdir = Path(tmpdir)
dependencies_as_dir(decrypted_dependencies, dep_tmpdir)
assert dep_tmpdir.is_dir() assert dep_tmpdir.is_dir()
assert (dep_tmpdir / "gen_1" / "var_1a").read_bytes() == b"var_1a" assert (dep_tmpdir / "gen_1" / "var_1a").read_bytes() == b"var_1a"
assert (dep_tmpdir / "gen_1" / "var_1b").read_bytes() == b"var_1b" assert (dep_tmpdir / "gen_1" / "var_1b").read_bytes() == b"var_1b"
@@ -232,3 +234,40 @@ def test_dependant_generators(
) )
assert child_file_path.is_file() assert child_file_path.is_file()
assert child_file_path.read_text() == "hello\n" assert child_file_path.read_text() == "hello\n"
@pytest.mark.impure
@pytest.mark.parametrize(
("prompt_type", "input_value"),
[
("line", "my input"),
("multiline", "my\nmultiline\ninput\n"),
# The hidden type cannot easily be tested, as getpass() reads from /dev/tty directly
# ("hidden", "my hidden input"),
],
)
def test_prompt(
monkeypatch: pytest.MonkeyPatch,
temporary_home: Path,
prompt_type: str,
input_value: str,
) -> None:
config = nested_dict()
my_generator = config["clan"]["core"]["vars"]["generators"]["my_generator"]
my_generator["files"]["my_value"]["secret"] = False
my_generator["prompts"]["prompt1"]["description"] = "dream2nix"
my_generator["prompts"]["prompt1"]["type"] = prompt_type
my_generator["script"] = "cat $prompts/prompt1 > $out/my_value"
flake = generate_flake(
temporary_home,
flake_template=CLAN_CORE / "templates" / "minimal",
machine_configs=dict(my_machine=config),
)
monkeypatch.chdir(flake.path)
monkeypatch.setattr("sys.stdin", StringIO(input_value))
cli.run(["vars", "generate", "--flake", str(flake.path), "my_machine"])
var_file_path = (
flake.path / "machines" / "my_machine" / "vars" / "my_generator" / "my_value"
)
assert var_file_path.is_file()
assert var_file_path.read_text() == input_value