From 543c518ed0573a23883aa1f0f0be0163bb098add Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rg=20Thalheim?= Date: Wed, 2 Jul 2025 17:33:32 +0200 Subject: [PATCH] make host key check an enum instead of an literal type this is more typesafe at runtime. --- pkgs/clan-cli/clan_cli/host_key_check.py | 28 +++++++++++++++++++ pkgs/clan-cli/clan_cli/machines/hardware.py | 8 ++---- pkgs/clan-cli/clan_cli/machines/install.py | 8 ++---- pkgs/clan-cli/clan_cli/machines/update.py | 8 ++---- pkgs/clan-cli/clan_cli/ssh/deploy_info.py | 11 +++----- .../clan-cli/clan_cli/ssh/test_deploy_info.py | 9 ++++-- pkgs/clan-cli/clan_cli/tests/hosts.py | 3 +- pkgs/clan-cli/clan_lib/ssh/host_key.py | 22 +++++++-------- pkgs/clan-cli/clan_lib/ssh/remote.py | 2 +- pkgs/clan-cli/clan_lib/ssh/remote_test.py | 5 ++-- pkgs/clan-cli/clan_lib/tests/test_create.py | 3 +- 11 files changed, 63 insertions(+), 44 deletions(-) create mode 100644 pkgs/clan-cli/clan_cli/host_key_check.py diff --git a/pkgs/clan-cli/clan_cli/host_key_check.py b/pkgs/clan-cli/clan_cli/host_key_check.py new file mode 100644 index 000000000..df1331574 --- /dev/null +++ b/pkgs/clan-cli/clan_cli/host_key_check.py @@ -0,0 +1,28 @@ +"""Common argument types and utilities for host key checking in clan CLI commands.""" + +import argparse + +from clan_lib.ssh.host_key import HostKeyCheck + + +def host_key_check_type(value: str) -> HostKeyCheck: + """ + Argparse type converter for HostKeyCheck enum. + """ + try: + return HostKeyCheck(value) + except ValueError: + valid_values = [e.value for e in HostKeyCheck] + msg = f"Invalid host key check mode: {value}. Valid options: {', '.join(valid_values)}" + raise argparse.ArgumentTypeError(msg) from None + + +def add_host_key_check_arg( + parser: argparse.ArgumentParser, default: HostKeyCheck = HostKeyCheck.ASK +) -> None: + parser.add_argument( + "--host-key-check", + type=host_key_check_type, + default=default, + help=f"Host key (.ssh/known_hosts) check mode. Options: {', '.join([e.value for e in HostKeyCheck])}", + ) diff --git a/pkgs/clan-cli/clan_cli/machines/hardware.py b/pkgs/clan-cli/clan_cli/machines/hardware.py index b65cf7831..2db751fd3 100644 --- a/pkgs/clan-cli/clan_cli/machines/hardware.py +++ b/pkgs/clan-cli/clan_cli/machines/hardware.py @@ -12,6 +12,7 @@ from clan_lib.machines.suggestions import validate_machine_names from clan_lib.ssh.remote import Remote from clan_cli.completions import add_dynamic_completer, complete_machines +from clan_cli.host_key_check import add_host_key_check_arg from .types import machine_name_type @@ -54,12 +55,7 @@ def register_update_hardware_config(parser: argparse.ArgumentParser) -> None: nargs="?", help="ssh address to install to in the form of user@host:2222", ) - parser.add_argument( - "--host-key-check", - choices=["strict", "ask", "tofu", "none"], - default="ask", - help="Host key (.ssh/known_hosts) check mode.", - ) + add_host_key_check_arg(parser) parser.add_argument( "--password", help="Pre-provided password the cli will prompt otherwise if needed.", diff --git a/pkgs/clan-cli/clan_cli/machines/install.py b/pkgs/clan-cli/clan_cli/machines/install.py index b93a7331d..a6222e88f 100644 --- a/pkgs/clan-cli/clan_cli/machines/install.py +++ b/pkgs/clan-cli/clan_cli/machines/install.py @@ -13,6 +13,7 @@ from clan_cli.completions import ( complete_machines, complete_target_host, ) +from clan_cli.host_key_check import add_host_key_check_arg from clan_cli.machines.hardware import HardwareConfig from clan_cli.ssh.deploy_info import DeployInfo, find_reachable_host, ssh_command_parse @@ -97,12 +98,7 @@ def register_install_parser(parser: argparse.ArgumentParser) -> None: help="do not reboot after installation (deprecated)", default=False, ) - parser.add_argument( - "--host-key-check", - choices=["strict", "ask", "tofu", "none"], - default="ask", - help="Host key (.ssh/known_hosts) check mode.", - ) + add_host_key_check_arg(parser) parser.add_argument( "--build-on", choices=[x.value for x in BuildOn], diff --git a/pkgs/clan-cli/clan_cli/machines/update.py b/pkgs/clan-cli/clan_cli/machines/update.py index 19ed3412b..76f5afb38 100644 --- a/pkgs/clan-cli/clan_cli/machines/update.py +++ b/pkgs/clan-cli/clan_cli/machines/update.py @@ -16,6 +16,7 @@ from clan_cli.completions import ( complete_machines, complete_tags, ) +from clan_cli.host_key_check import add_host_key_check_arg log = logging.getLogger(__name__) @@ -163,12 +164,7 @@ def register_update_parser(parser: argparse.ArgumentParser) -> None: ) add_dynamic_completer(tag_parser, complete_tags) - parser.add_argument( - "--host-key-check", - choices=["strict", "ask", "tofu", "none"], - default="ask", - help="Host key (.ssh/known_hosts) check mode.", - ) + add_host_key_check_arg(parser) parser.add_argument( "--target-host", type=str, diff --git a/pkgs/clan-cli/clan_cli/ssh/deploy_info.py b/pkgs/clan-cli/clan_cli/ssh/deploy_info.py index 9ba609c95..3734b47a7 100644 --- a/pkgs/clan-cli/clan_cli/ssh/deploy_info.py +++ b/pkgs/clan-cli/clan_cli/ssh/deploy_info.py @@ -8,12 +8,14 @@ from typing import Any from clan_lib.cmd import run from clan_lib.errors import ClanError from clan_lib.nix import nix_shell -from clan_lib.ssh.remote import HostKeyCheck, Remote +from clan_lib.ssh.host_key import HostKeyCheck +from clan_lib.ssh.remote import Remote from clan_cli.completions import ( add_dynamic_completer, complete_machines, ) +from clan_cli.host_key_check import add_host_key_check_arg from clan_cli.ssh.tor import TorTarget, spawn_tor, ssh_tor_reachable log = logging.getLogger(__name__) @@ -181,10 +183,5 @@ def register_parser(parser: argparse.ArgumentParser) -> None: "--png", help="specify the json file for ssh data as the qrcode image (generated by starting the clan installer)", ) - parser.add_argument( - "--host-key-check", - choices=["strict", "ask", "tofu", "none"], - default="tofu", - help="Host key (.ssh/known_hosts) check mode.", - ) + add_host_key_check_arg(parser, default=HostKeyCheck.TOFU) parser.set_defaults(func=ssh_command) diff --git a/pkgs/clan-cli/clan_cli/ssh/test_deploy_info.py b/pkgs/clan-cli/clan_cli/ssh/test_deploy_info.py index 3a4996f7d..5efa7cbfb 100644 --- a/pkgs/clan-cli/clan_cli/ssh/test_deploy_info.py +++ b/pkgs/clan-cli/clan_cli/ssh/test_deploy_info.py @@ -4,6 +4,7 @@ from pathlib import Path import pytest from clan_lib.cmd import RunOpts, run from clan_lib.nix import nix_shell +from clan_lib.ssh.host_key import HostKeyCheck from clan_lib.ssh.remote import Remote from clan_cli.ssh.deploy_info import DeployInfo, find_reachable_host @@ -23,7 +24,7 @@ def test_qrcode_scan(temp_dir: Path) -> None: run(cmd, RunOpts(input=data.encode())) # Call the qrcode_scan function - deploy_info = DeployInfo.from_qr_code(img_path, "none") + deploy_info = DeployInfo.from_qr_code(img_path, HostKeyCheck.NONE) host = deploy_info.addrs[0] assert host.address == "192.168.122.86" @@ -46,7 +47,7 @@ def test_qrcode_scan(temp_dir: Path) -> None: def test_from_json() -> None: data = '{"pass":"scabbed-defender-headlock","tor":"qjeerm4r6t55hcfum4pinnvscn5njlw2g3k7ilqfuu7cdt3ahaxhsbid.onion","addrs":["192.168.122.86"]}' - deploy_info = DeployInfo.from_json(json.loads(data), "none") + deploy_info = DeployInfo.from_json(json.loads(data), HostKeyCheck.NONE) host = deploy_info.addrs[0] assert host.password == "scabbed-defender-headlock" @@ -69,7 +70,9 @@ def test_from_json() -> None: @pytest.mark.with_core def test_find_reachable_host(hosts: list[Remote]) -> None: host = hosts[0] - deploy_info = DeployInfo.from_hostnames(["172.19.1.2", host.ssh_url()], "none") + deploy_info = DeployInfo.from_hostnames( + ["172.19.1.2", host.ssh_url()], HostKeyCheck.NONE + ) assert deploy_info.addrs[0].address == "172.19.1.2" diff --git a/pkgs/clan-cli/clan_cli/tests/hosts.py b/pkgs/clan-cli/clan_cli/tests/hosts.py index 84c32e720..f40c79e63 100644 --- a/pkgs/clan-cli/clan_cli/tests/hosts.py +++ b/pkgs/clan-cli/clan_cli/tests/hosts.py @@ -4,6 +4,7 @@ from pathlib import Path import pytest from clan_cli.tests.sshd import Sshd +from clan_lib.ssh.host_key import HostKeyCheck from clan_lib.ssh.remote import Remote @@ -16,7 +17,7 @@ def hosts(sshd: Sshd) -> list[Remote]: port=sshd.port, user=login, private_key=Path(sshd.key), - host_key_check="none", + host_key_check=HostKeyCheck.NONE, command_prefix="local_test", ) ] diff --git a/pkgs/clan-cli/clan_lib/ssh/host_key.py b/pkgs/clan-cli/clan_lib/ssh/host_key.py index 3f2e6968c..e97932ef3 100644 --- a/pkgs/clan-cli/clan_lib/ssh/host_key.py +++ b/pkgs/clan-cli/clan_lib/ssh/host_key.py @@ -1,15 +1,15 @@ # Adapted from https://github.com/numtide/deploykit -from typing import Literal +from enum import Enum from clan_lib.errors import ClanError -HostKeyCheck = Literal[ - "strict", # Strictly check ssh host keys, prompt for unknown ones - "ask", # Ask for confirmation on first use - "tofu", # Trust on ssh keys on first use - "none", # Do not check ssh host keys -] + +class HostKeyCheck(Enum): + STRICT = "strict" # Strictly check ssh host keys, prompt for unknown ones + ASK = "ask" # Ask for confirmation on first use + TOFU = "tofu" # Trust on ssh keys on first use + NONE = "none" # Do not check ssh host keys def hostkey_to_ssh_opts(host_key_check: HostKeyCheck) -> list[str]: @@ -17,13 +17,13 @@ def hostkey_to_ssh_opts(host_key_check: HostKeyCheck) -> list[str]: Convert a HostKeyCheck value to SSH options. """ match host_key_check: - case "strict": + case HostKeyCheck.STRICT: return ["-o", "StrictHostKeyChecking=yes"] - case "ask": + case HostKeyCheck.ASK: return [] - case "tofu": + case HostKeyCheck.TOFU: return ["-o", "StrictHostKeyChecking=accept-new"] - case "none": + case HostKeyCheck.NONE: return [ "-o", "StrictHostKeyChecking=no", diff --git a/pkgs/clan-cli/clan_lib/ssh/remote.py b/pkgs/clan-cli/clan_lib/ssh/remote.py index de486d80f..41db26ad5 100644 --- a/pkgs/clan-cli/clan_lib/ssh/remote.py +++ b/pkgs/clan-cli/clan_lib/ssh/remote.py @@ -39,7 +39,7 @@ class Remote: private_key: Path | None = None password: str | None = None forward_agent: bool = True - host_key_check: HostKeyCheck = "ask" + host_key_check: HostKeyCheck = HostKeyCheck.ASK verbose_ssh: bool = False ssh_options: dict[str, str] = field(default_factory=dict) tor_socks: bool = False diff --git a/pkgs/clan-cli/clan_lib/ssh/remote_test.py b/pkgs/clan-cli/clan_lib/ssh/remote_test.py index 6d8a094b7..00b6ca53c 100644 --- a/pkgs/clan-cli/clan_lib/ssh/remote_test.py +++ b/pkgs/clan-cli/clan_lib/ssh/remote_test.py @@ -9,6 +9,7 @@ from clan_lib.async_run import AsyncRuntime from clan_lib.cmd import ClanCmdTimeoutError, Log, RunOpts from clan_lib.errors import ClanError, CmdOut from clan_lib.ssh.remote import Remote +from clan_lib.ssh.host_key import HostKeyCheck from clan_lib.ssh.sudo_askpass_proxy import SudoAskpassProxy if sys.platform == "darwin": @@ -113,7 +114,7 @@ def test_parse_deployment_address( result = Remote.from_ssh_uri( machine_name=machine_name, address=test_addr, - ).override(host_key_check="strict") + ).override(host_key_check=HostKeyCheck.STRICT) if expected_exception: return @@ -131,7 +132,7 @@ def test_parse_deployment_address( def test_parse_ssh_options() -> None: addr = "root@example.com:2222?IdentityFile=/path/to/private/key&StrictRemoteKeyChecking=yes" host = Remote.from_ssh_uri(machine_name="foo", address=addr).override( - host_key_check="strict" + host_key_check=HostKeyCheck.STRICT ) assert host.address == "example.com" assert host.port == 2222 diff --git a/pkgs/clan-cli/clan_lib/tests/test_create.py b/pkgs/clan-cli/clan_lib/tests/test_create.py index 7090e8b85..ccbdce0ec 100644 --- a/pkgs/clan-cli/clan_lib/tests/test_create.py +++ b/pkgs/clan-cli/clan_lib/tests/test_create.py @@ -33,6 +33,7 @@ from clan_lib.nix_models.clan import ( ) from clan_lib.nix_models.clan import InventoryMachineDeploy as MachineDeploy from clan_lib.persist.util import set_value_by_path +from clan_lib.ssh.host_key import HostKeyCheck from clan_lib.ssh.remote import Remote, can_ssh_login log = logging.getLogger(__name__) @@ -198,7 +199,7 @@ def test_clan_create_api( clan_dir_flake.invalidate_cache() target_host = machine.target_host().override( - private_key=private_key, host_key_check="none" + private_key=private_key, host_key_check=HostKeyCheck.NONE ) result = can_ssh_login(target_host) assert result == "Online", f"Machine {machine.name} is not online"