clan-cli: Rename Host -> Remote move to clan_lib and mark as frozen

This commit is contained in:
Qubasa
2025-05-22 14:08:27 +02:00
parent 91994445ff
commit e14f30bdc0
31 changed files with 453 additions and 429 deletions

View File

@@ -3,8 +3,9 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from clan_lib.ssh.remote import Remote
import clan_cli.machines.machines as machines import clan_cli.machines.machines as machines
from clan_cli.ssh.host import Host
class SecretStoreBase(ABC): class SecretStoreBase(ABC):
@@ -26,7 +27,7 @@ class SecretStoreBase(ABC):
def exists(self, service: str, name: str) -> bool: def exists(self, service: str, name: str) -> bool:
pass pass
def needs_upload(self, host: Host) -> bool: def needs_upload(self, host: Remote) -> bool:
return True return True
@abstractmethod @abstractmethod

View File

@@ -5,11 +5,10 @@ from typing import override
from clan_lib.cmd import Log, RunOpts from clan_lib.cmd import Log, RunOpts
from clan_lib.nix import nix_shell from clan_lib.nix import nix_shell
from clan_lib.ssh.remote import Remote
from clan_cli.facts.secret_modules import SecretStoreBase
from clan_cli.machines.machines import Machine from clan_cli.machines.machines import Machine
from clan_cli.ssh.host import Host
from . import SecretStoreBase
class SecretStore(SecretStoreBase): class SecretStore(SecretStoreBase):
@@ -95,13 +94,14 @@ class SecretStore(SecretStoreBase):
return b"\n".join(hashes) return b"\n".join(hashes)
@override @override
def needs_upload(self, host: Host) -> bool: def needs_upload(self, host: Remote) -> bool:
local_hash = self.generate_hash() local_hash = self.generate_hash()
remote_hash = host.run( with host.ssh_control_master() as ssh:
# TODO get the path to the secrets from the machine remote_hash = ssh.run(
["cat", f"{self.machine.secrets_upload_directory}/.pass_info"], # TODO get the path to the secrets from the machine
RunOpts(log=Log.STDERR, check=False), ["cat", f"{self.machine.secrets_upload_directory}/.pass_info"],
).stdout.strip() RunOpts(log=Log.STDERR, check=False),
).stdout.strip()
if not remote_hash: if not remote_hash:
print("remote hash is empty") print("remote hash is empty")

View File

@@ -1,12 +1,13 @@
from pathlib import Path from pathlib import Path
from typing import override from typing import override
from clan_lib.ssh.remote import Remote
from clan_cli.machines.machines import Machine from clan_cli.machines.machines import Machine
from clan_cli.secrets.folders import sops_secrets_folder from clan_cli.secrets.folders import sops_secrets_folder
from clan_cli.secrets.machines import add_machine, has_machine from clan_cli.secrets.machines import add_machine, has_machine
from clan_cli.secrets.secrets import decrypt_secret, encrypt_secret, has_secret from clan_cli.secrets.secrets import decrypt_secret, encrypt_secret, has_secret
from clan_cli.secrets.sops import generate_private_key from clan_cli.secrets.sops import generate_private_key
from clan_cli.ssh.host import Host
from . import SecretStoreBase from . import SecretStoreBase
@@ -61,7 +62,7 @@ class SecretStore(SecretStoreBase):
) )
@override @override
def needs_upload(self, host: Host) -> bool: def needs_upload(self, host: Remote) -> bool:
return False return False
# We rely now on the vars backend to upload the age key # We rely now on the vars backend to upload the age key

View File

@@ -11,16 +11,17 @@ log = logging.getLogger(__name__)
def upload_secrets(machine: Machine) -> None: def upload_secrets(machine: Machine) -> None:
with machine.target_host() as host: host = machine.target_host()
if not machine.secret_facts_store.needs_upload(host):
machine.info("Secrets already uploaded")
return
with TemporaryDirectory(prefix="facts-upload-") as _tempdir: if not machine.secret_facts_store.needs_upload(host):
local_secret_dir = Path(_tempdir).resolve() machine.info("Secrets already uploaded")
machine.secret_facts_store.upload(local_secret_dir) return
remote_secret_dir = Path(machine.secrets_upload_directory)
upload(host, local_secret_dir, remote_secret_dir) with TemporaryDirectory(prefix="facts-upload-") as _tempdir:
local_secret_dir = Path(_tempdir).resolve()
machine.secret_facts_store.upload(local_secret_dir)
remote_secret_dir = Path(machine.secrets_upload_directory)
upload(host, local_secret_dir, remote_secret_dir)
def upload_command(args: argparse.Namespace) -> None: def upload_command(args: argparse.Namespace) -> None:

View File

@@ -103,26 +103,23 @@ def generate_machine_hardware_info(opts: HardwareGenerateOptions) -> HardwareCon
"--show-hardware-config", "--show-hardware-config",
] ]
with machine.target_host() as host: host = opts.machine.target_host()
host.ssh_options["StrictHostKeyChecking"] = "accept-new"
host.ssh_options["UserKnownHostsFile"] = "/dev/null"
if opts.password:
host.password = opts.password
out = host.run(config_command, become_root=True, opts=RunOpts(check=False)) with host.ssh_control_master() as ssh:
if out.returncode != 0: out = ssh.run(config_command, become_root=True, opts=RunOpts(check=False))
if "nixos-facter" in out.stderr and "not found" in out.stderr: if out.returncode != 0:
machine.error(str(out.stderr)) if "nixos-facter" in out.stderr and "not found" in out.stderr:
msg = ( machine.error(str(out.stderr))
"Please use our custom nixos install images from https://github.com/nix-community/nixos-images/releases/tag/nixos-unstable. " msg = (
"nixos-factor only works on nixos / clan systems currently." "Please use our custom nixos install images from https://github.com/nix-community/nixos-images/releases/tag/nixos-unstable. "
) "nixos-factor only works on nixos / clan systems currently."
raise ClanError(msg) )
machine.error(str(out))
msg = f"Failed to inspect {opts.machine}. Address: {host.target}"
raise ClanError(msg) raise ClanError(msg)
machine.error(str(out))
msg = f"Failed to inspect {opts.machine}. Address: {host.target}"
raise ClanError(msg)
backup_file = None backup_file = None
if hw_file.exists(): if hw_file.exists():
backup_file = hw_file.with_suffix(".bak") backup_file = hw_file.with_suffix(".bak")

View File

@@ -57,9 +57,9 @@ def install_machine(opts: InstallOptions) -> None:
generate_facts([machine]) generate_facts([machine])
generate_vars([machine]) generate_vars([machine])
host = machine.target_host()
with ( with (
TemporaryDirectory(prefix="nixos-install-") as _base_directory, TemporaryDirectory(prefix="nixos-install-") as _base_directory,
machine.target_host() as host,
): ):
base_directory = Path(_base_directory).resolve() base_directory = Path(_base_directory).resolve()
activation_secrets = base_directory / "activation_secrets" activation_secrets = base_directory / "activation_secrets"

View File

