make host key check an enum instead of an literal type

this is more typesafe at runtime.
This commit is contained in:
Jörg Thalheim
2025-07-02 17:33:32 +02:00
parent 7f744079a6
commit 3d83017acd
11 changed files with 63 additions and 44 deletions

View File

@@ -0,0 +1,28 @@
"""Common argument types and utilities for host key checking in clan CLI commands."""
import argparse
from clan_lib.ssh.host_key import HostKeyCheck
def host_key_check_type(value: str) -> HostKeyCheck:
"""
Argparse type converter for HostKeyCheck enum.
"""
try:
return HostKeyCheck(value)
except ValueError:
valid_values = [e.value for e in HostKeyCheck]
msg = f"Invalid host key check mode: {value}. Valid options: {', '.join(valid_values)}"
raise argparse.ArgumentTypeError(msg) from None
def add_host_key_check_arg(
parser: argparse.ArgumentParser, default: HostKeyCheck = HostKeyCheck.ASK
) -> None:
parser.add_argument(
"--host-key-check",
type=host_key_check_type,
default=default,
help=f"Host key (.ssh/known_hosts) check mode. Options: {', '.join([e.value for e in HostKeyCheck])}",
)

View File

@@ -12,6 +12,7 @@ from clan_lib.machines.suggestions import validate_machine_names
from clan_lib.ssh.remote import Remote from clan_lib.ssh.remote import Remote
from clan_cli.completions import add_dynamic_completer, complete_machines from clan_cli.completions import add_dynamic_completer, complete_machines
from clan_cli.host_key_check import add_host_key_check_arg
from .types import machine_name_type from .types import machine_name_type
@@ -54,12 +55,7 @@ def register_update_hardware_config(parser: argparse.ArgumentParser) -> None:
nargs="?", nargs="?",
help="ssh address to install to in the form of user@host:2222", help="ssh address to install to in the form of user@host:2222",
) )
parser.add_argument( add_host_key_check_arg(parser)
"--host-key-check",
choices=["strict", "ask", "tofu", "none"],
default="ask",
help="Host key (.ssh/known_hosts) check mode.",
)
parser.add_argument( parser.add_argument(
"--password", "--password",
help="Pre-provided password the cli will prompt otherwise if needed.", help="Pre-provided password the cli will prompt otherwise if needed.",

View File

@@ -13,6 +13,7 @@ from clan_cli.completions import (
complete_machines, complete_machines,
complete_target_host, complete_target_host,
) )
from clan_cli.host_key_check import add_host_key_check_arg
from clan_cli.machines.hardware import HardwareConfig from clan_cli.machines.hardware import HardwareConfig
from clan_cli.ssh.deploy_info import DeployInfo, find_reachable_host, ssh_command_parse from clan_cli.ssh.deploy_info import DeployInfo, find_reachable_host, ssh_command_parse
@@ -97,12 +98,7 @@ def register_install_parser(parser: argparse.ArgumentParser) -> None:
help="do not reboot after installation (deprecated)", help="do not reboot after installation (deprecated)",
default=False, default=False,
) )
parser.add_argument( add_host_key_check_arg(parser)
"--host-key-check",
choices=["strict", "ask", "tofu", "none"],
default="ask",
help="Host key (.ssh/known_hosts) check mode.",
)
parser.add_argument( parser.add_argument(
"--build-on", "--build-on",
choices=[x.value for x in BuildOn], choices=[x.value for x in BuildOn],

View File

@@ -16,6 +16,7 @@ from clan_cli.completions import (
complete_machines, complete_machines,
complete_tags, complete_tags,
) )
from clan_cli.host_key_check import add_host_key_check_arg
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -163,12 +164,7 @@ def register_update_parser(parser: argparse.ArgumentParser) -> None:
) )
add_dynamic_completer(tag_parser, complete_tags) add_dynamic_completer(tag_parser, complete_tags)
parser.add_argument( add_host_key_check_arg(parser)
"--host-key-check",
choices=["strict", "ask", "tofu", "none"],
default="ask",
help="Host key (.ssh/known_hosts) check mode.",
)
parser.add_argument( parser.add_argument(
"--target-host", "--target-host",
type=str, type=str,

View File

@@ -8,12 +8,14 @@ from typing import Any
from clan_lib.cmd import run from clan_lib.cmd import run
from clan_lib.errors import ClanError from clan_lib.errors import ClanError
from clan_lib.nix import nix_shell from clan_lib.nix import nix_shell
from clan_lib.ssh.remote import HostKeyCheck, Remote from clan_lib.ssh.host_key import HostKeyCheck
from clan_lib.ssh.remote import Remote
from clan_cli.completions import ( from clan_cli.completions import (
add_dynamic_completer, add_dynamic_completer,
complete_machines, complete_machines,
) )
from clan_cli.host_key_check import add_host_key_check_arg
from clan_cli.ssh.tor import TorTarget, spawn_tor, ssh_tor_reachable from clan_cli.ssh.tor import TorTarget, spawn_tor, ssh_tor_reachable
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -181,10 +183,5 @@ def register_parser(parser: argparse.ArgumentParser) -> None:
"--png", "--png",
help="specify the json file for ssh data as the qrcode image (generated by starting the clan installer)", help="specify the json file for ssh data as the qrcode image (generated by starting the clan installer)",
) )
parser.add_argument( add_host_key_check_arg(parser, default=HostKeyCheck.TOFU)
"--host-key-check",
choices=["strict", "ask", "tofu", "none"],
default="tofu",
help="Host key (.ssh/known_hosts) check mode.",
)
parser.set_defaults(func=ssh_command) parser.set_defaults(func=ssh_command)

