clan-cli: Simplify HostKeyCheck to a Literal instead of an Enum

This commit is contained in:
Qubasa
2025-06-23 15:08:44 +02:00
parent 65bb9021de
commit 054ea67fb7
10 changed files with 19 additions and 74 deletions

View File

@@ -12,7 +12,7 @@ from clan_lib.errors import ClanCmdError, ClanError
from clan_lib.git import commit_file
from clan_lib.machines.machines import Machine
from clan_lib.nix import nix_config, nix_eval
from clan_lib.ssh.remote import HostKeyCheck, Remote
from clan_lib.ssh.remote import Remote
from clan_cli.completions import add_dynamic_completer, complete_machines
@@ -158,7 +158,7 @@ def generate_machine_hardware_info(
def update_hardware_config_command(args: argparse.Namespace) -> None:
host_key_check = HostKeyCheck.from_str(args.host_key_check)
host_key_check = args.host_key_check
machine = Machine(flake=args.flake, name=args.machine)
opts = HardwareGenerateOptions(
machine=machine,

View File

@@ -12,7 +12,7 @@ from clan_lib.cmd import Log, RunOpts, run
from clan_lib.errors import ClanError
from clan_lib.machines.machines import Machine
from clan_lib.nix import nix_shell
from clan_lib.ssh.remote import HostKeyCheck, Remote
from clan_lib.ssh.remote import Remote
from clan_cli.completions import (
add_dynamic_completer,
@@ -182,11 +182,8 @@ def install_command(args: argparse.Namespace) -> None:
password = None
machine = Machine(name=args.machine, flake=args.flake, nix_options=args.option)
host_key_check = (
HostKeyCheck.from_str(args.host_key_check)
if args.host_key_check
else HostKeyCheck.ASK
)
host_key_check = args.host_key_check
if target_host_str is not None:
target_host = Remote.from_deployment_address(
machine_name=machine.name, address=target_host_str

View File

@@ -14,7 +14,7 @@ from clan_lib.colors import AnsiColor
from clan_lib.errors import ClanError
from clan_lib.machines.machines import Machine
from clan_lib.nix import nix_command, nix_config, nix_metadata
from clan_lib.ssh.remote import HostKeyCheck, Remote
from clan_lib.ssh.remote import Remote
from clan_cli.completions import (
add_dynamic_completer,
@@ -271,7 +271,7 @@ def update_command(args: argparse.Namespace) -> None:
]
)
host_key_check = HostKeyCheck.from_str(args.host_key_check)
host_key_check = args.host_key_check
with AsyncRuntime() as runtime:
for machine in machines:
if args.target_host:

View File

@@ -136,7 +136,7 @@ def ssh_shell_from_deploy(deploy_info: DeployInfo) -> None:
def ssh_command_parse(args: argparse.Namespace) -> DeployInfo | None:
host_key_check = HostKeyCheck.from_str(args.host_key_check)
host_key_check = args.host_key_check
if args.json:
json_file = Path(args.json)
if json_file.is_file():

View File

@@ -1,46 +0,0 @@
# Adapted from https://github.com/numtide/deploykit
from enum import Enum
from clan_lib.errors import ClanError
class HostKeyCheck(Enum):
# Strictly check ssh host keys, prompt for unknown ones
STRICT = 0
# Ask for confirmation on first use
ASK = 1
# Trust on ssh keys on first use
TOFU = 2
# Do not check ssh host keys
NONE = 3
@staticmethod
def from_str(label: str) -> "HostKeyCheck":
if label.upper() in HostKeyCheck.__members__:
return HostKeyCheck[label.upper()]
msg = f"Invalid choice: {label}."
description = "Choose from: " + ", ".join(HostKeyCheck.__members__)
raise ClanError(msg, description=description)
def __str__(self) -> str:
return self.name.lower()
def to_ssh_opt(self) -> list[str]:
match self:
case HostKeyCheck.STRICT:
return ["-o", "StrictHostKeyChecking=yes"]
case HostKeyCheck.ASK:
return []
case HostKeyCheck.TOFU:
return ["-o", "StrictHostKeyChecking=accept-new"]
case HostKeyCheck.NONE:
return [
"-o",
"StrictHostKeyChecking=no",
"-o",
"UserKnownHostsFile=/dev/null",
]
case _:
msg = "Invalid HostKeyCheck"
raise ClanError(msg)

View File

@@ -4,7 +4,7 @@ from pathlib import Path
import pytest
from clan_lib.cmd import RunOpts, run
from clan_lib.nix import nix_shell
from clan_lib.ssh.remote import HostKeyCheck, Remote
from clan_lib.ssh.remote import Remote
from clan_cli.ssh.deploy_info import DeployInfo, find_reachable_host
@@ -23,7 +23,7 @@ def test_qrcode_scan(temp_dir: Path) -> None:
run(cmd, RunOpts(input=data.encode()))
# Call the qrcode_scan function
deploy_info = DeployInfo.from_qr_code(img_path, HostKeyCheck.NONE)
deploy_info = DeployInfo.from_qr_code(img_path, "none")
host = deploy_info.addrs[0]
assert host.address == "192.168.122.86"
@@ -46,7 +46,7 @@ def test_qrcode_scan(temp_dir: Path) -> None:
def test_from_json() -> None:
data = '{"pass":"scabbed-defender-headlock","tor":"qjeerm4r6t55hcfum4pinnvscn5njlw2g3k7ilqfuu7cdt3ahaxhsbid.onion","addrs":["192.168.122.86"]}'
deploy_info = DeployInfo.from_json(json.loads(data), HostKeyCheck.NONE)
deploy_info = DeployInfo.from_json(json.loads(data), "none")
host = deploy_info.addrs[0]
assert host.password == "scabbed-defender-headlock"
@@ -69,9 +69,7 @@ def test_from_json() -> None:
@pytest.mark.with_core
def test_find_reachable_host(hosts: list[Remote]) -> None:
host = hosts[0]
deploy_info = DeployInfo.from_hostnames(
["172.19.1.2", host.ssh_url()], HostKeyCheck.NONE
)
deploy_info = DeployInfo.from_hostnames(["172.19.1.2", host.ssh_url()], "none")
assert deploy_info.addrs[0].address == "172.19.1.2"

View File

@@ -3,7 +3,6 @@ import pwd
from pathlib import Path
import pytest
from clan_cli.ssh.host_key import HostKeyCheck
from clan_cli.tests.sshd import Sshd
from clan_lib.ssh.remote import Remote
@@ -17,7 +16,7 @@ def hosts(sshd: Sshd) -> list[Remote]:
port=sshd.port,
user=login,
private_key=Path(sshd.key),
host_key_check=HostKeyCheck.NONE,
host_key_check="none",
command_prefix="local_test",
)
]