vars: refactor - ask prompts before running any generators

This commit is contained in:
DavHau
2024-09-06 16:07:20 +02:00
parent 675e4c5931
commit f5e65b5d0e
3 changed files with 37 additions and 23 deletions

View File

@@ -20,7 +20,7 @@ from clan_cli.machines.machines import Machine
from clan_cli.nix import nix_shell from clan_cli.nix import nix_shell
from .check import check_vars from .check import check_vars
from .prompt import prompt from .prompt import ask
from .public_modules import FactStoreBase from .public_modules import FactStoreBase
from .secret_modules import SecretStoreBase from .secret_modules import SecretStoreBase
@@ -98,8 +98,9 @@ def execute_generator(
regenerate: bool, regenerate: bool,
secret_vars_store: SecretStoreBase, secret_vars_store: SecretStoreBase,
public_vars_store: FactStoreBase, public_vars_store: FactStoreBase,
prompt_values: dict[str, str] | None = None, prompt_values: dict[str, str] | None,
) -> bool: ) -> bool:
prompt_values = {} if prompt_values is None else prompt_values
# check if all secrets exist and generate them if at least one is missing # check if all secrets exist and generate them if at least one is missing
needs_regeneration = not check_vars(machine, generator_name=generator_name) needs_regeneration = not check_vars(machine, generator_name=generator_name)
log.debug(f"{generator_name} needs_regeneration: {needs_regeneration}") log.debug(f"{generator_name} needs_regeneration: {needs_regeneration}")
@@ -118,17 +119,11 @@ def execute_generator(
) )
def get_prompt_value(prompt_name: str) -> str: def get_prompt_value(prompt_name: str) -> str:
if prompt_values: try:
try: return prompt_values[prompt_name]
return prompt_values[prompt_name] except KeyError as e:
except KeyError as e: msg = f"prompt value for '{prompt_name}' in generator {generator_name} not provided"
msg = f"prompt value for '{prompt_name}' in generator {generator_name} not provided" raise ClanError(msg) from e
raise ClanError(msg) from e
description = machine.vars_generators[generator_name]["prompts"][prompt_name][
"description"
]
_type = machine.vars_generators[generator_name]["prompts"][prompt_name]["type"]
return prompt(description, _type)
env = os.environ.copy() env = os.environ.copy()
with TemporaryDirectory() as tmp: with TemporaryDirectory() as tmp:
@@ -265,14 +260,20 @@ def _required_generators(
return list(sorter.static_order()) return list(sorter.static_order())
def _generate_vars_for_machine( def _ask_prompts(
machine: Machine, machine: Machine,
generator_name: str | None, generator_names: list[str],
regenerate: bool, ) -> dict[str, dict[str, str]]:
) -> bool: prompt_values: dict[str, dict[str, str]] = {}
return _generate_vars_for_machine_multi( for generator in generator_names:
machine, [generator_name] if generator_name else [], regenerate prompts = machine.vars_generators[generator]["prompts"]
) for prompt_name, _prompt in prompts.items():
if generator not in prompt_values:
prompt_values[generator] = {}
prompt_values[generator][prompt_name] = ask(
_prompt["description"], _prompt["type"]
)
return prompt_values
def _generate_vars_for_machine_multi( def _generate_vars_for_machine_multi(
@@ -291,6 +292,9 @@ def _generate_vars_for_machine_multi(
regenerate=regenerate, regenerate=regenerate,
secret_vars_store=machine.secret_vars_store, secret_vars_store=machine.secret_vars_store,
public_vars_store=machine.public_vars_store, public_vars_store=machine.public_vars_store,
prompt_values=_ask_prompts(machine, [generator_name]).get(
generator_name, {}
),
) )
if machine_updated: if machine_updated:
# flush caches to make sure the new secrets are available in evaluation # flush caches to make sure the new secrets are available in evaluation
@@ -298,6 +302,16 @@ def _generate_vars_for_machine_multi(
return machine_updated return machine_updated
def _generate_vars_for_machine(
machine: Machine,
generator_name: str | None,
regenerate: bool,
) -> bool:
return _generate_vars_for_machine_multi(
machine, [generator_name] if generator_name else [], regenerate
)
def generate_vars( def generate_vars(
machines: list[Machine], machines: list[Machine],
generator_name: str | None, generator_name: str | None,

View File

@@ -7,7 +7,7 @@ from clan_cli.errors import ClanError
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def prompt(description: str, input_type: str) -> str: def ask(description: str, input_type: str) -> str:
if input_type == "line": if input_type == "line":
result = input(f"Enter the value for {description}: ") result = input(f"Enter the value for {description}: ")
elif input_type == "multiline": elif input_type == "multiline":

View File

@@ -7,7 +7,7 @@ from clan_cli.completions import add_dynamic_completer, complete_machines
from clan_cli.machines.machines import Machine from clan_cli.machines.machines import Machine
from clan_cli.vars.get import get_var from clan_cli.vars.get import get_var
from .prompt import prompt from .prompt import ask
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -16,7 +16,7 @@ def set_command(machine: str, var_id: str, flake: FlakeId) -> None:
_machine = Machine(name=machine, flake=flake) _machine = Machine(name=machine, flake=flake)
var = get_var(_machine, var_id) var = get_var(_machine, var_id)
if sys.stdin.isatty(): if sys.stdin.isatty():
new_value = prompt(var.id, "hidden").encode("utf-8") new_value = ask(var.id, "hidden").encode("utf-8")
else: else:
new_value = sys.stdin.buffer.read() new_value = sys.stdin.buffer.read()
var.set(new_value) var.set(new_value)