View File

@@ -4,6 +4,7 @@ from pathlib import Path
import pytest import pytest
from clan_lib.cmd import RunOpts, run from clan_lib.cmd import RunOpts, run
from clan_lib.nix import nix_shell from clan_lib.nix import nix_shell
from clan_lib.ssh.host_key import HostKeyCheck
from clan_lib.ssh.remote import Remote from clan_lib.ssh.remote import Remote
from clan_cli.ssh.deploy_info import DeployInfo, find_reachable_host from clan_cli.ssh.deploy_info import DeployInfo, find_reachable_host
@@ -23,7 +24,7 @@ def test_qrcode_scan(temp_dir: Path) -> None:
run(cmd, RunOpts(input=data.encode())) run(cmd, RunOpts(input=data.encode()))
# Call the qrcode_scan function # Call the qrcode_scan function
deploy_info = DeployInfo.from_qr_code(img_path, "none") deploy_info = DeployInfo.from_qr_code(img_path, HostKeyCheck.NONE)
host = deploy_info.addrs[0] host = deploy_info.addrs[0]
assert host.address == "192.168.122.86" assert host.address == "192.168.122.86"
@@ -46,7 +47,7 @@ def test_qrcode_scan(temp_dir: Path) -> None:
def test_from_json() -> None: def test_from_json() -> None:
data = '{"pass":"scabbed-defender-headlock","tor":"qjeerm4r6t55hcfum4pinnvscn5njlw2g3k7ilqfuu7cdt3ahaxhsbid.onion","addrs":["192.168.122.86"]}' data = '{"pass":"scabbed-defender-headlock","tor":"qjeerm4r6t55hcfum4pinnvscn5njlw2g3k7ilqfuu7cdt3ahaxhsbid.onion","addrs":["192.168.122.86"]}'
deploy_info = DeployInfo.from_json(json.loads(data), "none") deploy_info = DeployInfo.from_json(json.loads(data), HostKeyCheck.NONE)
host = deploy_info.addrs[0] host = deploy_info.addrs[0]
assert host.password == "scabbed-defender-headlock" assert host.password == "scabbed-defender-headlock"
@@ -69,7 +70,9 @@ def test_from_json() -> None:
@pytest.mark.with_core @pytest.mark.with_core
def test_find_reachable_host(hosts: list[Remote]) -> None: def test_find_reachable_host(hosts: list[Remote]) -> None:
host = hosts[0] host = hosts[0]
deploy_info = DeployInfo.from_hostnames(["172.19.1.2", host.ssh_url()], "none") deploy_info = DeployInfo.from_hostnames(
["172.19.1.2", host.ssh_url()], HostKeyCheck.NONE
)
assert deploy_info.addrs[0].address == "172.19.1.2" assert deploy_info.addrs[0].address == "172.19.1.2"

View File

@@ -4,6 +4,7 @@ from pathlib import Path
import pytest import pytest
from clan_cli.tests.sshd import Sshd from clan_cli.tests.sshd import Sshd
from clan_lib.ssh.host_key import HostKeyCheck
from clan_lib.ssh.remote import Remote from clan_lib.ssh.remote import Remote
@@ -16,7 +17,7 @@ def hosts(sshd: Sshd) -> list[Remote]:
port=sshd.port, port=sshd.port,
user=login, user=login,
private_key=Path(sshd.key), private_key=Path(sshd.key),
host_key_check="none", host_key_check=HostKeyCheck.NONE,
command_prefix="local_test", command_prefix="local_test",
) )
] ]

View File

