vars: refactor - move generator specific code to Generator class

Several functions in generate.py were specific to generator instances. Let's move them into the Generator class
This commit is contained in:
DavHau
2025-08-19 14:22:00 +07:00
parent 0a1a63dfdd
commit aaca8f4763
3 changed files with 255 additions and 244 deletions

View File

@@ -1,12 +1,6 @@
import argparse
import logging
import os
import shutil
import sys
from collections.abc import Callable
from contextlib import ExitStack
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Literal
from clan_cli.completions import (
@@ -14,251 +8,20 @@ from clan_cli.completions import (
complete_machines,
complete_services_for_machine,
)
from clan_cli.vars._types import StoreBase
from clan_cli.vars.generator import Generator, GeneratorKey
from clan_cli.vars.migration import check_can_migrate, migrate_files
from clan_lib.api import API
from clan_lib.cmd import RunOpts, run
from clan_lib.errors import ClanError
from clan_lib.flake import require_flake
from clan_lib.git import commit_files
from clan_lib.machines.list import list_full_machines
from clan_lib.machines.machines import Machine
from clan_lib.nix import nix_config, nix_shell, nix_test_store
from clan_lib.nix import nix_config
from .graph import minimal_closure, requested_closure
from .prompt import ask
log = logging.getLogger(__name__)
def bubblewrap_cmd(generator: str, tmpdir: Path) -> list[str]:
test_store = nix_test_store()
real_bash_path = Path("bash")
if os.environ.get("IN_NIX_SANDBOX"):
bash_executable_path = Path(str(shutil.which("bash")))
real_bash_path = bash_executable_path.resolve()
# fmt: off
return nix_shell(
[
"bash",
"bubblewrap",
],
[
"bwrap",
"--unshare-all",
"--tmpfs", "/",
"--ro-bind", "/nix/store", "/nix/store",
"--ro-bind", "/bin/sh", "/bin/sh",
*(["--ro-bind", str(test_store), str(test_store)] if test_store else []),
"--dev", "/dev",
# not allowed to bind procfs in some sandboxes
"--bind", str(tmpdir), str(tmpdir),
"--chdir", "/",
# Doesn't work in our CI?
#"--proc", "/proc",
#"--hostname", "facts",
"--bind", "/proc", "/proc",
"--uid", "1000",
"--gid", "1000",
"--",
str(real_bash_path), "-c", generator
]
)
# fmt: on
# TODO: implement caching to not decrypt the same secret multiple times
def decrypt_dependencies(
machine: "Machine",
generator: Generator,
secret_vars_store: StoreBase,
public_vars_store: StoreBase,
) -> dict[str, dict[str, bytes]]:
generators = Generator.get_machine_generators(machine.name, machine.flake)
result: dict[str, dict[str, bytes]] = {}
for dep_key in set(generator.dependencies):
# For now, we only support dependencies from the same machine
if dep_key.machine != machine.name:
msg = f"Cross-machine dependencies are not supported. Generator {generator.name} depends on {dep_key.name} from machine {dep_key.machine}"
raise ClanError(msg)
result[dep_key.name] = {}
dep_generator = next((g for g in generators if g.name == dep_key.name), None)
if dep_generator is None:
msg = f"Generator {dep_key.name} not found in machine {machine.name}"
raise ClanError(msg)
dep_files = dep_generator.files
for file in dep_files:
if file.secret:
result[dep_key.name][file.name] = secret_vars_store.get(
dep_generator, file.name
)
else:
result[dep_key.name][file.name] = public_vars_store.get(
dep_generator, file.name
)
return result
# decrypt dependencies and return temporary file tree
def dependencies_as_dir(
decrypted_dependencies: dict[str, dict[str, bytes]],
tmpdir: Path,
) -> None:
for dep_generator, files in decrypted_dependencies.items():
dep_generator_dir = tmpdir / dep_generator
# Explicitly specify parents and exist_ok default values for clarity
dep_generator_dir.mkdir(mode=0o700, parents=False, exist_ok=False)
for file_name, file in files.items():
file_path = dep_generator_dir / file_name
# Avoid the file creation and chmod race
# If the file already existed,
# we'd have to create a temp one and rename instead;
# however, this is a clean dir so there shouldn't be any collisions
file_path.touch(mode=0o600, exist_ok=False)
file_path.write_bytes(file)
def _execute_generator(
machine: "Machine",
generator: Generator,
secret_vars_store: StoreBase,
public_vars_store: StoreBase,
prompt_values: dict[str, str],
no_sandbox: bool = False,
) -> None:
if not isinstance(machine.flake, Path):
msg = f"flake is not a Path: {machine.flake}"
msg += "fact/secret generation is only supported for local flakes"
# build temporary file tree of dependencies
decrypted_dependencies = decrypt_dependencies(
machine,
generator,
secret_vars_store,
public_vars_store,
)
def get_prompt_value(prompt_name: str) -> str:
try:
return prompt_values[prompt_name]
except KeyError as e:
msg = f"prompt value for '{prompt_name}' in generator {generator.name} not provided"
raise ClanError(msg) from e
env = os.environ.copy()
with ExitStack() as stack:
_tmpdir = stack.enter_context(TemporaryDirectory(prefix="vars-"))
tmpdir = Path(_tmpdir).resolve()
tmpdir_in = tmpdir / "in"
tmpdir_prompts = tmpdir / "prompts"
tmpdir_out = tmpdir / "out"
tmpdir_in.mkdir()
tmpdir_out.mkdir()
env["in"] = str(tmpdir_in)
env["out"] = str(tmpdir_out)
# populate dependency inputs
dependencies_as_dir(decrypted_dependencies, tmpdir_in)
# populate prompted values
# TODO: make prompts rest API friendly
if generator.prompts:
tmpdir_prompts.mkdir()
env["prompts"] = str(tmpdir_prompts)
for prompt in generator.prompts:
prompt_file = tmpdir_prompts / prompt.name
value = get_prompt_value(prompt.name)
prompt_file.write_text(value)
from clan_lib import bwrap
final_script = generator.final_script()
if sys.platform == "linux" and bwrap.bubblewrap_works():
cmd = bubblewrap_cmd(str(final_script), tmpdir)
elif sys.platform == "darwin":
from clan_lib.sandbox_exec import sandbox_exec_cmd
cmd = stack.enter_context(sandbox_exec_cmd(str(final_script), tmpdir))
else:
# For non-sandboxed execution (Linux without bubblewrap or other platforms)
if not no_sandbox:
msg = (
f"Cannot safely execute generator {generator.name}: Sandboxing is not available on this system\n"
f"Re-run 'vars generate' with '--no-sandbox' to disable sandboxing"
)
raise ClanError(msg)
cmd = ["bash", "-c", str(final_script)]
run(cmd, RunOpts(env=env, cwd=tmpdir))
files_to_commit = []
# store secrets
files = generator.files
public_changed = False
secret_changed = False
for file in files:
secret_file = tmpdir_out / file.name
if not secret_file.is_file():
msg = f"did not generate a file for '{file.name}' when running the following command:\n"
msg += str(final_script)
# list all files in the output directory
if tmpdir_out.is_dir():
msg += "\nOutput files:\n"
for f in tmpdir_out.iterdir():
msg += f" - {f.name}\n"
raise ClanError(msg)
if file.secret:
file_path = secret_vars_store.set(
generator,
file,
secret_file.read_bytes(),
)
secret_changed = True
else:
file_path = public_vars_store.set(
generator,
file,
secret_file.read_bytes(),
)
public_changed = True
if file_path:
files_to_commit.append(file_path)
validation = generator.validation()
if validation is not None:
if public_changed:
files_to_commit.append(
public_vars_store.set_validation(generator, validation)
)
if secret_changed:
files_to_commit.append(
secret_vars_store.set_validation(generator, validation)
)
commit_files(
files_to_commit,
machine.flake_dir,
f"Update vars via generator {generator.name} for machine {machine.name}",
)
def _ask_prompts(
generator: Generator,
) -> dict[str, str]:
prompt_values: dict[str, str] = {}
for prompt in generator.prompts:
var_id = f"{generator.name}/{prompt.name}"
prompt_values[prompt.name] = ask(
var_id,
prompt.prompt_type,
prompt.description if prompt.description != prompt.name else None,
)
return prompt_values
@API.register
def get_generators(
machine: Machine,
@@ -343,11 +106,8 @@ def _generate_vars_for_machine(
if check_can_migrate(machine, generator):
migrate_files(machine, generator)
else:
_execute_generator(
generator.execute(
machine=machine,
generator=generator,
secret_vars_store=machine.secret_vars_store,
public_vars_store=machine.public_vars_store,
prompt_values=prompt_values.get(generator.name, {}),
no_sandbox=no_sandbox,
)
@@ -368,7 +128,7 @@ def run_generators(
generators: GeneratorKey
| list[GeneratorKey]
| Literal["all", "minimal"] = "minimal",
prompt_values: dict[str, dict[str, str]] | PromptFunc = _ask_prompts,
prompt_values: dict[str, dict[str, str]] | PromptFunc = lambda g: g.ask_prompts(),
no_sandbox: bool = False,
) -> None:
"""Run the specified generators for a machine.