clan-lib: Refactor remote host handling to function parameters

This refactoring improves the separation of concerns by moving remote host creation logic from the Machine class to the calling functions, making the code more flexible and testable.
This commit is contained in:
Qubasa
2025-06-17 11:15:10 +02:00
committed by Jörg Thalheim
parent 2af619609a
commit 6a3bffad42
5 changed files with 98 additions and 64 deletions

View File

@@ -12,6 +12,7 @@ from clan_lib.errors import ClanCmdError, ClanError
from clan_lib.git import commit_file
from clan_lib.machines.machines import Machine
from clan_lib.nix import nix_config, nix_eval
from clan_lib.ssh.remote import HostKeyCheck, Remote
from clan_cli.completions import add_dynamic_completer, complete_machines
@@ -82,7 +83,9 @@ class HardwareGenerateOptions:
@API.register
def generate_machine_hardware_info(opts: HardwareGenerateOptions) -> HardwareConfig:
def generate_machine_hardware_info(
opts: HardwareGenerateOptions, target_host: Remote
) -> HardwareConfig:
"""
Generate hardware information for a machine
and place the resulting *.nix file in the machine's directory.
@@ -103,9 +106,7 @@ def generate_machine_hardware_info(opts: HardwareGenerateOptions) -> HardwareCon
"--show-hardware-config",
]
host = opts.machine.target_host()
with host.ssh_control_master() as ssh, ssh.become_root() as sudo_ssh:
with target_host.ssh_control_master() as ssh, ssh.become_root() as sudo_ssh:
out = sudo_ssh.run(config_command, opts=RunOpts(check=False))
if out.returncode != 0:
if "nixos-facter" in out.stderr and "not found" in out.stderr:
@@ -117,7 +118,7 @@ def generate_machine_hardware_info(opts: HardwareGenerateOptions) -> HardwareCon
raise ClanError(msg)
machine.error(str(out))
msg = f"Failed to inspect {opts.machine}. Address: {host.target}"
msg = f"Failed to inspect {opts.machine}. Address: {target_host.target}"
raise ClanError(msg)
backup_file = None
@@ -157,17 +158,28 @@ def generate_machine_hardware_info(opts: HardwareGenerateOptions) -> HardwareCon
def update_hardware_config_command(args: argparse.Namespace) -> None:
host_key_check = HostKeyCheck.from_str(args.host_key_check)
machine = Machine(
flake=args.flake,
name=args.machine,
override_target_host=args.target_host,
host_key_check=host_key_check,
)
opts = HardwareGenerateOptions(
machine=machine,
password=args.password,
backend=HardwareConfig(args.backend),
)
generate_machine_hardware_info(opts)
if args.target_host:
target_host = Remote.from_deployment_address(
machine_name=machine.name,
address=args.target_host,
host_key_check=host_key_check,
)
else:
target_host = machine.target_host()
generate_machine_hardware_info(opts, target_host)
def register_update_hardware_config(parser: argparse.ArgumentParser) -> None:
@@ -184,6 +196,12 @@ def register_update_hardware_config(parser: argparse.ArgumentParser) -> None:
nargs="?",
help="ssh address to install to in the form of user@host:2222",
)
parser.add_argument(
"--host-key-check",
choices=["strict", "ask", "tofu", "none"],
default="ask",
help="Host key (.ssh/known_hosts) check mode.",
)
parser.add_argument(
"--password",
help="Pre-provided password the cli will prompt otherwise if needed.",

View File

@@ -12,6 +12,7 @@ from clan_lib.cmd import Log, RunOpts, run
from clan_lib.errors import ClanError
from clan_lib.machines.machines import Machine
from clan_lib.nix import nix_shell
from clan_lib.ssh.remote import HostKeyCheck, Remote
from clan_cli.completions import (
add_dynamic_completer,
@@ -48,7 +49,7 @@ class InstallOptions:
@API.register
def install_machine(opts: InstallOptions) -> None:
def install_machine(opts: InstallOptions, target_host: Remote) -> None:
machine = opts.machine
machine.debug(f"installing {machine.name}")
@@ -56,7 +57,6 @@ def install_machine(opts: InstallOptions) -> None:
generate_facts([machine])
generate_vars([machine])
host = machine.target_host()
with (
TemporaryDirectory(prefix="nixos-install-") as _base_directory,
):
@@ -127,8 +127,8 @@ def install_machine(opts: InstallOptions) -> None:
if opts.build_on:
cmd += ["--build-on", opts.build_on.value]
if host.port:
cmd += ["--ssh-port", str(host.port)]
if target_host.port:
cmd += ["--ssh-port", str(target_host.port)]
if opts.kexec:
cmd += ["--kexec", opts.kexec]
@@ -138,7 +138,7 @@ def install_machine(opts: InstallOptions) -> None:
# Add nix options to nixos-anywhere
cmd.extend(opts.nix_options)
cmd.append(host.target)
cmd.append(target_host.target)
if opts.use_tor:
# nix copy does not support tor socks proxy
# cmd.append("--ssh-option")
@@ -162,7 +162,7 @@ def install_command(args: argparse.Namespace) -> None:
try:
# Only if the caller did not specify a target_host via args.target_host
# Find a suitable target_host that is reachable
target_host = args.target_host
target_host_str = args.target_host
deploy_info: DeployInfo | None = ssh_command_parse(args)
use_tor = False
@@ -170,9 +170,9 @@ def install_command(args: argparse.Namespace) -> None:
host = find_reachable_host(deploy_info)
if host is None:
use_tor = True
target_host = deploy_info.tor.target
target_host_str = deploy_info.tor.target
else:
target_host = host.target
target_host_str = host.target
if args.password:
password = args.password
@@ -181,12 +181,20 @@ def install_command(args: argparse.Namespace) -> None:
else:
password = None
machine = Machine(
name=args.machine,
flake=args.flake,
nix_options=args.option,
override_target_host=target_host,
machine = Machine(name=args.machine, flake=args.flake, nix_options=args.option)
host_key_check = (
HostKeyCheck.from_str(args.host_key_check)
if args.host_key_check
else HostKeyCheck.ASK
)
if target_host_str is not None:
target_host = Remote.from_deployment_address(
machine_name=machine.name,
address=target_host_str,
host_key_check=host_key_check,
)
else:
target_host = machine.target_host().with_data(host_key_check=host_key_check)
if machine._class_ == "darwin":
msg = "Installing macOS machines is not yet supported"
@@ -217,6 +225,7 @@ def install_command(args: argparse.Namespace) -> None:
identity_file=args.identity_file,
use_tor=use_tor,
),
target_host=target_host,
)
except KeyboardInterrupt:
log.warning("Interrupted by user")

View File

@@ -104,10 +104,12 @@ def upload_sources(machine: Machine, ssh: Remote) -> str:
@API.register
def deploy_machine(machine: Machine) -> None:
def deploy_machine(
machine: Machine, target_host: Remote, build_host: Remote | None
) -> None:
with ExitStack() as stack:
target_host = stack.enter_context(machine.target_host().ssh_control_master())
build_host = machine.build_host()
target_host = stack.enter_context(target_host.ssh_control_master())
if build_host is not None:
build_host = stack.enter_context(build_host.ssh_control_master())
@@ -198,24 +200,6 @@ def deploy_machine(machine: Machine) -> None:
)
def deploy_machines(machines: list[Machine]) -> None:
"""
Deploy to all hosts in parallel
"""
with AsyncRuntime() as runtime:
for machine in machines:
runtime.async_run(
AsyncOpts(
tid=machine.name, async_ctx=AsyncContext(prefix=machine.name)
),
deploy_machine,
machine,
)
runtime.join_all()
runtime.check_all()
def update_command(args: argparse.Namespace) -> None:
try:
if args.flake is None:
@@ -237,8 +221,6 @@ def update_command(args: argparse.Namespace) -> None:
name=machine_name,
flake=args.flake,
nix_options=args.option,
override_target_host=args.target_host,
override_build_host=args.build_host,
host_key_check=HostKeyCheck.from_str(args.host_key_check),
)
machines.append(machine)
@@ -285,8 +267,30 @@ def update_command(args: argparse.Namespace) -> None:
f"clanInternals.machines.{system}.{{{','.join(machine_names)}}}.config.system.clan.deployment.file",
]
)
# Run the deplyoyment
deploy_machines(machines_to_update)
host_key_check = HostKeyCheck.from_str(args.host_key_check)
with AsyncRuntime() as runtime:
for machine in machines:
if args.target_host:
target_host = Remote.from_deployment_address(
machine_name=machine.name,
address=args.target_host,
host_key_check=host_key_check,
)
else:
target_host = machine.target_host()
runtime.async_run(
AsyncOpts(
tid=machine.name,
async_ctx=AsyncContext(prefix=machine.name),
),
deploy_machine,
machine=machine,
target_host=target_host,
build_host=machine.build_host(),
)
runtime.join_all()
runtime.check_all()
except KeyboardInterrupt:
log.warning("Interrupted by user")

View File

@@ -32,8 +32,6 @@ class Machine:
flake: Flake
nix_options: list[str] = field(default_factory=list)
override_target_host: None | str = None
override_build_host: None | str = None
private_key: Path | None = None
host_key_check: HostKeyCheck = HostKeyCheck.STRICT
@@ -143,14 +141,6 @@ class Machine:
return self.flake.path
def target_host(self) -> Remote:
if self.override_target_host:
return Remote.from_deployment_address(
machine_name=self.name,
address=self.override_target_host,
host_key_check=self.host_key_check,
private_key=self.private_key,
)
remote = get_host(self.name, self.flake, field="targetHost")
if remote is None:
msg = f"'targetHost' is not set for machine '{self.name}'"
@@ -178,15 +168,6 @@ class Machine:
The host where the machine is built and deployed from.
Can be the same as the target host.
"""
if self.override_build_host:
return Remote.from_deployment_address(
machine_name=self.name,
address=self.override_build_host,
host_key_check=self.host_key_check,
private_key=self.private_key,
)
remote = get_host(self.name, self.flake, field="buildHost")
if remote:

View File

@@ -54,6 +54,28 @@ class Remote:
except ValueError:
return False
def with_data(self, host_key_check: HostKeyCheck | None = None) -> "Remote":
"""
Returns a new Remote instance with the same data but with a different host_key_check.
"""
return Remote(
address=self.address,
user=self.user,
command_prefix=self.command_prefix,
port=self.port,
private_key=self.private_key,
password=self.password,
forward_agent=self.forward_agent,
host_key_check=host_key_check
if host_key_check is not None
else self.host_key_check,
verbose_ssh=self.verbose_ssh,
ssh_options=self.ssh_options,
tor_socks=self.tor_socks,
_control_path_dir=self._control_path_dir,
_askpass_path=self._askpass_path,
)
@property
def target(self) -> str:
return f"{self.user}@{self.address}"