diff --git a/pkgs/clan-app/ui/src/workflows/Install/steps/installSteps.tsx b/pkgs/clan-app/ui/src/workflows/Install/steps/installSteps.tsx index ef7d8f79c..ff3df0543 100644 --- a/pkgs/clan-app/ui/src/workflows/Install/steps/installSteps.tsx +++ b/pkgs/clan-app/ui/src/workflows/Install/steps/installSteps.tsx @@ -550,7 +550,7 @@ const InstallSummary = () => { } const runGenerators = client.fetch("run_generators", { - all_prompt_values: store.install.promptValues, + prompt_values: store.install.promptValues, machine: { name: store.install.machineName, flake: { diff --git a/pkgs/clan-cli/clan_cli/tests/test_vars.py b/pkgs/clan-cli/clan_cli/tests/test_vars.py index f83f15a45..32dc3a910 100644 --- a/pkgs/clan-cli/clan_cli/tests/test_vars.py +++ b/pkgs/clan-cli/clan_cli/tests/test_vars.py @@ -11,7 +11,6 @@ from clan_cli.vars.check import check_vars from clan_cli.vars.generate import ( Generator, GeneratorKey, - create_machine_vars_interactive, get_generators, run_generators, ) @@ -700,8 +699,8 @@ def test_api_set_prompts( run_generators( machine=Machine(name="my_machine", flake=Flake(str(flake.path))), - generators=["my_generator"], - all_prompt_values={ + generators=[GeneratorKey(machine="my_machine", name="my_generator")], + prompt_values={ "my_generator": { "prompt1": "input1", } @@ -714,8 +713,8 @@ def test_api_set_prompts( assert store.get(my_generator, "prompt1").decode() == "input1" run_generators( machine=Machine(name="my_machine", flake=Flake(str(flake.path))), - generators=["my_generator"], - all_prompt_values={ + generators=[GeneratorKey(machine="my_machine", name="my_generator")], + prompt_values={ "my_generator": { "prompt1": "input2", } @@ -757,14 +756,11 @@ def test_stdout_of_generate( flake_.refresh() monkeypatch.chdir(flake_.path) flake = Flake(str(flake_.path)) - from clan_cli.vars.generate import create_machine_vars_interactive - # with capture_output as output: with caplog.at_level(logging.INFO): - create_machine_vars_interactive( + run_generators( Machine(name="my_machine", flake=flake), - "my_generator", - regenerate=False, + generators=[GeneratorKey(machine="my_machine", name="my_generator")], ) assert "Updated var my_generator/my_value" in caplog.text @@ -774,10 +770,9 @@ def test_stdout_of_generate( set_var("my_machine", "my_generator/my_value", b"world", flake) with caplog.at_level(logging.INFO): - create_machine_vars_interactive( + run_generators( Machine(name="my_machine", flake=flake), - "my_generator", - regenerate=True, + generators=[GeneratorKey(machine="my_machine", name="my_generator")], ) assert "Updated var my_generator/my_value" in caplog.text assert "old: world" in caplog.text @@ -785,19 +780,17 @@ def test_stdout_of_generate( caplog.clear() # check the output when nothing gets regenerated with caplog.at_level(logging.INFO): - create_machine_vars_interactive( + run_generators( Machine(name="my_machine", flake=flake), - "my_generator", - regenerate=True, + generators=[GeneratorKey(machine="my_machine", name="my_generator")], ) assert "Updated var" not in caplog.text assert "hello" in caplog.text caplog.clear() with caplog.at_level(logging.INFO): - create_machine_vars_interactive( + run_generators( Machine(name="my_machine", flake=flake), - "my_secret_generator", - regenerate=False, + generators=[GeneratorKey(machine="my_machine", name="my_secret_generator")], ) assert "Updated secret var my_secret_generator/my_secret" in caplog.text assert "hello" not in caplog.text @@ -809,10 +802,9 @@ def test_stdout_of_generate( Flake(str(flake.path)), ) with caplog.at_level(logging.INFO): - create_machine_vars_interactive( + run_generators( Machine(name="my_machine", flake=flake), - "my_secret_generator", - regenerate=True, + generators=[GeneratorKey(machine="my_machine", name="my_secret_generator")], ) assert "Updated secret var my_secret_generator/my_secret" in caplog.text assert "world" not in caplog.text @@ -899,10 +891,9 @@ def test_fails_when_files_are_left_from_other_backend( flake.refresh() monkeypatch.chdir(flake.path) for generator in ["my_secret_generator", "my_value_generator"]: - create_machine_vars_interactive( + run_generators( Machine(name="my_machine", flake=Flake(str(flake.path))), - generator, - regenerate=False, + generators=GeneratorKey(machine="my_machine", name=generator), ) # Will raise. It was secret before, but now it's not. my_secret_generator["files"]["my_secret"]["secret"] = ( @@ -916,16 +907,14 @@ def test_fails_when_files_are_left_from_other_backend( # This should raise an error if generator == "my_secret_generator": with pytest.raises(ClanError): - create_machine_vars_interactive( + run_generators( Machine(name="my_machine", flake=Flake(str(flake.path))), - generator, - regenerate=False, + generators=GeneratorKey(machine="my_machine", name=generator), ) else: - create_machine_vars_interactive( + run_generators( Machine(name="my_machine", flake=Flake(str(flake.path))), - generator, - regenerate=False, + generators=GeneratorKey(machine="my_machine", name=generator), ) diff --git a/pkgs/clan-cli/clan_cli/vars/generate.py b/pkgs/clan-cli/clan_cli/vars/generate.py index 7d69f93cf..c79f3aa58 100644 --- a/pkgs/clan-cli/clan_cli/vars/generate.py +++ b/pkgs/clan-cli/clan_cli/vars/generate.py @@ -3,9 +3,11 @@ import logging import os import shutil import sys +from collections.abc import Callable from contextlib import ExitStack from pathlib import Path from tempfile import TemporaryDirectory +from typing import Literal from clan_cli.completions import ( add_dynamic_completer, @@ -333,7 +335,7 @@ def _ensure_healthy( def _generate_vars_for_machine( machine: "Machine", generators: list[Generator], - all_prompt_values: dict[str, dict[str, str]], + prompt_values: dict[str, dict[str, str]], no_sandbox: bool = False, ) -> None: _ensure_healthy(machine=machine, generators=generators) @@ -346,68 +348,76 @@ def _generate_vars_for_machine( generator=generator, secret_vars_store=machine.secret_vars_store, public_vars_store=machine.public_vars_store, - prompt_values=all_prompt_values.get(generator.name, {}), + prompt_values=prompt_values.get(generator.name, {}), no_sandbox=no_sandbox, ) +PromptFunc = Callable[[Generator], dict[str, str]] +"""Type for a function that collects prompt values for a generator. + +The function receives a Generator and should return a dictionary mapping +prompt names to their values. This allows for custom prompt collection +strategies (e.g., interactive CLI, GUI, or programmatic). +""" + + @API.register def run_generators( machine: Machine, - all_prompt_values: dict[str, dict[str, str]], - generators: list[str] | None = None, + generators: GeneratorKey + | list[GeneratorKey] + | Literal["all", "minimal"] = "minimal", + prompt_values: dict[str, dict[str, str]] | PromptFunc = _ask_prompts, no_sandbox: bool = False, ) -> None: """Run the specified generators for a machine. Args: - machine_name (str): The name of the machine. - generators (list[str]): The list of generator names to run. - all_prompt_values (dict[str, dict[str, str]]): A dictionary mapping generator names - to their prompt values. - base_dir (Path): The base directory of the flake. - no_sandbox (bool): Whether to disable sandboxing when executing the generator. - Returns: - bool: True if any variables were generated, False otherwise. + machine: The machine to run generators for. + generators: Can be: + - GeneratorKey: Single generator to run (ensuring dependencies are met) + - list[GeneratorKey]: Specific generators to run exactly as provided. + Dependency generators are not added automatically in this case. + The caller must ensure that all dependencies are included. + - "all": Run all generators (full closure) + - "minimal": Run only missing generators (minimal closure) (default) + prompt_values: A dictionary mapping generator names to their prompt values, + or a function that returns prompt values for a generator. + no_sandbox: Whether to disable sandboxing when executing the generator. Raises: ClanError: If the machine or generator is not found, or if there are issues with executing the generator. """ - if not generators: - generator_objects = Generator.get_machine_generators( - machine.name, machine.flake + if generators == "all": + generator_objects = get_generators(machine, full_closure=True) + elif generators == "minimal": + generator_objects = get_generators(machine, full_closure=False) + elif isinstance(generators, GeneratorKey): + # Single generator - compute minimal closure for it + generator_objects = get_generators( + machine, full_closure=False, generator_name=generators.name ) + elif isinstance(generators, list): + if len(generators) == 0: + return + generator_keys = set(generators) + all_generators = get_generators(machine, full_closure=True) + generator_objects = [g for g in all_generators if g.key in generator_keys] else: - generators_set = set(generators) - generator_objects = [ - g - for g in Generator.get_machine_generators(machine.name, machine.flake) - if g.name in generators_set - ] + msg = f"Invalid generators argument: {generators}. Must be 'all', 'minimal', GeneratorKey, or a list of GeneratorKey" + raise ValueError(msg) + + # If prompt function provided, ask all prompts + # TODO: make this more lazy and ask for every generator on execution + if callable(prompt_values): + prompt_values = { + generator.name: prompt_values(generator) for generator in generator_objects + } _generate_vars_for_machine( machine=machine, generators=generator_objects, - all_prompt_values=all_prompt_values, - no_sandbox=no_sandbox, - ) - - -def create_machine_vars_interactive( - machine: "Machine", - generator_name: str | None, - regenerate: bool, - no_sandbox: bool = False, -) -> None: - generators = get_generators(machine, regenerate, generator_name) - if len(generators) == 0: - return - all_prompt_values = {} - for generator in generators: - all_prompt_values[generator.name] = _ask_prompts(generator) - _generate_vars_for_machine( - machine, - generators, - all_prompt_values, + prompt_values=prompt_values, no_sandbox=no_sandbox, ) @@ -421,10 +431,15 @@ def generate_vars( for machine in machines: errors = [] try: - create_machine_vars_interactive( + generators: GeneratorKey | Literal["all", "minimal"] + if generator_name: + generators = GeneratorKey(machine=machine.name, name=generator_name) + else: + generators = "all" if regenerate else "minimal" + + run_generators( machine, - generator_name, - regenerate, + generators=generators, no_sandbox=no_sandbox, ) machine.info("All vars are up to date") diff --git a/pkgs/clan-cli/clan_lib/tests/test_create.py b/pkgs/clan-cli/clan_lib/tests/test_create.py index 66497d17a..f4da2a0d9 100644 --- a/pkgs/clan-cli/clan_lib/tests/test_create.py +++ b/pkgs/clan-cli/clan_lib/tests/test_create.py @@ -218,7 +218,7 @@ def test_clan_create_api( clan_dir_flake.invalidate_cache() generators = get_generators(machine=machine, full_closure=True) - all_prompt_values = {} + collected_prompt_values = {} for generator in generators: prompt_values = {} for prompt in generator.prompts: @@ -228,12 +228,12 @@ def test_clan_create_api( else: msg = f"Prompt {var_id} not handled in test, please fix it" raise ClanError(msg) - all_prompt_values[generator.name] = prompt_values + collected_prompt_values[generator.name] = prompt_values run_generators( machine=machine, - generators=[gen.name for gen in generators], - all_prompt_values=all_prompt_values, + generators=[gen.key for gen in generators], + prompt_values=collected_prompt_values, ) clan_dir_flake.invalidate_cache()