clan-lib: Make Remote overridable over function arguments

This commit is contained in:
Qubasa
2025-06-17 11:15:10 +02:00
parent 520c0d78af
commit d29bba48e7
5 changed files with 98 additions and 64 deletions

View File

@@ -12,6 +12,7 @@ from clan_lib.errors import ClanCmdError, ClanError
from clan_lib.git import commit_file from clan_lib.git import commit_file
from clan_lib.machines.machines import Machine from clan_lib.machines.machines import Machine
from clan_lib.nix import nix_config, nix_eval from clan_lib.nix import nix_config, nix_eval
from clan_lib.ssh.remote import HostKeyCheck, Remote
from clan_cli.completions import add_dynamic_completer, complete_machines from clan_cli.completions import add_dynamic_completer, complete_machines
@@ -82,7 +83,9 @@ class HardwareGenerateOptions:
@API.register @API.register
def generate_machine_hardware_info(opts: HardwareGenerateOptions) -> HardwareConfig: def generate_machine_hardware_info(
opts: HardwareGenerateOptions, target_host: Remote
) -> HardwareConfig:
""" """
Generate hardware information for a machine Generate hardware information for a machine
and place the resulting *.nix file in the machine's directory. and place the resulting *.nix file in the machine's directory.
@@ -103,9 +106,7 @@ def generate_machine_hardware_info(opts: HardwareGenerateOptions) -> HardwareCon
"--show-hardware-config", "--show-hardware-config",
] ]
host = opts.machine.target_host() with target_host.ssh_control_master() as ssh, ssh.become_root() as sudo_ssh:
with host.ssh_control_master() as ssh, ssh.become_root() as sudo_ssh:
out = sudo_ssh.run(config_command, opts=RunOpts(check=False)) out = sudo_ssh.run(config_command, opts=RunOpts(check=False))
if out.returncode != 0: if out.returncode != 0:
if "nixos-facter" in out.stderr and "not found" in out.stderr: if "nixos-facter" in out.stderr and "not found" in out.stderr:
@@ -117,7 +118,7 @@ def generate_machine_hardware_info(opts: HardwareGenerateOptions) -> HardwareCon
raise ClanError(msg) raise ClanError(msg)
machine.error(str(out)) machine.error(str(out))
msg = f"Failed to inspect {opts.machine}. Address: {host.target}" msg = f"Failed to inspect {opts.machine}. Address: {target_host.target}"
raise ClanError(msg) raise ClanError(msg)
backup_file = None backup_file = None
@@ -157,17 +158,28 @@ def generate_machine_hardware_info(opts: HardwareGenerateOptions) -> HardwareCon
def update_hardware_config_command(args: argparse.Namespace) -> None: def update_hardware_config_command(args: argparse.Namespace) -> None:
host_key_check = HostKeyCheck.from_str(args.host_key_check)
machine = Machine( machine = Machine(
flake=args.flake, flake=args.flake,
name=args.machine, name=args.machine,
override_target_host=args.target_host, host_key_check=host_key_check,
) )
opts = HardwareGenerateOptions( opts = HardwareGenerateOptions(
machine=machine, machine=machine,
password=args.password, password=args.password,
backend=HardwareConfig(args.backend), backend=HardwareConfig(args.backend),
) )
generate_machine_hardware_info(opts)
if args.target_host:
target_host = Remote.from_deployment_address(
machine_name=machine.name,
address=args.target_host,
host_key_check=host_key_check,
)
else:
target_host = machine.target_host()
generate_machine_hardware_info(opts, target_host)
def register_update_hardware_config(parser: argparse.ArgumentParser) -> None: def register_update_hardware_config(parser: argparse.ArgumentParser) -> None:
@@ -184,6 +196,12 @@ def register_update_hardware_config(parser: argparse.ArgumentParser) -> None:
nargs="?", nargs="?",
help="ssh address to install to in the form of user@host:2222", help="ssh address to install to in the form of user@host:2222",
) )
parser.add_argument(
"--host-key-check",
choices=["strict", "ask", "tofu", "none"],
default="ask",
help="Host key (.ssh/known_hosts) check mode.",
)
parser.add_argument( parser.add_argument(
"--password", "--password",
help="Pre-provided password the cli will prompt otherwise if needed.", help="Pre-provided password the cli will prompt otherwise if needed.",

View File

@@ -12,6 +12,7 @@ from clan_lib.cmd import Log, RunOpts, run
from clan_lib.errors import ClanError from clan_lib.errors import ClanError
from clan_lib.machines.machines import Machine from clan_lib.machines.machines import Machine
from clan_lib.nix import nix_shell from clan_lib.nix import nix_shell
from clan_lib.ssh.remote import HostKeyCheck, Remote
from clan_cli.completions import ( from clan_cli.completions import (
add_dynamic_completer, add_dynamic_completer,
@@ -48,7 +49,7 @@ class InstallOptions:
@API.register @API.register
def install_machine(opts: InstallOptions) -> None: def install_machine(opts: InstallOptions, target_host: Remote) -> None:
machine = opts.machine machine = opts.machine
machine.debug(f"installing {machine.name}") machine.debug(f"installing {machine.name}")
@@ -56,7 +57,6 @@ def install_machine(opts: InstallOptions) -> None:
generate_facts([machine]) generate_facts([machine])
generate_vars([machine]) generate_vars([machine])
host = machine.target_host()
with ( with (
TemporaryDirectory(prefix="nixos-install-") as _base_directory, TemporaryDirectory(prefix="nixos-install-") as _base_directory,
): ):
@@ -127,8 +127,8 @@ def install_machine(opts: InstallOptions) -> None:
if opts.build_on: if opts.build_on:
cmd += ["--build-on", opts.build_on.value] cmd += ["--build-on", opts.build_on.value]
if host.port: if target_host.port:
cmd += ["--ssh-port", str(host.port)] cmd += ["--ssh-port", str(target_host.port)]
if opts.kexec: if opts.kexec:
cmd += ["--kexec", opts.kexec] cmd += ["--kexec", opts.kexec]
@@ -138,7 +138,7 @@ def install_machine(opts: InstallOptions) -> None:
# Add nix options to nixos-anywhere # Add nix options to nixos-anywhere
cmd.extend(opts.nix_options) cmd.extend(opts.nix_options)
cmd.append(host.target) cmd.append(target_host.target)
if opts.use_tor: if opts.use_tor:
# nix copy does not support tor socks proxy # nix copy does not support tor socks proxy
# cmd.append("--ssh-option") # cmd.append("--ssh-option")
@@ -162,7 +162,7 @@ def install_command(args: argparse.Namespace) -> None:
try: try:
# Only if the caller did not specify a target_host via args.target_host # Only if the caller did not specify a target_host via args.target_host
# Find a suitable target_host that is reachable # Find a suitable target_host that is reachable
target_host = args.target_host target_host_str = args.target_host
deploy_info: DeployInfo | None = ssh_command_parse(args) deploy_info: DeployInfo | None = ssh_command_parse(args)
use_tor = False use_tor = False
@@ -170,9 +170,9 @@ def install_command(args: argparse.Namespace) -> None:
host = find_reachable_host(deploy_info) host = find_reachable_host(deploy_info)
if host is None: if host is None:
use_tor = True use_tor = True
target_host = deploy_info.tor.target target_host_str = deploy_info.tor.target
else: else:
target_host = host.target target_host_str = host.target
if args.password: if args.password:
password = args.password password = args.password
@@ -181,12 +181,20 @@ def install_command(args: argparse.Namespace) -> None:
else: else:
password = None password = None
machine = Machine( machine = Machine(name=args.machine, flake=args.flake, nix_options=args.option)
name=args.machine, host_key_check = (
flake=args.flake, HostKeyCheck.from_str(args.host_key_check)
nix_options=args.option, if args.host_key_check
override_target_host=target_host, else HostKeyCheck.ASK
) )
if target_host_str is not None:
target_host = Remote.from_deployment_address(
machine_name=machine.name,
address=target_host_str,
host_key_check=host_key_check,
)
else:
target_host = machine.target_host().with_data(host_key_check=host_key_check)
if machine._class_ == "darwin": if machine._class_ == "darwin":
msg = "Installing macOS machines is not yet supported" msg = "Installing macOS machines is not yet supported"
@@ -217,6 +225,7 @@ def install_command(args: argparse.Namespace) -> None:
identity_file=args.identity_file, identity_file=args.identity_file,
use_tor=use_tor, use_tor=use_tor,
), ),
target_host=target_host,
) )
except KeyboardInterrupt: except KeyboardInterrupt:
log.warning("Interrupted by user") log.warning("Interrupted by user")