@@ -2,8 +2,6 @@ import importlib
import json import json
import logging import logging
import re import re
from collections.abc import Iterator
from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import cached_property from functools import cached_property
from pathlib import Path from pathlib import Path
@@ -12,12 +10,11 @@ from typing import TYPE_CHECKING, Any
from clan_lib.errors import ClanCmdError, ClanError from clan_lib.errors import ClanCmdError, ClanError
from clan_lib.flake import Flake from clan_lib.flake import Flake
from clan_lib.nix import nix_config, nix_test_store from clan_lib.nix import nix_config, nix_test_store
from clan_lib.ssh.remote import Remote
from clan_cli.facts import public_modules as facts_public_modules from clan_cli.facts import public_modules as facts_public_modules
from clan_cli.facts import secret_modules as facts_secret_modules from clan_cli.facts import secret_modules as facts_secret_modules
from clan_cli.ssh.host import Host
from clan_cli.ssh.host_key import HostKeyCheck from clan_cli.ssh.host_key import HostKeyCheck
from clan_cli.ssh.parse import parse_deployment_address
from clan_cli.vars._types import StoreBase from clan_cli.vars._types import StoreBase
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -146,37 +143,31 @@ class Machine:
def flake_dir(self) -> Path: def flake_dir(self) -> Path:
return self.flake.path return self.flake.path
@contextmanager def target_host(self) -> Remote:
def target_host(self) -> Iterator[Host]: return Remote.from_deployment_address(
with parse_deployment_address( machine_name=self.name,
self.name, address=self.target_host_address,
self.target_host_address, host_key_check=self.host_key_check,
self.host_key_check,
private_key=self.private_key, private_key=self.private_key,
meta={"machine": self}, )
) as target_host:
yield target_host
@contextmanager def build_host(self) -> Remote | None:
def build_host(self) -> Iterator[Host | None]:
""" """
The host where the machine is built and deployed from. The host where the machine is built and deployed from.
Can be the same as the target host. Can be the same as the target host.
""" """
build_host = self.override_build_host or self.deployment.get("buildHost") address = self.override_build_host or self.deployment.get("buildHost")
if build_host is None: if address is None:
yield None return None
return
# enable ssh agent forwarding to allow the build host to access the target host # enable ssh agent forwarding to allow the build host to access the target host
with parse_deployment_address( host = Remote.from_deployment_address(
self.name, machine_name=self.name,
build_host, address=address,
self.host_key_check, host_key_check=self.host_key_check,
forward_agent=True, forward_agent=True,
private_key=self.private_key, private_key=self.private_key,
meta={"machine": self}, )
) as build_host: return host
yield build_host
def nix( def nix(
self, self,

View File

@@ -5,7 +5,6 @@ import os
import re import re
import shlex import shlex
import sys import sys
from contextlib import ExitStack
from clan_lib.api import API from clan_lib.api import API
from clan_lib.async_run import AsyncContext, AsyncOpts, AsyncRuntime, is_async_cancelled from clan_lib.async_run import AsyncContext, AsyncOpts, AsyncRuntime, is_async_cancelled
@@ -13,6 +12,7 @@ from clan_lib.cmd import Log, MsgColor, RunOpts, run
from clan_lib.colors import AnsiColor from clan_lib.colors import AnsiColor
from clan_lib.errors import ClanError from clan_lib.errors import ClanError
from clan_lib.nix import nix_command, nix_config, nix_metadata from clan_lib.nix import nix_command, nix_config, nix_metadata
from clan_lib.ssh.remote import HostKeyCheck, Remote
from clan_cli.completions import ( from clan_cli.completions import (
add_dynamic_completer, add_dynamic_completer,
@@ -22,7 +22,6 @@ from clan_cli.facts.generate import generate_facts
from clan_cli.facts.upload import upload_secrets from clan_cli.facts.upload import upload_secrets
from clan_cli.machines.list import list_machines from clan_cli.machines.list import list_machines
from clan_cli.machines.machines import Machine from clan_cli.machines.machines import Machine
from clan_cli.ssh.host import Host, HostKeyCheck
from clan_cli.vars.generate import generate_vars from clan_cli.vars.generate import generate_vars
from clan_cli.vars.upload import upload_secret_vars from clan_cli.vars.upload import upload_secret_vars
@@ -44,161 +43,160 @@ def is_local_input(node: dict[str, dict[str, str]]) -> bool:
return local return local
def upload_sources(machine: Machine, host: Host) -> str: def upload_sources(machine: Machine, host: Remote) -> str:
env = host.nix_ssh_env(os.environ.copy()) with host.ssh_control_master() as ssh:
env = ssh.nix_ssh_env(os.environ.copy())
flake_url = ( flake_url = (
str(machine.flake.path) if machine.flake.is_local else machine.flake.identifier str(machine.flake.path)
) if machine.flake.is_local
flake_data = nix_metadata(flake_url) else machine.flake.identifier
has_path_inputs = any( )
is_local_input(node) for node in flake_data["locks"]["nodes"].values() flake_data = nix_metadata(flake_url)
) has_path_inputs = any(
is_local_input(node) for node in flake_data["locks"]["nodes"].values()
)
if not has_path_inputs: if not has_path_inputs:
# Just copy the flake to the remote machine, we can substitute other inputs there. # Just copy the flake to the remote machine, we can substitute other inputs there.
path = flake_data["path"] path = flake_data["path"]
cmd = nix_command(
remote_url = f"ssh-ng://{host.target}" [
"copy",
# MacOS doesn't come with a proper login shell for ssh and therefore doesn't have nix in $PATH as it doesn't source /etc/profile "--to",
if machine._class_ == "darwin": f"ssh://{host.target}",
remote_url += "?remote-program=bash -lc 'exec nix-daemon --stdio'" "--no-check-sigs",
path,
]
)
run(
cmd,
RunOpts(
env=env,
needs_user_terminal=True,
error_msg="failed to upload sources",
prefix=machine.name,
),
)
return path
# Slow path: we need to upload all sources to the remote machine
cmd = nix_command( cmd = nix_command(
[ [
"copy", "flake",
"archive",
"--to", "--to",
remote_url, f"ssh://{host.target}",
"--no-check-sigs", "--json",
path, flake_url,
] ]
) )
run( proc = run(
cmd, cmd,
RunOpts( RunOpts(
env=env, env=env, needs_user_terminal=True, error_msg="failed to upload sources"
needs_user_terminal=True,
error_msg="failed to upload sources",
prefix=machine.name,
), ),
) )
return path
# Slow path: we need to upload all sources to the remote machine try:
cmd = nix_command( return json.loads(proc.stdout)["path"]
[ except (json.JSONDecodeError, OSError) as e:
"flake", msg = (
"archive", f"failed to parse output of {shlex.join(cmd)}: {e}\nGot: {proc.stdout}"
"--to", )
f"ssh://{host.target}", raise ClanError(msg) from e
"--json",
flake_url,
]
)
proc = run(
cmd,
RunOpts(
env=env, needs_user_terminal=True, error_msg="failed to upload sources"
),
)
try:
return json.loads(proc.stdout)["path"]
except (json.JSONDecodeError, OSError) as e:
msg = f"failed to parse output of {shlex.join(cmd)}: {e}\nGot: {proc.stdout}"
raise ClanError(msg) from e
@API.register @API.register
def deploy_machine(machine: Machine) -> None: def deploy_machine(machine: Machine) -> None:
with ExitStack() as stack: target_host = machine.target_host()
target_host = stack.enter_context(machine.target_host()) build_host = machine.build_host()
build_host = stack.enter_context(machine.build_host())
host = build_host or target_host host = build_host or target_host
generate_facts([machine], service=None, regenerate=False) generate_facts([machine], service=None, regenerate=False)
generate_vars([machine], generator_name=None, regenerate=False) generate_vars([machine], generator_name=None, regenerate=False)
upload_secrets(machine) upload_secrets(machine)
upload_secret_vars(machine, target_host) upload_secret_vars(machine, target_host)
path = upload_sources(machine, host) path = upload_sources(machine, host)
nix_options = [ nix_options = [
"--show-trace", "--show-trace",
"--option", "--option",
"keep-going", "keep-going",
"true", "true",
"--option", "--option",
"accept-flake-config", "accept-flake-config",
"true", "true",
"-L", "-L",
*machine.nix_options, *machine.nix_options,
"--flake", "--flake",
f"{path}#{machine.name}", f"{path}#{machine.name}",
]
become_root = True
if machine._class_ == "nixos":
nix_options += [
"--fast",
"--build-host",
"",
] ]
become_root = True if build_host:
become_root = False
nix_options += ["--target-host", target_host.target]
if machine._class_ == "nixos": if target_host.user != "root":
nix_options += [ nix_options += ["--use-remote-sudo"]
"--fast", switch_cmd = ["nixos-rebuild", "switch", *nix_options]
"--build-host", elif machine._class_ == "darwin":
"", # use absolute path to darwin-rebuild
] switch_cmd = [
"/run/current-system/sw/bin/darwin-rebuild",
"switch",
*nix_options,
]
if build_host: remote_env = host.nix_ssh_env(control_master=False)
become_root = False ret = host.run(
nix_options += ["--target-host", target_host.target] switch_cmd,
RunOpts(
check=False,
log=Log.BOTH,
msg_color=MsgColor(stderr=AnsiColor.DEFAULT),
needs_user_terminal=True,
),
extra_env=remote_env,
become_root=become_root,
control_master=False,
)
if target_host.user != "root": if is_async_cancelled():
nix_options += ["--use-remote-sudo"] return
switch_cmd = ["nixos-rebuild", "switch", *nix_options]
elif machine._class_ == "darwin":
# use absolute path to darwin-rebuild
switch_cmd = [
"/run/current-system/sw/bin/darwin-rebuild",
"switch",
*nix_options,
]
remote_env = host.nix_ssh_env(None, local_ssh=False) # retry nixos-rebuild switch if the first attempt failed
if ret.returncode != 0:
is_mobile = machine.deployment.get("nixosMobileWorkaround", False)
# if the machine is mobile, we retry to deploy with the mobile workaround method
if is_mobile:
machine.info(
"Mobile machine detected, applying workaround deployment method"
)
ret = host.run( ret = host.run(
switch_cmd, ["nixos--rebuild", "test", *nix_options] if is_mobile else switch_cmd,
RunOpts( RunOpts(
check=False,
log=Log.BOTH, log=Log.BOTH,
msg_color=MsgColor(stderr=AnsiColor.DEFAULT), msg_color=MsgColor(stderr=AnsiColor.DEFAULT),
needs_user_terminal=True, needs_user_terminal=True,
), ),
extra_env=remote_env, extra_env=remote_env,
become_root=become_root, become_root=become_root,
control_master=False,
) )
if is_async_cancelled():
return
# retry nixos-rebuild switch if the first attempt failed
if ret.returncode != 0:
is_mobile = machine.deployment.get("nixosMobileWorkaround", False)
# if the machine is mobile, we retry to deploy with the mobile workaround method
if is_mobile:
machine.info(
"Mobile machine detected, applying workaround deployment method"
)
ret = host.run(
["nixos--rebuild", "test", *nix_options] if is_mobile else switch_cmd,
RunOpts(
log=Log.BOTH,
msg_color=MsgColor(stderr=AnsiColor.DEFAULT),
needs_user_terminal=True,
),
extra_env=remote_env,
become_root=become_root,
)
def deploy_machines(machines: list[Machine]) -> None: def deploy_machines(machines: list[Machine]) -> None:
""" """

View File

@@ -10,15 +10,15 @@ from clan_lib.async_run import AsyncRuntime
from clan_lib.cmd import run from clan_lib.cmd import run
from clan_lib.errors import ClanError from clan_lib.errors import ClanError
from clan_lib.nix import nix_shell from clan_lib.nix import nix_shell
from clan_lib.ssh.parse import parse_deployment_address
from clan_lib.ssh.remote import Remote, is_ssh_reachable
from clan_cli.completions import ( from clan_cli.completions import (
add_dynamic_completer, add_dynamic_completer,
complete_machines, complete_machines,
) )
from clan_cli.machines.machines import Machine from clan_cli.machines.machines import Machine
from clan_cli.ssh.host import Host, is_ssh_reachable
from clan_cli.ssh.host_key import HostKeyCheck from clan_cli.ssh.host_key import HostKeyCheck
from clan_cli.ssh.parse import parse_deployment_address
from clan_cli.ssh.tor import TorTarget, spawn_tor, ssh_tor_reachable from clan_cli.ssh.tor import TorTarget, spawn_tor, ssh_tor_reachable
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -51,12 +51,12 @@ def is_ipv6(ip: str) -> bool:
def find_reachable_host( def find_reachable_host(
deploy_info: DeployInfo, host_key_check: HostKeyCheck deploy_info: DeployInfo, host_key_check: HostKeyCheck
) -> Host | None: ) -> Remote | None:
host = None host = None
for addr in deploy_info.addrs: for addr in deploy_info.addrs:
host_addr = f"[{addr}]" if is_ipv6(addr) else addr host_addr = f"[{addr}]" if is_ipv6(addr) else addr
host_ = parse_deployment_address( host_ = parse_deployment_address(
machine_name="uknown", host=host_addr, host_key_check=host_key_check machine_name="uknown", address=host_addr, host_key_check=host_key_check
) )
if is_ssh_reachable(host_): if is_ssh_reachable(host_):
host = host_ host = host_
@@ -88,7 +88,8 @@ def ssh_shell_from_deploy(
deploy_info: DeployInfo, runtime: AsyncRuntime, host_key_check: HostKeyCheck deploy_info: DeployInfo, runtime: AsyncRuntime, host_key_check: HostKeyCheck
) -> None: ) -> None:
if host := find_reachable_host(deploy_info, host_key_check): if host := find_reachable_host(deploy_info, host_key_check):
host.interactive_ssh() with host.ssh_control_master() as ssh:
ssh.interactive_ssh()
else: else:
log.info("Could not reach host via clearnet 'addrs'") log.info("Could not reach host via clearnet 'addrs'")
log.info(f"Trying to reach host via tor '{deploy_info.tor}'") log.info(f"Trying to reach host via tor '{deploy_info.tor}'")
@@ -97,7 +98,13 @@ def ssh_shell_from_deploy(
msg = "No tor address provided, please provide a tor address." msg = "No tor address provided, please provide a tor address."
raise ClanError(msg) raise ClanError(msg)
if ssh_tor_reachable(TorTarget(onion=deploy_info.tor, port=22)): if ssh_tor_reachable(TorTarget(onion=deploy_info.tor, port=22)):
host = Host(host=deploy_info.tor, password=deploy_info.pwd, tor_socks=True) host = Remote(
address=deploy_info.tor,
user="root",
password=deploy_info.pwd,
tor_socks=True,
command_prefix="tor",
)
else: else:
msg = "Could not reach host via tor either." msg = "Could not reach host via tor either."
raise ClanError(msg) raise ClanError(msg)

View File

@@ -2,14 +2,14 @@ from dataclasses import dataclass
from typing import Generic from typing import Generic
from clan_lib.errors import CmdOut from clan_lib.errors import CmdOut
from clan_lib.ssh.remote import Remote
from clan_cli.ssh import T from clan_cli.ssh import T
from clan_cli.ssh.host import Host
@dataclass @dataclass
class HostResult(Generic[T]): class HostResult(Generic[T]):
host: Host host: Remote
_result: T | Exception _result: T | Exception
@property @property

View File

@@ -4,12 +4,11 @@ from tempfile import TemporaryDirectory
from clan_lib.cmd import Log, RunOpts from clan_lib.cmd import Log, RunOpts
from clan_lib.errors import ClanError from clan_lib.errors import ClanError
from clan_lib.ssh.remote import Remote
from clan_cli.ssh.host import Host
def upload( def upload(
host: Host, host: Remote,
local_src: Path, local_src: Path,
remote_dest: Path, # must be a directory remote_dest: Path, # must be a directory
file_user: str = "root", file_user: str = "root",
@@ -99,8 +98,8 @@ def upload(
raise ClanError(msg) raise ClanError(msg)
# TODO accept `input` to be an IO object instead of bytes so that we don't have to read the tarfile into memory. # TODO accept `input` to be an IO object instead of bytes so that we don't have to read the tarfile into memory.
with tar_path.open("rb") as f: with tar_path.open("rb") as f, host.ssh_control_master() as ssh:
host.run( ssh.run(
[ [
"bash", "bash",
"-c", "-c",

View File

@@ -3,21 +3,22 @@ import pwd
from pathlib import Path from pathlib import Path
import pytest import pytest
from clan_cli.ssh.host import Host
from clan_cli.ssh.host_key import HostKeyCheck from clan_cli.ssh.host_key import HostKeyCheck
from clan_cli.tests.sshd import Sshd from clan_cli.tests.sshd import Sshd
from clan_lib.ssh.remote import Remote
@pytest.fixture @pytest.fixture
def hosts(sshd: Sshd) -> list[Host]: def hosts(sshd: Sshd) -> list[Remote]:
login = pwd.getpwuid(os.getuid()).pw_name login = pwd.getpwuid(os.getuid()).pw_name
group = [ group = [
Host( Remote(
"127.0.0.1", "127.0.0.1",
port=sshd.port, port=sshd.port,
user=login, user=login,
private_key=Path(sshd.key), private_key=Path(sshd.key),
host_key_check=HostKeyCheck.NONE, host_key_check=HostKeyCheck.NONE,
command_prefix="local_test",
) )
] ]

View File

@@ -1,8 +1,8 @@
from clan_cli.ssh.host import Host
from clan_lib.async_run import AsyncRuntime from clan_lib.async_run import AsyncRuntime
from clan_lib.cmd import ClanCmdTimeoutError, Log, RunOpts from clan_lib.cmd import ClanCmdTimeoutError, Log, RunOpts
from clan_lib.ssh.remote import Remote
host = Host("some_host") host = Remote("some_host", user="root", command_prefix="local_test")
def test_run_environment(runtime: AsyncRuntime) -> None: def test_run_environment(runtime: AsyncRuntime) -> None:

View File

@@ -4,12 +4,11 @@ from collections.abc import Generator
from typing import Any, NamedTuple from typing import Any, NamedTuple
import pytest import pytest
from clan_cli.ssh.host import Host
from clan_cli.ssh.host_key import HostKeyCheck from clan_cli.ssh.host_key import HostKeyCheck
from clan_cli.ssh.parse import parse_deployment_address
from clan_lib.async_run import AsyncRuntime from clan_lib.async_run import AsyncRuntime
from clan_lib.cmd import ClanCmdTimeoutError, Log, RunOpts from clan_lib.cmd import ClanCmdTimeoutError, Log, RunOpts
from clan_lib.errors import ClanError, CmdOut from clan_lib.errors import ClanError, CmdOut
from clan_lib.ssh.remote import Remote
if sys.platform == "darwin": if sys.platform == "darwin":
pytest.skip("preload doesn't work on darwin", allow_module_level=True) pytest.skip("preload doesn't work on darwin", allow_module_level=True)
@@ -110,12 +109,16 @@ def test_parse_deployment_address(
with maybe_check_exception: with maybe_check_exception:
machine_name = "foo" machine_name = "foo"
result = parse_deployment_address(machine_name, test_addr, HostKeyCheck.STRICT) result = Remote.from_deployment_address(
machine_name=machine_name,
address=test_addr,
host_key_check=HostKeyCheck.STRICT,
)
if expected_exception: if expected_exception:
return return
assert result.host == expected_host assert result.address == expected_host
assert result.port == expected_port assert result.port == expected_port
assert result.user == expected_user or ( assert result.user == expected_user or (
expected_user == "" and result.user == "root" expected_user == "" and result.user == "root"
@@ -126,16 +129,18 @@ def test_parse_deployment_address(
def test_parse_ssh_options() -> None: def test_parse_ssh_options() -> None:
addr = "root@example.com:2222?IdentityFile=/path/to/private/key&StrictHostKeyChecking=yes" addr = "root@example.com:2222?IdentityFile=/path/to/private/key&StrictRemoteKeyChecking=yes"
host = parse_deployment_address("foo", addr, HostKeyCheck.STRICT) host = Remote.from_deployment_address(
assert host.host == "example.com" machine_name="foo", address=addr, host_key_check=HostKeyCheck.STRICT
)
assert host.address == "example.com"
assert host.port == 2222 assert host.port == 2222
assert host.user == "root" assert host.user == "root"
assert host.ssh_options["IdentityFile"] == "/path/to/private/key" assert host.ssh_options["IdentityFile"] == "/path/to/private/key"
assert host.ssh_options["StrictHostKeyChecking"] == "yes" assert host.ssh_options["StrictRemoteKeyChecking"] == "yes"
def test_run(hosts: list[Host], runtime: AsyncRuntime) -> None: def test_run(hosts: list[Remote], runtime: AsyncRuntime) -> None:
for host in hosts: for host in hosts:
proc = runtime.async_run( proc = runtime.async_run(
None, host.run_local, ["echo", "hello"], RunOpts(log=Log.STDERR) None, host.run_local, ["echo", "hello"], RunOpts(log=Log.STDERR)
@@ -143,7 +148,7 @@ def test_run(hosts: list[Host], runtime: AsyncRuntime) -> None:
assert proc.wait().result.stdout == "hello\n" assert proc.wait().result.stdout == "hello\n"
def test_run_environment(hosts: list[Host], runtime: AsyncRuntime) -> None: def test_run_environment(hosts: list[Remote], runtime: AsyncRuntime) -> None:
for host in hosts: for host in hosts:
proc = runtime.async_run( proc = runtime.async_run(
None, None,
@@ -165,7 +170,7 @@ def test_run_environment(hosts: list[Host], runtime: AsyncRuntime) -> None:
assert "env_var=true" in p2.wait().result.stdout assert "env_var=true" in p2.wait().result.stdout
def test_run_no_shell(hosts: list[Host], runtime: AsyncRuntime) -> None: def test_run_no_shell(hosts: list[Remote], runtime: AsyncRuntime) -> None:
for host in hosts: for host in hosts:
proc = runtime.async_run( proc = runtime.async_run(
None, host.run_local, ["echo", "hello"], RunOpts(log=Log.STDERR) None, host.run_local, ["echo", "hello"], RunOpts(log=Log.STDERR)
@@ -173,9 +178,10 @@ def test_run_no_shell(hosts: list[Host], runtime: AsyncRuntime) -> None:
assert proc.wait().result.stdout == "hello\n" assert proc.wait().result.stdout == "hello\n"
def test_run_function(hosts: list[Host], runtime: AsyncRuntime) -> None: def test_run_function(hosts: list[Remote], runtime: AsyncRuntime) -> None:
def some_func(h: Host) -> bool: def some_func(h: Remote) -> bool:
p = h.run(["echo", "hello"]) with h.ssh_control_master() as ssh:
p = ssh.run(["echo", "hello"])
return p.stdout == "hello\n" return p.stdout == "hello\n"
for host in hosts: for host in hosts:
@@ -183,7 +189,7 @@ def test_run_function(hosts: list[Host], runtime: AsyncRuntime) -> None:
assert proc.wait().result assert proc.wait().result
def test_timeout(hosts: list[Host], runtime: AsyncRuntime) -> None: def test_timeout(hosts: list[Remote], runtime: AsyncRuntime) -> None:
for host in hosts: for host in hosts:
proc = runtime.async_run( proc = runtime.async_run(
None, host.run_local, ["sleep", "10"], RunOpts(timeout=0.01) None, host.run_local, ["sleep", "10"], RunOpts(timeout=0.01)
@@ -192,7 +198,7 @@ def test_timeout(hosts: list[Host], runtime: AsyncRuntime) -> None:
assert isinstance(error, ClanCmdTimeoutError) assert isinstance(error, ClanCmdTimeoutError)
def test_run_exception(hosts: list[Host], runtime: AsyncRuntime) -> None: def test_run_exception(hosts: list[Remote], runtime: AsyncRuntime) -> None:
for host in hosts: for host in hosts:
proc = runtime.async_run( proc = runtime.async_run(
None, host.run_local, ["exit 1"], RunOpts(shell=True, check=False) None, host.run_local, ["exit 1"], RunOpts(shell=True, check=False)
@@ -211,8 +217,8 @@ def test_run_exception(hosts: list[Host], runtime: AsyncRuntime) -> None:
raise AssertionError(msg) raise AssertionError(msg)
def test_run_function_exception(hosts: list[Host], runtime: AsyncRuntime) -> None: def test_run_function_exception(hosts: list[Remote], runtime: AsyncRuntime) -> None:
def some_func(h: Host) -> CmdOut: def some_func(h: Remote) -> CmdOut:
return h.run_local(["exit 1"], RunOpts(shell=True)) return h.run_local(["exit 1"], RunOpts(shell=True))
try: try:

View File

@@ -1,18 +1,17 @@
from pathlib import Path from pathlib import Path
import pytest import pytest
from clan_cli.ssh.host import Host, HostKeyCheck
from clan_cli.ssh.upload import upload from clan_cli.ssh.upload import upload
from clan_lib.ssh.remote import Remote
@pytest.mark.with_core @pytest.mark.with_core
def test_upload_single_file( def test_upload_single_file(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
temporary_home: Path, temporary_home: Path,
hosts: list[Host], hosts: list[Remote],
) -> None: ) -> None:
host = hosts[0] host = hosts[0]
host.host_key_check = HostKeyCheck.NONE
src_file = temporary_home / "test.txt" src_file = temporary_home / "test.txt"
src_file.write_text("test") src_file.write_text("test")

View File

@@ -6,8 +6,8 @@ from pathlib import Path
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from clan_cli.machines import machines from clan_cli.machines import machines
from clan_cli.ssh.host import Host
from clan_lib.errors import ClanError from clan_lib.errors import ClanError
from clan_lib.ssh.remote import Remote
if TYPE_CHECKING: if TYPE_CHECKING:
from .generate import Generator, Var from .generate import Generator, Var
@@ -184,5 +184,5 @@ class StoreBase(ABC):
pass pass
@abstractmethod @abstractmethod
def upload(self, host: Host, phases: list[str]) -> None: def upload(self, host: Remote, phases: list[str]) -> None:
pass pass

View File

@@ -3,10 +3,10 @@ from collections.abc import Iterable
from pathlib import Path from pathlib import Path
from clan_cli.machines.machines import Machine from clan_cli.machines.machines import Machine
from clan_cli.ssh.host import Host
from clan_cli.vars._types import StoreBase from clan_cli.vars._types import StoreBase
from clan_cli.vars.generate import Generator, Var from clan_cli.vars.generate import Generator, Var
from clan_lib.errors import ClanError from clan_lib.errors import ClanError
from clan_lib.ssh.remote import Remote
class FactStore(StoreBase): class FactStore(StoreBase):
@@ -73,6 +73,6 @@ class FactStore(StoreBase):
msg = "populate_dir is not implemented for public vars stores" msg = "populate_dir is not implemented for public vars stores"
raise NotImplementedError(msg) raise NotImplementedError(msg)
def upload(self, host: Host, phases: list[str]) -> None: def upload(self, host: Remote, phases: list[str]) -> None:
msg = "upload is not implemented for public vars stores" msg = "upload is not implemented for public vars stores"
raise NotImplementedError(msg) raise NotImplementedError(msg)

View File

@@ -4,11 +4,11 @@ from collections.abc import Iterable
from pathlib import Path from pathlib import Path
from clan_cli.machines.machines import Machine from clan_cli.machines.machines import Machine
from clan_cli.ssh.host import Host
from clan_cli.vars._types import StoreBase from clan_cli.vars._types import StoreBase
from clan_cli.vars.generate import Generator, Var from clan_cli.vars.generate import Generator, Var
from clan_lib.dirs import vm_state_dir from clan_lib.dirs import vm_state_dir
from clan_lib.errors import ClanError from clan_lib.errors import ClanError
from clan_lib.ssh.remote import Remote
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -70,6 +70,6 @@ class FactStore(StoreBase):
msg = "populate_dir is not implemented for public vars stores" msg = "populate_dir is not implemented for public vars stores"
raise NotImplementedError(msg) raise NotImplementedError(msg)
def upload(self, host: Host, phases: list[str]) -> None: def upload(self, host: Remote, phases: list[str]) -> None:
msg = "upload is not implemented for public vars stores" msg = "upload is not implemented for public vars stores"
raise NotImplementedError(msg) raise NotImplementedError(msg)

View File

@@ -3,9 +3,9 @@ import tempfile
from pathlib import Path from pathlib import Path
from clan_cli.machines.machines import Machine from clan_cli.machines.machines import Machine
from clan_cli.ssh.host import Host
from clan_cli.vars._types import StoreBase from clan_cli.vars._types import StoreBase
from clan_cli.vars.generate import Generator, Var from clan_cli.vars.generate import Generator, Var
from clan_lib.ssh.remote import Remote
class SecretStore(StoreBase): class SecretStore(StoreBase):
@@ -46,6 +46,6 @@ class SecretStore(StoreBase):
shutil.copytree(self.dir, output_dir) shutil.copytree(self.dir, output_dir)
shutil.rmtree(self.dir) shutil.rmtree(self.dir)
def upload(self, host: Host, phases: list[str]) -> None: def upload(self, host: Remote, phases: list[str]) -> None:
msg = "Cannot upload secrets with FS backend" msg = "Cannot upload secrets with FS backend"
raise NotImplementedError(msg) raise NotImplementedError(msg)

View File

@@ -8,12 +8,12 @@ from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from clan_cli.machines.machines import Machine from clan_cli.machines.machines import Machine
from clan_cli.ssh.host import Host
from clan_cli.ssh.upload import upload from clan_cli.ssh.upload import upload
from clan_cli.vars._types import StoreBase from clan_cli.vars._types import StoreBase
from clan_cli.vars.generate import Generator, Var from clan_cli.vars.generate import Generator, Var
from clan_lib.cmd import CmdOut, Log, RunOpts, run from clan_lib.cmd import CmdOut, Log, RunOpts, run
from clan_lib.nix import nix_shell from clan_lib.nix import nix_shell
from clan_lib.ssh.remote import Remote
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -147,16 +147,17 @@ class SecretStore(StoreBase):
manifest += hashes manifest += hashes
return b"\n".join(manifest) return b"\n".join(manifest)
def needs_upload(self, host: Host) -> bool: def needs_upload(self, host: Remote) -> bool:
local_hash = self.generate_hash() local_hash = self.generate_hash()
remote_hash = host.run( with host.ssh_control_master() as ssh:
# TODO get the path to the secrets from the machine remote_hash = ssh.run(
[ # TODO get the path to the secrets from the machine
"cat", [
f"{self.machine.deployment['password-store']['secretLocation']}/.{self._store_backend}_info", "cat",
], f"{self.machine.deployment['password-store']['secretLocation']}/.{self._store_backend}_info",
RunOpts(log=Log.STDERR, check=False), ],
).stdout.strip() RunOpts(log=Log.STDERR, check=False),
).stdout.strip()
if not remote_hash: if not remote_hash:
print("remote hash is empty") print("remote hash is empty")
@@ -226,7 +227,7 @@ class SecretStore(StoreBase):
(output_dir / f".{self._store_backend}_info").write_bytes(self.generate_hash()) (output_dir / f".{self._store_backend}_info").write_bytes(self.generate_hash())
def upload(self, host: Host, phases: list[str]) -> None: def upload(self, host: Remote, phases: list[str]) -> None:
if "partitioning" in phases: if "partitioning" in phases:
msg = "Cannot upload partitioning secrets" msg = "Cannot upload partitioning secrets"
raise NotImplementedError(msg) raise NotImplementedError(msg)

View File

@@ -22,12 +22,12 @@ from clan_cli.secrets.secrets import (
groups_folder, groups_folder,
has_secret, has_secret,
) )
from clan_cli.ssh.host import Host
from clan_cli.ssh.upload import upload from clan_cli.ssh.upload import upload
from clan_cli.vars._types import StoreBase from clan_cli.vars._types import StoreBase
from clan_cli.vars.generate import Generator from clan_cli.vars.generate import Generator
from clan_cli.vars.var import Var from clan_cli.vars.var import Var
from clan_lib.errors import ClanError from clan_lib.errors import ClanError
from clan_lib.ssh.remote import Remote
@dataclass @dataclass
@@ -224,7 +224,7 @@ class SecretStore(StoreBase):
target_path.chmod(file.mode) target_path.chmod(file.mode)
@override @override
def upload(self, host: Host, phases: list[str]) -> None: def upload(self, host: Remote, phases: list[str]) -> None:
if "partitioning" in phases: if "partitioning" in phases:
msg = "Cannot upload partitioning secrets" msg = "Cannot upload partitioning secrets"
raise NotImplementedError(msg) raise NotImplementedError(msg)

View File

@@ -3,10 +3,10 @@ from collections.abc import Iterable
from pathlib import Path from pathlib import Path
from clan_cli.machines.machines import Machine from clan_cli.machines.machines import Machine
from clan_cli.ssh.host import Host
from clan_cli.vars._types import StoreBase from clan_cli.vars._types import StoreBase
from clan_cli.vars.generate import Generator, Var from clan_cli.vars.generate import Generator, Var
from clan_lib.dirs import vm_state_dir from clan_lib.dirs import vm_state_dir
from clan_lib.ssh.remote import Remote
class SecretStore(StoreBase): class SecretStore(StoreBase):
@@ -61,6 +61,6 @@ class SecretStore(StoreBase):
shutil.rmtree(output_dir) shutil.rmtree(output_dir)
shutil.copytree(self.dir, output_dir) shutil.copytree(self.dir, output_dir)
def upload(self, host: Host, phases: list[str]) -> None: def upload(self, host: Remote, phases: list[str]) -> None:
msg = "Cannot upload secrets to VMs" msg = "Cannot upload secrets to VMs"
raise NotImplementedError(msg) raise NotImplementedError(msg)

View File

@@ -4,12 +4,12 @@ from pathlib import Path
from clan_cli.completions import add_dynamic_completer, complete_machines from clan_cli.completions import add_dynamic_completer, complete_machines
from clan_cli.machines.machines import Machine from clan_cli.machines.machines import Machine
from clan_cli.ssh.host import Host from clan_lib.ssh.remote import Remote
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def upload_secret_vars(machine: Machine, host: Host) -> None: def upload_secret_vars(machine: Machine, host: Remote) -> None:
machine.secret_vars_store.upload(host, phases=["activation", "users", "services"]) machine.secret_vars_store.upload(host, phases=["activation", "users", "services"])
@@ -28,8 +28,8 @@ def upload_command(args: argparse.Namespace) -> None:
populate_secret_vars(machine, directory) populate_secret_vars(machine, directory)
return return
with machine.target_host() as host: host = machine.target_host()
upload_secret_vars(machine, host) upload_secret_vars(machine, host)
def register_upload_parser(parser: argparse.ArgumentParser) -> None: def register_upload_parser(parser: argparse.ArgumentParser) -> None:

View File

@@ -30,8 +30,9 @@ def check_machine_online(
timeout = opts.timeout if opts and opts.timeout else 2 timeout = opts.timeout if opts and opts.timeout else 2
for _ in range(opts.retries if opts and opts.retries else 10): for _ in range(opts.retries if opts and opts.retries else 10):
with machine.target_host() as target: host = machine.target_host()
res = target.run( with host.ssh_control_master() as ssh:
res = ssh.run(
["true"], ["true"],
RunOpts(timeout=timeout, check=False, needs_user_terminal=True), RunOpts(timeout=timeout, check=False, needs_user_terminal=True),
) )

View File

@@ -6,13 +6,14 @@ from clan_lib.errors import ClanError
def create_backup(machine: Machine, provider: str | None = None) -> None: def create_backup(machine: Machine, provider: str | None = None) -> None:
machine.info(f"creating backup for {machine.name}") machine.info(f"creating backup for {machine.name}")
backup_scripts = machine.eval_nix("config.clan.core.backups") backup_scripts = machine.eval_nix("config.clan.core.backups")
host = machine.target_host()
if provider is None: if provider is None:
if not backup_scripts["providers"]: if not backup_scripts["providers"]:
msg = "No providers specified" msg = "No providers specified"
raise ClanError(msg) raise ClanError(msg)
with machine.target_host() as host: with host.ssh_control_master() as ssh:
for provider in backup_scripts["providers"]: for provider in backup_scripts["providers"]:
proc = host.run( proc = ssh.run(
[backup_scripts["providers"][provider]["create"]], [backup_scripts["providers"][provider]["create"]],
) )
if proc.returncode != 0: if proc.returncode != 0:
@@ -23,8 +24,8 @@ def create_backup(machine: Machine, provider: str | None = None) -> None:
if provider not in backup_scripts["providers"]: if provider not in backup_scripts["providers"]:
msg = f"provider {provider} not found" msg = f"provider {provider} not found"
raise ClanError(msg) raise ClanError(msg)
with machine.target_host() as host: with host.ssh_control_master() as ssh:
proc = host.run( proc = ssh.run(
[backup_scripts["providers"][provider]["create"]], [backup_scripts["providers"][provider]["create"]],
) )
if proc.returncode != 0: if proc.returncode != 0:

View File

@@ -2,10 +2,10 @@ import json
from dataclasses import dataclass from dataclasses import dataclass
from clan_cli.machines.machines import Machine from clan_cli.machines.machines import Machine
from clan_cli.ssh.host import Host
from clan_lib.cmd import Log, RunOpts from clan_lib.cmd import Log, RunOpts
from clan_lib.errors import ClanError from clan_lib.errors import ClanError
from clan_lib.ssh.remote import Remote
@dataclass @dataclass
@@ -14,14 +14,15 @@ class Backup:
job_name: str | None = None job_name: str | None = None
def list_provider(machine: Machine, host: Host, provider: str) -> list[Backup]: def list_provider(machine: Machine, host: Remote, provider: str) -> list[Backup]:
results = [] results = []
backup_metadata = machine.eval_nix("config.clan.core.backups") backup_metadata = machine.eval_nix("config.clan.core.backups")
list_command = backup_metadata["providers"][provider]["list"] list_command = backup_metadata["providers"][provider]["list"]
proc = host.run( with host.ssh_control_master() as ssh:
[list_command], proc = ssh.run(
RunOpts(log=Log.NONE, check=False), [list_command],
) RunOpts(log=Log.NONE, check=False),
)
if proc.returncode != 0: if proc.returncode != 0:
# TODO this should be a warning, only raise exception if no providers succeed # TODO this should be a warning, only raise exception if no providers succeed
msg = f"Failed to list backups for provider {provider}:" msg = f"Failed to list backups for provider {provider}:"
@@ -44,12 +45,12 @@ def list_provider(machine: Machine, host: Host, provider: str) -> list[Backup]:
def list_backups(machine: Machine, provider: str | None = None) -> list[Backup]: def list_backups(machine: Machine, provider: str | None = None) -> list[Backup]:
backup_metadata = machine.eval_nix("config.clan.core.backups") backup_metadata = machine.eval_nix("config.clan.core.backups")
results = [] results = []
with machine.target_host() as host: host = machine.target_host()
if provider is None: if provider is None:
for _provider in backup_metadata["providers"]: for _provider in backup_metadata["providers"]:
results += list_provider(machine, host, _provider) results += list_provider(machine, host, _provider)
else: else:
results += list_provider(machine, host, provider) results += list_provider(machine, host, provider)
return results return results

View File

@@ -1,12 +1,12 @@
from clan_cli.machines.machines import Machine from clan_cli.machines.machines import Machine
from clan_cli.ssh.host import Host
from clan_lib.cmd import Log, RunOpts from clan_lib.cmd import Log, RunOpts
from clan_lib.errors import ClanError from clan_lib.errors import ClanError
from clan_lib.ssh.remote import Remote
def restore_service( def restore_service(
machine: Machine, host: Host, name: str, provider: str, service: str machine: Machine, host: Remote, name: str, provider: str, service: str
) -> None: ) -> None:
backup_metadata = machine.eval_nix("config.clan.core.backups") backup_metadata = machine.eval_nix("config.clan.core.backups")
backup_folders = machine.eval_nix("config.clan.core.state") backup_folders = machine.eval_nix("config.clan.core.state")
@@ -21,34 +21,35 @@ def restore_service(
# FIXME: If we have too many folder this might overflow the stack. # FIXME: If we have too many folder this might overflow the stack.
env["FOLDERS"] = ":".join(set(folders)) env["FOLDERS"] = ":".join(set(folders))
if pre_restore := backup_folders[service]["preRestoreCommand"]: with host.ssh_control_master() as ssh:
proc = host.run( if pre_restore := backup_folders[service]["preRestoreCommand"]:
[pre_restore], proc = ssh.run(
[pre_restore],
RunOpts(log=Log.STDERR),
extra_env=env,
)
if proc.returncode != 0:
msg = f"failed to run preRestoreCommand: {pre_restore}, error was: {proc.stdout}"
raise ClanError(msg)
proc = ssh.run(
[backup_metadata["providers"][provider]["restore"]],
RunOpts(log=Log.STDERR), RunOpts(log=Log.STDERR),
extra_env=env, extra_env=env,
) )
if proc.returncode != 0: if proc.returncode != 0:
msg = f"failed to run preRestoreCommand: {pre_restore}, error was: {proc.stdout}" msg = f"failed to restore backup: {backup_metadata['providers'][provider]['restore']}"
raise ClanError(msg) raise ClanError(msg)
proc = host.run( if post_restore := backup_folders[service]["postRestoreCommand"]:
[backup_metadata["providers"][provider]["restore"]], proc = ssh.run(
RunOpts(log=Log.STDERR), [post_restore],
extra_env=env, RunOpts(log=Log.STDERR),
) extra_env=env,
if proc.returncode != 0: )
msg = f"failed to restore backup: {backup_metadata['providers'][provider]['restore']}" if proc.returncode != 0:
raise ClanError(msg) msg = f"failed to run postRestoreCommand: {post_restore}, error was: {proc.stdout}"
raise ClanError(msg)
if post_restore := backup_folders[service]["postRestoreCommand"]:
proc = host.run(
[post_restore],
RunOpts(log=Log.STDERR),
extra_env=env,
)
if proc.returncode != 0:
msg = f"failed to run postRestoreCommand: {post_restore}, error was: {proc.stdout}"
raise ClanError(msg)
def restore_backup( def restore_backup(
@@ -58,7 +59,8 @@ def restore_backup(
service: str | None = None, service: str | None = None,
) -> None: ) -> None:
errors = [] errors = []
with machine.target_host() as host: host = machine.target_host()
with host.ssh_control_master():
if service is None: if service is None:
backup_folders = machine.eval_nix("config.clan.core.state") backup_folders = machine.eval_nix("config.clan.core.state")
for _service in backup_folders: for _service in backup_folders:

View File

View File

@@ -1,23 +1,26 @@
import re import re
import urllib.parse import urllib.parse
from pathlib import Path from pathlib import Path
from typing import Any from typing import TYPE_CHECKING, Any
from clan_cli.ssh.host_key import HostKeyCheck
from clan_lib.errors import ClanError from clan_lib.errors import ClanError
from clan_cli.ssh.host import Host if TYPE_CHECKING:
from clan_cli.ssh.host_key import HostKeyCheck from clan_lib.ssh.remote import Remote
def parse_deployment_address( def parse_deployment_address(
*,
machine_name: str, machine_name: str,
host: str, address: str,
host_key_check: HostKeyCheck, host_key_check: HostKeyCheck,
forward_agent: bool = True, forward_agent: bool = True,
meta: dict[str, Any] | None = None, meta: dict[str, Any] | None = None,
private_key: Path | None = None, private_key: Path | None = None,
) -> Host: ) -> "Remote":
parts = host.split("?", maxsplit=1) parts = address.split("?", maxsplit=1)
endpoint, maybe_options = parts if len(parts) == 2 else (parts[0], "") endpoint, maybe_options = parts if len(parts) == 2 else (parts[0], "")
parts = endpoint.split("@") parts = endpoint.split("@")
@@ -25,15 +28,15 @@ def parse_deployment_address(
case 2: case 2:
user, host_port = parts user, host_port = parts
case 1: case 1:
user, host_port = "", parts[0] user, host_port = "root", parts[0]
case _: case _:
msg = f"Invalid host, got `{host}` but expected something like `[user@]hostname[:port]`" msg = f"Invalid host, got `{address}` but expected something like `[user@]hostname[:port]`"
raise ClanError(msg) raise ClanError(msg)
# Make this check now rather than failing with a `ValueError` # Make this check now rather than failing with a `ValueError`
# when looking up the port from the `urlsplit` result below: # when looking up the port from the `urlsplit` result below:
if host_port.count(":") > 1 and not re.match(r".*\[.*]", host_port): if host_port.count(":") > 1 and not re.match(r".*\[.*]", host_port):
msg = f"Invalid hostname: {host}. IPv6 addresses must be enclosed in brackets , e.g. [::1]" msg = f"Invalid hostname: {address}. IPv6 addresses must be enclosed in brackets , e.g. [::1]"
raise ClanError(msg) raise ClanError(msg)
options: dict[str, str] = {} options: dict[str, str] = {}
@@ -43,7 +46,7 @@ def parse_deployment_address(
parts = o.split("=", maxsplit=1) parts = o.split("=", maxsplit=1)
if len(parts) != 2: if len(parts) != 2:
msg = ( msg = (
f"Invalid option in host `{host}`: option `{o}` does not have " f"Invalid option in host `{address}`: option `{o}` does not have "
f"a value (i.e. expected something like `name=value`)" f"a value (i.e. expected something like `name=value`)"
) )
raise ClanError(msg) raise ClanError(msg)
@@ -52,19 +55,19 @@ def parse_deployment_address(
result = urllib.parse.urlsplit(f"//{host_port}") result = urllib.parse.urlsplit(f"//{host_port}")
if not result.hostname: if not result.hostname:
msg = f"Invalid host, got `{host}` but expected something like `[user@]hostname[:port]`" msg = f"Invalid host, got `{address}` but expected something like `[user@]hostname[:port]`"
raise ClanError(msg) raise ClanError(msg)
hostname = result.hostname hostname = result.hostname
port = result.port port = result.port
from clan_lib.ssh.remote import Remote
return Host( return Remote(
hostname, address=hostname,
user=user, user=user,
port=port, port=port,
private_key=private_key, private_key=private_key,
host_key_check=host_key_check, host_key_check=host_key_check,
command_prefix=machine_name, command_prefix=machine_name,
forward_agent=forward_agent, forward_agent=forward_agent,
meta={} if meta is None else meta.copy(),
ssh_options=options, ssh_options=options,
) )

View File

@@ -1,98 +1,92 @@
# Adapted from https://github.com/numtide/deploykit # ruff: noqa: SLF001
import logging import logging
import os import os
import shlex import shlex
import socket import socket
import subprocess import subprocess
import sys import sys
import types from collections.abc import Iterator
from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from shlex import quote from shlex import quote
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Any
from clan_lib.cmd import CmdOut, RunOpts, run
from clan_lib.colors import AnsiColor
from clan_lib.errors import ClanError
from clan_lib.nix import nix_shell
from clan_cli.ssh.host_key import HostKeyCheck from clan_cli.ssh.host_key import HostKeyCheck
cmdlog = logging.getLogger(__name__) from clan_lib.cmd import CmdOut, RunOpts, run
from clan_lib.colors import AnsiColor
from clan_lib.errors import ClanError # Assuming these are available
from clan_lib.nix import nix_shell
from clan_lib.ssh.parse import parse_deployment_address
cmdlog = logging.getLogger(__name__)
# Seconds until a message is printed when _run produces no output. # Seconds until a message is printed when _run produces no output.
NO_OUTPUT_TIMEOUT = 20 NO_OUTPUT_TIMEOUT = 20
@dataclass @dataclass(frozen=True)
class Host: class Remote:
host: str address: str
user: str | None = None user: str
command_prefix: str
port: int | None = None port: int | None = None
private_key: Path | None = None private_key: Path | None = None
password: str | None = None password: str | None = None
forward_agent: bool = False forward_agent: bool = True
command_prefix: str | None = None
host_key_check: HostKeyCheck = HostKeyCheck.ASK host_key_check: HostKeyCheck = HostKeyCheck.ASK
meta: dict[str, Any] = field(default_factory=dict)
verbose_ssh: bool = False verbose_ssh: bool = False
ssh_options: dict[str, str] = field(default_factory=dict) ssh_options: dict[str, str] = field(default_factory=dict)
tor_socks: bool = False tor_socks: bool = False
_control_path_dir: Path | None = None
_temp_dir: TemporaryDirectory | None = None
def __enter__(self) -> "Host":
directory = None
if sys.platform == "darwin" and os.environ.get("TMPDIR", "").startswith(
"/var/folders/"
):
# macOS's tmpdir is too long for unix domain sockets
directory = "/tmp/"
self._temp_dir = TemporaryDirectory(prefix="clan-ssh-", dir=directory)
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> None:
try:
if self._temp_dir:
self._temp_dir.cleanup()
except OSError:
pass
def __post_init__(self) -> None:
if not self.command_prefix:
self.command_prefix = self.host
if not self.user:
self.user = "root"
def __str__(self) -> str: def __str__(self) -> str:
return self.target return self.target
@property @property
def target(self) -> str: def target(self) -> str:
return f"{self.user}@{self.host}" return f"{self.user}@{self.address}"
@classmethod @classmethod
def from_host(cls, host: "Host") -> "Host": def with_user(cls, host: "Remote", user: str) -> "Remote":
"""
Return a new Remote object with the specified user.
"""
return cls( return cls(
host=host.host, address=host.address,
user=host.user, user=user,
command_prefix=host.command_prefix,
port=host.port, port=host.port,
private_key=host.private_key, private_key=host.private_key,
password=host.password,
forward_agent=host.forward_agent, forward_agent=host.forward_agent,
command_prefix=host.command_prefix,
host_key_check=host.host_key_check, host_key_check=host.host_key_check,
meta=host.meta.copy(),
verbose_ssh=host.verbose_ssh, verbose_ssh=host.verbose_ssh,
ssh_options=host.ssh_options.copy(), ssh_options=host.ssh_options,
tor_socks=host.tor_socks,
)
@classmethod
def from_deployment_address(
cls,
*,
machine_name: str,
address: str,
host_key_check: HostKeyCheck,
forward_agent: bool = True,
private_key: Path | None = None,
) -> "Remote":
"""
Parse a deployment address and return a Host object.
"""
return parse_deployment_address(
machine_name=machine_name,
address=address,
host_key_check=host_key_check,
forward_agent=forward_agent,
private_key=private_key,
) )
def run_local( def run_local(
@@ -109,7 +103,6 @@ class Host:
env = opts.env or os.environ.copy() env = opts.env or os.environ.copy()
if extra_env: if extra_env:
env.update(extra_env) env.update(extra_env)
displayed_cmd = " ".join(cmd) displayed_cmd = " ".join(cmd)
cmdlog.info( cmdlog.info(
f"$ {displayed_cmd}", f"$ {displayed_cmd}",
@@ -122,6 +115,36 @@ class Host:
opts.prefix = self.command_prefix opts.prefix = self.command_prefix
return run(cmd, opts) return run(cmd, opts)
@contextmanager
def ssh_control_master(self) -> Iterator["Remote"]:
"""
Context manager to manage SSH ControlMaster connections.
This will create a temporary directory for the control socket.
"""
directory = None
if sys.platform == "darwin" and os.environ.get("TMPDIR", "").startswith(
"/var/folders/"
):
directory = "/tmp/"
# Use more specific prefix for the temp dir to avoid potential collisions if multiple hosts used
prefix = f"clan-ssh-{self.address}-{self.port or 22}-{self.user}-"
temp_dir = TemporaryDirectory(prefix=prefix, dir=directory)
yield Remote(
address=self.address,
user=self.user,
command_prefix=self.command_prefix,
port=self.port,
private_key=self.private_key,
password=self.password,
forward_agent=self.forward_agent,
host_key_check=self.host_key_check,
verbose_ssh=self.verbose_ssh,
ssh_options=self.ssh_options,
tor_socks=self.tor_socks,
_control_path_dir=Path(temp_dir.name),
)
temp_dir.cleanup()
def run( def run(
self, self,
cmd: list[str], cmd: list[str],
@@ -131,36 +154,32 @@ class Host:
tty: bool = False, tty: bool = False,
verbose_ssh: bool = False, verbose_ssh: bool = False,
quiet: bool = False, quiet: bool = False,
control_master: bool = True,
) -> CmdOut: ) -> CmdOut:
""" """
Command to run on the host via ssh Internal method to run a command on the host via ssh.
`control_path_dir`: If provided, SSH ControlMaster options will be used.
""" """
if extra_env is None: if extra_env is None:
extra_env = {} extra_env = {}
if opts is None: if opts is None:
opts = RunOpts() opts = RunOpts()
# Quote all added environment variables sudo = ""
if become_root and self.user != "root":
sudo = "sudo -- "
env_vars = [] env_vars = []
for k, v in extra_env.items(): for k, v in extra_env.items():
env_vars.append(f"{shlex.quote(k)}={shlex.quote(v)}") env_vars.append(f"{shlex.quote(k)}={shlex.quote(v)}")
sudo = []
if become_root and self.user != "root":
# If we are not root and we need to become root, prepend sudo
sudo = ["sudo", "--"]
if opts.prefix is None: if opts.prefix is None:
opts.prefix = self.command_prefix opts.prefix = self.command_prefix
# always set needs_user_terminal to True because ssh asks for passwords
opts.needs_user_terminal = True opts.needs_user_terminal = True
if opts.cwd is not None: if opts.cwd is not None:
msg = "cwd is not supported for remote commands" msg = "cwd is not supported for remote commands"
raise ClanError(msg) raise ClanError(msg)
# Build a pretty command for logging
displayed_cmd = "" displayed_cmd = ""
export_cmd = "" export_cmd = ""
if env_vars: if env_vars:
@@ -177,111 +196,104 @@ class Host:
}, },
) )
# Build the ssh command
bash_cmd = export_cmd bash_cmd = export_cmd
if opts.shell: if opts.shell:
bash_cmd += " ".join(cmd) bash_cmd += " ".join(cmd)
opts.shell = False opts.shell = False
else: else:
bash_cmd += 'exec "$@"' bash_cmd += 'exec "$@"'
# FIXME we assume bash to be present here? Should be documented...
ssh_cmd = [
*self.ssh_cmd(verbose_ssh=verbose_ssh, tty=tty),
"--",
*sudo,
"bash",
"-c",
quote(bash_cmd),
"--",
" ".join(map(quote, cmd)),
]
# Run the ssh command ssh_cmd_list = self.ssh_cmd(
return run(ssh_cmd, opts) verbose_ssh=verbose_ssh, tty=tty, control_master=control_master
)
ssh_cmd_list.extend(
["--", f"{sudo}bash -c {quote(bash_cmd)} -- {' '.join(map(quote, cmd))}"]
)
return run(ssh_cmd_list, opts)
def nix_ssh_env( def nix_ssh_env(
self, env: dict[str, str] | None, local_ssh: bool = True self,
env: dict[str, str] | None = None,
control_master: bool = True,
) -> dict[str, str]: ) -> dict[str, str]:
if env is None: if env is None:
env = {} env = {}
env["NIX_SSHOPTS"] = " ".join(self.ssh_cmd_opts(local_ssh=local_ssh)) env["NIX_SSHOPTS"] = " ".join(
self.ssh_cmd_opts(control_master=control_master) # Renamed
)
return env return env
def ssh_cmd_opts( def ssh_cmd_opts(
self, self,
local_ssh: bool = True, control_master: bool = True,
) -> list[str]: ) -> list[str]:
effective_control_path_dir = self._control_path_dir
if self._control_path_dir is None and not control_master:
effective_control_path_dir = None
elif self._control_path_dir is None and control_master:
msg = "Control path directory is not set. Please with Remote.ssh_control_master() as ctx to set it."
raise ClanError(msg)
ssh_opts = ["-A"] if self.forward_agent else [] ssh_opts = ["-A"] if self.forward_agent else []
if self.port: if self.port:
ssh_opts.extend(["-p", str(self.port)]) ssh_opts.extend(["-p", str(self.port)])
for k, v in self.ssh_options.items(): for k, v in self.ssh_options.items():
ssh_opts.extend(["-o", f"{k}={shlex.quote(v)}"]) ssh_opts.extend(["-o", f"{k}={shlex.quote(v)}"])
ssh_opts.extend(self.host_key_check.to_ssh_opt()) ssh_opts.extend(self.host_key_check.to_ssh_opt())
if self.private_key: if self.private_key:
ssh_opts.extend(["-i", str(self.private_key)]) ssh_opts.extend(["-i", str(self.private_key)])
if local_ssh and self._temp_dir: if effective_control_path_dir:
ssh_opts.extend(["-o", "ControlPersist=30m"]) socket_path = (
ssh_opts.extend( effective_control_path_dir
[ / f"clan-{self.address}-{self.port or 22}-{self.user}"
"-o",
f"ControlPath={Path(self._temp_dir.name) / 'clan-%h-%p-%r'}",
]
) )
ssh_opts.extend(["-o", "ControlPersist=30m"])
ssh_opts.extend(["-o", f"ControlPath={socket_path}"])
ssh_opts.extend(["-o", "ControlMaster=auto"]) ssh_opts.extend(["-o", "ControlMaster=auto"])
return ssh_opts return ssh_opts
def ssh_cmd( def ssh_cmd(
self, self, verbose_ssh: bool = False, tty: bool = False, control_master: bool = True
verbose_ssh: bool = False,
tty: bool = False,
) -> list[str]: ) -> list[str]:
packages = [] packages = []
password_args = [] password_args = []
if self.password: if self.password:
packages.append("sshpass") packages.append("sshpass")
password_args = [ password_args = ["sshpass", "-p", self.password]
"sshpass",
"-p",
self.password,
]
ssh_opts = self.ssh_cmd_opts() current_ssh_opts = self.ssh_cmd_opts(control_master=control_master)
if verbose_ssh or self.verbose_ssh: if verbose_ssh or self.verbose_ssh:
ssh_opts.extend(["-v"]) current_ssh_opts.extend(["-v"])
if tty: if tty:
ssh_opts.extend(["-t"]) current_ssh_opts.extend(["-t"])
if self.tor_socks: if self.tor_socks:
packages.append("netcat") packages.append("netcat")
ssh_opts.append("-o") current_ssh_opts.extend(
ssh_opts.append("ProxyCommand=nc -x 127.0.0.1:9050 -X 5 %h %p") ["-o", "ProxyCommand=nc -x 127.0.0.1:9050 -X 5 %h %p"]
)
cmd = [ cmd = [
*password_args, *password_args,
"ssh", "ssh",
self.target, self.target,
*ssh_opts, *current_ssh_opts,
] ]
return nix_shell(packages, cmd) return nix_shell(packages, cmd)
def interactive_ssh(self) -> None: def interactive_ssh(self) -> None:
subprocess.run(self.ssh_cmd()) cmd_list = self.ssh_cmd(tty=True)
subprocess.run(cmd_list)
def is_ssh_reachable(host: Host) -> bool: def is_ssh_reachable(host: Remote) -> bool:
with socket.socket( address_family = socket.AF_INET6 if ":" in host.address else socket.AF_INET
socket.AF_INET6 if ":" in host.host else socket.AF_INET, socket.SOCK_STREAM with socket.socket(address_family, socket.SOCK_STREAM) as sock:
) as sock:
sock.settimeout(2) sock.settimeout(2)
try: try:
sock.connect((host.host, host.port or 22)) sock.connect((host.address, host.port or 22))
sock.close()
except OSError: except OSError:
return False return False
else: else:

View File

@@ -14,7 +14,6 @@ from clan_cli.machines.machines import Machine
from clan_cli.secrets.key import generate_key from clan_cli.secrets.key import generate_key
from clan_cli.secrets.sops import maybe_get_admin_public_key from clan_cli.secrets.sops import maybe_get_admin_public_key
from clan_cli.secrets.users import add_user from clan_cli.secrets.users import add_user
from clan_cli.ssh.host import Host
from clan_cli.ssh.host_key import HostKeyCheck from clan_cli.ssh.host_key import HostKeyCheck
from clan_cli.vars.generate import generate_vars_for_machine, get_generators_closure from clan_cli.vars.generate import generate_vars_for_machine, get_generators_closure
@@ -28,6 +27,7 @@ from clan_lib.inventory import patch_inventory_with
from clan_lib.nix import nix_command from clan_lib.nix import nix_command
from clan_lib.nix_models.inventory import Machine as InventoryMachine from clan_lib.nix_models.inventory import Machine as InventoryMachine
from clan_lib.nix_models.inventory import MachineDeploy from clan_lib.nix_models.inventory import MachineDeploy
from clan_lib.ssh.remote import Remote
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -118,9 +118,9 @@ def fix_flake_inputs(clan_dir: Path, clan_core_dir: Path) -> None:
@pytest.mark.with_core @pytest.mark.with_core
@pytest.mark.skipif(sys.platform == "darwin", reason="sshd fails to start on darwin") @pytest.mark.skipif(sys.platform == "darwin", reason="sshd fails to start on darwin")
def test_clan_create_api( def test_clan_create_api(
temporary_home: Path, test_lib_root: Path, clan_core: Path, hosts: list[Host] temporary_home: Path, test_lib_root: Path, clan_core: Path, hosts: list[Remote]
) -> None: ) -> None:
host_ip = hosts[0].host host_ip = hosts[0].address
host_user = hosts[0].user host_user = hosts[0].user
vm_name = "test-clan" vm_name = "test-clan"
clan_core_dir_var = str(clan_core) clan_core_dir_var = str(clan_core)
@@ -176,7 +176,9 @@ def test_clan_create_api(
clan_dir_flake = Flake(str(dest_clan_dir)) clan_dir_flake = Flake(str(dest_clan_dir))
machines: list[Machine] = [] machines: list[Machine] = []
host = Host(user=host_user, host=host_ip, port=int(ssh_port_var)) host = Remote(
user=host_user, address=host_ip, port=int(ssh_port_var), command_prefix=vm_name
)
# TODO: We need to merge Host and Machine class these duplicate targetHost stuff is a nightmare # TODO: We need to merge Host and Machine class these duplicate targetHost stuff is a nightmare
inv_machine = InventoryMachine( inv_machine = InventoryMachine(
name=vm_name, deploy=MachineDeploy(targetHost=f"{host.target}:{ssh_port_var}") name=vm_name, deploy=MachineDeploy(targetHost=f"{host.target}:{ssh_port_var}")