test_vars: mock ask function instead of sys.stdin

This commit is contained in:
Jörg Thalheim
2024-11-26 12:06:42 +01:00
committed by Mic92
parent 4fdbadc7c5
commit 39db147e48
2 changed files with 12 additions and 4 deletions

View File

@@ -6,8 +6,13 @@ from clan_cli.errors import ClanError
log = logging.getLogger(__name__)
# This is for simulating user input in tests.
MOCK_PROMPT_RESPONSE = None
def ask(description: str, input_type: str) -> str:
if MOCK_PROMPT_RESPONSE:
return next(MOCK_PROMPT_RESPONSE)
if input_type == "line":
result = input(f"Enter the value for {description}: ")
elif input_type == "multiline":

View File

@@ -1,7 +1,6 @@
import json
import shutil
from dataclasses import dataclass
from io import StringIO
from pathlib import Path
import pytest
@@ -389,7 +388,9 @@ def test_prompt(
my_generator["script"] = "cat $prompts/prompt1 > $out/my_value"
flake.refresh()
monkeypatch.chdir(flake.path)
monkeypatch.setattr("sys.stdin", StringIO(input_value))
monkeypatch.setattr(
"clan_cli.vars.prompt.MOCK_PROMPT_RESPONSE", iter([input_value])
)
cli.run(["vars", "generate", "--flake", str(flake.path), "my_machine"])
in_repo_store = in_repo.FactStore(
Machine(name="my_machine", flake=FlakeId(str(flake.path)))
@@ -497,7 +498,9 @@ def test_prompt_create_file(
flake.refresh()
monkeypatch.chdir(flake.path)
sops_setup.init()
monkeypatch.setattr("sys.stdin", StringIO("input1\ninput2\n"))
monkeypatch.setattr(
"clan_cli.vars.prompt.MOCK_PROMPT_RESPONSE", iter(["input1", "input2"])
)
cli.run(["vars", "generate", "--flake", str(flake.path), "my_machine"])
sops_store = sops.SecretStore(
Machine(name="my_machine", flake=FlakeId(str(flake.path)))
@@ -520,7 +523,7 @@ def test_api_get_prompts(
my_generator["files"]["prompt1"]["secret"] = False
flake.refresh()
monkeypatch.chdir(flake.path)
monkeypatch.setattr("sys.stdin", StringIO("input1"))
monkeypatch.setattr("clan_cli.vars.prompt.MOCK_PROMPT_RESPONSE", iter(["input1"]))
cli.run(["vars", "generate", "--flake", str(flake.path), "my_machine"])
machine = Machine(name="my_machine", flake=FlakeId(str(flake.path)))
api_prompts = get_prompts(machine)