diff --git a/pkgs/clan-cli/clan_cli/vars/list.py b/pkgs/clan-cli/clan_cli/vars/list.py index 59f10af58..5a7b7d401 100644 --- a/pkgs/clan-cli/clan_cli/vars/list.py +++ b/pkgs/clan-cli/clan_cli/vars/list.py @@ -3,6 +3,7 @@ import importlib import logging from clan_cli.api import API +from clan_cli.clan_uri import FlakeId from clan_cli.completions import add_dynamic_completer, complete_machines from clan_cli.errors import ClanError from clan_cli.machines.machines import Machine @@ -57,8 +58,8 @@ def _get_previous_value( @API.register -# TODO: use machine_name -def get_prompts(machine: Machine) -> list[Generator]: +def get_prompts(base_dir: str, machine_name: str) -> list[Generator]: + machine = Machine(name=machine_name, flake=FlakeId(base_dir)) generators: list[Generator] = machine.vars_generators for generator in generators: for prompt in generator.prompts: @@ -70,7 +71,10 @@ def get_prompts(machine: Machine) -> list[Generator]: # TODO: for missing prompts, default to existing values # TODO: raise error if mandatory prompt not provided @API.register -def set_prompts(machine: Machine, updates: list[GeneratorUpdate]) -> None: +def set_prompts( + base_dir: str, machine_name: str, updates: list[GeneratorUpdate] +) -> None: + machine = Machine(name=machine_name, flake=FlakeId(base_dir)) for update in updates: for generator in machine.vars_generators: if generator.name == update.generator: diff --git a/pkgs/clan-cli/tests/test_vars.py b/pkgs/clan-cli/tests/test_vars.py index 2a3714960..c7aee65ed 100644 --- a/pkgs/clan-cli/tests/test_vars.py +++ b/pkgs/clan-cli/tests/test_vars.py @@ -597,23 +597,24 @@ def test_api_set_prompts( flake.refresh() monkeypatch.chdir(flake.path) - machine = Machine(name="my_machine", flake=FlakeId(str(flake.path))) + params = {"machine_name": "my_machine", "base_dir": str(flake.path)} set_prompts( - machine, - [ + **params, + updates=[ GeneratorUpdate( generator="my_generator", prompt_values={"prompt1": "input1"}, ) ], ) + machine = Machine(name="my_machine", flake=FlakeId(str(flake.path))) store = in_repo.FactStore(machine) assert store.exists(Generator("my_generator"), "prompt1") assert store.get(Generator("my_generator"), "prompt1").decode() == "input1" set_prompts( - machine, - [ + **params, + updates=[ GeneratorUpdate( generator="my_generator", prompt_values={"prompt1": "input2"}, @@ -622,7 +623,7 @@ def test_api_set_prompts( ) assert store.get(Generator("my_generator"), "prompt1").decode() == "input2" - api_prompts = get_prompts(machine) + api_prompts = get_prompts(**params) assert len(api_prompts) == 1 assert api_prompts[0].name == "my_generator" assert api_prompts[0].prompts[0].name == "prompt1"