clan-cli: refactor HostGroup._run_local to work with RunOpts

This commit is contained in:
Qubasa
2024-11-28 16:05:51 +01:00
parent 466044e85f
commit 95cb239206
4 changed files with 78 additions and 123 deletions

View File

@@ -9,8 +9,7 @@ from pathlib import Path
from shlex import quote from shlex import quote
from typing import IO, Any from typing import IO, Any
from clan_cli.cmd import CmdOut, Log, MsgColor, RunOpts from clan_cli.cmd import CmdOut, Log, MsgColor, RunOpts, run
from clan_cli.cmd import run as local_run
from clan_cli.colors import AnsiColor from clan_cli.colors import AnsiColor
from clan_cli.ssh.host_key import HostKeyCheck from clan_cli.ssh.host_key import HostKeyCheck
@@ -49,60 +48,18 @@ class Host:
host = f"[{host}]" host = f"[{host}]"
return f"{self.user or 'root'}@{host}" return f"{self.user or 'root'}@{host}"
def _run(
self,
cmd: list[str],
*,
stdout: IO[bytes] | None = None,
stderr: IO[bytes] | None = None,
input: bytes | None = None, # noqa: A002
env: dict[str, str] | None = None,
cwd: Path | None = None,
log: Log = Log.BOTH,
check: bool = True,
error_msg: str | None = None,
needs_user_terminal: bool = False,
msg_color: MsgColor | None = None,
shell: bool = False,
timeout: float = math.inf,
) -> CmdOut:
res = local_run(
cmd,
RunOpts(
shell=shell,
stdout=stdout,
prefix=self.command_prefix,
timeout=timeout,
stderr=stderr,
input=input,
env=env,
cwd=cwd,
log=log,
check=check,
error_msg=error_msg,
msg_color=msg_color,
needs_user_terminal=needs_user_terminal,
),
)
return res
def run_local( def run_local(
self, self,
cmd: list[str], cmd: list[str],
stdout: IO[bytes] | None = None, opts: RunOpts | None = None,
stderr: IO[bytes] | None = None,
extra_env: dict[str, str] | None = None, extra_env: dict[str, str] | None = None,
cwd: None | Path = None,
check: bool = True,
timeout: float = math.inf,
shell: bool = False,
needs_user_terminal: bool = False,
log: Log = Log.BOTH,
) -> CmdOut: ) -> CmdOut:
""" """
Command to run locally for the host Command to run locally for the host
""" """
env = os.environ.copy() if opts is None:
opts = RunOpts()
env = opts.env or os.environ.copy()
if extra_env: if extra_env:
env.update(extra_env) env.update(extra_env)
@@ -114,18 +71,9 @@ class Host:
"color": AnsiColor.GREEN.value, "color": AnsiColor.GREEN.value,
}, },
) )
return self._run( opts.env = env
cmd, opts.prefix = self.command_prefix
shell=shell, return run(cmd, opts)
stdout=stdout,
stderr=stderr,
env=env,
cwd=cwd,
check=check,
needs_user_terminal=needs_user_terminal,
timeout=timeout,
log=log,
)
def run( def run(
self, self,
@@ -187,20 +135,22 @@ class Host:
f"{sudo}bash -c {quote(bash_cmd)} -- {' '.join(map(quote, cmd))}", f"{sudo}bash -c {quote(bash_cmd)} -- {' '.join(map(quote, cmd))}",
] ]
# Run the ssh command opts = RunOpts(
return self._run(
ssh_cmd,
shell=False, shell=False,
stdout=stdout, stdout=stdout,
stderr=stderr, stderr=stderr,
log=log, log=log,
cwd=cwd, cwd=cwd,
check=check, check=check,
prefix=self.command_prefix,
timeout=timeout, timeout=timeout,
msg_color=msg_color, msg_color=msg_color,
needs_user_terminal=True, # ssh asks for a password needs_user_terminal=True, # ssh asks for a password
) )
# Run the ssh command
return run(ssh_cmd, opts)
def nix_ssh_env(self, env: dict[str, str] | None) -> dict[str, str]: def nix_ssh_env(self, env: dict[str, str] | None) -> dict[str, str]:
if env is None: if env is None:
env = {} env = {}

