diff --git a/pkgs/clan-cli/clan_cli/vms/inspect.py b/pkgs/clan-cli/clan_cli/vms/inspect.py index a2f5dc513..ada68be7c 100644 --- a/pkgs/clan-cli/clan_cli/vms/inspect.py +++ b/pkgs/clan-cli/clan_cli/vms/inspect.py @@ -30,7 +30,6 @@ class VmConfig: waypipe: bool = False - def __post_init__(self) -> None: if isinstance(self.flake_url, str): self.flake_url = FlakeId(self.flake_url) diff --git a/pkgs/clan-cli/clan_cli/vms/run.py b/pkgs/clan-cli/clan_cli/vms/run.py index 73dd2ecf2..30c5fe254 100644 --- a/pkgs/clan-cli/clan_cli/vms/run.py +++ b/pkgs/clan-cli/clan_cli/vms/run.py @@ -3,17 +3,23 @@ import importlib import json import logging import os -from contextlib import ExitStack +import socket +import subprocess +import time +from collections.abc import Iterator +from contextlib import ExitStack, contextmanager from pathlib import Path from tempfile import TemporaryDirectory -from clan_cli.cmd import Log, run +from clan_cli.cmd import CmdOut, Log, handle_output, run from clan_cli.completions import add_dynamic_completer, complete_machines from clan_cli.dirs import module_root, user_cache_dir, vm_state_dir -from clan_cli.errors import ClanError +from clan_cli.errors import ClanCmdError, ClanError from clan_cli.facts.generate import generate_facts from clan_cli.machines.machines import Machine from clan_cli.nix import nix_shell +from clan_cli.qemu.qga import QgaSession +from clan_cli.qemu.qmp import QEMUMonitorProtocol from .inspect import VmConfig, inspect_vm from .qemu import qemu_command @@ -107,14 +113,94 @@ def prepare_disk( return disk_img -def run_vm( +@contextmanager +def start_vm( + args: list[str], + packages: list[str], + extra_env: dict[str, str], + stdout: int | None = None, + stderr: int | None = None, +) -> Iterator[subprocess.Popen]: + env = os.environ.copy() + env.update(extra_env) + cmd = nix_shell(packages, args) + with subprocess.Popen(cmd, env=env, stdout=stdout, stderr=stderr) as process: + try: + yield process + finally: + process.terminate() + try: + # Fix me: This should in future properly shutdown the VM using qmp + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + + +class QemuVm: + def __init__( + self, + machine: Machine, + process: subprocess.Popen, + ) -> None: + self.machine = machine + self.process = process + self.state_dir = vm_state_dir(self.machine.flake, self.machine.name) + self.qmp_socket_file = self.state_dir / "qmp.sock" + self.qga_socket_file = self.state_dir / "qga.sock" + + # wait for vm to be up then connect and return qmp instance + @contextmanager + def qmp_connect(self) -> Iterator[QEMUMonitorProtocol]: + with QEMUMonitorProtocol( + address=str(os.path.realpath(self.qmp_socket_file)), + ) as qmp: + qmp.connect() + yield qmp + + @contextmanager + def qga_connect(self, timeout_sec: float = 100) -> Iterator[QgaSession]: + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + # try to reconnect a couple of times if connection refused + socket_file = os.path.realpath(self.qga_socket_file) + start_time = time.time() + while time.time() - start_time < timeout_sec: + try: + sock.connect(str(socket_file)) + except ConnectionRefusedError: + time.sleep(0.1) + else: + break + sock.connect(str(socket_file)) + yield QgaSession(sock) + finally: + sock.close() + + def wait_up(self, timeout_sec: float = 60) -> None: + start_time = time.time() + while time.time() - start_time < timeout_sec: + if self.process.poll() is not None: + msg = "VM failed to start. Qemu process exited with code {self.process.returncode}" + raise ClanError(msg) + if self.qmp_socket_file.exists(): + break + time.sleep(0.1) + + def wait_down(self) -> int: + return self.process.wait() + + +@contextmanager +def spawn_vm( vm: VmConfig, *, cachedir: Path | None = None, socketdir: Path | None = None, nix_options: list[str] | None = None, portmap: list[tuple[int, int]] | None = None, -) -> None: + stdout: int | None = None, + stderr: int | None = None, +) -> Iterator[QemuVm]: if portmap is None: portmap = [] if nix_options is None: @@ -141,7 +227,7 @@ def run_vm( # TODO: We should get this from the vm argument nixos_config = build_vm(machine, cachedir, nix_options) - state_dir = vm_state_dir(str(vm.flake_url), machine.name) + state_dir = vm_state_dir(vm.flake_url, machine.name) state_dir.mkdir(parents=True, exist_ok=True) # specify socket files for qmp and qga @@ -185,24 +271,71 @@ def run_vm( packages = ["nixpkgs#qemu"] - env = os.environ.copy() + extra_env = {} if vm.graphics and not vm.waypipe: packages.append("nixpkgs#virt-viewer") remote_viewer_mimetypes = module_root() / "vms" / "mimetypes" - env["XDG_DATA_DIRS"] = ( - f"{remote_viewer_mimetypes}:{env.get('XDG_DATA_DIRS', '')}" + extra_env["XDG_DATA_DIRS"] = ( + f"{remote_viewer_mimetypes}:{os.environ.get('XDG_DATA_DIRS', '')}" ) with ( start_waypipe(qemu_cmd.vsock_cid, f"[{vm.machine_name}] "), start_virtiofsd(virtiofsd_socket), + start_vm( + qemu_cmd.args, packages, extra_env, stdout=stdout, stderr=stderr + ) as process, ): - run( - nix_shell(packages, qemu_cmd.args), - env=env, - log=Log.BOTH, - error_msg=f"Could not start vm {machine}", - ) + qemu_vm = QemuVm(machine, process) + qemu_vm.wait_up() + + try: + yield qemu_vm + finally: + try: + with qemu_vm.qmp_connect() as qmp: + qmp.command("system_powerdown") + qemu_vm.wait_down() + except OSError: + pass + # TODO: add a timeout here instead of waiting indefinitely + + +def run_vm( + vm_config: VmConfig, + *, + cachedir: Path | None = None, + socketdir: Path | None = None, + nix_options: list[str] | None = None, + portmap: list[tuple[int, int]] | None = None, +) -> None: + with spawn_vm( + vm_config, + cachedir=cachedir, + socketdir=socketdir, + nix_options=nix_options, + portmap=portmap, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) as vm: + stdout_buf, stderr_buf = handle_output(vm.process, Log.BOTH) + args: list[str] = vm.process.args # type: ignore[assignment] + cmd_out = CmdOut( + stdout=stdout_buf, + stderr=stderr_buf, + cwd=Path.cwd(), + command_list=args, + returncode=vm.process.returncode, + msg=f"Could not start vm {vm_config.machine_name}", + env={}, + ) + + if vm.process.returncode != 0: + raise ClanCmdError(cmd_out) + rc = vm.wait_down() + if rc != 0: + msg = f"VM exited with code {rc}" + raise ClanError(msg) def run_command( diff --git a/pkgs/clan-cli/tests/helpers/vms.py b/pkgs/clan-cli/tests/helpers/vms.py deleted file mode 100644 index 0b0247ea6..000000000 --- a/pkgs/clan-cli/tests/helpers/vms.py +++ /dev/null @@ -1,132 +0,0 @@ -import contextlib -import os -import socket -import sys -import threading -import traceback -from collections.abc import Iterator -from pathlib import Path -from time import sleep - -from clan_cli.dirs import vm_state_dir -from clan_cli.errors import ClanError -from clan_cli.qemu.qga import QgaSession -from clan_cli.qemu.qmp import QEMUMonitorProtocol - -from . import cli - - -def find_free_port() -> int: - """Find an unused localhost port from 1024-65535 and return it.""" - with contextlib.closing(socket.socket(type=socket.SOCK_STREAM)) as sock: - sock.bind(("127.0.0.1", 0)) - return sock.getsockname()[1] - - -class VmThread(threading.Thread): - def __init__(self, machine_name: str, ssh_port: int | None = None) -> None: - super().__init__() - self.machine_name = machine_name - self.ssh_port = ssh_port - self.exception: Exception | None = None - self.daemon = True - - def run(self) -> None: - try: - cli.run( - ["vms", "run", self.machine_name, "--publish", f"{self.ssh_port}:22"] - ) - except Exception as ex: - # print exception details - print(traceback.format_exc(), file=sys.stderr) - print(sys.exc_info()[2], file=sys.stderr) - self.exception = ex - - -def run_vm_in_thread(machine_name: str, ssh_port: int | None = None) -> VmThread: - # runs machine and prints exceptions - if ssh_port is None: - ssh_port = find_free_port() - - vm_thread = VmThread(machine_name, ssh_port) - vm_thread.start() - return vm_thread - - -# wait for qmp socket to exist -def wait_vm_up(machine_name: str, vm: VmThread, flake_url: str | None = None) -> None: - if flake_url is None: - flake_url = str(Path.cwd()) - socket_file = vm_state_dir(flake_url, machine_name) / "qmp.sock" - timeout: float = 600 - while True: - if vm.exception: - msg = "VM failed to start" - raise ClanError(msg) from vm.exception - if timeout <= 0: - msg = f"qmp socket {socket_file} not found. Is the VM running?" - raise TimeoutError(msg) - if socket_file.exists(): - break - sleep(0.1) - timeout -= 0.1 - - -# wait for vm to be down by checking if qmp socket is down -def wait_vm_down(machine_name: str, vm: VmThread, flake_url: str | None = None) -> None: - if flake_url is None: - flake_url = str(Path.cwd()) - socket_file = vm_state_dir(flake_url, machine_name) / "qmp.sock" - timeout: float = 300 - while socket_file.exists(): - if vm.exception: - msg = "VM failed to start" - raise ClanError(msg) from vm.exception - if timeout <= 0: - msg = f"qmp socket {socket_file} still exists. Is the VM down?" - raise TimeoutError(msg) - sleep(0.1) - timeout -= 0.1 - - -# wait for vm to be up then connect and return qmp instance -@contextlib.contextmanager -def qmp_connect( - machine_name: str, vm: VmThread, flake_url: str | None = None -) -> Iterator[QEMUMonitorProtocol]: - if flake_url is None: - flake_url = str(Path.cwd()) - state_dir = vm_state_dir(flake_url, machine_name) - wait_vm_up(machine_name, vm, flake_url) - with QEMUMonitorProtocol( - address=str(os.path.realpath(state_dir / "qmp.sock")), - ) as qmp: - qmp.connect() - yield qmp - - -# wait for vm to be up then connect and return qga instance -@contextlib.contextmanager -def qga_connect( - machine_name: str, vm: VmThread, flake_url: str | None = None -) -> Iterator[QgaSession]: - if flake_url is None: - flake_url = str(Path.cwd()) - state_dir = vm_state_dir(flake_url, machine_name) - wait_vm_up(machine_name, vm, flake_url) - - sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - try: - # try to reconnect a couple of times if connection refused - socket_file = os.path.realpath(state_dir / "qga.sock") - for _ in range(100): - try: - sock.connect(str(socket_file)) - except ConnectionRefusedError: - sleep(0.1) - else: - break - sock.connect(str(socket_file)) - yield QgaSession(sock) - finally: - sock.close() diff --git a/pkgs/clan-cli/tests/test_vars_deployment.py b/pkgs/clan-cli/tests/test_vars_deployment.py index 22296909b..d9d127497 100644 --- a/pkgs/clan-cli/tests/test_vars_deployment.py +++ b/pkgs/clan-cli/tests/test_vars_deployment.py @@ -1,14 +1,17 @@ import json +from contextlib import ExitStack from pathlib import Path import pytest from age_keys import SopsSetup from clan_cli import cmd +from clan_cli.clan_uri import FlakeId +from clan_cli.machines.machines import Machine from clan_cli.nix import nix_eval, run +from clan_cli.vms.run import inspect_vm, spawn_vm from fixtures_flakes import generate_flake from helpers import cli from helpers.nixos_config import nested_dict -from helpers.vms import qga_connect, run_vm_in_thread, wait_vm_down from root import CLAN_CORE @@ -61,9 +64,9 @@ def test_vm_deployment( flake_template=CLAN_CORE / "templates" / "minimal", machine_configs={"m1_machine": machine1_config, "m2_machine": machine2_config}, ) - monkeypatch.chdir(flake.path) + sops_setup.init() - cli.run(["vars", "generate"]) + cli.run(["vars", "generate", "--flake", str(flake.path)]) # check sops secrets not empty for machine in ["m1_machine", "m2_machine"]: sops_secrets = json.loads( @@ -94,13 +97,15 @@ def test_vm_deployment( ).stdout.strip() assert "no-such-path" not in shared_secret_path # run nix flake lock - cmd.run(["nix", "flake", "lock"]) - vm_m1 = run_vm_in_thread("m1_machine") - vm_m2 = run_vm_in_thread("m2_machine") - with ( - qga_connect("m1_machine", vm_m1) as qga_m1, - qga_connect("m2_machine", vm_m2) as qga_m2, - ): + cmd.run(["nix", "flake", "lock"], cwd=flake.path) + + vm1_config = inspect_vm(machine=Machine("m1_machine", FlakeId(str(flake.path)))) + vm2_config = inspect_vm(machine=Machine("m2_machine", FlakeId(str(flake.path)))) + with ExitStack() as stack: + vm1 = stack.enter_context(spawn_vm(vm1_config)) + vm2 = stack.enter_context(spawn_vm(vm2_config)) + qga_m1 = stack.enter_context(vm1.qga_connect()) + qga_m2 = stack.enter_context(vm2.qga_connect()) # check my_secret is deployed _, out, _ = qga_m1.run( "cat /run/secrets/vars/m1_generator/my_secret", check=True @@ -122,9 +127,3 @@ def test_vm_deployment( check=False, ) assert returncode != 0 - qga_m1.exec_cmd("poweroff") - qga_m2.exec_cmd("poweroff") - wait_vm_down("m1_machine", vm_m1) - wait_vm_down("m2_machine", vm_m2) - vm_m1.join() - vm_m2.join() diff --git a/pkgs/clan-cli/tests/test_vms_cli.py b/pkgs/clan-cli/tests/test_vms_cli.py index 7944fecf0..57d9faf7a 100644 --- a/pkgs/clan-cli/tests/test_vms_cli.py +++ b/pkgs/clan-cli/tests/test_vms_cli.py @@ -2,10 +2,12 @@ from pathlib import Path from typing import TYPE_CHECKING import pytest +from clan_cli.clan_uri import FlakeId +from clan_cli.machines.machines import Machine +from clan_cli.vms.run import inspect_vm, spawn_vm from fixtures_flakes import FlakeForTest, generate_flake from helpers import cli from helpers.nixos_config import nested_dict -from helpers.vms import qga_connect, qmp_connect, run_vm_in_thread, wait_vm_down from root import CLAN_CORE from stdout import CaptureOutput @@ -84,30 +86,18 @@ def test_vm_persistence( machine_configs=config, ) - monkeypatch.chdir(flake.path) + vm_config = inspect_vm(machine=Machine("my_machine", FlakeId(str(flake.path)))) - vm = run_vm_in_thread("my_machine") - - # wait for the VM to start and connect qga - with qga_connect("my_machine", vm) as qga: + with spawn_vm(vm_config) as vm, vm.qga_connect() as qga: # create state via qmp command instead of systemd service qga.run("echo 'dream2nix' > /var/my-state/root", check=True) qga.run("echo 'dream2nix' > /var/my-state/test", check=True) qga.run("chown test /var/my-state/test", check=True) qga.run("chown test /var/user-state", check=True) qga.run("touch /var/my-state/rebooting", check=True) - qga.exec_cmd("poweroff") - - # wait for socket to be down (systemd service 'poweroff' rebooting machine) - wait_vm_down("my_machine", vm) - - vm.join() ## start vm again - vm = run_vm_in_thread("my_machine") - - ## connect second time - with qga_connect("my_machine", vm) as qga: + with spawn_vm(vm_config) as vm, vm.qga_connect() as qga: # check state exists qga.run("cat /var/my-state/test", check=True) # ensure root file is owned by root @@ -131,7 +121,3 @@ def test_vm_persistence( "systemctl --failed | tee /tmp/yolo | grep -q '0 loaded units listed' || ( cat /tmp/yolo && false )" ) assert exitcode == 0, out - - with qmp_connect("my_machine", vm) as qmp: - qmp.command("system_powerdown") - vm.join()