From f5e65b5d0e4d3bbe04620484e38883a4d56d8446 Mon Sep 17 00:00:00 2001 From: DavHau Date: Fri, 6 Sep 2024 16:07:20 +0200 Subject: [PATCH] vars: refactor - ask prompts before running any generators --- pkgs/clan-cli/clan_cli/vars/generate.py | 54 ++++++++++++++++--------- pkgs/clan-cli/clan_cli/vars/prompt.py | 2 +- pkgs/clan-cli/clan_cli/vars/set.py | 4 +- 3 files changed, 37 insertions(+), 23 deletions(-) diff --git a/pkgs/clan-cli/clan_cli/vars/generate.py b/pkgs/clan-cli/clan_cli/vars/generate.py index cd04f8a1e..f85d9f29b 100644 --- a/pkgs/clan-cli/clan_cli/vars/generate.py +++ b/pkgs/clan-cli/clan_cli/vars/generate.py @@ -20,7 +20,7 @@ from clan_cli.machines.machines import Machine from clan_cli.nix import nix_shell from .check import check_vars -from .prompt import prompt +from .prompt import ask from .public_modules import FactStoreBase from .secret_modules import SecretStoreBase @@ -98,8 +98,9 @@ def execute_generator( regenerate: bool, secret_vars_store: SecretStoreBase, public_vars_store: FactStoreBase, - prompt_values: dict[str, str] | None = None, + prompt_values: dict[str, str] | None, ) -> bool: + prompt_values = {} if prompt_values is None else prompt_values # check if all secrets exist and generate them if at least one is missing needs_regeneration = not check_vars(machine, generator_name=generator_name) log.debug(f"{generator_name} needs_regeneration: {needs_regeneration}") @@ -118,17 +119,11 @@ def execute_generator( ) def get_prompt_value(prompt_name: str) -> str: - if prompt_values: - try: - return prompt_values[prompt_name] - except KeyError as e: - msg = f"prompt value for '{prompt_name}' in generator {generator_name} not provided" - raise ClanError(msg) from e - description = machine.vars_generators[generator_name]["prompts"][prompt_name][ - "description" - ] - _type = machine.vars_generators[generator_name]["prompts"][prompt_name]["type"] - return prompt(description, _type) + try: + return prompt_values[prompt_name] + except KeyError as e: + msg = f"prompt value for '{prompt_name}' in generator {generator_name} not provided" + raise ClanError(msg) from e env = os.environ.copy() with TemporaryDirectory() as tmp: @@ -265,14 +260,20 @@ def _required_generators( return list(sorter.static_order()) -def _generate_vars_for_machine( +def _ask_prompts( machine: Machine, - generator_name: str | None, - regenerate: bool, -) -> bool: - return _generate_vars_for_machine_multi( - machine, [generator_name] if generator_name else [], regenerate - ) + generator_names: list[str], +) -> dict[str, dict[str, str]]: + prompt_values: dict[str, dict[str, str]] = {} + for generator in generator_names: + prompts = machine.vars_generators[generator]["prompts"] + for prompt_name, _prompt in prompts.items(): + if generator not in prompt_values: + prompt_values[generator] = {} + prompt_values[generator][prompt_name] = ask( + _prompt["description"], _prompt["type"] + ) + return prompt_values def _generate_vars_for_machine_multi( @@ -291,6 +292,9 @@ def _generate_vars_for_machine_multi( regenerate=regenerate, secret_vars_store=machine.secret_vars_store, public_vars_store=machine.public_vars_store, + prompt_values=_ask_prompts(machine, [generator_name]).get( + generator_name, {} + ), ) if machine_updated: # flush caches to make sure the new secrets are available in evaluation @@ -298,6 +302,16 @@ def _generate_vars_for_machine_multi( return machine_updated +def _generate_vars_for_machine( + machine: Machine, + generator_name: str | None, + regenerate: bool, +) -> bool: + return _generate_vars_for_machine_multi( + machine, [generator_name] if generator_name else [], regenerate + ) + + def generate_vars( machines: list[Machine], generator_name: str | None, diff --git a/pkgs/clan-cli/clan_cli/vars/prompt.py b/pkgs/clan-cli/clan_cli/vars/prompt.py index 4ac1aebe4..bbca8e55c 100644 --- a/pkgs/clan-cli/clan_cli/vars/prompt.py +++ b/pkgs/clan-cli/clan_cli/vars/prompt.py @@ -7,7 +7,7 @@ from clan_cli.errors import ClanError log = logging.getLogger(__name__) -def prompt(description: str, input_type: str) -> str: +def ask(description: str, input_type: str) -> str: if input_type == "line": result = input(f"Enter the value for {description}: ") elif input_type == "multiline": diff --git a/pkgs/clan-cli/clan_cli/vars/set.py b/pkgs/clan-cli/clan_cli/vars/set.py index bf16adb80..bc21913d9 100644 --- a/pkgs/clan-cli/clan_cli/vars/set.py +++ b/pkgs/clan-cli/clan_cli/vars/set.py @@ -7,7 +7,7 @@ from clan_cli.completions import add_dynamic_completer, complete_machines from clan_cli.machines.machines import Machine from clan_cli.vars.get import get_var -from .prompt import prompt +from .prompt import ask log = logging.getLogger(__name__) @@ -16,7 +16,7 @@ def set_command(machine: str, var_id: str, flake: FlakeId) -> None: _machine = Machine(name=machine, flake=flake) var = get_var(_machine, var_id) if sys.stdin.isatty(): - new_value = prompt(var.id, "hidden").encode("utf-8") + new_value = ask(var.id, "hidden").encode("utf-8") else: new_value = sys.stdin.buffer.read() var.set(new_value)