View File

@@ -104,10 +104,12 @@ def upload_sources(machine: Machine, ssh: Remote) -> str:
@API.register @API.register
def deploy_machine(machine: Machine) -> None: def deploy_machine(
machine: Machine, target_host: Remote, build_host: Remote | None
) -> None:
with ExitStack() as stack: with ExitStack() as stack:
target_host = stack.enter_context(machine.target_host().ssh_control_master()) target_host = stack.enter_context(target_host.ssh_control_master())
build_host = machine.build_host()
if build_host is not None: if build_host is not None:
build_host = stack.enter_context(build_host.ssh_control_master()) build_host = stack.enter_context(build_host.ssh_control_master())
@@ -198,24 +200,6 @@ def deploy_machine(machine: Machine) -> None:
) )
def deploy_machines(machines: list[Machine]) -> None:
"""
Deploy to all hosts in parallel
"""
with AsyncRuntime() as runtime:
for machine in machines:
runtime.async_run(
AsyncOpts(
tid=machine.name, async_ctx=AsyncContext(prefix=machine.name)
),
deploy_machine,
machine,
)
runtime.join_all()
runtime.check_all()
def update_command(args: argparse.Namespace) -> None: def update_command(args: argparse.Namespace) -> None:
try: try:
if args.flake is None: if args.flake is None:
@@ -237,8 +221,6 @@ def update_command(args: argparse.Namespace) -> None:
name=machine_name, name=machine_name,
flake=args.flake, flake=args.flake,
nix_options=args.option, nix_options=args.option,
override_target_host=args.target_host,
override_build_host=args.build_host,
host_key_check=HostKeyCheck.from_str(args.host_key_check), host_key_check=HostKeyCheck.from_str(args.host_key_check),
) )
machines.append(machine) machines.append(machine)
@@ -285,8 +267,30 @@ def update_command(args: argparse.Namespace) -> None:
f"clanInternals.machines.{system}.{{{','.join(machine_names)}}}.config.system.clan.deployment.file", f"clanInternals.machines.{system}.{{{','.join(machine_names)}}}.config.system.clan.deployment.file",
] ]
) )
# Run the deplyoyment
deploy_machines(machines_to_update) host_key_check = HostKeyCheck.from_str(args.host_key_check)
with AsyncRuntime() as runtime:
for machine in machines:
if args.target_host:
target_host = Remote.from_deployment_address(
machine_name=machine.name,
address=args.target_host,
host_key_check=host_key_check,
)
else:
target_host = machine.target_host()
runtime.async_run(
AsyncOpts(
tid=machine.name,
async_ctx=AsyncContext(prefix=machine.name),
),
deploy_machine,
machine=machine,
target_host=target_host,
build_host=machine.build_host(),
)
runtime.join_all()
runtime.check_all()
except KeyboardInterrupt: except KeyboardInterrupt:
log.warning("Interrupted by user") log.warning("Interrupted by user")

