clan-cli: Refactor ssh classes to dataclasses
This commit is contained in:
@@ -6,6 +6,7 @@ from tempfile import TemporaryDirectory
|
|||||||
|
|
||||||
from clan_cli.completions import add_dynamic_completer, complete_machines
|
from clan_cli.completions import add_dynamic_completer, complete_machines
|
||||||
from clan_cli.machines.machines import Machine
|
from clan_cli.machines.machines import Machine
|
||||||
|
from clan_cli.ssh.upload import upload
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -23,7 +24,7 @@ def upload_secrets(machine: Machine) -> None:
|
|||||||
secret_facts_store.upload(local_secret_dir)
|
secret_facts_store.upload(local_secret_dir)
|
||||||
remote_secret_dir = Path(machine.secrets_upload_directory)
|
remote_secret_dir = Path(machine.secrets_upload_directory)
|
||||||
|
|
||||||
machine.target_host.upload(local_secret_dir, remote_secret_dir)
|
upload(machine.target_host, local_secret_dir, remote_secret_dir)
|
||||||
|
|
||||||
|
|
||||||
def upload_command(args: argparse.Namespace) -> None:
|
def upload_command(args: argparse.Namespace) -> None:
|
||||||
|
|||||||
@@ -5,15 +5,13 @@ import math
|
|||||||
import os
|
import os
|
||||||
import shlex
|
import shlex
|
||||||
import subprocess
|
import subprocess
|
||||||
import tarfile
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shlex import quote
|
from shlex import quote
|
||||||
from tempfile import TemporaryDirectory
|
|
||||||
from typing import IO, Any
|
from typing import IO, Any
|
||||||
|
|
||||||
from clan_cli.cmd import Log
|
from clan_cli.cmd import Log
|
||||||
from clan_cli.cmd import run as local_run
|
from clan_cli.cmd import run as local_run
|
||||||
from clan_cli.errors import ClanError
|
|
||||||
from clan_cli.ssh.host_key import HostKeyCheck
|
from clan_cli.ssh.host_key import HostKeyCheck
|
||||||
|
|
||||||
cmdlog = logging.getLogger(__name__)
|
cmdlog = logging.getLogger(__name__)
|
||||||
@@ -23,47 +21,22 @@ cmdlog = logging.getLogger(__name__)
|
|||||||
NO_OUTPUT_TIMEOUT = 20
|
NO_OUTPUT_TIMEOUT = 20
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class Host:
|
class Host:
|
||||||
def __init__(
|
host: str
|
||||||
self,
|
user: str | None = None
|
||||||
host: str,
|
port: int | None = None
|
||||||
user: str | None = None,
|
key: str | None = None
|
||||||
port: int | None = None,
|
forward_agent: bool = False
|
||||||
key: str | None = None,
|
command_prefix: str | None = None
|
||||||
forward_agent: bool = False,
|
host_key_check: HostKeyCheck = HostKeyCheck.ASK
|
||||||
command_prefix: str | None = None,
|
meta: dict[str, Any] = field(default_factory=dict)
|
||||||
host_key_check: HostKeyCheck = HostKeyCheck.ASK,
|
verbose_ssh: bool = False
|
||||||
meta: dict[str, Any] | None = None,
|
ssh_options: dict[str, str] = field(default_factory=dict)
|
||||||
verbose_ssh: bool = False,
|
|
||||||
ssh_options: dict[str, str] | None = None,
|
def __post_init__(self) -> None:
|
||||||
) -> None:
|
if not self.command_prefix:
|
||||||
"""
|
self.command_prefix = self.host
|
||||||
Creates a Host
|
|
||||||
@host the hostname to connect to via ssh
|
|
||||||
@port the port to connect to via ssh
|
|
||||||
@forward_agent: whether to forward ssh agent
|
|
||||||
@command_prefix: string to prefix each line of the command output with, defaults to host
|
|
||||||
@host_key_check: whether to check ssh host keys
|
|
||||||
@verbose_ssh: Enables verbose logging on ssh connections
|
|
||||||
@meta: meta attributes associated with the host. Those can be accessed in custom functions passed to `run_function`
|
|
||||||
"""
|
|
||||||
if ssh_options is None:
|
|
||||||
ssh_options = {}
|
|
||||||
if meta is None:
|
|
||||||
meta = {}
|
|
||||||
self.host = host
|
|
||||||
self.user = user
|
|
||||||
self.port = port
|
|
||||||
self.key = key
|
|
||||||
if command_prefix:
|
|
||||||
self.command_prefix = command_prefix
|
|
||||||
else:
|
|
||||||
self.command_prefix = host
|
|
||||||
self.forward_agent = forward_agent
|
|
||||||
self.host_key_check = host_key_check
|
|
||||||
self.meta = meta
|
|
||||||
self.verbose_ssh = verbose_ssh
|
|
||||||
self._ssh_options = ssh_options
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return str(self)
|
return str(self)
|
||||||
@@ -131,19 +104,11 @@ class Host:
|
|||||||
check: bool = True,
|
check: bool = True,
|
||||||
timeout: float = math.inf,
|
timeout: float = math.inf,
|
||||||
shell: bool = False,
|
shell: bool = False,
|
||||||
|
needs_user_terminal: bool = False,
|
||||||
log: Log = Log.BOTH,
|
log: Log = Log.BOTH,
|
||||||
) -> subprocess.CompletedProcess[str]:
|
) -> subprocess.CompletedProcess[str]:
|
||||||
"""
|
"""
|
||||||
Command to run locally for the host
|
Command to run locally for the host
|
||||||
|
|
||||||
@cmd the command to run
|
|
||||||
@stdout if not None stdout of the command will be redirected to this file i.e. stdout=subprocess.PIPE
|
|
||||||
@stderr if not None stderr of the command will be redirected to this file i.e. stderr=subprocess.PIPE
|
|
||||||
@extra_env environment variables to override when running the command
|
|
||||||
@cwd current working directory to run the process in
|
|
||||||
@timeout: Timeout in seconds for the command to complete
|
|
||||||
|
|
||||||
@return subprocess.CompletedProcess result of the command
|
|
||||||
"""
|
"""
|
||||||
env = os.environ.copy()
|
env = os.environ.copy()
|
||||||
if extra_env:
|
if extra_env:
|
||||||
@@ -159,6 +124,7 @@ class Host:
|
|||||||
env=env,
|
env=env,
|
||||||
cwd=cwd,
|
cwd=cwd,
|
||||||
check=check,
|
check=check,
|
||||||
|
needs_user_terminal=needs_user_terminal,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
log=log,
|
log=log,
|
||||||
)
|
)
|
||||||
@@ -180,17 +146,6 @@ class Host:
|
|||||||
) -> subprocess.CompletedProcess[str]:
|
) -> subprocess.CompletedProcess[str]:
|
||||||
"""
|
"""
|
||||||
Command to run on the host via ssh
|
Command to run on the host via ssh
|
||||||
|
|
||||||
@cmd the command to run
|
|
||||||
@stdout if not None stdout of the command will be redirected to this file i.e. stdout=subprocss.PIPE
|
|
||||||
@stderr if not None stderr of the command will be redirected to this file i.e. stderr=subprocess.PIPE
|
|
||||||
@become_root if the ssh_user is not root than sudo is prepended
|
|
||||||
@extra_env environment variables to override when running the command
|
|
||||||
@cwd current working directory to run the process in
|
|
||||||
@verbose_ssh: Enables verbose logging on ssh connections
|
|
||||||
@timeout: Timeout in seconds for the command to complete
|
|
||||||
|
|
||||||
@return subprocess.CompletedProcess result of the ssh command
|
|
||||||
"""
|
"""
|
||||||
if extra_env is None:
|
if extra_env is None:
|
||||||
extra_env = {}
|
extra_env = {}
|
||||||
@@ -246,88 +201,13 @@ class Host:
|
|||||||
env["NIX_SSHOPTS"] = " ".join(self.ssh_cmd_opts)
|
env["NIX_SSHOPTS"] = " ".join(self.ssh_cmd_opts)
|
||||||
return env
|
return env
|
||||||
|
|
||||||
def upload(
|
|
||||||
self,
|
|
||||||
local_src: Path, # must be a directory
|
|
||||||
remote_dest: Path, # must be a directory
|
|
||||||
file_user: str = "root",
|
|
||||||
file_group: str = "root",
|
|
||||||
dir_mode: int = 0o700,
|
|
||||||
file_mode: int = 0o400,
|
|
||||||
) -> None:
|
|
||||||
# check if the remote destination is a directory (no suffix)
|
|
||||||
if remote_dest.suffix:
|
|
||||||
msg = "Only directories are allowed"
|
|
||||||
raise ClanError(msg)
|
|
||||||
|
|
||||||
if not local_src.is_dir():
|
|
||||||
msg = "Only directories are allowed"
|
|
||||||
raise ClanError(msg)
|
|
||||||
|
|
||||||
# Create the tarball from the temporary directory
|
|
||||||
with TemporaryDirectory(prefix="facts-upload-") as tardir:
|
|
||||||
tar_path = Path(tardir) / "upload.tar.gz"
|
|
||||||
# We set the permissions of the files and directories in the tarball to read only and owned by root
|
|
||||||
# As first uploading the tarball and then changing the permissions can lead an attacker to
|
|
||||||
# do a race condition attack
|
|
||||||
with tarfile.open(str(tar_path), "w:gz") as tar:
|
|
||||||
for root, dirs, files in local_src.walk():
|
|
||||||
for mdir in dirs:
|
|
||||||
dir_path = Path(root) / mdir
|
|
||||||
tarinfo = tar.gettarinfo(
|
|
||||||
dir_path, arcname=str(dir_path.relative_to(str(local_src)))
|
|
||||||
)
|
|
||||||
tarinfo.mode = dir_mode
|
|
||||||
tarinfo.uname = file_user
|
|
||||||
tarinfo.gname = file_group
|
|
||||||
tar.addfile(tarinfo)
|
|
||||||
for file in files:
|
|
||||||
file_path = Path(root) / file
|
|
||||||
tarinfo = tar.gettarinfo(
|
|
||||||
file_path,
|
|
||||||
arcname=str(file_path.relative_to(str(local_src))),
|
|
||||||
)
|
|
||||||
tarinfo.mode = file_mode
|
|
||||||
tarinfo.uname = file_user
|
|
||||||
tarinfo.gname = file_group
|
|
||||||
with file_path.open("rb") as f:
|
|
||||||
tar.addfile(tarinfo, f)
|
|
||||||
|
|
||||||
cmd = [
|
|
||||||
*self.ssh_cmd(),
|
|
||||||
"rm",
|
|
||||||
"-r",
|
|
||||||
str(remote_dest),
|
|
||||||
";",
|
|
||||||
"mkdir",
|
|
||||||
f"--mode={dir_mode:o}",
|
|
||||||
"-p",
|
|
||||||
str(remote_dest),
|
|
||||||
"&&",
|
|
||||||
"tar",
|
|
||||||
"-C",
|
|
||||||
str(remote_dest),
|
|
||||||
"-xzf",
|
|
||||||
"-",
|
|
||||||
]
|
|
||||||
|
|
||||||
# TODO accept `input` to be an IO object instead of bytes so that we don't have to read the tarfile into memory.
|
|
||||||
with tar_path.open("rb") as f:
|
|
||||||
local_run(
|
|
||||||
cmd,
|
|
||||||
input=f.read(),
|
|
||||||
log=Log.BOTH,
|
|
||||||
needs_user_terminal=True,
|
|
||||||
prefix=self.command_prefix,
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ssh_cmd_opts(
|
def ssh_cmd_opts(
|
||||||
self,
|
self,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
ssh_opts = ["-A"] if self.forward_agent else []
|
ssh_opts = ["-A"] if self.forward_agent else []
|
||||||
|
|
||||||
for k, v in self._ssh_options.items():
|
for k, v in self.ssh_options.items():
|
||||||
ssh_opts.extend(["-o", f"{k}={shlex.quote(v)}"])
|
ssh_opts.extend(["-o", f"{k}={shlex.quote(v)}"])
|
||||||
|
|
||||||
ssh_opts.extend(self.host_key_check.to_ssh_opt())
|
ssh_opts.extend(self.host_key_check.to_ssh_opt())
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import IO, Any
|
from typing import IO, Any
|
||||||
@@ -26,15 +27,9 @@ def _worker(
|
|||||||
results[idx] = HostResult(host, e)
|
results[idx] = HostResult(host, e)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class HostGroup:
|
class HostGroup:
|
||||||
def __init__(self, hosts: list[Host]) -> None:
|
hosts: list[Host]
|
||||||
self.hosts = hosts
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return str(self)
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return f"HostGroup({self.hosts})"
|
|
||||||
|
|
||||||
def _run_local(
|
def _run_local(
|
||||||
self,
|
self,
|
||||||
@@ -189,11 +184,6 @@ class HostGroup:
|
|||||||
) -> Results:
|
) -> Results:
|
||||||
"""
|
"""
|
||||||
Command to run on the remote host via ssh
|
Command to run on the remote host via ssh
|
||||||
@stdout if not None stdout of the command will be redirected to this file i.e. stdout=subprocss.PIPE
|
|
||||||
@stderr if not None stderr of the command will be redirected to this file i.e. stderr=subprocess.PIPE
|
|
||||||
@cwd current working directory to run the process in
|
|
||||||
@verbose_ssh: Enables verbose logging on ssh connections
|
|
||||||
@timeout: Timeout in seconds for the command to complete
|
|
||||||
|
|
||||||
@return a lists of tuples containing Host and the result of the command for this Host
|
@return a lists of tuples containing Host and the result of the command for this Host
|
||||||
"""
|
"""
|
||||||
@@ -227,12 +217,6 @@ class HostGroup:
|
|||||||
) -> Results:
|
) -> Results:
|
||||||
"""
|
"""
|
||||||
Command to run locally for each host in the group in parallel
|
Command to run locally for each host in the group in parallel
|
||||||
@cmd the command to run
|
|
||||||
@stdout if not None stdout of the command will be redirected to this file i.e. stdout=subprocss.PIPE
|
|
||||||
@stderr if not None stderr of the command will be redirected to this file i.e. stderr=subprocess.PIPE
|
|
||||||
@cwd current working directory to run the process in
|
|
||||||
@extra_env environment variables to override when running the command
|
|
||||||
@timeout: Timeout in seconds for the command to complete
|
|
||||||
|
|
||||||
@return a lists of tuples containing Host and the result of the command for this Host
|
@return a lists of tuples containing Host and the result of the command for this Host
|
||||||
"""
|
"""
|
||||||
@@ -256,8 +240,6 @@ class HostGroup:
|
|||||||
) -> list[HostResult[T]]:
|
) -> list[HostResult[T]]:
|
||||||
"""
|
"""
|
||||||
Function to run for each host in the group in parallel
|
Function to run for each host in the group in parallel
|
||||||
|
|
||||||
@func the function to call
|
|
||||||
"""
|
"""
|
||||||
threads = []
|
threads = []
|
||||||
results: list[HostResult[T]] = [
|
results: list[HostResult[T]] = [
|
||||||
|
|||||||
@@ -1,14 +1,15 @@
|
|||||||
import subprocess
|
import subprocess
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Generic
|
from typing import Generic
|
||||||
|
|
||||||
from clan_cli.ssh import T
|
from clan_cli.ssh import T
|
||||||
from clan_cli.ssh.host import Host
|
from clan_cli.ssh.host import Host
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class HostResult(Generic[T]):
|
class HostResult(Generic[T]):
|
||||||
def __init__(self, host: Host, result: T | Exception) -> None:
|
host: Host
|
||||||
self.host = host
|
_result: T | Exception
|
||||||
self._result = result
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def error(self) -> Exception | None:
|
def error(self) -> Exception | None:
|
||||||
|
|||||||
84
pkgs/clan-cli/clan_cli/ssh/upload.py
Normal file
84
pkgs/clan-cli/clan_cli/ssh/upload.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
import tarfile
|
||||||
|
from pathlib import Path
|
||||||
|
from tempfile import TemporaryDirectory
|
||||||
|
|
||||||
|
from clan_cli.cmd import Log
|
||||||
|
from clan_cli.cmd import run as run_local
|
||||||
|
from clan_cli.errors import ClanError
|
||||||
|
from clan_cli.ssh.host import Host
|
||||||
|
|
||||||
|
|
||||||
|
def upload(
|
||||||
|
host: Host,
|
||||||
|
local_src: Path, # must be a directory
|
||||||
|
remote_dest: Path, # must be a directory
|
||||||
|
file_user: str = "root",
|
||||||
|
file_group: str = "root",
|
||||||
|
dir_mode: int = 0o700,
|
||||||
|
file_mode: int = 0o400,
|
||||||
|
) -> None:
|
||||||
|
# check if the remote destination is a directory (no suffix)
|
||||||
|
if remote_dest.suffix:
|
||||||
|
msg = "Only directories are allowed"
|
||||||
|
raise ClanError(msg)
|
||||||
|
|
||||||
|
if not local_src.is_dir():
|
||||||
|
msg = "Only directories are allowed"
|
||||||
|
raise ClanError(msg)
|
||||||
|
|
||||||
|
# Create the tarball from the temporary directory
|
||||||
|
with TemporaryDirectory(prefix="facts-upload-") as tardir:
|
||||||
|
tar_path = Path(tardir) / "upload.tar.gz"
|
||||||
|
# We set the permissions of the files and directories in the tarball to read only and owned by root
|
||||||
|
# As first uploading the tarball and then changing the permissions can lead an attacker to
|
||||||
|
# do a race condition attack
|
||||||
|
with tarfile.open(str(tar_path), "w:gz") as tar:
|
||||||
|
for root, dirs, files in local_src.walk():
|
||||||
|
for mdir in dirs:
|
||||||
|
dir_path = Path(root) / mdir
|
||||||
|
tarinfo = tar.gettarinfo(
|
||||||
|
dir_path, arcname=str(dir_path.relative_to(str(local_src)))
|
||||||
|
)
|
||||||
|
tarinfo.mode = dir_mode
|
||||||
|
tarinfo.uname = file_user
|
||||||
|
tarinfo.gname = file_group
|
||||||
|
tar.addfile(tarinfo)
|
||||||
|
for file in files:
|
||||||
|
file_path = Path(root) / file
|
||||||
|
tarinfo = tar.gettarinfo(
|
||||||
|
file_path,
|
||||||
|
arcname=str(file_path.relative_to(str(local_src))),
|
||||||
|
)
|
||||||
|
tarinfo.mode = file_mode
|
||||||
|
tarinfo.uname = file_user
|
||||||
|
tarinfo.gname = file_group
|
||||||
|
with file_path.open("rb") as f:
|
||||||
|
tar.addfile(tarinfo, f)
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
*host.ssh_cmd(),
|
||||||
|
"rm",
|
||||||
|
"-r",
|
||||||
|
str(remote_dest),
|
||||||
|
";",
|
||||||
|
"mkdir",
|
||||||
|
f"--mode={dir_mode:o}",
|
||||||
|
"-p",
|
||||||
|
str(remote_dest),
|
||||||
|
"&&",
|
||||||
|
"tar",
|
||||||
|
"-C",
|
||||||
|
str(remote_dest),
|
||||||
|
"-xzf",
|
||||||
|
"-",
|
||||||
|
]
|
||||||
|
|
||||||
|
# TODO accept `input` to be an IO object instead of bytes so that we don't have to read the tarfile into memory.
|
||||||
|
with tar_path.open("rb") as f:
|
||||||
|
run_local(
|
||||||
|
cmd,
|
||||||
|
input=f.read(),
|
||||||
|
log=Log.BOTH,
|
||||||
|
prefix=host.command_prefix,
|
||||||
|
needs_user_terminal=True,
|
||||||
|
)
|
||||||
@@ -6,6 +6,7 @@ from tempfile import TemporaryDirectory
|
|||||||
|
|
||||||
from clan_cli.completions import add_dynamic_completer, complete_machines
|
from clan_cli.completions import add_dynamic_completer, complete_machines
|
||||||
from clan_cli.machines.machines import Machine
|
from clan_cli.machines.machines import Machine
|
||||||
|
from clan_cli.ssh.upload import upload
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -20,8 +21,8 @@ def upload_secret_vars(machine: Machine) -> None:
|
|||||||
with TemporaryDirectory(prefix="vars-upload-") as tempdir:
|
with TemporaryDirectory(prefix="vars-upload-") as tempdir:
|
||||||
secret_dir = Path(tempdir)
|
secret_dir = Path(tempdir)
|
||||||
secret_store.upload(secret_dir)
|
secret_store.upload(secret_dir)
|
||||||
machine.target_host.upload(
|
upload(
|
||||||
secret_dir, Path(machine.secret_vars_upload_directory)
|
machine.target_host, secret_dir, Path(machine.secret_vars_upload_directory)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ class ConfigItem:
|
|||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def nix_config() -> dict[str, ConfigItem]:
|
def nix_config() -> dict[str, ConfigItem]:
|
||||||
proc = subprocess.run(
|
proc = subprocess.run(
|
||||||
["nix", "show-config", "--json"], check=True, stdout=subprocess.PIPE
|
["nix", "config", "show", "--json"], check=True, stdout=subprocess.PIPE
|
||||||
)
|
)
|
||||||
data = json.loads(proc.stdout)
|
data = json.loads(proc.stdout)
|
||||||
return {name: ConfigItem(**c) for name, c in data.items()}
|
return {name: ConfigItem(**c) for name, c in data.items()}
|
||||||
|
|||||||
Reference in New Issue
Block a user