API/vars: simplify get/set prompts

This commit is contained in:
Johannes Kirschbauer
2025-01-07 11:09:41 +01:00
parent 5497a6e44b
commit 6b209f1008
2 changed files with 14 additions and 9 deletions

View File

@@ -3,6 +3,7 @@ import importlib
import logging
from clan_cli.api import API
from clan_cli.clan_uri import FlakeId
from clan_cli.completions import add_dynamic_completer, complete_machines
from clan_cli.errors import ClanError
from clan_cli.machines.machines import Machine
@@ -57,8 +58,8 @@ def _get_previous_value(
@API.register
# TODO: use machine_name
def get_prompts(machine: Machine) -> list[Generator]:
def get_prompts(base_dir: str, machine_name: str) -> list[Generator]:
machine = Machine(name=machine_name, flake=FlakeId(base_dir))
generators: list[Generator] = machine.vars_generators
for generator in generators:
for prompt in generator.prompts:
@@ -70,7 +71,10 @@ def get_prompts(machine: Machine) -> list[Generator]:
# TODO: for missing prompts, default to existing values
# TODO: raise error if mandatory prompt not provided
@API.register
def set_prompts(machine: Machine, updates: list[GeneratorUpdate]) -> None:
def set_prompts(
base_dir: str, machine_name: str, updates: list[GeneratorUpdate]
) -> None:
machine = Machine(name=machine_name, flake=FlakeId(base_dir))
for update in updates:
for generator in machine.vars_generators:
if generator.name == update.generator:

View File

@@ -597,23 +597,24 @@ def test_api_set_prompts(
flake.refresh()
monkeypatch.chdir(flake.path)
machine = Machine(name="my_machine", flake=FlakeId(str(flake.path)))
params = {"machine_name": "my_machine", "base_dir": str(flake.path)}
set_prompts(
machine,
[
**params,
updates=[
GeneratorUpdate(
generator="my_generator",
prompt_values={"prompt1": "input1"},
)
],
)
machine = Machine(name="my_machine", flake=FlakeId(str(flake.path)))
store = in_repo.FactStore(machine)
assert store.exists(Generator("my_generator"), "prompt1")
assert store.get(Generator("my_generator"), "prompt1").decode() == "input1"
set_prompts(
machine,
[
**params,
updates=[
GeneratorUpdate(
generator="my_generator",
prompt_values={"prompt1": "input2"},
@@ -622,7 +623,7 @@ def test_api_set_prompts(
)
assert store.get(Generator("my_generator"), "prompt1").decode() == "input2"
api_prompts = get_prompts(machine)
api_prompts = get_prompts(**params)
assert len(api_prompts) == 1
assert api_prompts[0].name == "my_generator"
assert api_prompts[0].prompts[0].name == "prompt1"