View File

@@ -32,8 +32,6 @@ class Machine:
flake: Flake flake: Flake
nix_options: list[str] = field(default_factory=list) nix_options: list[str] = field(default_factory=list)
override_target_host: None | str = None
override_build_host: None | str = None
private_key: Path | None = None private_key: Path | None = None
host_key_check: HostKeyCheck = HostKeyCheck.STRICT host_key_check: HostKeyCheck = HostKeyCheck.STRICT
@@ -143,14 +141,6 @@ class Machine:
return self.flake.path return self.flake.path
def target_host(self) -> Remote: def target_host(self) -> Remote:
if self.override_target_host:
return Remote.from_deployment_address(
machine_name=self.name,
address=self.override_target_host,
host_key_check=self.host_key_check,
private_key=self.private_key,
)
remote = get_host(self.name, self.flake, field="targetHost") remote = get_host(self.name, self.flake, field="targetHost")
if remote is None: if remote is None:
msg = f"'targetHost' is not set for machine '{self.name}'" msg = f"'targetHost' is not set for machine '{self.name}'"
@@ -178,15 +168,6 @@ class Machine:
The host where the machine is built and deployed from. The host where the machine is built and deployed from.
Can be the same as the target host. Can be the same as the target host.
""" """
if self.override_build_host:
return Remote.from_deployment_address(
machine_name=self.name,
address=self.override_build_host,
host_key_check=self.host_key_check,
private_key=self.private_key,
)
remote = get_host(self.name, self.flake, field="buildHost") remote = get_host(self.name, self.flake, field="buildHost")
if remote: if remote:

View File

@@ -54,6 +54,28 @@ class Remote:
except ValueError: except ValueError:
return False return False
def with_data(self, host_key_check: HostKeyCheck | None = None) -> "Remote":
"""
Returns a new Remote instance with the same data but with a different host_key_check.
"""
return Remote(
address=self.address,
user=self.user,
command_prefix=self.command_prefix,
port=self.port,
private_key=self.private_key,
password=self.password,
forward_agent=self.forward_agent,
host_key_check=host_key_check
if host_key_check is not None
else self.host_key_check,
verbose_ssh=self.verbose_ssh,
ssh_options=self.ssh_options,
tor_socks=self.tor_socks,
_control_path_dir=self._control_path_dir,
_askpass_path=self._askpass_path,
)
@property @property
def target(self) -> str: def target(self) -> str:
return f"{self.user}@{self.address}" return f"{self.user}@{self.address}"