Merge pull request 'vars: add api endpoint set_prompts' (#2044) from DavHau/clan-core:DavHau-dave into main
This commit is contained in:
@@ -25,6 +25,12 @@ class Generator:
|
|||||||
prompts: list[Prompt]
|
prompts: list[Prompt]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GeneratorUpdate:
|
||||||
|
generator: str
|
||||||
|
prompt_values: dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Var:
|
class Var:
|
||||||
_store: "StoreBase"
|
_store: "StoreBase"
|
||||||
|
|||||||
@@ -99,6 +99,7 @@ def execute_generator(
|
|||||||
regenerate: bool,
|
regenerate: bool,
|
||||||
secret_vars_store: SecretStoreBase,
|
secret_vars_store: SecretStoreBase,
|
||||||
public_vars_store: FactStoreBase,
|
public_vars_store: FactStoreBase,
|
||||||
|
prompt_values: dict[str, str] | None = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
# check if all secrets exist and generate them if at least one is missing
|
# check if all secrets exist and generate them if at least one is missing
|
||||||
needs_regeneration = not check_vars(machine, generator_name=generator_name)
|
needs_regeneration = not check_vars(machine, generator_name=generator_name)
|
||||||
@@ -116,6 +117,20 @@ def execute_generator(
|
|||||||
decrypted_dependencies = decrypt_dependencies(
|
decrypted_dependencies = decrypt_dependencies(
|
||||||
machine, generator_name, secret_vars_store, public_vars_store, shared=is_shared
|
machine, generator_name, secret_vars_store, public_vars_store, shared=is_shared
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_prompt_value(prompt_name: str) -> str:
|
||||||
|
if prompt_values:
|
||||||
|
try:
|
||||||
|
return prompt_values[prompt_name]
|
||||||
|
except KeyError as e:
|
||||||
|
msg = f"prompt value for '{prompt_name}' in generator {generator_name} not provided"
|
||||||
|
raise ClanError(msg) from e
|
||||||
|
description = machine.vars_generators[generator_name]["prompts"][prompt_name][
|
||||||
|
"description"
|
||||||
|
]
|
||||||
|
_type = machine.vars_generators[generator_name]["prompts"][prompt_name]["type"]
|
||||||
|
return prompt(description, _type)
|
||||||
|
|
||||||
env = os.environ.copy()
|
env = os.environ.copy()
|
||||||
with TemporaryDirectory() as tmp:
|
with TemporaryDirectory() as tmp:
|
||||||
tmpdir = Path(tmp)
|
tmpdir = Path(tmp)
|
||||||
@@ -133,11 +148,9 @@ def execute_generator(
|
|||||||
if machine.vars_generators[generator_name]["prompts"]:
|
if machine.vars_generators[generator_name]["prompts"]:
|
||||||
tmpdir_prompts.mkdir()
|
tmpdir_prompts.mkdir()
|
||||||
env["prompts"] = str(tmpdir_prompts)
|
env["prompts"] = str(tmpdir_prompts)
|
||||||
for prompt_name, prompt_ in machine.vars_generators[generator_name][
|
for prompt_name in machine.vars_generators[generator_name]["prompts"]:
|
||||||
"prompts"
|
|
||||||
].items():
|
|
||||||
prompt_file = tmpdir_prompts / prompt_name
|
prompt_file = tmpdir_prompts / prompt_name
|
||||||
value = prompt(prompt_["description"], prompt_["type"])
|
value = get_prompt_value(prompt_name)
|
||||||
prompt_file.write_text(value)
|
prompt_file.write_text(value)
|
||||||
|
|
||||||
if sys.platform == "linux":
|
if sys.platform == "linux":
|
||||||
|
|||||||
@@ -6,17 +6,20 @@ from clan_cli.api import API
|
|||||||
from clan_cli.completions import add_dynamic_completer, complete_machines
|
from clan_cli.completions import add_dynamic_completer, complete_machines
|
||||||
from clan_cli.machines.machines import Machine
|
from clan_cli.machines.machines import Machine
|
||||||
|
|
||||||
from ._types import Generator, Prompt, StoreBase, Var
|
from ._types import Generator, GeneratorUpdate, Prompt, Var
|
||||||
|
from .generate import execute_generator
|
||||||
|
from .public_modules import FactStoreBase
|
||||||
|
from .secret_modules import SecretStoreBase
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def public_store(machine: Machine) -> StoreBase:
|
def public_store(machine: Machine) -> FactStoreBase:
|
||||||
public_vars_module = importlib.import_module(machine.public_vars_module)
|
public_vars_module = importlib.import_module(machine.public_vars_module)
|
||||||
return public_vars_module.FactStore(machine=machine)
|
return public_vars_module.FactStore(machine=machine)
|
||||||
|
|
||||||
|
|
||||||
def secret_store(machine: Machine) -> StoreBase:
|
def secret_store(machine: Machine) -> SecretStoreBase:
|
||||||
secret_vars_module = importlib.import_module(machine.secret_vars_module)
|
secret_vars_module = importlib.import_module(machine.secret_vars_module)
|
||||||
return secret_vars_module.SecretStore(machine=machine)
|
return secret_vars_module.SecretStore(machine=machine)
|
||||||
|
|
||||||
@@ -65,6 +68,20 @@ def get_prompts(machine: Machine) -> list[Generator]:
|
|||||||
return generators
|
return generators
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Ensure generator dependencies are met (executed in correct order etc.)
|
||||||
|
@API.register
|
||||||
|
def set_prompts(machine: Machine, updates: list[GeneratorUpdate]) -> None:
|
||||||
|
for update in updates:
|
||||||
|
execute_generator(
|
||||||
|
machine,
|
||||||
|
update.generator,
|
||||||
|
regenerate=True,
|
||||||
|
secret_vars_store=secret_store(machine),
|
||||||
|
public_vars_store=public_store(machine),
|
||||||
|
prompt_values=update.prompt_values,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def stringify_vars(_vars: list[Var]) -> str:
|
def stringify_vars(_vars: list[Var]) -> str:
|
||||||
return "\n".join([str(var) for var in _vars])
|
return "\n".join([str(var) for var in _vars])
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from .prompt import prompt
|
|||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_command(machine: str, var_id: str, flake: FlakeId) -> None:
|
def set_command(machine: str, var_id: str, flake: FlakeId) -> None:
|
||||||
_machine = Machine(name=machine, flake=flake)
|
_machine = Machine(name=machine, flake=flake)
|
||||||
var = get_var(_machine, var_id)
|
var = get_var(_machine, var_id)
|
||||||
if sys.stdin.isatty():
|
if sys.stdin.isatty():
|
||||||
@@ -22,8 +22,8 @@ def get_command(machine: str, var_id: str, flake: FlakeId) -> None:
|
|||||||
var.set(new_value)
|
var.set(new_value)
|
||||||
|
|
||||||
|
|
||||||
def _get_command(args: argparse.Namespace) -> None:
|
def _set_command(args: argparse.Namespace) -> None:
|
||||||
get_command(args.machine, args.var_id, args.flake)
|
set_command(args.machine, args.var_id, args.flake)
|
||||||
|
|
||||||
|
|
||||||
def register_set_parser(parser: argparse.ArgumentParser) -> None:
|
def register_set_parser(parser: argparse.ArgumentParser) -> None:
|
||||||
@@ -38,4 +38,4 @@ def register_set_parser(parser: argparse.ArgumentParser) -> None:
|
|||||||
help="The var id for which to set the value. Example: ssh-keys/pubkey",
|
help="The var id for which to set the value. Example: ssh-keys/pubkey",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.set_defaults(func=_get_command)
|
parser.set_defaults(func=_set_command)
|
||||||
|
|||||||
@@ -436,3 +436,46 @@ def test_api_get_prompts(
|
|||||||
assert api_prompts[0].name == "my_generator"
|
assert api_prompts[0].name == "my_generator"
|
||||||
assert api_prompts[0].prompts[0].name == "prompt1"
|
assert api_prompts[0].prompts[0].name == "prompt1"
|
||||||
assert api_prompts[0].prompts[0].previous_value == "input1"
|
assert api_prompts[0].prompts[0].previous_value == "input1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.impure
|
||||||
|
def test_api_set_prompts(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
temporary_home: Path,
|
||||||
|
) -> None:
|
||||||
|
from clan_cli.vars._types import GeneratorUpdate
|
||||||
|
from clan_cli.vars.list import set_prompts
|
||||||
|
|
||||||
|
config = nested_dict()
|
||||||
|
my_generator = config["clan"]["core"]["vars"]["generators"]["my_generator"]
|
||||||
|
my_generator["prompts"]["prompt1"]["type"] = "line"
|
||||||
|
my_generator["files"]["prompt1"]["secret"] = False
|
||||||
|
flake = generate_flake(
|
||||||
|
temporary_home,
|
||||||
|
flake_template=CLAN_CORE / "templates" / "minimal",
|
||||||
|
machine_configs={"my_machine": config},
|
||||||
|
)
|
||||||
|
monkeypatch.chdir(flake.path)
|
||||||
|
machine = Machine(name="my_machine", flake=FlakeId(str(flake.path)))
|
||||||
|
set_prompts(
|
||||||
|
machine,
|
||||||
|
[
|
||||||
|
GeneratorUpdate(
|
||||||
|
generator="my_generator",
|
||||||
|
prompt_values={"prompt1": "input1"},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
store = in_repo.FactStore(machine)
|
||||||
|
assert store.exists("my_generator", "prompt1")
|
||||||
|
assert store.get("my_generator", "prompt1").decode() == "input1"
|
||||||
|
set_prompts(
|
||||||
|
machine,
|
||||||
|
[
|
||||||
|
GeneratorUpdate(
|
||||||
|
generator="my_generator",
|
||||||
|
prompt_values={"prompt1": "input2"},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert store.get("my_generator", "prompt1").decode() == "input2"
|
||||||
|
|||||||
Reference in New Issue
Block a user