clan-cli: Fix ignored debug flag in clan vms run, refactor Host.run to use RunOpts

This commit is contained in:
Qubasa
2024-12-03 16:01:51 +01:00
parent f033a193d5
commit 164c621dc0
14 changed files with 92 additions and 151 deletions

View File

@@ -102,6 +102,7 @@ def register_common_flags(parser: argparse.ArgumentParser) -> None:
for _choice, child_parser in action.choices.items():
has_subparsers = True
register_common_flags(child_parser)
if not has_subparsers:
add_common_flags(parser)

View File

@@ -2,7 +2,7 @@ import argparse
import json
from dataclasses import dataclass
from clan_cli.cmd import Log
from clan_cli.cmd import Log, RunOpts
from clan_cli.completions import (
add_dynamic_completer,
complete_backup_providers_for_machine,
@@ -23,8 +23,7 @@ def list_provider(machine: Machine, provider: str) -> list[Backup]:
backup_metadata = json.loads(machine.eval_nix("config.clan.core.backups"))
proc = machine.target_host.run(
[backup_metadata["providers"][provider]["list"]],
log=Log.STDERR,
check=False,
RunOpts(log=Log.STDERR, check=False),
)
if proc.returncode != 0:
# TODO this should be a warning, only raise exception if no providers succeed

View File

@@ -1,7 +1,7 @@
import argparse
import json
from clan_cli.cmd import Log
from clan_cli.cmd import Log, RunOpts
from clan_cli.completions import (
add_dynamic_completer,
complete_backup_providers_for_machine,
@@ -28,7 +28,7 @@ def restore_service(machine: Machine, name: str, provider: str, service: str) ->
if pre_restore := backup_folders[service]["preRestoreCommand"]:
proc = machine.target_host.run(
[pre_restore],
log=Log.STDERR,
RunOpts(log=Log.STDERR),
extra_env=env,
)
if proc.returncode != 0:
@@ -37,7 +37,7 @@ def restore_service(machine: Machine, name: str, provider: str, service: str) ->
proc = machine.target_host.run(
[backup_metadata["providers"][provider]["restore"]],
log=Log.STDERR,
RunOpts(log=Log.STDERR),
extra_env=env,
)
if proc.returncode != 0:
@@ -47,7 +47,7 @@ def restore_service(machine: Machine, name: str, provider: str, service: str) ->
if post_restore := backup_folders[service]["postRestoreCommand"]:
proc = machine.target_host.run(
[post_restore],
log=Log.STDERR,
RunOpts(log=Log.STDERR),
extra_env=env,
)
if proc.returncode != 0:

View File

@@ -39,10 +39,10 @@ class ClanCmdTimeoutError(ClanError):
class Log(Enum):
NONE = 0
STDERR = 1
STDOUT = 2
BOTH = 3
NONE = 4
@dataclass
@@ -276,12 +276,12 @@ def run(
else:
filtered_input = options.input.decode("ascii", "replace")
print_trace(
f"$: echo '{filtered_input}' | {indent_command(cmd)}",
f"echo '{filtered_input}' | {indent_command(cmd)}",
cmdlog,
options.prefix,
)
elif cmdlog.isEnabledFor(logging.DEBUG):
print_trace(f"$: {indent_command(cmd)}", cmdlog, options.prefix)
print_trace(f"{indent_command(cmd)}", cmdlog, options.prefix)
start = timeit.default_timer()
with ExitStack() as stack:
@@ -343,8 +343,7 @@ def run_no_output(
*,
env: dict[str, str] | None = None,
cwd: Path | None = None,
log: Log = Log.STDERR,
logger: logging.Logger = cmdlog,
log: Log = Log.NONE,
prefix: str | None = None,
check: bool = True,
error_msg: str | None = None,
@@ -355,20 +354,22 @@ def run_no_output(
Like run, but automatically suppresses all output, if not in DEBUG log level.
If in DEBUG log level the stdout of commands will be shown.
"""
if cwd is None:
cwd = Path.cwd()
if logger.isEnabledFor(logging.DEBUG):
return run(cmd, RunOpts(env=env, log=log, check=check, error_msg=error_msg))
log = Log.NONE
return run(
cmd,
RunOpts(
opts = RunOpts(
env=env,
cwd=cwd,
log=log,
check=check,
prefix=prefix,
error_msg=error_msg,
needs_user_terminal=needs_user_terminal,
shell=shell,
),
prefix=prefix,
)
if cmdlog.isEnabledFor(logging.DEBUG):
opts.log = log if log.value > Log.STDERR.value else Log.STDERR
else:
opts.log = log
return run(
cmd,
opts,
)

View File

@@ -3,7 +3,7 @@ import subprocess
from pathlib import Path
from typing import override
from clan_cli.cmd import Log
from clan_cli.cmd import Log, RunOpts
from clan_cli.machines.machines import Machine
from clan_cli.nix import nix_shell
@@ -98,8 +98,7 @@ class SecretStore(SecretStoreBase):
remote_hash = self.machine.target_host.run(
# TODO get the path to the secrets from the machine
["cat", f"{self.machine.secrets_upload_directory}/.pass_info"],
log=Log.STDERR,
check=False,
RunOpts(log=Log.STDERR, check=False),
).stdout.strip()
if not remote_hash:

View File

@@ -154,15 +154,13 @@ def deploy_machine(machines: MachineGroup) -> None:
env = host.nix_ssh_env(None)
ret = host.run(
switch_cmd,
RunOpts(check=False, msg_color=MsgColor(stderr=AnsiColor.DEFAULT)),
extra_env=env,
check=False,
msg_color=MsgColor(stderr=AnsiColor.DEFAULT),
)
ret = host.run(
switch_cmd,
RunOpts(check=False, msg_color=MsgColor(stderr=AnsiColor.DEFAULT)),
extra_env=env,
check=False,
msg_color=MsgColor(stderr=AnsiColor.DEFAULT),
)
# if the machine is mobile, we retry to deploy with the mobile workaround method
@@ -170,13 +168,17 @@ def deploy_machine(machines: MachineGroup) -> None:
if is_mobile and ret.returncode != 0:
log.info("Mobile machine detected, applying workaround deployment method")
ret = host.run(
test_cmd, extra_env=env, msg_color=MsgColor(stderr=AnsiColor.DEFAULT)
test_cmd,
RunOpts(msg_color=MsgColor(stderr=AnsiColor.DEFAULT)),
extra_env=env,
)
# retry nixos-rebuild switch if the first attempt failed
elif ret.returncode != 0:
ret = host.run(
switch_cmd, extra_env=env, msg_color=MsgColor(stderr=AnsiColor.DEFAULT)
switch_cmd,
RunOpts(msg_color=MsgColor(stderr=AnsiColor.DEFAULT)),
extra_env=env,
)
if len(machines.group.hosts) > 1:

View File

@@ -1,16 +1,15 @@
# Adapted from https://github.com/numtide/deploykit
import logging
import math
import os
import shlex
from dataclasses import dataclass, field
from pathlib import Path
from shlex import quote
from typing import IO, Any
from typing import Any
from clan_cli.cmd import CmdOut, Log, MsgColor, RunOpts, run
from clan_cli.cmd import CmdOut, RunOpts, run
from clan_cli.colors import AnsiColor
from clan_cli.errors import ClanError
from clan_cli.ssh.host_key import HostKeyCheck
cmdlog = logging.getLogger(__name__)
@@ -78,18 +77,11 @@ class Host:
def run(
self,
cmd: list[str],
stdout: IO[bytes] | None = None,
stderr: IO[bytes] | None = None,
opts: RunOpts | None = None,
become_root: bool = False,
extra_env: dict[str, str] | None = None,
cwd: None | Path = None,
check: bool = True,
timeout: float = math.inf,
verbose_ssh: bool = False,
tty: bool = False,
msg_color: MsgColor | None = None,
shell: bool = False,
log: Log = Log.BOTH,
verbose_ssh: bool = False,
) -> CmdOut:
"""
Command to run on the host via ssh
@@ -107,6 +99,16 @@ class Host:
for k, v in extra_env.items():
env_vars.append(f"{shlex.quote(k)}={shlex.quote(v)}")
if opts is None:
opts = RunOpts()
else:
opts.needs_user_terminal = True
opts.prefix = self.command_prefix
if opts.cwd is not None:
msg = "cwd is not supported for remote commands"
raise ClanError(msg)
# Build a pretty command for logging
displayed_cmd = ""
export_cmd = ""
@@ -124,8 +126,9 @@ class Host:
# Build the ssh command
bash_cmd = export_cmd
if shell:
if opts.shell:
bash_cmd += " ".join(cmd)
opts.shell = False
else:
bash_cmd += 'exec "$@"'
# FIXME we assume bash to be present here? Should be documented...
@@ -135,19 +138,6 @@ class Host:
f"{sudo}bash -c {quote(bash_cmd)} -- {' '.join(map(quote, cmd))}",
]
opts = RunOpts(
shell=False,
stdout=stdout,
stderr=stderr,
log=log,
cwd=cwd,
check=check,
prefix=self.command_prefix,
timeout=timeout,
msg_color=msg_color,
needs_user_terminal=True, # ssh asks for a password
)
# Run the ssh command
return run(ssh_cmd, opts)

View File

@@ -1,12 +1,10 @@
import logging
import math
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from threading import Thread
from typing import IO, Any
from typing import Any
from clan_cli.cmd import Log, RunOpts
from clan_cli.cmd import RunOpts
from clan_cli.errors import ClanError
from clan_cli.ssh import T
from clan_cli.ssh.host import Host
@@ -42,8 +40,6 @@ class HostGroup:
verbose_ssh: bool = False,
tty: bool = False,
) -> None:
if extra_env is None:
extra_env = {}
try:
proc = host.run_local(
cmd,
@@ -57,37 +53,20 @@ class HostGroup:
def _run_remote(
self,
cmd: list[str],
opts: RunOpts,
host: Host,
results: Results,
stdout: IO[bytes] | None = None,
stderr: IO[bytes] | None = None,
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
verbose_ssh: bool = False,
timeout: float = math.inf,
tty: bool = False,
shell: bool = False,
log: Log = Log.BOTH,
) -> None:
if cwd is not None:
msg = "cwd is not supported for remote commands"
raise ClanError(msg)
if extra_env is None:
extra_env = {}
try:
proc = host.run(
cmd,
stdout=stdout,
stderr=stderr,
extra_env=extra_env,
cwd=cwd,
check=check,
verbose_ssh=verbose_ssh,
timeout=timeout,
opts=opts,
tty=tty,
shell=shell,
log=log,
)
results.append(HostResult(host, proc))
except Exception as e:
@@ -116,12 +95,12 @@ class HostGroup:
verbose_ssh: bool = False,
tty: bool = False,
) -> Results:
if opts is None:
opts = RunOpts()
if extra_env is None:
extra_env = {}
results: Results = []
threads = []
if opts is None:
opts = RunOpts()
for host in self.hosts:
if local:
thread = Thread(
@@ -143,16 +122,10 @@ class HostGroup:
"results": results,
"cmd": cmd,
"host": host,
"stdout": opts.stdout,
"stderr": opts.stderr,
"opts": opts,
"extra_env": extra_env,
"cwd": opts.cwd,
"check": opts.check,
"timeout": opts.timeout,
"verbose_ssh": verbose_ssh,
"tty": tty,
"shell": opts.shell,
"log": opts.log,
},
)
@@ -170,33 +143,17 @@ class HostGroup:
def run(
self,
cmd: list[str],
stdout: IO[bytes] | None = None,
stderr: IO[bytes] | None = None,
opts: RunOpts | None = None,
*,
extra_env: dict[str, str] | None = None,
cwd: None | Path = None,
check: bool = True,
verbose_ssh: bool = False,
timeout: float = math.inf,
tty: bool = False,
log: Log = Log.BOTH,
shell: bool = False,
) -> Results:
"""
Command to run on the remote host via ssh
@return a lists of tuples containing Host and the result of the command for this Host
"""
if extra_env is None:
extra_env = {}
opts = RunOpts(
shell=shell,
stdout=stdout,
stderr=stderr,
log=log,
timeout=timeout,
cwd=cwd,
check=check,
)
return self._run(
cmd,
opts,
@@ -208,31 +165,16 @@ class HostGroup:
def run_local(
self,
cmd: list[str],
stdout: IO[bytes] | None = None,
stderr: IO[bytes] | None = None,
opts: RunOpts | None = None,
*,
extra_env: dict[str, str] | None = None,
cwd: None | Path = None,
check: bool = True,
timeout: float = math.inf,
shell: bool = False,
log: Log = Log.BOTH,
) -> Results:
"""
Command to run locally for each host in the group in parallel
@return a lists of tuples containing Host and the result of the command for this Host
"""
if extra_env is None:
extra_env = {}
opts = RunOpts(
stdout=stdout,
stderr=stderr,
cwd=cwd,
check=check,
timeout=timeout,
shell=shell,
log=log,
)
return self._run(
cmd,
opts,

View File

@@ -62,7 +62,8 @@ def upload(
str(remote_dest),
";",
"mkdir",
f"--mode={dir_mode:o}",
"-m",
f"{dir_mode:o}",
"-p",
str(remote_dest),
"&&",

View File

@@ -140,8 +140,7 @@ class SecretStore(SecretStoreBase):
"cat",
f"{self.machine.deployment["password-store"]["secretLocation"]}/.pass_info",
],
log=Log.STDERR,
check=False,
RunOpts(log=Log.STDERR, check=False),
).stdout.strip()
if not remote_hash:

View File

@@ -417,7 +417,8 @@ def register_run_parser(parser: argparse.ArgumentParser) -> None:
default=False,
)
parser.add_argument(
"command",
"--command",
"-c",
nargs=argparse.REMAINDER,
help="command to run in the vm",
)

View File

@@ -7,11 +7,15 @@ hosts = HostGroup([Host("some_host")])
def test_run_environment() -> None:
p2 = hosts.run_local(
["echo $env_var"], extra_env={"env_var": "true"}, shell=True, log=Log.STDERR
["echo $env_var"],
RunOpts(shell=True, log=Log.STDERR),
extra_env={"env_var": "true"},
)
assert p2[0].result.stdout == "true\n"
p3 = hosts.run_local(["env"], extra_env={"env_var": "true"}, log=Log.STDERR)
p3 = hosts.run_local(
["env"], RunOpts(shell=True, log=Log.STDERR), extra_env={"env_var": "true"}
)
assert "env_var=true" in p3[0].result.stdout
@@ -21,7 +25,7 @@ def test_run_local() -> None:
def test_timeout() -> None:
try:
hosts.run_local(["sleep", "10"], timeout=0.01)
hosts.run_local(["sleep", "10"], RunOpts(timeout=0.01))
except Exception:
pass
else:
@@ -40,7 +44,7 @@ def test_run_function() -> None:
def test_run_exception() -> None:
try:
hosts.run_local(["exit 1"], shell=True)
hosts.run_local(["exit 1"], RunOpts(shell=True))
except Exception:
pass
else:
@@ -62,5 +66,5 @@ def test_run_function_exception() -> None:
def test_run_local_non_shell() -> None:
p2 = hosts.run_local(["echo", "1"], log=Log.STDERR)
p2 = hosts.run_local(["echo", "1"], RunOpts(log=Log.STDERR))
assert p2[0].result.stdout == "1\n"

View File

@@ -21,21 +21,23 @@ def test_parse_ipv6() -> None:
def test_run(host_group: HostGroup) -> None:
proc = host_group.run_local(["echo", "hello"], log=Log.STDERR)
proc = host_group.run_local(["echo", "hello"], RunOpts(log=Log.STDERR))
assert proc[0].result.stdout == "hello\n"
def test_run_environment(host_group: HostGroup) -> None:
p1 = host_group.run(
["echo $env_var"], extra_env={"env_var": "true"}, shell=True, log=Log.STDERR
["echo $env_var"],
RunOpts(shell=True, log=Log.STDERR),
extra_env={"env_var": "true"},
)
assert p1[0].result.stdout == "true\n"
p2 = host_group.run(["env"], log=Log.STDERR, extra_env={"env_var": "true"})
p2 = host_group.run(["env"], RunOpts(log=Log.STDERR), extra_env={"env_var": "true"})
assert "env_var=true" in p2[0].result.stdout
def test_run_no_shell(host_group: HostGroup) -> None:
proc = host_group.run(["echo", "$hello"], log=Log.STDERR)
proc = host_group.run(["echo", "$hello"], RunOpts(log=Log.STDERR))
assert proc[0].result.stdout == "$hello\n"
@@ -50,7 +52,7 @@ def test_run_function(host_group: HostGroup) -> None:
def test_timeout(host_group: HostGroup) -> None:
try:
host_group.run_local(["sleep", "10"], timeout=0.01)
host_group.run_local(["sleep", "10"], RunOpts(timeout=0.01))
except Exception:
pass
else:
@@ -59,11 +61,11 @@ def test_timeout(host_group: HostGroup) -> None:
def test_run_exception(host_group: HostGroup) -> None:
r = host_group.run(["exit 1"], check=False, shell=True)
r = host_group.run(["exit 1"], RunOpts(check=False, shell=True))
assert r[0].result.returncode == 1
try:
host_group.run(["exit 1"], shell=True)
host_group.run(["exit 1"], RunOpts(shell=True))
except Exception:
pass
else:

View File

@@ -51,7 +51,7 @@ def test_run(
"user1",
]
)
cli.run(["vms", "run", "--no-block", "vm1", "shutdown", "-h", "now"])
cli.run(["vms", "run", "--no-block", "vm1", "-c", "shutdown", "-h", "now"])
@pytest.mark.skipif(no_kvm, reason="Requires KVM")