vars: refactor - ask prompts before running any generators

This commit is contained in:
DavHau
2024-09-06 16:07:20 +02:00
parent 675e4c5931
commit f5e65b5d0e
3 changed files with 37 additions and 23 deletions

View File

@@ -20,7 +20,7 @@ from clan_cli.machines.machines import Machine
from clan_cli.nix import nix_shell
from .check import check_vars
from .prompt import prompt
from .prompt import ask
from .public_modules import FactStoreBase
from .secret_modules import SecretStoreBase
@@ -98,8 +98,9 @@ def execute_generator(
regenerate: bool,
secret_vars_store: SecretStoreBase,
public_vars_store: FactStoreBase,
prompt_values: dict[str, str] | None = None,
prompt_values: dict[str, str] | None,
) -> bool:
prompt_values = {} if prompt_values is None else prompt_values
# check if all secrets exist and generate them if at least one is missing
needs_regeneration = not check_vars(machine, generator_name=generator_name)
log.debug(f"{generator_name} needs_regeneration: {needs_regeneration}")
@@ -118,17 +119,11 @@ def execute_generator(
)
def get_prompt_value(prompt_name: str) -> str:
if prompt_values:
try:
return prompt_values[prompt_name]
except KeyError as e:
msg = f"prompt value for '{prompt_name}' in generator {generator_name} not provided"
raise ClanError(msg) from e
description = machine.vars_generators[generator_name]["prompts"][prompt_name][
"description"
]
_type = machine.vars_generators[generator_name]["prompts"][prompt_name]["type"]
return prompt(description, _type)
try:
return prompt_values[prompt_name]
except KeyError as e:
msg = f"prompt value for '{prompt_name}' in generator {generator_name} not provided"
raise ClanError(msg) from e
env = os.environ.copy()
with TemporaryDirectory() as tmp:
@@ -265,14 +260,20 @@ def _required_generators(
return list(sorter.static_order())
def _generate_vars_for_machine(
def _ask_prompts(
machine: Machine,
generator_name: str | None,
regenerate: bool,
) -> bool:
return _generate_vars_for_machine_multi(
machine, [generator_name] if generator_name else [], regenerate
)
generator_names: list[str],
) -> dict[str, dict[str, str]]:
prompt_values: dict[str, dict[str, str]] = {}
for generator in generator_names:
prompts = machine.vars_generators[generator]["prompts"]
for prompt_name, _prompt in prompts.items():
if generator not in prompt_values:
prompt_values[generator] = {}
prompt_values[generator][prompt_name] = ask(
_prompt["description"], _prompt["type"]
)
return prompt_values
def _generate_vars_for_machine_multi(
@@ -291,6 +292,9 @@ def _generate_vars_for_machine_multi(
regenerate=regenerate,
secret_vars_store=machine.secret_vars_store,
public_vars_store=machine.public_vars_store,
prompt_values=_ask_prompts(machine, [generator_name]).get(
generator_name, {}
),
)
if machine_updated:
# flush caches to make sure the new secrets are available in evaluation
@@ -298,6 +302,16 @@ def _generate_vars_for_machine_multi(
return machine_updated
def _generate_vars_for_machine(
machine: Machine,
generator_name: str | None,
regenerate: bool,
) -> bool:
return _generate_vars_for_machine_multi(
machine, [generator_name] if generator_name else [], regenerate
)
def generate_vars(
machines: list[Machine],
generator_name: str | None,

View File

@@ -7,7 +7,7 @@ from clan_cli.errors import ClanError
log = logging.getLogger(__name__)
def prompt(description: str, input_type: str) -> str:
def ask(description: str, input_type: str) -> str:
if input_type == "line":
result = input(f"Enter the value for {description}: ")
elif input_type == "multiline":

View File

@@ -7,7 +7,7 @@ from clan_cli.completions import add_dynamic_completer, complete_machines
from clan_cli.machines.machines import Machine
from clan_cli.vars.get import get_var
from .prompt import prompt
from .prompt import ask
log = logging.getLogger(__name__)
@@ -16,7 +16,7 @@ def set_command(machine: str, var_id: str, flake: FlakeId) -> None:
_machine = Machine(name=machine, flake=flake)
var = get_var(_machine, var_id)
if sys.stdin.isatty():
new_value = prompt(var.id, "hidden").encode("utf-8")
new_value = ask(var.id, "hidden").encode("utf-8")
else:
new_value = sys.stdin.buffer.read()
var.set(new_value)