clan-cli: Add --host-key-check to machine update

This commit is contained in:
Qubasa
2024-10-05 23:33:44 +02:00
parent 8df6ed40b5
commit 7bd50b03b3
4 changed files with 30 additions and 8 deletions

View File

@@ -13,7 +13,7 @@ from clan_cli.errors import ClanError
from clan_cli.facts import public_modules as facts_public_modules from clan_cli.facts import public_modules as facts_public_modules
from clan_cli.facts import secret_modules as facts_secret_modules from clan_cli.facts import secret_modules as facts_secret_modules
from clan_cli.nix import nix_build, nix_config, nix_eval, nix_metadata from clan_cli.nix import nix_build, nix_config, nix_eval, nix_metadata
from clan_cli.ssh import Host, parse_deployment_address from clan_cli.ssh import Host, HostKeyCheck, parse_deployment_address
from clan_cli.vars.public_modules import FactStoreBase from clan_cli.vars.public_modules import FactStoreBase
from clan_cli.vars.secret_modules import SecretStoreBase from clan_cli.vars.secret_modules import SecretStoreBase
@@ -27,6 +27,7 @@ class Machine:
nix_options: list[str] = field(default_factory=list) nix_options: list[str] = field(default_factory=list)
cached_deployment: None | dict[str, Any] = None cached_deployment: None | dict[str, Any] = None
override_target_host: None | str = None override_target_host: None | str = None
host_key_check: HostKeyCheck = HostKeyCheck.STRICT
_eval_cache: dict[str, str] = field(default_factory=dict) _eval_cache: dict[str, str] = field(default_factory=dict)
_build_cache: dict[str, Path] = field(default_factory=dict) _build_cache: dict[str, Path] = field(default_factory=dict)
@@ -143,7 +144,10 @@ class Machine:
@property @property
def target_host(self) -> Host: def target_host(self) -> Host:
return parse_deployment_address( return parse_deployment_address(
self.name, self.target_host_address, meta={"machine": self} self.name,
self.target_host_address,
self.host_key_check,
meta={"machine": self},
) )
@property @property
@@ -159,6 +163,7 @@ class Machine:
return parse_deployment_address( return parse_deployment_address(
self.name, self.name,
build_host, build_host,
self.host_key_check,
forward_agent=True, forward_agent=True,
meta={"machine": self, "target_host": self.target_host}, meta={"machine": self, "target_host": self.target_host},
) )

View File

@@ -171,6 +171,7 @@ def update(args: argparse.Namespace) -> None:
name=args.machines[0], flake=args.flake, nix_options=args.option name=args.machines[0], flake=args.flake, nix_options=args.option
) )
machine.override_target_host = args.target_host machine.override_target_host = args.target_host
machine.host_key_check = HostKeyCheck.from_str(args.host_key_check)
machines.append(machine) machines.append(machine)
elif args.target_host is not None: elif args.target_host is not None:
@@ -187,7 +188,7 @@ def update(args: argparse.Namespace) -> None:
except ClanError: # check if we have a build host set except ClanError: # check if we have a build host set
ignored_machines.append(machine) ignored_machines.append(machine)
continue continue
machine.host_key_check = HostKeyCheck.from_str(args.host_key_check)
machines.append(machine) machines.append(machine)
if not machines and ignored_machines != []: if not machines and ignored_machines != []:
@@ -201,8 +202,8 @@ def update(args: argparse.Namespace) -> None:
else: else:
machines = get_selected_machines(args.flake, args.option, args.machines) machines = get_selected_machines(args.flake, args.option, args.machines)
group = MachineGroup(machines) host_group = MachineGroup(machines)
deploy_machine(group) deploy_machine(host_group)
def register_update_parser(parser: argparse.ArgumentParser) -> None: def register_update_parser(parser: argparse.ArgumentParser) -> None:
@@ -216,6 +217,12 @@ def register_update_parser(parser: argparse.ArgumentParser) -> None:
) )
add_dynamic_completer(machines_parser, complete_machines) add_dynamic_completer(machines_parser, complete_machines)
parser.add_argument(
"--host-key-check",
choices=["strict", "tofu", "none"],
default="strict",
help="Host key (.ssh/known_hosts) check mode",
)
parser.add_argument( parser.add_argument(
"--target-host", "--target-host",

View File

@@ -133,6 +133,14 @@ class HostKeyCheck(Enum):
# Do not check ssh host keys # Do not check ssh host keys
NONE = 2 NONE = 2
@staticmethod
def from_str(label: str) -> "HostKeyCheck":
if label.upper() in HostKeyCheck.__members__:
return HostKeyCheck[label.upper()]
msg = f"Invalid choice: {label}."
description = "Choose from: " + ", ".join(HostKeyCheck.__members__)
raise ClanError(msg, description=description)
class Host: class Host:
def __init__( def __init__(
@@ -790,6 +798,7 @@ class HostGroup:
def parse_deployment_address( def parse_deployment_address(
machine_name: str, machine_name: str,
host: str, host: str,
host_key_check: HostKeyCheck,
forward_agent: bool = True, forward_agent: bool = True,
meta: dict[str, Any] | None = None, meta: dict[str, Any] | None = None,
) -> Host: ) -> Host:
@@ -820,6 +829,7 @@ def parse_deployment_address(
hostname, hostname,
user=user, user=user,
port=port, port=port,
host_key_check=host_key_check,
command_prefix=machine_name, command_prefix=machine_name,
forward_agent=forward_agent, forward_agent=forward_agent,
meta=meta, meta=meta,

View File

@@ -1,13 +1,13 @@
import subprocess import subprocess
from clan_cli.ssh import Host, HostGroup, parse_deployment_address from clan_cli.ssh import Host, HostGroup, HostKeyCheck, parse_deployment_address
def test_parse_ipv6() -> None: def test_parse_ipv6() -> None:
host = parse_deployment_address("foo", "[fe80::1%eth0]:2222") host = parse_deployment_address("foo", "[fe80::1%eth0]:2222", HostKeyCheck.STRICT)
assert host.host == "fe80::1%eth0" assert host.host == "fe80::1%eth0"
assert host.port == 2222 assert host.port == 2222
host = parse_deployment_address("foo", "[fe80::1%eth0]") host = parse_deployment_address("foo", "[fe80::1%eth0]", HostKeyCheck.STRICT)
assert host.host == "fe80::1%eth0" assert host.host == "fe80::1%eth0"
assert host.port is None assert host.port is None