clan-cli: Fix ignored debug flag in clan vms run, refactor Host.run to use RunOpts

This commit is contained in:
Qubasa
2024-12-03 16:01:51 +01:00
parent e0cedb956a
commit 570bceff4e
14 changed files with 92 additions and 151 deletions

View File

@@ -102,6 +102,7 @@ def register_common_flags(parser: argparse.ArgumentParser) -> None:
for _choice, child_parser in action.choices.items(): for _choice, child_parser in action.choices.items():
has_subparsers = True has_subparsers = True
register_common_flags(child_parser) register_common_flags(child_parser)
if not has_subparsers: if not has_subparsers:
add_common_flags(parser) add_common_flags(parser)

View File

@@ -2,7 +2,7 @@ import argparse
import json import json
from dataclasses import dataclass from dataclasses import dataclass
from clan_cli.cmd import Log from clan_cli.cmd import Log, RunOpts
from clan_cli.completions import ( from clan_cli.completions import (
add_dynamic_completer, add_dynamic_completer,
complete_backup_providers_for_machine, complete_backup_providers_for_machine,
@@ -23,8 +23,7 @@ def list_provider(machine: Machine, provider: str) -> list[Backup]:
backup_metadata = json.loads(machine.eval_nix("config.clan.core.backups")) backup_metadata = json.loads(machine.eval_nix("config.clan.core.backups"))
proc = machine.target_host.run( proc = machine.target_host.run(
[backup_metadata["providers"][provider]["list"]], [backup_metadata["providers"][provider]["list"]],
log=Log.STDERR, RunOpts(log=Log.STDERR, check=False),
check=False,
) )
if proc.returncode != 0: if proc.returncode != 0:
# TODO this should be a warning, only raise exception if no providers succeed # TODO this should be a warning, only raise exception if no providers succeed

View File

@@ -1,7 +1,7 @@
import argparse import argparse
import json import json
from clan_cli.cmd import Log from clan_cli.cmd import Log, RunOpts
from clan_cli.completions import ( from clan_cli.completions import (
add_dynamic_completer, add_dynamic_completer,
complete_backup_providers_for_machine, complete_backup_providers_for_machine,
@@ -28,7 +28,7 @@ def restore_service(machine: Machine, name: str, provider: str, service: str) ->
if pre_restore := backup_folders[service]["preRestoreCommand"]: if pre_restore := backup_folders[service]["preRestoreCommand"]:
proc = machine.target_host.run( proc = machine.target_host.run(
[pre_restore], [pre_restore],
log=Log.STDERR, RunOpts(log=Log.STDERR),
extra_env=env, extra_env=env,
) )
if proc.returncode != 0: if proc.returncode != 0:
@@ -37,7 +37,7 @@ def restore_service(machine: Machine, name: str, provider: str, service: str) ->
proc = machine.target_host.run( proc = machine.target_host.run(
[backup_metadata["providers"][provider]["restore"]], [backup_metadata["providers"][provider]["restore"]],
log=Log.STDERR, RunOpts(log=Log.STDERR),
extra_env=env, extra_env=env,
) )
if proc.returncode != 0: if proc.returncode != 0:
@@ -47,7 +47,7 @@ def restore_service(machine: Machine, name: str, provider: str, service: str) ->
if post_restore := backup_folders[service]["postRestoreCommand"]: if post_restore := backup_folders[service]["postRestoreCommand"]:
proc = machine.target_host.run( proc = machine.target_host.run(
[post_restore], [post_restore],
log=Log.STDERR, RunOpts(log=Log.STDERR),
extra_env=env, extra_env=env,
) )
if proc.returncode != 0: if proc.returncode != 0:

View File

@@ -39,10 +39,10 @@ class ClanCmdTimeoutError(ClanError):
class Log(Enum): class Log(Enum):
NONE = 0
STDERR = 1 STDERR = 1
STDOUT = 2 STDOUT = 2
BOTH = 3 BOTH = 3
NONE = 4
@dataclass @dataclass
@@ -276,12 +276,12 @@ def run(
else: else:
filtered_input = options.input.decode("ascii", "replace") filtered_input = options.input.decode("ascii", "replace")
print_trace( print_trace(
f"$: echo '{filtered_input}' | {indent_command(cmd)}", f"echo '{filtered_input}' | {indent_command(cmd)}",
cmdlog, cmdlog,
options.prefix, options.prefix,
) )
elif cmdlog.isEnabledFor(logging.DEBUG): elif cmdlog.isEnabledFor(logging.DEBUG):
print_trace(f"$: {indent_command(cmd)}", cmdlog, options.prefix) print_trace(f"{indent_command(cmd)}", cmdlog, options.prefix)
start = timeit.default_timer() start = timeit.default_timer()
with ExitStack() as stack: with ExitStack() as stack:
@@ -343,8 +343,7 @@ def run_no_output(
*, *,
env: dict[str, str] | None = None, env: dict[str, str] | None = None,
cwd: Path | None = None, cwd: Path | None = None,
log: Log = Log.STDERR, log: Log = Log.NONE,
logger: logging.Logger = cmdlog,
prefix: str | None = None, prefix: str | None = None,
check: bool = True, check: bool = True,
error_msg: str | None = None, error_msg: str | None = None,
@@ -355,20 +354,22 @@ def run_no_output(
Like run, but automatically suppresses all output, if not in DEBUG log level. Like run, but automatically suppresses all output, if not in DEBUG log level.
If in DEBUG log level the stdout of commands will be shown. If in DEBUG log level the stdout of commands will be shown.
""" """
if cwd is None: opts = RunOpts(
cwd = Path.cwd() env=env,
if logger.isEnabledFor(logging.DEBUG): cwd=cwd,
return run(cmd, RunOpts(env=env, log=log, check=check, error_msg=error_msg)) log=log,
log = Log.NONE check=check,
error_msg=error_msg,
needs_user_terminal=needs_user_terminal,
shell=shell,
prefix=prefix,
)
if cmdlog.isEnabledFor(logging.DEBUG):
opts.log = log if log.value > Log.STDERR.value else Log.STDERR
else:
opts.log = log
return run( return run(
cmd, cmd,
RunOpts( opts,
env=env,
log=log,
check=check,
prefix=prefix,
error_msg=error_msg,
needs_user_terminal=needs_user_terminal,
shell=shell,
),
) )

View File

@@ -3,7 +3,7 @@ import subprocess
from pathlib import Path from pathlib import Path
from typing import override from typing import override
from clan_cli.cmd import Log from clan_cli.cmd import Log, RunOpts
from clan_cli.machines.machines import Machine from clan_cli.machines.machines import Machine
from clan_cli.nix import nix_shell from clan_cli.nix import nix_shell
@@ -98,8 +98,7 @@ class SecretStore(SecretStoreBase):
remote_hash = self.machine.target_host.run( remote_hash = self.machine.target_host.run(
# TODO get the path to the secrets from the machine # TODO get the path to the secrets from the machine
["cat", f"{self.machine.secrets_upload_directory}/.pass_info"], ["cat", f"{self.machine.secrets_upload_directory}/.pass_info"],
log=Log.STDERR, RunOpts(log=Log.STDERR, check=False),
check=False,
).stdout.strip() ).stdout.strip()
if not remote_hash: if not remote_hash:

View File

@@ -154,15 +154,13 @@ def deploy_machine(machines: MachineGroup) -> None:
env = host.nix_ssh_env(None) env = host.nix_ssh_env(None)
ret = host.run( ret = host.run(
switch_cmd, switch_cmd,
RunOpts(check=False, msg_color=MsgColor(stderr=AnsiColor.DEFAULT)),
extra_env=env, extra_env=env,
check=False,
msg_color=MsgColor(stderr=AnsiColor.DEFAULT),
) )
ret = host.run( ret = host.run(
switch_cmd, switch_cmd,
RunOpts(check=False, msg_color=MsgColor(stderr=AnsiColor.DEFAULT)),
extra_env=env, extra_env=env,
check=False,
msg_color=MsgColor(stderr=AnsiColor.DEFAULT),
) )
# if the machine is mobile, we retry to deploy with the mobile workaround method # if the machine is mobile, we retry to deploy with the mobile workaround method
@@ -170,13 +168,17 @@ def deploy_machine(machines: MachineGroup) -> None:
if is_mobile and ret.returncode != 0: if is_mobile and ret.returncode != 0:
log.info("Mobile machine detected, applying workaround deployment method") log.info("Mobile machine detected, applying workaround deployment method")
ret = host.run( ret = host.run(
test_cmd, extra_env=env, msg_color=MsgColor(stderr=AnsiColor.DEFAULT) test_cmd,
RunOpts(msg_color=MsgColor(stderr=AnsiColor.DEFAULT)),
extra_env=env,
) )
# retry nixos-rebuild switch if the first attempt failed # retry nixos-rebuild switch if the first attempt failed
elif ret.returncode != 0: elif ret.returncode != 0:
ret = host.run( ret = host.run(
switch_cmd, extra_env=env, msg_color=MsgColor(stderr=AnsiColor.DEFAULT) switch_cmd,
RunOpts(msg_color=MsgColor(stderr=AnsiColor.DEFAULT)),
extra_env=env,
) )
if len(machines.group.hosts) > 1: if len(machines.group.hosts) > 1:

View File

@@ -1,16 +1,15 @@
# Adapted from https://github.com/numtide/deploykit # Adapted from https://github.com/numtide/deploykit
import logging import logging
import math
import os import os
import shlex import shlex
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path
from shlex import quote from shlex import quote
from typing import IO, Any from typing import Any
from clan_cli.cmd import CmdOut, Log, MsgColor, RunOpts, run from clan_cli.cmd import CmdOut, RunOpts, run
from clan_cli.colors import AnsiColor from clan_cli.colors import AnsiColor
from clan_cli.errors import ClanError
from clan_cli.ssh.host_key import HostKeyCheck from clan_cli.ssh.host_key import HostKeyCheck
cmdlog = logging.getLogger(__name__) cmdlog = logging.getLogger(__name__)
@@ -78,18 +77,11 @@ class Host:
def run( def run(
self, self,
cmd: list[str], cmd: list[str],
stdout: IO[bytes] | None = None, opts: RunOpts | None = None,
stderr: IO[bytes] | None = None,
become_root: bool = False, become_root: bool = False,
extra_env: dict[str, str] | None = None, extra_env: dict[str, str] | None = None,
cwd: None | Path = None,
check: bool = True,
timeout: float = math.inf,
verbose_ssh: bool = False,
tty: bool = False, tty: bool = False,
msg_color: MsgColor | None = None, verbose_ssh: bool = False,
shell: bool = False,
log: Log = Log.BOTH,
) -> CmdOut: ) -> CmdOut:
""" """
Command to run on the host via ssh Command to run on the host via ssh
@@ -107,6 +99,16 @@ class Host:
for k, v in extra_env.items(): for k, v in extra_env.items():
env_vars.append(f"{shlex.quote(k)}={shlex.quote(v)}") env_vars.append(f"{shlex.quote(k)}={shlex.quote(v)}")
if opts is None:
opts = RunOpts()
else:
opts.needs_user_terminal = True
opts.prefix = self.command_prefix
if opts.cwd is not None:
msg = "cwd is not supported for remote commands"
raise ClanError(msg)
# Build a pretty command for logging # Build a pretty command for logging
displayed_cmd = "" displayed_cmd = ""
export_cmd = "" export_cmd = ""
@@ -124,8 +126,9 @@ class Host:
# Build the ssh command # Build the ssh command
bash_cmd = export_cmd bash_cmd = export_cmd
if shell: if opts.shell:
bash_cmd += " ".join(cmd) bash_cmd += " ".join(cmd)
opts.shell = False
else: else:
bash_cmd += 'exec "$@"' bash_cmd += 'exec "$@"'
# FIXME we assume bash to be present here? Should be documented... # FIXME we assume bash to be present here? Should be documented...
@@ -135,19 +138,6 @@ class Host:
f"{sudo}bash -c {quote(bash_cmd)} -- {' '.join(map(quote, cmd))}", f"{sudo}bash -c {quote(bash_cmd)} -- {' '.join(map(quote, cmd))}",
] ]
opts = RunOpts(
shell=False,
stdout=stdout,
stderr=stderr,
log=log,
cwd=cwd,
check=check,
prefix=self.command_prefix,
timeout=timeout,
msg_color=msg_color,
needs_user_terminal=True, # ssh asks for a password
)
# Run the ssh command # Run the ssh command
return run(ssh_cmd, opts) return run(ssh_cmd, opts)

View File

@@ -1,12 +1,10 @@
import logging import logging
import math
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path
from threading import Thread from threading import Thread
from typing import IO, Any from typing import Any
from clan_cli.cmd import Log, RunOpts from clan_cli.cmd import RunOpts
from clan_cli.errors import ClanError from clan_cli.errors import ClanError
from clan_cli.ssh import T from clan_cli.ssh import T
from clan_cli.ssh.host import Host from clan_cli.ssh.host import Host
@@ -42,8 +40,6 @@ class HostGroup:
verbose_ssh: bool = False, verbose_ssh: bool = False,
tty: bool = False, tty: bool = False,
) -> None: ) -> None:
if extra_env is None:
extra_env = {}
try: try:
proc = host.run_local( proc = host.run_local(
cmd, cmd,
@@ -57,37 +53,20 @@ class HostGroup:
def _run_remote( def _run_remote(
self, self,
cmd: list[str], cmd: list[str],
opts: RunOpts,
host: Host, host: Host,
results: Results, results: Results,
stdout: IO[bytes] | None = None,
stderr: IO[bytes] | None = None,
extra_env: dict[str, str] | None = None, extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
verbose_ssh: bool = False, verbose_ssh: bool = False,
timeout: float = math.inf,
tty: bool = False, tty: bool = False,
shell: bool = False,
log: Log = Log.BOTH,
) -> None: ) -> None:
if cwd is not None:
msg = "cwd is not supported for remote commands"
raise ClanError(msg)
if extra_env is None:
extra_env = {}
try: try:
proc = host.run( proc = host.run(
cmd, cmd,
stdout=stdout,
stderr=stderr,
extra_env=extra_env, extra_env=extra_env,
cwd=cwd,
check=check,
verbose_ssh=verbose_ssh, verbose_ssh=verbose_ssh,
timeout=timeout, opts=opts,
tty=tty, tty=tty,
shell=shell,
log=log,
) )
results.append(HostResult(host, proc)) results.append(HostResult(host, proc))
except Exception as e: except Exception as e:
@@ -116,12 +95,12 @@ class HostGroup:
verbose_ssh: bool = False, verbose_ssh: bool = False,
tty: bool = False, tty: bool = False,
) -> Results: ) -> Results:
if opts is None:
opts = RunOpts()
if extra_env is None:
extra_env = {}
results: Results = [] results: Results = []
threads = [] threads = []
if opts is None:
opts = RunOpts()
for host in self.hosts: for host in self.hosts:
if local: if local:
thread = Thread( thread = Thread(
@@ -143,16 +122,10 @@ class HostGroup:
"results": results, "results": results,
"cmd": cmd, "cmd": cmd,
"host": host, "host": host,
"stdout": opts.stdout, "opts": opts,
"stderr": opts.stderr,
"extra_env": extra_env, "extra_env": extra_env,
"cwd": opts.cwd,
"check": opts.check,
"timeout": opts.timeout,
"verbose_ssh": verbose_ssh, "verbose_ssh": verbose_ssh,
"tty": tty, "tty": tty,
"shell": opts.shell,
"log": opts.log,
}, },
) )
@@ -170,33 +143,17 @@ class HostGroup:
def run( def run(
self, self,
cmd: list[str], cmd: list[str],
stdout: IO[bytes] | None = None, opts: RunOpts | None = None,
stderr: IO[bytes] | None = None, *,
extra_env: dict[str, str] | None = None, extra_env: dict[str, str] | None = None,
cwd: None | Path = None,
check: bool = True,
verbose_ssh: bool = False, verbose_ssh: bool = False,
timeout: float = math.inf,
tty: bool = False, tty: bool = False,
log: Log = Log.BOTH,
shell: bool = False,
) -> Results: ) -> Results:
""" """
Command to run on the remote host via ssh Command to run on the remote host via ssh
@return a lists of tuples containing Host and the result of the command for this Host @return a lists of tuples containing Host and the result of the command for this Host
""" """
if extra_env is None:
extra_env = {}
opts = RunOpts(
shell=shell,
stdout=stdout,
stderr=stderr,
log=log,
timeout=timeout,
cwd=cwd,
check=check,
)
return self._run( return self._run(
cmd, cmd,
opts, opts,
@@ -208,31 +165,16 @@ class HostGroup:
def run_local( def run_local(
self, self,
cmd: list[str], cmd: list[str],
stdout: IO[bytes] | None = None, opts: RunOpts | None = None,
stderr: IO[bytes] | None = None, *,
extra_env: dict[str, str] | None = None, extra_env: dict[str, str] | None = None,
cwd: None | Path = None,
check: bool = True,
timeout: float = math.inf,
shell: bool = False,
log: Log = Log.BOTH,
) -> Results: ) -> Results:
""" """
Command to run locally for each host in the group in parallel Command to run locally for each host in the group in parallel
@return a lists of tuples containing Host and the result of the command for this Host @return a lists of tuples containing Host and the result of the command for this Host
""" """
if extra_env is None:
extra_env = {}
opts = RunOpts(
stdout=stdout,
stderr=stderr,
cwd=cwd,
check=check,
timeout=timeout,
shell=shell,
log=log,
)
return self._run( return self._run(
cmd, cmd,
opts, opts,

View File

@@ -62,7 +62,8 @@ def upload(
str(remote_dest), str(remote_dest),
";", ";",
"mkdir", "mkdir",
f"--mode={dir_mode:o}", "-m",
f"{dir_mode:o}",
"-p", "-p",
str(remote_dest), str(remote_dest),
"&&", "&&",

View File

@@ -140,8 +140,7 @@ class SecretStore(SecretStoreBase):
"cat", "cat",
f"{self.machine.deployment["password-store"]["secretLocation"]}/.pass_info", f"{self.machine.deployment["password-store"]["secretLocation"]}/.pass_info",
], ],
log=Log.STDERR, RunOpts(log=Log.STDERR, check=False),
check=False,
).stdout.strip() ).stdout.strip()
if not remote_hash: if not remote_hash:

View File

@@ -417,7 +417,8 @@ def register_run_parser(parser: argparse.ArgumentParser) -> None:
default=False, default=False,
) )
parser.add_argument( parser.add_argument(
"command", "--command",
"-c",
nargs=argparse.REMAINDER, nargs=argparse.REMAINDER,
help="command to run in the vm", help="command to run in the vm",
) )