View File

@@ -6,7 +6,7 @@ from pathlib import Path
from threading import Thread from threading import Thread
from typing import IO, Any from typing import IO, Any
from clan_cli.cmd import Log from clan_cli.cmd import Log, RunOpts
from clan_cli.errors import ClanError from clan_cli.errors import ClanError
from clan_cli.ssh import T from clan_cli.ssh import T
from clan_cli.ssh.host import Host from clan_cli.ssh.host import Host
@@ -33,33 +33,22 @@ class HostGroup:
def _run_local( def _run_local(
self, self,
*,
cmd: list[str], cmd: list[str],
opts: RunOpts,
extra_env: dict[str, str] | None,
host: Host, host: Host,
results: Results, results: Results,
stdout: IO[bytes] | None = None,
stderr: IO[bytes] | None = None,
extra_env: dict[str, str] | None = None,
cwd: None | Path = None,
check: bool = True,
verbose_ssh: bool = False, verbose_ssh: bool = False,
timeout: float = math.inf,
shell: bool = False,
tty: bool = False, tty: bool = False,
log: Log = Log.BOTH,
) -> None: ) -> None:
if extra_env is None: if extra_env is None:
extra_env = {} extra_env = {}
try: try:
proc = host.run_local( proc = host.run_local(
cmd, cmd,
stdout=stdout, opts,
stderr=stderr, extra_env,
extra_env=extra_env,
cwd=cwd,
check=check,
timeout=timeout,
shell=shell,
log=log,
) )
results.append(HostResult(host, proc)) results.append(HostResult(host, proc))
except Exception as e: except Exception as e:
@@ -121,49 +110,59 @@ class HostGroup:
def _run( def _run(
self, self,
cmd: list[str], cmd: list[str],
opts: RunOpts | None = None,
local: bool = False, local: bool = False,
stdout: IO[bytes] | None = None,
stderr: IO[bytes] | None = None,
extra_env: dict[str, str] | None = None, extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
timeout: float = math.inf,
verbose_ssh: bool = False, verbose_ssh: bool = False,
tty: bool = False, tty: bool = False,
shell: bool = False,
log: Log = Log.BOTH,
) -> Results: ) -> Results:
if opts is None:
opts = RunOpts()
if extra_env is None: if extra_env is None:
extra_env = {} extra_env = {}
results: Results = [] results: Results = []
threads = [] threads = []
for host in self.hosts: for host in self.hosts:
fn = self._run_local if local else self._run_remote if local:
thread = Thread( thread = Thread(
target=fn, target=self._run_local,
kwargs={
"cmd": cmd,
"opts": opts,
"host": host,
"results": results,
"extra_env": extra_env,
"verbose_ssh": verbose_ssh,
"tty": tty,
},
)
else:
thread = Thread(
target=self._run_remote,
kwargs={ kwargs={
"results": results, "results": results,
"cmd": cmd, "cmd": cmd,
"host": host, "host": host,
"stdout": stdout, "stdout": opts.stdout,
"stderr": stderr, "stderr": opts.stderr,
"extra_env": extra_env, "extra_env": extra_env,
"cwd": cwd, "cwd": opts.cwd,
"check": check, "check": opts.check,
"timeout": timeout, "timeout": opts.timeout,
"verbose_ssh": verbose_ssh, "verbose_ssh": verbose_ssh,
"tty": tty, "tty": tty,
"shell": shell, "shell": opts.shell,
"log": log, "log": opts.log,
}, },
) )
thread.start() thread.start()
threads.append(thread) threads.append(thread)
for thread in threads: for thread in threads:
thread.join() thread.join()
if check: if opts.check:
self._reraise_errors(results) self._reraise_errors(results)
return results return results
@@ -174,7 +173,7 @@ class HostGroup:
stdout: IO[bytes] | None = None, stdout: IO[bytes] | None = None,
stderr: IO[bytes] | None = None, stderr: IO[bytes] | None = None,
extra_env: dict[str, str] | None = None, extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None, cwd: None | Path = None,
check: bool = True, check: bool = True,
verbose_ssh: bool = False, verbose_ssh: bool = False,
timeout: float = math.inf, timeout: float = math.inf,
@@ -189,18 +188,21 @@ class HostGroup:
""" """
if extra_env is None: if extra_env is None:
extra_env = {} extra_env = {}
return self._run( opts = RunOpts(
cmd,
shell=shell, shell=shell,
stdout=stdout, stdout=stdout,
stderr=stderr, stderr=stderr,
extra_env=extra_env, log=log,
timeout=timeout,
cwd=cwd, cwd=cwd,
check=check, check=check,
)
return self._run(
cmd,
opts,
extra_env=extra_env,
verbose_ssh=verbose_ssh, verbose_ssh=verbose_ssh,
timeout=timeout,
tty=tty, tty=tty,
log=log,
) )
def run_local( def run_local(
@@ -209,7 +211,7 @@ class HostGroup:
stdout: IO[bytes] | None = None, stdout: IO[bytes] | None = None,
stderr: IO[bytes] | None = None, stderr: IO[bytes] | None = None,
extra_env: dict[str, str] | None = None, extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None, cwd: None | Path = None,
check: bool = True, check: bool = True,
timeout: float = math.inf, timeout: float = math.inf,
shell: bool = False, shell: bool = False,
@@ -222,18 +224,21 @@ class HostGroup:
""" """
if extra_env is None: if extra_env is None:
extra_env = {} extra_env = {}
return self._run( opts = RunOpts(
cmd,
local=True,
stdout=stdout, stdout=stdout,
stderr=stderr, stderr=stderr,
extra_env=extra_env,
cwd=cwd, cwd=cwd,
check=check, check=check,
timeout=timeout, timeout=timeout,
shell=shell, shell=shell,
log=log, log=log,
) )
return self._run(
cmd,
opts,
local=True,
extra_env=extra_env,
)
def run_function( def run_function(
self, func: Callable[[Host], T], check: bool = True self, func: Callable[[Host], T], check: bool = True

View File

@@ -1,4 +1,4 @@
from clan_cli.cmd import Log from clan_cli.cmd import Log, RunOpts
from clan_cli.ssh.host import Host from clan_cli.ssh.host import Host
from clan_cli.ssh.host_group import HostGroup from clan_cli.ssh.host_group import HostGroup
@@ -31,7 +31,7 @@ def test_timeout() -> None:
def test_run_function() -> None: def test_run_function() -> None:
def some_func(h: Host) -> bool: def some_func(h: Host) -> bool:
par = h.run_local(["echo", "hello"], log=Log.STDERR) par = h.run_local(["echo", "hello"], RunOpts(log=Log.STDERR))
return par.stdout == "hello\n" return par.stdout == "hello\n"
res = hosts.run_function(some_func) res = hosts.run_function(some_func)
@@ -50,7 +50,7 @@ def test_run_exception() -> None:
def test_run_function_exception() -> None: def test_run_function_exception() -> None:
def some_func(h: Host) -> None: def some_func(h: Host) -> None:
h.run_local(["exit 1"], shell=True) h.run_local(["exit 1"], RunOpts(shell=True))
try: try:
hosts.run_function(some_func) hosts.run_function(some_func)

View File

@@ -1,5 +1,5 @@
import pytest import pytest
from clan_cli.cmd import Log from clan_cli.cmd import Log, RunOpts
from clan_cli.errors import ClanError, CmdOut from clan_cli.errors import ClanError, CmdOut
from clan_cli.ssh.host import Host from clan_cli.ssh.host import Host
from clan_cli.ssh.host_group import HostGroup from clan_cli.ssh.host_group import HostGroup
@@ -73,7 +73,7 @@ def test_run_exception(host_group: HostGroup) -> None:
def test_run_function_exception(host_group: HostGroup) -> None: def test_run_function_exception(host_group: HostGroup) -> None:
def some_func(h: Host) -> CmdOut: def some_func(h: Host) -> CmdOut:
return h.run_local(["exit 1"], shell=True) return h.run_local(["exit 1"], RunOpts(shell=True))
try: try:
host_group.run_function(some_func) host_group.run_function(some_func)