make host key check an enum instead of an literal type
this is more typesafe at runtime.
This commit is contained in:
28
pkgs/clan-cli/clan_cli/host_key_check.py
Normal file
28
pkgs/clan-cli/clan_cli/host_key_check.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
"""Common argument types and utilities for host key checking in clan CLI commands."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from clan_lib.ssh.host_key import HostKeyCheck
|
||||||
|
|
||||||
|
|
||||||
|
def host_key_check_type(value: str) -> HostKeyCheck:
|
||||||
|
"""
|
||||||
|
Argparse type converter for HostKeyCheck enum.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return HostKeyCheck(value)
|
||||||
|
except ValueError:
|
||||||
|
valid_values = [e.value for e in HostKeyCheck]
|
||||||
|
msg = f"Invalid host key check mode: {value}. Valid options: {', '.join(valid_values)}"
|
||||||
|
raise argparse.ArgumentTypeError(msg) from None
|
||||||
|
|
||||||
|
|
||||||
|
def add_host_key_check_arg(
|
||||||
|
parser: argparse.ArgumentParser, default: HostKeyCheck = HostKeyCheck.ASK
|
||||||
|
) -> None:
|
||||||
|
parser.add_argument(
|
||||||
|
"--host-key-check",
|
||||||
|
type=host_key_check_type,
|
||||||
|
default=default,
|
||||||
|
help=f"Host key (.ssh/known_hosts) check mode. Options: {', '.join([e.value for e in HostKeyCheck])}",
|
||||||
|
)
|
||||||
@@ -12,6 +12,7 @@ from clan_lib.machines.suggestions import validate_machine_names
|
|||||||
from clan_lib.ssh.remote import Remote
|
from clan_lib.ssh.remote import Remote
|
||||||
|
|
||||||
from clan_cli.completions import add_dynamic_completer, complete_machines
|
from clan_cli.completions import add_dynamic_completer, complete_machines
|
||||||
|
from clan_cli.host_key_check import add_host_key_check_arg
|
||||||
|
|
||||||
from .types import machine_name_type
|
from .types import machine_name_type
|
||||||
|
|
||||||
@@ -54,12 +55,7 @@ def register_update_hardware_config(parser: argparse.ArgumentParser) -> None:
|
|||||||
nargs="?",
|
nargs="?",
|
||||||
help="ssh address to install to in the form of user@host:2222",
|
help="ssh address to install to in the form of user@host:2222",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
add_host_key_check_arg(parser)
|
||||||
"--host-key-check",
|
|
||||||
choices=["strict", "ask", "tofu", "none"],
|
|
||||||
default="ask",
|
|
||||||
help="Host key (.ssh/known_hosts) check mode.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--password",
|
"--password",
|
||||||
help="Pre-provided password the cli will prompt otherwise if needed.",
|
help="Pre-provided password the cli will prompt otherwise if needed.",
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from clan_cli.completions import (
|
|||||||
complete_machines,
|
complete_machines,
|
||||||
complete_target_host,
|
complete_target_host,
|
||||||
)
|
)
|
||||||
|
from clan_cli.host_key_check import add_host_key_check_arg
|
||||||
from clan_cli.machines.hardware import HardwareConfig
|
from clan_cli.machines.hardware import HardwareConfig
|
||||||
from clan_cli.ssh.deploy_info import DeployInfo, find_reachable_host, ssh_command_parse
|
from clan_cli.ssh.deploy_info import DeployInfo, find_reachable_host, ssh_command_parse
|
||||||
|
|
||||||
@@ -97,12 +98,7 @@ def register_install_parser(parser: argparse.ArgumentParser) -> None:
|
|||||||
help="do not reboot after installation (deprecated)",
|
help="do not reboot after installation (deprecated)",
|
||||||
default=False,
|
default=False,
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
add_host_key_check_arg(parser)
|
||||||
"--host-key-check",
|
|
||||||
choices=["strict", "ask", "tofu", "none"],
|
|
||||||
default="ask",
|
|
||||||
help="Host key (.ssh/known_hosts) check mode.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--build-on",
|
"--build-on",
|
||||||
choices=[x.value for x in BuildOn],
|
choices=[x.value for x in BuildOn],
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from clan_cli.completions import (
|
|||||||
complete_machines,
|
complete_machines,
|
||||||
complete_tags,
|
complete_tags,
|
||||||
)
|
)
|
||||||
|
from clan_cli.host_key_check import add_host_key_check_arg
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -163,12 +164,7 @@ def register_update_parser(parser: argparse.ArgumentParser) -> None:
|
|||||||
)
|
)
|
||||||
add_dynamic_completer(tag_parser, complete_tags)
|
add_dynamic_completer(tag_parser, complete_tags)
|
||||||
|
|
||||||
parser.add_argument(
|
add_host_key_check_arg(parser)
|
||||||
"--host-key-check",
|
|
||||||
choices=["strict", "ask", "tofu", "none"],
|
|
||||||
default="ask",
|
|
||||||
help="Host key (.ssh/known_hosts) check mode.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--target-host",
|
"--target-host",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
@@ -8,12 +8,14 @@ from typing import Any
|
|||||||
from clan_lib.cmd import run
|
from clan_lib.cmd import run
|
||||||
from clan_lib.errors import ClanError
|
from clan_lib.errors import ClanError
|
||||||
from clan_lib.nix import nix_shell
|
from clan_lib.nix import nix_shell
|
||||||
from clan_lib.ssh.remote import HostKeyCheck, Remote
|
from clan_lib.ssh.host_key import HostKeyCheck
|
||||||
|
from clan_lib.ssh.remote import Remote
|
||||||
|
|
||||||
from clan_cli.completions import (
|
from clan_cli.completions import (
|
||||||
add_dynamic_completer,
|
add_dynamic_completer,
|
||||||
complete_machines,
|
complete_machines,
|
||||||
)
|
)
|
||||||
|
from clan_cli.host_key_check import add_host_key_check_arg
|
||||||
from clan_cli.ssh.tor import TorTarget, spawn_tor, ssh_tor_reachable
|
from clan_cli.ssh.tor import TorTarget, spawn_tor, ssh_tor_reachable
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
@@ -181,10 +183,5 @@ def register_parser(parser: argparse.ArgumentParser) -> None:
|
|||||||
"--png",
|
"--png",
|
||||||
help="specify the json file for ssh data as the qrcode image (generated by starting the clan installer)",
|
help="specify the json file for ssh data as the qrcode image (generated by starting the clan installer)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
add_host_key_check_arg(parser, default=HostKeyCheck.TOFU)
|
||||||
"--host-key-check",
|
|
||||||
choices=["strict", "ask", "tofu", "none"],
|
|
||||||
default="tofu",
|
|
||||||
help="Host key (.ssh/known_hosts) check mode.",
|
|
||||||
)
|
|
||||||
parser.set_defaults(func=ssh_command)
|
parser.set_defaults(func=ssh_command)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from pathlib import Path
|
|||||||
import pytest
|
import pytest
|
||||||
from clan_lib.cmd import RunOpts, run
|
from clan_lib.cmd import RunOpts, run
|
||||||
from clan_lib.nix import nix_shell
|
from clan_lib.nix import nix_shell
|
||||||
|
from clan_lib.ssh.host_key import HostKeyCheck
|
||||||
from clan_lib.ssh.remote import Remote
|
from clan_lib.ssh.remote import Remote
|
||||||
|
|
||||||
from clan_cli.ssh.deploy_info import DeployInfo, find_reachable_host
|
from clan_cli.ssh.deploy_info import DeployInfo, find_reachable_host
|
||||||
@@ -23,7 +24,7 @@ def test_qrcode_scan(temp_dir: Path) -> None:
|
|||||||
run(cmd, RunOpts(input=data.encode()))
|
run(cmd, RunOpts(input=data.encode()))
|
||||||
|
|
||||||
# Call the qrcode_scan function
|
# Call the qrcode_scan function
|
||||||
deploy_info = DeployInfo.from_qr_code(img_path, "none")
|
deploy_info = DeployInfo.from_qr_code(img_path, HostKeyCheck.NONE)
|
||||||
|
|
||||||
host = deploy_info.addrs[0]
|
host = deploy_info.addrs[0]
|
||||||
assert host.address == "192.168.122.86"
|
assert host.address == "192.168.122.86"
|
||||||
@@ -46,7 +47,7 @@ def test_qrcode_scan(temp_dir: Path) -> None:
|
|||||||
|
|
||||||
def test_from_json() -> None:
|
def test_from_json() -> None:
|
||||||
data = '{"pass":"scabbed-defender-headlock","tor":"qjeerm4r6t55hcfum4pinnvscn5njlw2g3k7ilqfuu7cdt3ahaxhsbid.onion","addrs":["192.168.122.86"]}'
|
data = '{"pass":"scabbed-defender-headlock","tor":"qjeerm4r6t55hcfum4pinnvscn5njlw2g3k7ilqfuu7cdt3ahaxhsbid.onion","addrs":["192.168.122.86"]}'
|
||||||
deploy_info = DeployInfo.from_json(json.loads(data), "none")
|
deploy_info = DeployInfo.from_json(json.loads(data), HostKeyCheck.NONE)
|
||||||
|
|
||||||
host = deploy_info.addrs[0]
|
host = deploy_info.addrs[0]
|
||||||
assert host.password == "scabbed-defender-headlock"
|
assert host.password == "scabbed-defender-headlock"
|
||||||
@@ -69,7 +70,9 @@ def test_from_json() -> None:
|
|||||||
@pytest.mark.with_core
|
@pytest.mark.with_core
|
||||||
def test_find_reachable_host(hosts: list[Remote]) -> None:
|
def test_find_reachable_host(hosts: list[Remote]) -> None:
|
||||||
host = hosts[0]
|
host = hosts[0]
|
||||||
deploy_info = DeployInfo.from_hostnames(["172.19.1.2", host.ssh_url()], "none")
|
deploy_info = DeployInfo.from_hostnames(
|
||||||
|
["172.19.1.2", host.ssh_url()], HostKeyCheck.NONE
|
||||||
|
)
|
||||||
|
|
||||||
assert deploy_info.addrs[0].address == "172.19.1.2"
|
assert deploy_info.addrs[0].address == "172.19.1.2"
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from clan_cli.tests.sshd import Sshd
|
from clan_cli.tests.sshd import Sshd
|
||||||
|
from clan_lib.ssh.host_key import HostKeyCheck
|
||||||
from clan_lib.ssh.remote import Remote
|
from clan_lib.ssh.remote import Remote
|
||||||
|
|
||||||
|
|
||||||
@@ -16,7 +17,7 @@ def hosts(sshd: Sshd) -> list[Remote]:
|
|||||||
port=sshd.port,
|
port=sshd.port,
|
||||||
user=login,
|
user=login,
|
||||||
private_key=Path(sshd.key),
|
private_key=Path(sshd.key),
|
||||||
host_key_check="none",
|
host_key_check=HostKeyCheck.NONE,
|
||||||
command_prefix="local_test",
|
command_prefix="local_test",
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
# Adapted from https://github.com/numtide/deploykit
|
# Adapted from https://github.com/numtide/deploykit
|
||||||
|
|
||||||
from typing import Literal
|
from enum import Enum
|
||||||
|
|
||||||
from clan_lib.errors import ClanError
|
from clan_lib.errors import ClanError
|
||||||
|
|
||||||
HostKeyCheck = Literal[
|
|
||||||
"strict", # Strictly check ssh host keys, prompt for unknown ones
|
class HostKeyCheck(Enum):
|
||||||
"ask", # Ask for confirmation on first use
|
STRICT = "strict" # Strictly check ssh host keys, prompt for unknown ones
|
||||||
"tofu", # Trust on ssh keys on first use
|
ASK = "ask" # Ask for confirmation on first use
|
||||||
"none", # Do not check ssh host keys
|
TOFU = "tofu" # Trust on ssh keys on first use
|
||||||
]
|
NONE = "none" # Do not check ssh host keys
|
||||||
|
|
||||||
|
|
||||||
def hostkey_to_ssh_opts(host_key_check: HostKeyCheck) -> list[str]:
|
def hostkey_to_ssh_opts(host_key_check: HostKeyCheck) -> list[str]:
|
||||||
@@ -17,13 +17,13 @@ def hostkey_to_ssh_opts(host_key_check: HostKeyCheck) -> list[str]:
|
|||||||
Convert a HostKeyCheck value to SSH options.
|
Convert a HostKeyCheck value to SSH options.
|
||||||
"""
|
"""
|
||||||
match host_key_check:
|
match host_key_check:
|
||||||
case "strict":
|
case HostKeyCheck.STRICT:
|
||||||
return ["-o", "StrictHostKeyChecking=yes"]
|
return ["-o", "StrictHostKeyChecking=yes"]
|
||||||
case "ask":
|
case HostKeyCheck.ASK:
|
||||||
return []
|
return []
|
||||||
case "tofu":
|
case HostKeyCheck.TOFU:
|
||||||
return ["-o", "StrictHostKeyChecking=accept-new"]
|
return ["-o", "StrictHostKeyChecking=accept-new"]
|
||||||
case "none":
|
case HostKeyCheck.NONE:
|
||||||
return [
|
return [
|
||||||
"-o",
|
"-o",
|
||||||
"StrictHostKeyChecking=no",
|
"StrictHostKeyChecking=no",
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ class Remote:
|
|||||||
private_key: Path | None = None
|
private_key: Path | None = None
|
||||||
password: str | None = None
|
password: str | None = None
|
||||||
forward_agent: bool = True
|
forward_agent: bool = True
|
||||||
host_key_check: HostKeyCheck = "ask"
|
host_key_check: HostKeyCheck = HostKeyCheck.ASK
|
||||||
verbose_ssh: bool = False
|
verbose_ssh: bool = False
|
||||||
ssh_options: dict[str, str] = field(default_factory=dict)
|
ssh_options: dict[str, str] = field(default_factory=dict)
|
||||||
tor_socks: bool = False
|
tor_socks: bool = False
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from clan_lib.async_run import AsyncRuntime
|
|||||||
from clan_lib.cmd import ClanCmdTimeoutError, Log, RunOpts
|
from clan_lib.cmd import ClanCmdTimeoutError, Log, RunOpts
|
||||||
from clan_lib.errors import ClanError, CmdOut
|
from clan_lib.errors import ClanError, CmdOut
|
||||||
from clan_lib.ssh.remote import Remote
|
from clan_lib.ssh.remote import Remote
|
||||||
|
from clan_lib.ssh.host_key import HostKeyCheck
|
||||||
from clan_lib.ssh.sudo_askpass_proxy import SudoAskpassProxy
|
from clan_lib.ssh.sudo_askpass_proxy import SudoAskpassProxy
|
||||||
|
|
||||||
if sys.platform == "darwin":
|
if sys.platform == "darwin":
|
||||||
@@ -113,7 +114,7 @@ def test_parse_deployment_address(
|
|||||||
result = Remote.from_ssh_uri(
|
result = Remote.from_ssh_uri(
|
||||||
machine_name=machine_name,
|
machine_name=machine_name,
|
||||||
address=test_addr,
|
address=test_addr,
|
||||||
).override(host_key_check="strict")
|
).override(host_key_check=HostKeyCheck.STRICT)
|
||||||
|
|
||||||
if expected_exception:
|
if expected_exception:
|
||||||
return
|
return
|
||||||
@@ -131,7 +132,7 @@ def test_parse_deployment_address(
|
|||||||
def test_parse_ssh_options() -> None:
|
def test_parse_ssh_options() -> None:
|
||||||
addr = "root@example.com:2222?IdentityFile=/path/to/private/key&StrictRemoteKeyChecking=yes"
|
addr = "root@example.com:2222?IdentityFile=/path/to/private/key&StrictRemoteKeyChecking=yes"
|
||||||
host = Remote.from_ssh_uri(machine_name="foo", address=addr).override(
|
host = Remote.from_ssh_uri(machine_name="foo", address=addr).override(
|
||||||
host_key_check="strict"
|
host_key_check=HostKeyCheck.STRICT
|
||||||
)
|
)
|
||||||
assert host.address == "example.com"
|
assert host.address == "example.com"
|
||||||
assert host.port == 2222
|
assert host.port == 2222
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ from clan_lib.nix_models.clan import (
|
|||||||
)
|
)
|
||||||
from clan_lib.nix_models.clan import InventoryMachineDeploy as MachineDeploy
|
from clan_lib.nix_models.clan import InventoryMachineDeploy as MachineDeploy
|
||||||
from clan_lib.persist.util import set_value_by_path
|
from clan_lib.persist.util import set_value_by_path
|
||||||
|
from clan_lib.ssh.host_key import HostKeyCheck
|
||||||
from clan_lib.ssh.remote import Remote, can_ssh_login
|
from clan_lib.ssh.remote import Remote, can_ssh_login
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
@@ -198,7 +199,7 @@ def test_clan_create_api(
|
|||||||
clan_dir_flake.invalidate_cache()
|
clan_dir_flake.invalidate_cache()
|
||||||
|
|
||||||
target_host = machine.target_host().override(
|
target_host = machine.target_host().override(
|
||||||
private_key=private_key, host_key_check="none"
|
private_key=private_key, host_key_check=HostKeyCheck.NONE
|
||||||
)
|
)
|
||||||
result = can_ssh_login(target_host)
|
result = can_ssh_login(target_host)
|
||||||
assert result == "Online", f"Machine {machine.name} is not online"
|
assert result == "Online", f"Machine {machine.name} is not online"
|
||||||
|
|||||||
Reference in New Issue
Block a user