@@ -1,15 +1,15 @@
# Adapted from https://github.com/numtide/deploykit # Adapted from https://github.com/numtide/deploykit
from typing import Literal from enum import Enum
from clan_lib.errors import ClanError from clan_lib.errors import ClanError
HostKeyCheck = Literal[
"strict", # Strictly check ssh host keys, prompt for unknown ones class HostKeyCheck(Enum):
"ask", # Ask for confirmation on first use STRICT = "strict" # Strictly check ssh host keys, prompt for unknown ones
"tofu", # Trust on ssh keys on first use ASK = "ask" # Ask for confirmation on first use
"none", # Do not check ssh host keys TOFU = "tofu" # Trust on ssh keys on first use
] NONE = "none" # Do not check ssh host keys
def hostkey_to_ssh_opts(host_key_check: HostKeyCheck) -> list[str]: def hostkey_to_ssh_opts(host_key_check: HostKeyCheck) -> list[str]:
@@ -17,13 +17,13 @@ def hostkey_to_ssh_opts(host_key_check: HostKeyCheck) -> list[str]:
Convert a HostKeyCheck value to SSH options. Convert a HostKeyCheck value to SSH options.
""" """
match host_key_check: match host_key_check:
case "strict": case HostKeyCheck.STRICT:
return ["-o", "StrictHostKeyChecking=yes"] return ["-o", "StrictHostKeyChecking=yes"]
case "ask": case HostKeyCheck.ASK:
return [] return []
case "tofu": case HostKeyCheck.TOFU:
return ["-o", "StrictHostKeyChecking=accept-new"] return ["-o", "StrictHostKeyChecking=accept-new"]
case "none": case HostKeyCheck.NONE:
return [ return [
"-o", "-o",
"StrictHostKeyChecking=no", "StrictHostKeyChecking=no",

View File

@@ -39,7 +39,7 @@ class Remote:
private_key: Path | None = None private_key: Path | None = None
password: str | None = None password: str | None = None
forward_agent: bool = True forward_agent: bool = True
host_key_check: HostKeyCheck = "ask" host_key_check: HostKeyCheck = HostKeyCheck.ASK
verbose_ssh: bool = False verbose_ssh: bool = False
ssh_options: dict[str, str] = field(default_factory=dict) ssh_options: dict[str, str] = field(default_factory=dict)
tor_socks: bool = False tor_socks: bool = False

View File

@@ -9,6 +9,7 @@ from clan_lib.async_run import AsyncRuntime
from clan_lib.cmd import ClanCmdTimeoutError, Log, RunOpts from clan_lib.cmd import ClanCmdTimeoutError, Log, RunOpts
from clan_lib.errors import ClanError, CmdOut from clan_lib.errors import ClanError, CmdOut
from clan_lib.ssh.remote import Remote from clan_lib.ssh.remote import Remote
from clan_lib.ssh.host_key import HostKeyCheck
from clan_lib.ssh.sudo_askpass_proxy import SudoAskpassProxy from clan_lib.ssh.sudo_askpass_proxy import SudoAskpassProxy
if sys.platform == "darwin": if sys.platform == "darwin":
@@ -113,7 +114,7 @@ def test_parse_deployment_address(
result = Remote.from_ssh_uri( result = Remote.from_ssh_uri(
machine_name=machine_name, machine_name=machine_name,
address=test_addr, address=test_addr,
).override(host_key_check="strict") ).override(host_key_check=HostKeyCheck.STRICT)
if expected_exception: if expected_exception:
return return
@@ -131,7 +132,7 @@ def test_parse_deployment_address(
def test_parse_ssh_options() -> None: def test_parse_ssh_options() -> None:
addr = "root@example.com:2222?IdentityFile=/path/to/private/key&StrictRemoteKeyChecking=yes" addr = "root@example.com:2222?IdentityFile=/path/to/private/key&StrictRemoteKeyChecking=yes"
host = Remote.from_ssh_uri(machine_name="foo", address=addr).override( host = Remote.from_ssh_uri(machine_name="foo", address=addr).override(
host_key_check="strict" host_key_check=HostKeyCheck.STRICT
) )
assert host.address == "example.com" assert host.address == "example.com"
assert host.port == 2222 assert host.port == 2222

View File

@@ -33,6 +33,7 @@ from clan_lib.nix_models.clan import (
) )
from clan_lib.nix_models.clan import InventoryMachineDeploy as MachineDeploy from clan_lib.nix_models.clan import InventoryMachineDeploy as MachineDeploy
from clan_lib.persist.util import set_value_by_path from clan_lib.persist.util import set_value_by_path
from clan_lib.ssh.host_key import HostKeyCheck
from clan_lib.ssh.remote import Remote, can_ssh_login from clan_lib.ssh.remote import Remote, can_ssh_login
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -198,7 +199,7 @@ def test_clan_create_api(
clan_dir_flake.invalidate_cache() clan_dir_flake.invalidate_cache()
target_host = machine.target_host().override( target_host = machine.target_host().override(
private_key=private_key, host_key_check="none" private_key=private_key, host_key_check=HostKeyCheck.NONE
) )
result = can_ssh_login(target_host) result = can_ssh_login(target_host)
assert result == "Online", f"Machine {machine.name} is not online" assert result == "Online", f"Machine {machine.name} is not online"