Merge pull request 'API/vars: simplify get/set prompts' (#2695) from hsjobeki/clan-core:hsjobeki-main into main
This commit is contained in:
@@ -3,6 +3,7 @@ import importlib
|
||||
import logging
|
||||
|
||||
from clan_cli.api import API
|
||||
from clan_cli.clan_uri import FlakeId
|
||||
from clan_cli.completions import add_dynamic_completer, complete_machines
|
||||
from clan_cli.errors import ClanError
|
||||
from clan_cli.machines.machines import Machine
|
||||
@@ -57,8 +58,8 @@ def _get_previous_value(
|
||||
|
||||
|
||||
@API.register
|
||||
# TODO: use machine_name
|
||||
def get_prompts(machine: Machine) -> list[Generator]:
|
||||
def get_prompts(base_dir: str, machine_name: str) -> list[Generator]:
|
||||
machine = Machine(name=machine_name, flake=FlakeId(base_dir))
|
||||
generators: list[Generator] = machine.vars_generators
|
||||
for generator in generators:
|
||||
for prompt in generator.prompts:
|
||||
@@ -70,7 +71,10 @@ def get_prompts(machine: Machine) -> list[Generator]:
|
||||
# TODO: for missing prompts, default to existing values
|
||||
# TODO: raise error if mandatory prompt not provided
|
||||
@API.register
|
||||
def set_prompts(machine: Machine, updates: list[GeneratorUpdate]) -> None:
|
||||
def set_prompts(
|
||||
base_dir: str, machine_name: str, updates: list[GeneratorUpdate]
|
||||
) -> None:
|
||||
machine = Machine(name=machine_name, flake=FlakeId(base_dir))
|
||||
for update in updates:
|
||||
for generator in machine.vars_generators:
|
||||
if generator.name == update.generator:
|
||||
|
||||
@@ -597,23 +597,24 @@ def test_api_set_prompts(
|
||||
flake.refresh()
|
||||
|
||||
monkeypatch.chdir(flake.path)
|
||||
machine = Machine(name="my_machine", flake=FlakeId(str(flake.path)))
|
||||
params = {"machine_name": "my_machine", "base_dir": str(flake.path)}
|
||||
|
||||
set_prompts(
|
||||
machine,
|
||||
[
|
||||
**params,
|
||||
updates=[
|
||||
GeneratorUpdate(
|
||||
generator="my_generator",
|
||||
prompt_values={"prompt1": "input1"},
|
||||
)
|
||||
],
|
||||
)
|
||||
machine = Machine(name="my_machine", flake=FlakeId(str(flake.path)))
|
||||
store = in_repo.FactStore(machine)
|
||||
assert store.exists(Generator("my_generator"), "prompt1")
|
||||
assert store.get(Generator("my_generator"), "prompt1").decode() == "input1"
|
||||
set_prompts(
|
||||
machine,
|
||||
[
|
||||
**params,
|
||||
updates=[
|
||||
GeneratorUpdate(
|
||||
generator="my_generator",
|
||||
prompt_values={"prompt1": "input2"},
|
||||
@@ -622,7 +623,7 @@ def test_api_set_prompts(
|
||||
)
|
||||
assert store.get(Generator("my_generator"), "prompt1").decode() == "input2"
|
||||
|
||||
api_prompts = get_prompts(machine)
|
||||
api_prompts = get_prompts(**params)
|
||||
assert len(api_prompts) == 1
|
||||
assert api_prompts[0].name == "my_generator"
|
||||
assert api_prompts[0].prompts[0].name == "prompt1"
|
||||
|
||||
Reference in New Issue
Block a user