From b9154fddd296ca2a0828c022531965d677b85065 Mon Sep 17 00:00:00 2001 From: Qubasa Date: Mon, 25 Nov 2024 19:47:17 +0100 Subject: [PATCH] clan-cli: Refactor ssh classes to dataclasses --- pkgs/clan-cli/clan_cli/facts/upload.py | 3 +- pkgs/clan-cli/clan_cli/ssh/host.py | 158 +++-------------------- pkgs/clan-cli/clan_cli/ssh/host_group.py | 24 +--- pkgs/clan-cli/clan_cli/ssh/results.py | 7 +- pkgs/clan-cli/clan_cli/ssh/upload.py | 84 ++++++++++++ pkgs/clan-cli/clan_cli/vars/upload.py | 5 +- pkgs/clan-cli/tests/nix_config.py | 2 +- 7 files changed, 116 insertions(+), 167 deletions(-) create mode 100644 pkgs/clan-cli/clan_cli/ssh/upload.py diff --git a/pkgs/clan-cli/clan_cli/facts/upload.py b/pkgs/clan-cli/clan_cli/facts/upload.py index 49dcddb2b..47bda5e74 100644 --- a/pkgs/clan-cli/clan_cli/facts/upload.py +++ b/pkgs/clan-cli/clan_cli/facts/upload.py @@ -6,6 +6,7 @@ from tempfile import TemporaryDirectory from clan_cli.completions import add_dynamic_completer, complete_machines from clan_cli.machines.machines import Machine +from clan_cli.ssh.upload import upload log = logging.getLogger(__name__) @@ -23,7 +24,7 @@ def upload_secrets(machine: Machine) -> None: secret_facts_store.upload(local_secret_dir) remote_secret_dir = Path(machine.secrets_upload_directory) - machine.target_host.upload(local_secret_dir, remote_secret_dir) + upload(machine.target_host, local_secret_dir, remote_secret_dir) def upload_command(args: argparse.Namespace) -> None: diff --git a/pkgs/clan-cli/clan_cli/ssh/host.py b/pkgs/clan-cli/clan_cli/ssh/host.py index 52daca9bd..a45aefc26 100644 --- a/pkgs/clan-cli/clan_cli/ssh/host.py +++ b/pkgs/clan-cli/clan_cli/ssh/host.py @@ -5,15 +5,13 @@ import math import os import shlex import subprocess -import tarfile +from dataclasses import dataclass, field from pathlib import Path from shlex import quote -from tempfile import TemporaryDirectory from typing import IO, Any from clan_cli.cmd import Log from clan_cli.cmd import run as local_run -from clan_cli.errors import ClanError from clan_cli.ssh.host_key import HostKeyCheck cmdlog = logging.getLogger(__name__) @@ -23,47 +21,22 @@ cmdlog = logging.getLogger(__name__) NO_OUTPUT_TIMEOUT = 20 +@dataclass class Host: - def __init__( - self, - host: str, - user: str | None = None, - port: int | None = None, - key: str | None = None, - forward_agent: bool = False, - command_prefix: str | None = None, - host_key_check: HostKeyCheck = HostKeyCheck.ASK, - meta: dict[str, Any] | None = None, - verbose_ssh: bool = False, - ssh_options: dict[str, str] | None = None, - ) -> None: - """ - Creates a Host - @host the hostname to connect to via ssh - @port the port to connect to via ssh - @forward_agent: whether to forward ssh agent - @command_prefix: string to prefix each line of the command output with, defaults to host - @host_key_check: whether to check ssh host keys - @verbose_ssh: Enables verbose logging on ssh connections - @meta: meta attributes associated with the host. Those can be accessed in custom functions passed to `run_function` - """ - if ssh_options is None: - ssh_options = {} - if meta is None: - meta = {} - self.host = host - self.user = user - self.port = port - self.key = key - if command_prefix: - self.command_prefix = command_prefix - else: - self.command_prefix = host - self.forward_agent = forward_agent - self.host_key_check = host_key_check - self.meta = meta - self.verbose_ssh = verbose_ssh - self._ssh_options = ssh_options + host: str + user: str | None = None + port: int | None = None + key: str | None = None + forward_agent: bool = False + command_prefix: str | None = None + host_key_check: HostKeyCheck = HostKeyCheck.ASK + meta: dict[str, Any] = field(default_factory=dict) + verbose_ssh: bool = False + ssh_options: dict[str, str] = field(default_factory=dict) + + def __post_init__(self) -> None: + if not self.command_prefix: + self.command_prefix = self.host def __repr__(self) -> str: return str(self) @@ -131,19 +104,11 @@ class Host: check: bool = True, timeout: float = math.inf, shell: bool = False, + needs_user_terminal: bool = False, log: Log = Log.BOTH, ) -> subprocess.CompletedProcess[str]: """ Command to run locally for the host - - @cmd the command to run - @stdout if not None stdout of the command will be redirected to this file i.e. stdout=subprocess.PIPE - @stderr if not None stderr of the command will be redirected to this file i.e. stderr=subprocess.PIPE - @extra_env environment variables to override when running the command - @cwd current working directory to run the process in - @timeout: Timeout in seconds for the command to complete - - @return subprocess.CompletedProcess result of the command """ env = os.environ.copy() if extra_env: @@ -159,6 +124,7 @@ class Host: env=env, cwd=cwd, check=check, + needs_user_terminal=needs_user_terminal, timeout=timeout, log=log, ) @@ -180,17 +146,6 @@ class Host: ) -> subprocess.CompletedProcess[str]: """ Command to run on the host via ssh - - @cmd the command to run - @stdout if not None stdout of the command will be redirected to this file i.e. stdout=subprocss.PIPE - @stderr if not None stderr of the command will be redirected to this file i.e. stderr=subprocess.PIPE - @become_root if the ssh_user is not root than sudo is prepended - @extra_env environment variables to override when running the command - @cwd current working directory to run the process in - @verbose_ssh: Enables verbose logging on ssh connections - @timeout: Timeout in seconds for the command to complete - - @return subprocess.CompletedProcess result of the ssh command """ if extra_env is None: extra_env = {} @@ -246,88 +201,13 @@ class Host: env["NIX_SSHOPTS"] = " ".join(self.ssh_cmd_opts) return env - def upload( - self, - local_src: Path, # must be a directory - remote_dest: Path, # must be a directory - file_user: str = "root", - file_group: str = "root", - dir_mode: int = 0o700, - file_mode: int = 0o400, - ) -> None: - # check if the remote destination is a directory (no suffix) - if remote_dest.suffix: - msg = "Only directories are allowed" - raise ClanError(msg) - - if not local_src.is_dir(): - msg = "Only directories are allowed" - raise ClanError(msg) - - # Create the tarball from the temporary directory - with TemporaryDirectory(prefix="facts-upload-") as tardir: - tar_path = Path(tardir) / "upload.tar.gz" - # We set the permissions of the files and directories in the tarball to read only and owned by root - # As first uploading the tarball and then changing the permissions can lead an attacker to - # do a race condition attack - with tarfile.open(str(tar_path), "w:gz") as tar: - for root, dirs, files in local_src.walk(): - for mdir in dirs: - dir_path = Path(root) / mdir - tarinfo = tar.gettarinfo( - dir_path, arcname=str(dir_path.relative_to(str(local_src))) - ) - tarinfo.mode = dir_mode - tarinfo.uname = file_user - tarinfo.gname = file_group - tar.addfile(tarinfo) - for file in files: - file_path = Path(root) / file - tarinfo = tar.gettarinfo( - file_path, - arcname=str(file_path.relative_to(str(local_src))), - ) - tarinfo.mode = file_mode - tarinfo.uname = file_user - tarinfo.gname = file_group - with file_path.open("rb") as f: - tar.addfile(tarinfo, f) - - cmd = [ - *self.ssh_cmd(), - "rm", - "-r", - str(remote_dest), - ";", - "mkdir", - f"--mode={dir_mode:o}", - "-p", - str(remote_dest), - "&&", - "tar", - "-C", - str(remote_dest), - "-xzf", - "-", - ] - - # TODO accept `input` to be an IO object instead of bytes so that we don't have to read the tarfile into memory. - with tar_path.open("rb") as f: - local_run( - cmd, - input=f.read(), - log=Log.BOTH, - needs_user_terminal=True, - prefix=self.command_prefix, - ) - @property def ssh_cmd_opts( self, ) -> list[str]: ssh_opts = ["-A"] if self.forward_agent else [] - for k, v in self._ssh_options.items(): + for k, v in self.ssh_options.items(): ssh_opts.extend(["-o", f"{k}={shlex.quote(v)}"]) ssh_opts.extend(self.host_key_check.to_ssh_opt()) diff --git a/pkgs/clan-cli/clan_cli/ssh/host_group.py b/pkgs/clan-cli/clan_cli/ssh/host_group.py index 425681cfa..c6434441d 100644 --- a/pkgs/clan-cli/clan_cli/ssh/host_group.py +++ b/pkgs/clan-cli/clan_cli/ssh/host_group.py @@ -1,6 +1,7 @@ import logging import math from collections.abc import Callable +from dataclasses import dataclass from pathlib import Path from threading import Thread from typing import IO, Any @@ -26,15 +27,9 @@ def _worker( results[idx] = HostResult(host, e) +@dataclass class HostGroup: - def __init__(self, hosts: list[Host]) -> None: - self.hosts = hosts - - def __repr__(self) -> str: - return str(self) - - def __str__(self) -> str: - return f"HostGroup({self.hosts})" + hosts: list[Host] def _run_local( self, @@ -189,11 +184,6 @@ class HostGroup: ) -> Results: """ Command to run on the remote host via ssh - @stdout if not None stdout of the command will be redirected to this file i.e. stdout=subprocss.PIPE - @stderr if not None stderr of the command will be redirected to this file i.e. stderr=subprocess.PIPE - @cwd current working directory to run the process in - @verbose_ssh: Enables verbose logging on ssh connections - @timeout: Timeout in seconds for the command to complete @return a lists of tuples containing Host and the result of the command for this Host """ @@ -227,12 +217,6 @@ class HostGroup: ) -> Results: """ Command to run locally for each host in the group in parallel - @cmd the command to run - @stdout if not None stdout of the command will be redirected to this file i.e. stdout=subprocss.PIPE - @stderr if not None stderr of the command will be redirected to this file i.e. stderr=subprocess.PIPE - @cwd current working directory to run the process in - @extra_env environment variables to override when running the command - @timeout: Timeout in seconds for the command to complete @return a lists of tuples containing Host and the result of the command for this Host """ @@ -256,8 +240,6 @@ class HostGroup: ) -> list[HostResult[T]]: """ Function to run for each host in the group in parallel - - @func the function to call """ threads = [] results: list[HostResult[T]] = [ diff --git a/pkgs/clan-cli/clan_cli/ssh/results.py b/pkgs/clan-cli/clan_cli/ssh/results.py index 45a69efa5..727ad3dd9 100644 --- a/pkgs/clan-cli/clan_cli/ssh/results.py +++ b/pkgs/clan-cli/clan_cli/ssh/results.py @@ -1,14 +1,15 @@ import subprocess +from dataclasses import dataclass from typing import Generic from clan_cli.ssh import T from clan_cli.ssh.host import Host +@dataclass class HostResult(Generic[T]): - def __init__(self, host: Host, result: T | Exception) -> None: - self.host = host - self._result = result + host: Host + _result: T | Exception @property def error(self) -> Exception | None: diff --git a/pkgs/clan-cli/clan_cli/ssh/upload.py b/pkgs/clan-cli/clan_cli/ssh/upload.py new file mode 100644 index 000000000..747f4f0d0 --- /dev/null +++ b/pkgs/clan-cli/clan_cli/ssh/upload.py @@ -0,0 +1,84 @@ +import tarfile +from pathlib import Path +from tempfile import TemporaryDirectory + +from clan_cli.cmd import Log +from clan_cli.cmd import run as run_local +from clan_cli.errors import ClanError +from clan_cli.ssh.host import Host + + +def upload( + host: Host, + local_src: Path, # must be a directory + remote_dest: Path, # must be a directory + file_user: str = "root", + file_group: str = "root", + dir_mode: int = 0o700, + file_mode: int = 0o400, +) -> None: + # check if the remote destination is a directory (no suffix) + if remote_dest.suffix: + msg = "Only directories are allowed" + raise ClanError(msg) + + if not local_src.is_dir(): + msg = "Only directories are allowed" + raise ClanError(msg) + + # Create the tarball from the temporary directory + with TemporaryDirectory(prefix="facts-upload-") as tardir: + tar_path = Path(tardir) / "upload.tar.gz" + # We set the permissions of the files and directories in the tarball to read only and owned by root + # As first uploading the tarball and then changing the permissions can lead an attacker to + # do a race condition attack + with tarfile.open(str(tar_path), "w:gz") as tar: + for root, dirs, files in local_src.walk(): + for mdir in dirs: + dir_path = Path(root) / mdir + tarinfo = tar.gettarinfo( + dir_path, arcname=str(dir_path.relative_to(str(local_src))) + ) + tarinfo.mode = dir_mode + tarinfo.uname = file_user + tarinfo.gname = file_group + tar.addfile(tarinfo) + for file in files: + file_path = Path(root) / file + tarinfo = tar.gettarinfo( + file_path, + arcname=str(file_path.relative_to(str(local_src))), + ) + tarinfo.mode = file_mode + tarinfo.uname = file_user + tarinfo.gname = file_group + with file_path.open("rb") as f: + tar.addfile(tarinfo, f) + + cmd = [ + *host.ssh_cmd(), + "rm", + "-r", + str(remote_dest), + ";", + "mkdir", + f"--mode={dir_mode:o}", + "-p", + str(remote_dest), + "&&", + "tar", + "-C", + str(remote_dest), + "-xzf", + "-", + ] + + # TODO accept `input` to be an IO object instead of bytes so that we don't have to read the tarfile into memory. + with tar_path.open("rb") as f: + run_local( + cmd, + input=f.read(), + log=Log.BOTH, + prefix=host.command_prefix, + needs_user_terminal=True, + ) diff --git a/pkgs/clan-cli/clan_cli/vars/upload.py b/pkgs/clan-cli/clan_cli/vars/upload.py index ce22cad67..8c78974e1 100644 --- a/pkgs/clan-cli/clan_cli/vars/upload.py +++ b/pkgs/clan-cli/clan_cli/vars/upload.py @@ -6,6 +6,7 @@ from tempfile import TemporaryDirectory from clan_cli.completions import add_dynamic_completer, complete_machines from clan_cli.machines.machines import Machine +from clan_cli.ssh.upload import upload log = logging.getLogger(__name__) @@ -20,8 +21,8 @@ def upload_secret_vars(machine: Machine) -> None: with TemporaryDirectory(prefix="vars-upload-") as tempdir: secret_dir = Path(tempdir) secret_store.upload(secret_dir) - machine.target_host.upload( - secret_dir, Path(machine.secret_vars_upload_directory) + upload( + machine.target_host, secret_dir, Path(machine.secret_vars_upload_directory) ) diff --git a/pkgs/clan-cli/tests/nix_config.py b/pkgs/clan-cli/tests/nix_config.py index 2f0d753ae..3f142542f 100644 --- a/pkgs/clan-cli/tests/nix_config.py +++ b/pkgs/clan-cli/tests/nix_config.py @@ -18,7 +18,7 @@ class ConfigItem: @pytest.fixture(scope="session") def nix_config() -> dict[str, ConfigItem]: proc = subprocess.run( - ["nix", "show-config", "--json"], check=True, stdout=subprocess.PIPE + ["nix", "config", "show", "--json"], check=True, stdout=subprocess.PIPE ) data = json.loads(proc.stdout) return {name: ConfigItem(**c) for name, c in data.items()}