View File

@@ -7,11 +7,15 @@ hosts = HostGroup([Host("some_host")])
def test_run_environment() -> None: def test_run_environment() -> None:
p2 = hosts.run_local( p2 = hosts.run_local(
["echo $env_var"], extra_env={"env_var": "true"}, shell=True, log=Log.STDERR ["echo $env_var"],
RunOpts(shell=True, log=Log.STDERR),
extra_env={"env_var": "true"},
) )
assert p2[0].result.stdout == "true\n" assert p2[0].result.stdout == "true\n"
p3 = hosts.run_local(["env"], extra_env={"env_var": "true"}, log=Log.STDERR) p3 = hosts.run_local(
["env"], RunOpts(shell=True, log=Log.STDERR), extra_env={"env_var": "true"}
)
assert "env_var=true" in p3[0].result.stdout assert "env_var=true" in p3[0].result.stdout
@@ -21,7 +25,7 @@ def test_run_local() -> None:
def test_timeout() -> None: def test_timeout() -> None:
try: try:
hosts.run_local(["sleep", "10"], timeout=0.01) hosts.run_local(["sleep", "10"], RunOpts(timeout=0.01))
except Exception: except Exception:
pass pass
else: else:
@@ -40,7 +44,7 @@ def test_run_function() -> None:
def test_run_exception() -> None: def test_run_exception() -> None:
try: try:
hosts.run_local(["exit 1"], shell=True) hosts.run_local(["exit 1"], RunOpts(shell=True))
except Exception: except Exception:
pass pass
else: else:
@@ -62,5 +66,5 @@ def test_run_function_exception() -> None:
def test_run_local_non_shell() -> None: def test_run_local_non_shell() -> None:
p2 = hosts.run_local(["echo", "1"], log=Log.STDERR) p2 = hosts.run_local(["echo", "1"], RunOpts(log=Log.STDERR))
assert p2[0].result.stdout == "1\n" assert p2[0].result.stdout == "1\n"

