diff --git a/pkgs/clan-cli/clan_cli/ssh/host.py b/pkgs/clan-cli/clan_cli/ssh/host.py index e5ada8d64..6aac7d3e8 100644 --- a/pkgs/clan-cli/clan_cli/ssh/host.py +++ b/pkgs/clan-cli/clan_cli/ssh/host.py @@ -9,8 +9,7 @@ from pathlib import Path from shlex import quote from typing import IO, Any -from clan_cli.cmd import CmdOut, Log, MsgColor, RunOpts -from clan_cli.cmd import run as local_run +from clan_cli.cmd import CmdOut, Log, MsgColor, RunOpts, run from clan_cli.colors import AnsiColor from clan_cli.ssh.host_key import HostKeyCheck @@ -49,60 +48,18 @@ class Host: host = f"[{host}]" return f"{self.user or 'root'}@{host}" - def _run( - self, - cmd: list[str], - *, - stdout: IO[bytes] | None = None, - stderr: IO[bytes] | None = None, - input: bytes | None = None, # noqa: A002 - env: dict[str, str] | None = None, - cwd: Path | None = None, - log: Log = Log.BOTH, - check: bool = True, - error_msg: str | None = None, - needs_user_terminal: bool = False, - msg_color: MsgColor | None = None, - shell: bool = False, - timeout: float = math.inf, - ) -> CmdOut: - res = local_run( - cmd, - RunOpts( - shell=shell, - stdout=stdout, - prefix=self.command_prefix, - timeout=timeout, - stderr=stderr, - input=input, - env=env, - cwd=cwd, - log=log, - check=check, - error_msg=error_msg, - msg_color=msg_color, - needs_user_terminal=needs_user_terminal, - ), - ) - return res - def run_local( self, cmd: list[str], - stdout: IO[bytes] | None = None, - stderr: IO[bytes] | None = None, + opts: RunOpts | None = None, extra_env: dict[str, str] | None = None, - cwd: None | Path = None, - check: bool = True, - timeout: float = math.inf, - shell: bool = False, - needs_user_terminal: bool = False, - log: Log = Log.BOTH, ) -> CmdOut: """ Command to run locally for the host """ - env = os.environ.copy() + if opts is None: + opts = RunOpts() + env = opts.env or os.environ.copy() if extra_env: env.update(extra_env) @@ -114,18 +71,9 @@ class Host: "color": AnsiColor.GREEN.value, }, ) - return self._run( - cmd, - shell=shell, - stdout=stdout, - stderr=stderr, - env=env, - cwd=cwd, - check=check, - needs_user_terminal=needs_user_terminal, - timeout=timeout, - log=log, - ) + opts.env = env + opts.prefix = self.command_prefix + return run(cmd, opts) def run( self, @@ -187,20 +135,22 @@ class Host: f"{sudo}bash -c {quote(bash_cmd)} -- {' '.join(map(quote, cmd))}", ] - # Run the ssh command - return self._run( - ssh_cmd, + opts = RunOpts( shell=False, stdout=stdout, stderr=stderr, log=log, cwd=cwd, check=check, + prefix=self.command_prefix, timeout=timeout, msg_color=msg_color, needs_user_terminal=True, # ssh asks for a password ) + # Run the ssh command + return run(ssh_cmd, opts) + def nix_ssh_env(self, env: dict[str, str] | None) -> dict[str, str]: if env is None: env = {} diff --git a/pkgs/clan-cli/clan_cli/ssh/host_group.py b/pkgs/clan-cli/clan_cli/ssh/host_group.py index bf1890c16..fae849e9d 100644 --- a/pkgs/clan-cli/clan_cli/ssh/host_group.py +++ b/pkgs/clan-cli/clan_cli/ssh/host_group.py @@ -6,7 +6,7 @@ from pathlib import Path from threading import Thread from typing import IO, Any -from clan_cli.cmd import Log +from clan_cli.cmd import Log, RunOpts from clan_cli.errors import ClanError from clan_cli.ssh import T from clan_cli.ssh.host import Host @@ -33,33 +33,22 @@ class HostGroup: def _run_local( self, + *, cmd: list[str], + opts: RunOpts, + extra_env: dict[str, str] | None, host: Host, results: Results, - stdout: IO[bytes] | None = None, - stderr: IO[bytes] | None = None, - extra_env: dict[str, str] | None = None, - cwd: None | Path = None, - check: bool = True, verbose_ssh: bool = False, - timeout: float = math.inf, - shell: bool = False, tty: bool = False, - log: Log = Log.BOTH, ) -> None: if extra_env is None: extra_env = {} try: proc = host.run_local( cmd, - stdout=stdout, - stderr=stderr, - extra_env=extra_env, - cwd=cwd, - check=check, - timeout=timeout, - shell=shell, - log=log, + opts, + extra_env, ) results.append(HostResult(host, proc)) except Exception as e: @@ -121,49 +110,59 @@ class HostGroup: def _run( self, cmd: list[str], + opts: RunOpts | None = None, local: bool = False, - stdout: IO[bytes] | None = None, - stderr: IO[bytes] | None = None, extra_env: dict[str, str] | None = None, - cwd: None | str | Path = None, - check: bool = True, - timeout: float = math.inf, verbose_ssh: bool = False, tty: bool = False, - shell: bool = False, - log: Log = Log.BOTH, ) -> Results: + if opts is None: + opts = RunOpts() if extra_env is None: extra_env = {} results: Results = [] threads = [] for host in self.hosts: - fn = self._run_local if local else self._run_remote - thread = Thread( - target=fn, - kwargs={ - "results": results, - "cmd": cmd, - "host": host, - "stdout": stdout, - "stderr": stderr, - "extra_env": extra_env, - "cwd": cwd, - "check": check, - "timeout": timeout, - "verbose_ssh": verbose_ssh, - "tty": tty, - "shell": shell, - "log": log, - }, - ) + if local: + thread = Thread( + target=self._run_local, + kwargs={ + "cmd": cmd, + "opts": opts, + "host": host, + "results": results, + "extra_env": extra_env, + "verbose_ssh": verbose_ssh, + "tty": tty, + }, + ) + else: + thread = Thread( + target=self._run_remote, + kwargs={ + "results": results, + "cmd": cmd, + "host": host, + "stdout": opts.stdout, + "stderr": opts.stderr, + "extra_env": extra_env, + "cwd": opts.cwd, + "check": opts.check, + "timeout": opts.timeout, + "verbose_ssh": verbose_ssh, + "tty": tty, + "shell": opts.shell, + "log": opts.log, + }, + ) + thread.start() threads.append(thread) for thread in threads: thread.join() - if check: + if opts.check: self._reraise_errors(results) return results @@ -174,7 +173,7 @@ class HostGroup: stdout: IO[bytes] | None = None, stderr: IO[bytes] | None = None, extra_env: dict[str, str] | None = None, - cwd: None | str | Path = None, + cwd: None | Path = None, check: bool = True, verbose_ssh: bool = False, timeout: float = math.inf, @@ -189,18 +188,21 @@ class HostGroup: """ if extra_env is None: extra_env = {} - return self._run( - cmd, + opts = RunOpts( shell=shell, stdout=stdout, stderr=stderr, - extra_env=extra_env, + log=log, + timeout=timeout, cwd=cwd, check=check, + ) + return self._run( + cmd, + opts, + extra_env=extra_env, verbose_ssh=verbose_ssh, - timeout=timeout, tty=tty, - log=log, ) def run_local( @@ -209,7 +211,7 @@ class HostGroup: stdout: IO[bytes] | None = None, stderr: IO[bytes] | None = None, extra_env: dict[str, str] | None = None, - cwd: None | str | Path = None, + cwd: None | Path = None, check: bool = True, timeout: float = math.inf, shell: bool = False, @@ -222,18 +224,21 @@ class HostGroup: """ if extra_env is None: extra_env = {} - return self._run( - cmd, - local=True, + opts = RunOpts( stdout=stdout, stderr=stderr, - extra_env=extra_env, cwd=cwd, check=check, timeout=timeout, shell=shell, log=log, ) + return self._run( + cmd, + opts, + local=True, + extra_env=extra_env, + ) def run_function( self, func: Callable[[Host], T], check: bool = True diff --git a/pkgs/clan-cli/tests/test_ssh_local.py b/pkgs/clan-cli/tests/test_ssh_local.py index 6e2737569..f8d21b065 100644 --- a/pkgs/clan-cli/tests/test_ssh_local.py +++ b/pkgs/clan-cli/tests/test_ssh_local.py @@ -1,4 +1,4 @@ -from clan_cli.cmd import Log +from clan_cli.cmd import Log, RunOpts from clan_cli.ssh.host import Host from clan_cli.ssh.host_group import HostGroup @@ -31,7 +31,7 @@ def test_timeout() -> None: def test_run_function() -> None: def some_func(h: Host) -> bool: - par = h.run_local(["echo", "hello"], log=Log.STDERR) + par = h.run_local(["echo", "hello"], RunOpts(log=Log.STDERR)) return par.stdout == "hello\n" res = hosts.run_function(some_func) @@ -50,7 +50,7 @@ def test_run_exception() -> None: def test_run_function_exception() -> None: def some_func(h: Host) -> None: - h.run_local(["exit 1"], shell=True) + h.run_local(["exit 1"], RunOpts(shell=True)) try: hosts.run_function(some_func) diff --git a/pkgs/clan-cli/tests/test_ssh_remote.py b/pkgs/clan-cli/tests/test_ssh_remote.py index 851728a9f..204a2cdd8 100644 --- a/pkgs/clan-cli/tests/test_ssh_remote.py +++ b/pkgs/clan-cli/tests/test_ssh_remote.py @@ -1,5 +1,5 @@ import pytest -from clan_cli.cmd import Log +from clan_cli.cmd import Log, RunOpts from clan_cli.errors import ClanError, CmdOut from clan_cli.ssh.host import Host from clan_cli.ssh.host_group import HostGroup @@ -73,7 +73,7 @@ def test_run_exception(host_group: HostGroup) -> None: def test_run_function_exception(host_group: HostGroup) -> None: def some_func(h: Host) -> CmdOut: - return h.run_local(["exit 1"], shell=True) + return h.run_local(["exit 1"], RunOpts(shell=True)) try: host_group.run_function(some_func)