API/vars: simplify get/set prompts
This commit is contained in:
@@ -3,6 +3,7 @@ import importlib
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from clan_cli.api import API
|
from clan_cli.api import API
|
||||||
|
from clan_cli.clan_uri import FlakeId
|
||||||
from clan_cli.completions import add_dynamic_completer, complete_machines
|
from clan_cli.completions import add_dynamic_completer, complete_machines
|
||||||
from clan_cli.errors import ClanError
|
from clan_cli.errors import ClanError
|
||||||
from clan_cli.machines.machines import Machine
|
from clan_cli.machines.machines import Machine
|
||||||
@@ -57,8 +58,8 @@ def _get_previous_value(
|
|||||||
|
|
||||||
|
|
||||||
@API.register
|
@API.register
|
||||||
# TODO: use machine_name
|
def get_prompts(base_dir: str, machine_name: str) -> list[Generator]:
|
||||||
def get_prompts(machine: Machine) -> list[Generator]:
|
machine = Machine(name=machine_name, flake=FlakeId(base_dir))
|
||||||
generators: list[Generator] = machine.vars_generators
|
generators: list[Generator] = machine.vars_generators
|
||||||
for generator in generators:
|
for generator in generators:
|
||||||
for prompt in generator.prompts:
|
for prompt in generator.prompts:
|
||||||
@@ -70,7 +71,10 @@ def get_prompts(machine: Machine) -> list[Generator]:
|
|||||||
# TODO: for missing prompts, default to existing values
|
# TODO: for missing prompts, default to existing values
|
||||||
# TODO: raise error if mandatory prompt not provided
|
# TODO: raise error if mandatory prompt not provided
|
||||||
@API.register
|
@API.register
|
||||||
def set_prompts(machine: Machine, updates: list[GeneratorUpdate]) -> None:
|
def set_prompts(
|
||||||
|
base_dir: str, machine_name: str, updates: list[GeneratorUpdate]
|
||||||
|
) -> None:
|
||||||
|
machine = Machine(name=machine_name, flake=FlakeId(base_dir))
|
||||||
for update in updates:
|
for update in updates:
|
||||||
for generator in machine.vars_generators:
|
for generator in machine.vars_generators:
|
||||||
if generator.name == update.generator:
|
if generator.name == update.generator:
|
||||||
|
|||||||
@@ -597,23 +597,24 @@ def test_api_set_prompts(
|
|||||||
flake.refresh()
|
flake.refresh()
|
||||||
|
|
||||||
monkeypatch.chdir(flake.path)
|
monkeypatch.chdir(flake.path)
|
||||||
machine = Machine(name="my_machine", flake=FlakeId(str(flake.path)))
|
params = {"machine_name": "my_machine", "base_dir": str(flake.path)}
|
||||||
|
|
||||||
set_prompts(
|
set_prompts(
|
||||||
machine,
|
**params,
|
||||||
[
|
updates=[
|
||||||
GeneratorUpdate(
|
GeneratorUpdate(
|
||||||
generator="my_generator",
|
generator="my_generator",
|
||||||
prompt_values={"prompt1": "input1"},
|
prompt_values={"prompt1": "input1"},
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
machine = Machine(name="my_machine", flake=FlakeId(str(flake.path)))
|
||||||
store = in_repo.FactStore(machine)
|
store = in_repo.FactStore(machine)
|
||||||
assert store.exists(Generator("my_generator"), "prompt1")
|
assert store.exists(Generator("my_generator"), "prompt1")
|
||||||
assert store.get(Generator("my_generator"), "prompt1").decode() == "input1"
|
assert store.get(Generator("my_generator"), "prompt1").decode() == "input1"
|
||||||
set_prompts(
|
set_prompts(
|
||||||
machine,
|
**params,
|
||||||
[
|
updates=[
|
||||||
GeneratorUpdate(
|
GeneratorUpdate(
|
||||||
generator="my_generator",
|
generator="my_generator",
|
||||||
prompt_values={"prompt1": "input2"},
|
prompt_values={"prompt1": "input2"},
|
||||||
@@ -622,7 +623,7 @@ def test_api_set_prompts(
|
|||||||
)
|
)
|
||||||
assert store.get(Generator("my_generator"), "prompt1").decode() == "input2"
|
assert store.get(Generator("my_generator"), "prompt1").decode() == "input2"
|
||||||
|
|
||||||
api_prompts = get_prompts(machine)
|
api_prompts = get_prompts(**params)
|
||||||
assert len(api_prompts) == 1
|
assert len(api_prompts) == 1
|
||||||
assert api_prompts[0].name == "my_generator"
|
assert api_prompts[0].name == "my_generator"
|
||||||
assert api_prompts[0].prompts[0].name == "prompt1"
|
assert api_prompts[0].prompts[0].name == "prompt1"
|
||||||
|
|||||||
Reference in New Issue
Block a user