clan-cli: Refactor ssh classes to dataclasses

This commit is contained in:
Qubasa
2024-11-25 19:47:17 +01:00
parent e16990e493
commit b9154fddd2
7 changed files with 116 additions and 167 deletions

View File

@@ -6,6 +6,7 @@ from tempfile import TemporaryDirectory
from clan_cli.completions import add_dynamic_completer, complete_machines
from clan_cli.machines.machines import Machine
from clan_cli.ssh.upload import upload
log = logging.getLogger(__name__)
@@ -23,7 +24,7 @@ def upload_secrets(machine: Machine) -> None:
secret_facts_store.upload(local_secret_dir)
remote_secret_dir = Path(machine.secrets_upload_directory)
machine.target_host.upload(local_secret_dir, remote_secret_dir)
upload(machine.target_host, local_secret_dir, remote_secret_dir)
def upload_command(args: argparse.Namespace) -> None:

View File

@@ -5,15 +5,13 @@ import math
import os
import shlex
import subprocess
import tarfile
from dataclasses import dataclass, field
from pathlib import Path
from shlex import quote
from tempfile import TemporaryDirectory
from typing import IO, Any
from clan_cli.cmd import Log
from clan_cli.cmd import run as local_run
from clan_cli.errors import ClanError
from clan_cli.ssh.host_key import HostKeyCheck
cmdlog = logging.getLogger(__name__)
@@ -23,47 +21,22 @@ cmdlog = logging.getLogger(__name__)
NO_OUTPUT_TIMEOUT = 20
@dataclass
class Host:
def __init__(
self,
host: str,
user: str | None = None,
port: int | None = None,
key: str | None = None,
forward_agent: bool = False,
command_prefix: str | None = None,
host_key_check: HostKeyCheck = HostKeyCheck.ASK,
meta: dict[str, Any] | None = None,
verbose_ssh: bool = False,
ssh_options: dict[str, str] | None = None,
) -> None:
"""
Creates a Host
@host the hostname to connect to via ssh
@port the port to connect to via ssh
@forward_agent: whether to forward ssh agent
@command_prefix: string to prefix each line of the command output with, defaults to host
@host_key_check: whether to check ssh host keys
@verbose_ssh: Enables verbose logging on ssh connections
@meta: meta attributes associated with the host. Those can be accessed in custom functions passed to `run_function`
"""
if ssh_options is None:
ssh_options = {}
if meta is None:
meta = {}
self.host = host
self.user = user
self.port = port
self.key = key
if command_prefix:
self.command_prefix = command_prefix
else:
self.command_prefix = host
self.forward_agent = forward_agent
self.host_key_check = host_key_check
self.meta = meta
self.verbose_ssh = verbose_ssh
self._ssh_options = ssh_options
host: str
user: str | None = None
port: int | None = None
key: str | None = None
forward_agent: bool = False
command_prefix: str | None = None
host_key_check: HostKeyCheck = HostKeyCheck.ASK
meta: dict[str, Any] = field(default_factory=dict)
verbose_ssh: bool = False
ssh_options: dict[str, str] = field(default_factory=dict)
def __post_init__(self) -> None:
if not self.command_prefix:
self.command_prefix = self.host
def __repr__(self) -> str:
return str(self)
@@ -131,19 +104,11 @@ class Host:
check: bool = True,
timeout: float = math.inf,
shell: bool = False,
needs_user_terminal: bool = False,
log: Log = Log.BOTH,
) -> subprocess.CompletedProcess[str]:
"""
Command to run locally for the host
@cmd the command to run
@stdout if not None stdout of the command will be redirected to this file i.e. stdout=subprocess.PIPE
@stderr if not None stderr of the command will be redirected to this file i.e. stderr=subprocess.PIPE
@extra_env environment variables to override when running the command
@cwd current working directory to run the process in
@timeout: Timeout in seconds for the command to complete
@return subprocess.CompletedProcess result of the command
"""
env = os.environ.copy()
if extra_env:
@@ -159,6 +124,7 @@ class Host:
env=env,
cwd=cwd,
check=check,
needs_user_terminal=needs_user_terminal,
timeout=timeout,
log=log,
)
@@ -180,17 +146,6 @@ class Host:
) -> subprocess.CompletedProcess[str]:
"""
Command to run on the host via ssh
@cmd the command to run
@stdout if not None stdout of the command will be redirected to this file i.e. stdout=subprocss.PIPE
@stderr if not None stderr of the command will be redirected to this file i.e. stderr=subprocess.PIPE
@become_root if the ssh_user is not root than sudo is prepended
@extra_env environment variables to override when running the command
@cwd current working directory to run the process in
@verbose_ssh: Enables verbose logging on ssh connections
@timeout: Timeout in seconds for the command to complete
@return subprocess.CompletedProcess result of the ssh command
"""
if extra_env is None:
extra_env = {}
@@ -246,88 +201,13 @@ class Host:
env["NIX_SSHOPTS"] = " ".join(self.ssh_cmd_opts)
return env
def upload(
self,
local_src: Path, # must be a directory
remote_dest: Path, # must be a directory
file_user: str = "root",
file_group: str = "root",
dir_mode: int = 0o700,
file_mode: int = 0o400,
) -> None:
# check if the remote destination is a directory (no suffix)
if remote_dest.suffix:
msg = "Only directories are allowed"
raise ClanError(msg)
if not local_src.is_dir():
msg = "Only directories are allowed"
raise ClanError(msg)
# Create the tarball from the temporary directory
with TemporaryDirectory(prefix="facts-upload-") as tardir:
tar_path = Path(tardir) / "upload.tar.gz"
# We set the permissions of the files and directories in the tarball to read only and owned by root
# As first uploading the tarball and then changing the permissions can lead an attacker to
# do a race condition attack
with tarfile.open(str(tar_path), "w:gz") as tar:
for root, dirs, files in local_src.walk():
for mdir in dirs:
dir_path = Path(root) / mdir
tarinfo = tar.gettarinfo(
dir_path, arcname=str(dir_path.relative_to(str(local_src)))
)
tarinfo.mode = dir_mode
tarinfo.uname = file_user
tarinfo.gname = file_group
tar.addfile(tarinfo)
for file in files:
file_path = Path(root) / file
tarinfo = tar.gettarinfo(
file_path,
arcname=str(file_path.relative_to(str(local_src))),
)
tarinfo.mode = file_mode
tarinfo.uname = file_user
tarinfo.gname = file_group
with file_path.open("rb") as f:
tar.addfile(tarinfo, f)
cmd = [
*self.ssh_cmd(),
"rm",
"-r",
str(remote_dest),
";",
"mkdir",
f"--mode={dir_mode:o}",
"-p",
str(remote_dest),
"&&",
"tar",
"-C",
str(remote_dest),
"-xzf",
"-",
]
# TODO accept `input` to be an IO object instead of bytes so that we don't have to read the tarfile into memory.
with tar_path.open("rb") as f:
local_run(
cmd,
input=f.read(),
log=Log.BOTH,
needs_user_terminal=True,
prefix=self.command_prefix,
)
@property
def ssh_cmd_opts(
self,
) -> list[str]:
ssh_opts = ["-A"] if self.forward_agent else []
for k, v in self._ssh_options.items():
for k, v in self.ssh_options.items():
ssh_opts.extend(["-o", f"{k}={shlex.quote(v)}"])
ssh_opts.extend(self.host_key_check.to_ssh_opt())

