Merge pull request 'Vars: remove spurious 'fake_prompt' with mocked method' (#4659) from cleaner into main

Reviewed-on: https://git.clan.lol/clan/clan-core/pulls/4659
This commit is contained in:
hsjobeki
2025-08-09 22:53:18 +00:00
2 changed files with 33 additions and 38 deletions

View File

@@ -30,7 +30,7 @@ from .graph import (
minimal_closure, minimal_closure,
requested_closure, requested_closure,
) )
from .prompt import Prompt, PromptType, ask from .prompt import Prompt, ask
from .var import Var from .var import Var
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -385,26 +385,6 @@ def _ask_prompts(
return prompt_values return prompt_values
def _fake_prompts(
generator: Generator,
) -> dict[str, str]:
prompt_values: dict[str, str] = {}
for prompt in generator.prompts:
var_id = f"{generator.name}/{prompt.name}"
if prompt.prompt_type == PromptType.HIDDEN:
prompt_values[prompt.name] = "fake_hidden_value"
elif prompt.prompt_type == PromptType.MULTILINE_HIDDEN:
prompt_values[prompt.name] = "fake\nmultiline\nhidden\nvalue"
elif prompt.prompt_type == PromptType.MULTILINE:
prompt_values[prompt.name] = "fake\nmultiline\nvalue"
elif prompt.prompt_type == PromptType.LINE:
prompt_values[prompt.name] = "fake_line_value"
else:
msg = f"Unknown prompt type {prompt.prompt_type} for prompt {var_id} in generator {generator.name}"
raise ClanError(msg)
return prompt_values
def _get_previous_value( def _get_previous_value(
machine: "Machine", machine: "Machine",
generator: Generator, generator: Generator,
@@ -550,7 +530,6 @@ def create_machine_vars_interactive(
generator_name: str | None, generator_name: str | None,
regenerate: bool, regenerate: bool,
no_sandbox: bool = False, no_sandbox: bool = False,
fake_prompts: bool = False,
) -> bool: ) -> bool:
_generator = None _generator = None
if generator_name: if generator_name:
@@ -580,10 +559,7 @@ def create_machine_vars_interactive(
return False return False
all_prompt_values = {} all_prompt_values = {}
for generator in generators: for generator in generators:
if fake_prompts: all_prompt_values[generator.name] = _ask_prompts(generator)
all_prompt_values[generator.name] = _fake_prompts(generator)
else:
all_prompt_values[generator.name] = _ask_prompts(generator)
return _generate_vars_for_machine( return _generate_vars_for_machine(
machine, machine,
generators, generators,
@@ -597,8 +573,7 @@ def generate_vars(
generator_name: str | None = None, generator_name: str | None = None,
regenerate: bool = False, regenerate: bool = False,
no_sandbox: bool = False, no_sandbox: bool = False,
fake_prompts: bool = False, ) -> None:
) -> bool:
was_regenerated = False was_regenerated = False
for machine in machines: for machine in machines:
errors = [] errors = []
@@ -608,7 +583,6 @@ def generate_vars(
generator_name, generator_name,
regenerate, regenerate,
no_sandbox=no_sandbox, no_sandbox=no_sandbox,
fake_prompts=fake_prompts,
) )
except Exception as exc: except Exception as exc:
errors += [(machine, exc)] errors += [(machine, exc)]
@@ -624,8 +598,6 @@ def generate_vars(
for machine in machines: for machine in machines:
machine.info("All vars are already up to date") machine.info("All vars are already up to date")
return was_regenerated
def generate_command(args: argparse.Namespace) -> None: def generate_command(args: argparse.Namespace) -> None:
flake = require_flake(args.flake) flake = require_flake(args.flake)
@@ -649,15 +621,12 @@ def generate_command(args: argparse.Namespace) -> None:
f"clanInternals.machines.{system}.{{{','.join(machine_names)}}}.config.clan.core.vars.generators.*.validationHash", f"clanInternals.machines.{system}.{{{','.join(machine_names)}}}.config.clan.core.vars.generators.*.validationHash",
] ]
) )
has_changed = generate_vars( generate_vars(
machines, machines,
args.generator, args.generator,
args.regenerate, args.regenerate,
no_sandbox=args.no_sandbox, no_sandbox=args.no_sandbox,
fake_prompts=args.fake_prompts,
) )
if has_changed:
flake.invalidate_cache()
def register_generate_parser(parser: argparse.ArgumentParser) -> None: def register_generate_parser(parser: argparse.ArgumentParser) -> None:

View File

@@ -10,9 +10,12 @@ from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from typing import Any, override from typing import Any, override
from unittest.mock import patch
from clan_cli.vars.generate import generate_vars from clan_cli.vars.generate import Generator, generate_vars
from clan_cli.vars.prompt import PromptType
from clan_lib.dirs import find_toplevel from clan_lib.dirs import find_toplevel
from clan_lib.errors import ClanError
from clan_lib.flake.flake import Flake from clan_lib.flake.flake import Flake
from clan_lib.machines.machines import Machine from clan_lib.machines.machines import Machine
from clan_lib.nix import nix_config, nix_eval, nix_test_store from clan_lib.nix import nix_config, nix_eval, nix_test_store
@@ -218,13 +221,36 @@ def main() -> None:
) )
+ "\n" + "\n"
) )
with NamedTemporaryFile("w") as f:
def mocked_prompts(
generator: Generator,
) -> dict[str, str]:
prompt_values: dict[str, str] = {}
for prompt in generator.prompts:
var_id = f"{generator.name}/{prompt.name}"
if prompt.prompt_type == PromptType.HIDDEN:
prompt_values[prompt.name] = "fake_hidden_value"
elif prompt.prompt_type == PromptType.MULTILINE_HIDDEN:
prompt_values[prompt.name] = "fake\nmultiline\nhidden\nvalue"
elif prompt.prompt_type == PromptType.MULTILINE:
prompt_values[prompt.name] = "fake\nmultiline\nvalue"
elif prompt.prompt_type == PromptType.LINE:
prompt_values[prompt.name] = "fake_line_value"
else:
msg = f"Unknown prompt type {prompt.prompt_type} for prompt {var_id} in generator {generator.name}"
raise ClanError(msg)
return prompt_values
with (
patch("clan_cli.vars.generate._ask_prompts", new=mocked_prompts),
NamedTemporaryFile("w") as f,
):
f.write("# created: 2023-07-17T10:51:45+02:00\n") f.write("# created: 2023-07-17T10:51:45+02:00\n")
f.write(f"# public key: {sops_pub_key}\n") f.write(f"# public key: {sops_pub_key}\n")
f.write(sops_priv_key) f.write(sops_priv_key)
f.seek(0) f.seek(0)
os.environ["SOPS_AGE_KEY_FILE"] = f.name os.environ["SOPS_AGE_KEY_FILE"] = f.name
generate_vars(list(machines), fake_prompts=True) generate_vars(list(machines))
if __name__ == "__main__": if __name__ == "__main__":