API/vars: use string based interfaces to get and set vars to avoid state mutations
This commit is contained in:
committed by
hsjobeki
parent
b6059fc506
commit
c6fe4f2625
@@ -260,7 +260,11 @@ def _ask_prompts(
|
||||
prompt_values: dict[str, str] = {}
|
||||
for prompt in generator.prompts:
|
||||
var_id = f"{generator.name}/{prompt.name}"
|
||||
prompt_values[prompt.name] = ask(var_id, prompt.prompt_type)
|
||||
prompt_values[prompt.name] = ask(
|
||||
var_id,
|
||||
prompt.prompt_type,
|
||||
prompt.description if prompt.description != prompt.name else None,
|
||||
)
|
||||
return prompt_values
|
||||
|
||||
|
||||
|
||||
@@ -2,10 +2,10 @@ import argparse
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from clan_cli.api import API
|
||||
from clan_cli.clan_uri import FlakeId
|
||||
from clan_cli.completions import add_dynamic_completer, complete_machines
|
||||
from clan_cli.errors import ClanError
|
||||
from clan_cli.machines.machines import Machine
|
||||
|
||||
from .generate import Var
|
||||
from .list import get_vars
|
||||
@@ -13,8 +13,9 @@ from .list import get_vars
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_var(machine: Machine, var_id: str) -> Var:
|
||||
vars_ = get_vars(machine)
|
||||
@API.register
|
||||
def get_var(base_dir: str, machine_name: str, var_id: str) -> Var:
|
||||
vars_ = get_vars(base_dir=base_dir, machine_name=machine_name)
|
||||
results = []
|
||||
for var in vars_:
|
||||
if var.id == var_id:
|
||||
@@ -42,8 +43,7 @@ def get_var(machine: Machine, var_id: str) -> Var:
|
||||
|
||||
|
||||
def get_command(machine_name: str, var_id: str, flake: FlakeId) -> None:
|
||||
machine = Machine(name=machine_name, flake=flake)
|
||||
var = get_var(machine, var_id)
|
||||
var = get_var(str(flake.path), machine_name, var_id)
|
||||
if not var.exists:
|
||||
msg = f"Var {var.id} has not been generated yet"
|
||||
raise ClanError(msg)
|
||||
|
||||
@@ -25,7 +25,9 @@ def secret_store(machine: Machine) -> StoreBase:
|
||||
return secret_vars_module.SecretStore(machine=machine)
|
||||
|
||||
|
||||
def get_vars(machine: Machine) -> list[Var]:
|
||||
@API.register
|
||||
def get_vars(base_dir: str, machine_name: str) -> list[Var]:
|
||||
machine = Machine(name=machine_name, flake=FlakeId(base_dir))
|
||||
pub_store = public_store(machine)
|
||||
sec_store = secret_store(machine)
|
||||
all_vars = []
|
||||
@@ -58,7 +60,7 @@ def _get_previous_value(
|
||||
|
||||
|
||||
@API.register
|
||||
def get_prompts(base_dir: str, machine_name: str) -> list[Generator]:
|
||||
def get_generators(base_dir: str, machine_name: str) -> list[Generator]:
|
||||
machine = Machine(name=machine_name, flake=FlakeId(base_dir))
|
||||
generators: list[Generator] = machine.vars_generators
|
||||
for generator in generators:
|
||||
@@ -96,7 +98,7 @@ def stringify_vars(_vars: list[Var]) -> str:
|
||||
|
||||
|
||||
def stringify_all_vars(machine: Machine) -> str:
|
||||
return stringify_vars(get_vars(machine))
|
||||
return stringify_vars(get_vars(str(machine.flake), machine.name))
|
||||
|
||||
|
||||
def list_command(args: argparse.Namespace) -> None:
|
||||
|
||||
@@ -37,16 +37,25 @@ class Prompt:
|
||||
)
|
||||
|
||||
|
||||
def ask(description: str, input_type: PromptType) -> str:
|
||||
def ask(
|
||||
ident: str,
|
||||
input_type: PromptType,
|
||||
label: str | None,
|
||||
) -> str:
|
||||
text = f"Enter the value for {ident}:"
|
||||
if label:
|
||||
text = f"{label}"
|
||||
|
||||
if MOCK_PROMPT_RESPONSE:
|
||||
return next(MOCK_PROMPT_RESPONSE)
|
||||
match input_type:
|
||||
case PromptType.LINE:
|
||||
result = input(f"Enter the value for {description}: ")
|
||||
result = input(f"{text}: ")
|
||||
case PromptType.MULTILINE:
|
||||
print(f"Enter the value for {description} (Finish with Ctrl-D): ")
|
||||
print(f"{text} (Finish with Ctrl-D): ")
|
||||
result = sys.stdin.read()
|
||||
case PromptType.HIDDEN:
|
||||
result = getpass(f"Enter the value for {description} (hidden): ")
|
||||
result = getpass(f"{text} (hidden): ")
|
||||
|
||||
log.info("Input received. Processing...")
|
||||
return result
|
||||
|
||||
@@ -23,7 +23,7 @@ def set_var(
|
||||
else:
|
||||
_machine = machine
|
||||
if isinstance(var, str):
|
||||
_var = get_var(_machine, var)
|
||||
_var = get_var(str(flake.path), _machine.name, var)
|
||||
else:
|
||||
_var = var
|
||||
path = _var.set(value)
|
||||
@@ -36,12 +36,17 @@ def set_var(
|
||||
|
||||
|
||||
def set_via_stdin(machine: str, var_id: str, flake: FlakeId) -> None:
|
||||
_machine = Machine(name=machine, flake=flake)
|
||||
var = get_var(_machine, var_id)
|
||||
var = get_var(str(flake.path), machine, var_id)
|
||||
if sys.stdin.isatty():
|
||||
new_value = ask(var.id, PromptType.HIDDEN).encode("utf-8")
|
||||
new_value = ask(
|
||||
var.id,
|
||||
PromptType.HIDDEN,
|
||||
None,
|
||||
).encode("utf-8")
|
||||
else:
|
||||
new_value = sys.stdin.buffer.read()
|
||||
|
||||
_machine = Machine(name=machine, flake=flake)
|
||||
set_var(_machine, var, new_value, flake)
|
||||
|
||||
|
||||
|
||||
@@ -170,9 +170,16 @@ def test_generate_public_and_secret_vars(
|
||||
"Update vars via generator my_shared_generator for machine my_machine"
|
||||
in commit_message
|
||||
)
|
||||
assert get_var(machine, "my_generator/my_value").printable_value == "public"
|
||||
assert (
|
||||
get_var(machine, "my_shared_generator/my_shared_value").printable_value
|
||||
get_var(
|
||||
str(machine.flake.path), machine.name, "my_generator/my_value"
|
||||
).printable_value
|
||||
== "public"
|
||||
)
|
||||
assert (
|
||||
get_var(
|
||||
str(machine.flake.path), machine.name, "my_shared_generator/my_shared_value"
|
||||
).printable_value
|
||||
== "shared"
|
||||
)
|
||||
vars_text = stringify_all_vars(machine)
|
||||
@@ -587,7 +594,7 @@ def test_api_set_prompts(
|
||||
flake: ClanFlake,
|
||||
) -> None:
|
||||
from clan_cli.vars._types import GeneratorUpdate
|
||||
from clan_cli.vars.list import get_prompts, set_prompts
|
||||
from clan_cli.vars.list import get_generators, set_prompts
|
||||
|
||||
config = flake.machines["my_machine"]
|
||||
config["nixpkgs"]["hostPlatform"] = "x86_64-linux"
|
||||
@@ -623,11 +630,11 @@ def test_api_set_prompts(
|
||||
)
|
||||
assert store.get(Generator("my_generator"), "prompt1").decode() == "input2"
|
||||
|
||||
api_prompts = get_prompts(**params)
|
||||
assert len(api_prompts) == 1
|
||||
assert api_prompts[0].name == "my_generator"
|
||||
assert api_prompts[0].prompts[0].name == "prompt1"
|
||||
assert api_prompts[0].prompts[0].previous_value == "input2"
|
||||
generators = get_generators(**params)
|
||||
assert len(generators) == 1
|
||||
assert generators[0].name == "my_generator"
|
||||
assert generators[0].prompts[0].name == "prompt1"
|
||||
assert generators[0].prompts[0].previous_value == "input2"
|
||||
|
||||
|
||||
@pytest.mark.with_core
|
||||
@@ -843,19 +850,27 @@ def test_invalidation(
|
||||
monkeypatch.chdir(flake.path)
|
||||
cli.run(["vars", "generate", "--flake", str(flake.path), "my_machine"])
|
||||
machine = Machine(name="my_machine", flake=FlakeId(str(flake.path)))
|
||||
value1 = get_var(machine, "my_generator/my_value").printable_value
|
||||
value1 = get_var(
|
||||
str(machine.flake.path), machine.name, "my_generator/my_value"
|
||||
).printable_value
|
||||
# generate again and make sure nothing changes without the invalidation data being set
|
||||
cli.run(["vars", "generate", "--flake", str(flake.path), "my_machine"])
|
||||
value1_new = get_var(machine, "my_generator/my_value").printable_value
|
||||
value1_new = get_var(
|
||||
str(machine.flake.path), machine.name, "my_generator/my_value"
|
||||
).printable_value
|
||||
assert value1 == value1_new
|
||||
# set the invalidation data of the generator
|
||||
my_generator["validation"] = 1
|
||||
flake.refresh()
|
||||
# generate again and make sure the value changes
|
||||
cli.run(["vars", "generate", "--flake", str(flake.path), "my_machine"])
|
||||
value2 = get_var(machine, "my_generator/my_value").printable_value
|
||||
value2 = get_var(
|
||||
str(machine.flake.path), machine.name, "my_generator/my_value"
|
||||
).printable_value
|
||||
assert value1 != value2
|
||||
# generate again without changing invalidation data -> value should not change
|
||||
cli.run(["vars", "generate", "--flake", str(flake.path), "my_machine"])
|
||||
value2_new = get_var(machine, "my_generator/my_value").printable_value
|
||||
value2_new = get_var(
|
||||
str(machine.flake.path), machine.name, "my_generator/my_value"
|
||||
).printable_value
|
||||
assert value2 == value2_new
|
||||
|
||||
Reference in New Issue
Block a user