View File

@@ -1,6 +1,7 @@
import logging
import math
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from threading import Thread
from typing import IO, Any
@@ -26,15 +27,9 @@ def _worker(
results[idx] = HostResult(host, e)
@dataclass
class HostGroup:
def __init__(self, hosts: list[Host]) -> None:
self.hosts = hosts
def __repr__(self) -> str:
return str(self)
def __str__(self) -> str:
return f"HostGroup({self.hosts})"
hosts: list[Host]
def _run_local(
self,
@@ -189,11 +184,6 @@ class HostGroup:
) -> Results:
"""
Command to run on the remote host via ssh
@stdout if not None stdout of the command will be redirected to this file i.e. stdout=subprocss.PIPE
@stderr if not None stderr of the command will be redirected to this file i.e. stderr=subprocess.PIPE
@cwd current working directory to run the process in
@verbose_ssh: Enables verbose logging on ssh connections
@timeout: Timeout in seconds for the command to complete
@return a lists of tuples containing Host and the result of the command for this Host
"""
@@ -227,12 +217,6 @@ class HostGroup:
) -> Results:
"""
Command to run locally for each host in the group in parallel
@cmd the command to run
@stdout if not None stdout of the command will be redirected to this file i.e. stdout=subprocss.PIPE
@stderr if not None stderr of the command will be redirected to this file i.e. stderr=subprocess.PIPE
@cwd current working directory to run the process in
@extra_env environment variables to override when running the command
@timeout: Timeout in seconds for the command to complete
@return a lists of tuples containing Host and the result of the command for this Host
"""
@@ -256,8 +240,6 @@ class HostGroup:
) -> list[HostResult[T]]:
"""
Function to run for each host in the group in parallel
@func the function to call
"""
threads = []
results: list[HostResult[T]] = [

View File

@@ -1,14 +1,15 @@
import subprocess
from dataclasses import dataclass
from typing import Generic
from clan_cli.ssh import T
from clan_cli.ssh.host import Host
@dataclass
class HostResult(Generic[T]):
def __init__(self, host: Host, result: T | Exception) -> None:
self.host = host
self._result = result
host: Host
_result: T | Exception
@property
def error(self) -> Exception | None:

View File

@@ -0,0 +1,84 @@
import tarfile
from pathlib import Path
from tempfile import TemporaryDirectory
from clan_cli.cmd import Log
from clan_cli.cmd import run as run_local
from clan_cli.errors import ClanError
from clan_cli.ssh.host import Host
def upload(
host: Host,
local_src: Path, # must be a directory
remote_dest: Path, # must be a directory
file_user: str = "root",
file_group: str = "root",
dir_mode: int = 0o700,
file_mode: int = 0o400,
) -> None:
# check if the remote destination is a directory (no suffix)
if remote_dest.suffix:
msg = "Only directories are allowed"
raise ClanError(msg)
if not local_src.is_dir():
msg = "Only directories are allowed"
raise ClanError(msg)
# Create the tarball from the temporary directory
with TemporaryDirectory(prefix="facts-upload-") as tardir:
tar_path = Path(tardir) / "upload.tar.gz"
# We set the permissions of the files and directories in the tarball to read only and owned by root
# As first uploading the tarball and then changing the permissions can lead an attacker to
# do a race condition attack
with tarfile.open(str(tar_path), "w:gz") as tar:
for root, dirs, files in local_src.walk():
for mdir in dirs:
dir_path = Path(root) / mdir
tarinfo = tar.gettarinfo(
dir_path, arcname=str(dir_path.relative_to(str(local_src)))
)
tarinfo.mode = dir_mode
tarinfo.uname = file_user
tarinfo.gname = file_group
tar.addfile(tarinfo)
for file in files:
file_path = Path(root) / file
tarinfo = tar.gettarinfo(
file_path,
arcname=str(file_path.relative_to(str(local_src))),
)
tarinfo.mode = file_mode
tarinfo.uname = file_user
tarinfo.gname = file_group
with file_path.open("rb") as f:
tar.addfile(tarinfo, f)
cmd = [
*host.ssh_cmd(),
"rm",
"-r",
str(remote_dest),
";",
"mkdir",
f"--mode={dir_mode:o}",
"-p",
str(remote_dest),
"&&",
"tar",
"-C",
str(remote_dest),
"-xzf",
"-",
]
# TODO accept `input` to be an IO object instead of bytes so that we don't have to read the tarfile into memory.
with tar_path.open("rb") as f:
run_local(
cmd,
input=f.read(),
log=Log.BOTH,
prefix=host.command_prefix,
needs_user_terminal=True,
)

View File

@@ -6,6 +6,7 @@ from tempfile import TemporaryDirectory
from clan_cli.completions import add_dynamic_completer, complete_machines
from clan_cli.machines.machines import Machine
from clan_cli.ssh.upload import upload
log = logging.getLogger(__name__)
@@ -20,8 +21,8 @@ def upload_secret_vars(machine: Machine) -> None:
with TemporaryDirectory(prefix="vars-upload-") as tempdir:
secret_dir = Path(tempdir)
secret_store.upload(secret_dir)
machine.target_host.upload(
secret_dir, Path(machine.secret_vars_upload_directory)
upload(
machine.target_host, secret_dir, Path(machine.secret_vars_upload_directory)
)

View File

@@ -18,7 +18,7 @@ class ConfigItem:
@pytest.fixture(scope="session")
def nix_config() -> dict[str, ConfigItem]:
proc = subprocess.run(
["nix", "show-config", "--json"], check=True, stdout=subprocess.PIPE
["nix", "config", "show", "--json"], check=True, stdout=subprocess.PIPE
)
data = json.loads(proc.stdout)
return {name: ConfigItem(**c) for name, c in data.items()}