View File

@@ -21,21 +21,23 @@ def test_parse_ipv6() -> None:
def test_run(host_group: HostGroup) -> None: def test_run(host_group: HostGroup) -> None:
proc = host_group.run_local(["echo", "hello"], log=Log.STDERR) proc = host_group.run_local(["echo", "hello"], RunOpts(log=Log.STDERR))
assert proc[0].result.stdout == "hello\n" assert proc[0].result.stdout == "hello\n"
def test_run_environment(host_group: HostGroup) -> None: def test_run_environment(host_group: HostGroup) -> None:
p1 = host_group.run( p1 = host_group.run(
["echo $env_var"], extra_env={"env_var": "true"}, shell=True, log=Log.STDERR ["echo $env_var"],
RunOpts(shell=True, log=Log.STDERR),
extra_env={"env_var": "true"},
) )
assert p1[0].result.stdout == "true\n" assert p1[0].result.stdout == "true\n"
p2 = host_group.run(["env"], log=Log.STDERR, extra_env={"env_var": "true"}) p2 = host_group.run(["env"], RunOpts(log=Log.STDERR), extra_env={"env_var": "true"})
assert "env_var=true" in p2[0].result.stdout assert "env_var=true" in p2[0].result.stdout
def test_run_no_shell(host_group: HostGroup) -> None: def test_run_no_shell(host_group: HostGroup) -> None:
proc = host_group.run(["echo", "$hello"], log=Log.STDERR) proc = host_group.run(["echo", "$hello"], RunOpts(log=Log.STDERR))
assert proc[0].result.stdout == "$hello\n" assert proc[0].result.stdout == "$hello\n"
@@ -50,7 +52,7 @@ def test_run_function(host_group: HostGroup) -> None:
def test_timeout(host_group: HostGroup) -> None: def test_timeout(host_group: HostGroup) -> None:
try: try:
host_group.run_local(["sleep", "10"], timeout=0.01) host_group.run_local(["sleep", "10"], RunOpts(timeout=0.01))
except Exception: except Exception:
pass pass
else: else:
@@ -59,11 +61,11 @@ def test_timeout(host_group: HostGroup) -> None:
def test_run_exception(host_group: HostGroup) -> None: def test_run_exception(host_group: HostGroup) -> None:
r = host_group.run(["exit 1"], check=False, shell=True) r = host_group.run(["exit 1"], RunOpts(check=False, shell=True))
assert r[0].result.returncode == 1 assert r[0].result.returncode == 1
try: try:
host_group.run(["exit 1"], shell=True) host_group.run(["exit 1"], RunOpts(shell=True))
except Exception: except Exception:
pass pass
else: else:

View File

@@ -51,7 +51,7 @@ def test_run(
"user1", "user1",
] ]
) )
cli.run(["vms", "run", "--no-block", "vm1", "shutdown", "-h", "now"]) cli.run(["vms", "run", "--no-block", "vm1", "-c", "shutdown", "-h", "now"])
@pytest.mark.skipif(no_kvm, reason="Requires KVM") @pytest.mark.skipif(no_kvm, reason="Requires KVM")