diff --git a/pkgs/clan-cli/clan_cli/machines/hardware.py b/pkgs/clan-cli/clan_cli/machines/hardware.py index 4815f46f0..657735463 100644 --- a/pkgs/clan-cli/clan_cli/machines/hardware.py +++ b/pkgs/clan-cli/clan_cli/machines/hardware.py @@ -12,7 +12,7 @@ from clan_lib.errors import ClanCmdError, ClanError from clan_lib.git import commit_file from clan_lib.machines.machines import Machine from clan_lib.nix import nix_config, nix_eval -from clan_lib.ssh.remote import HostKeyCheck, Remote +from clan_lib.ssh.remote import Remote from clan_cli.completions import add_dynamic_completer, complete_machines @@ -158,7 +158,7 @@ def generate_machine_hardware_info( def update_hardware_config_command(args: argparse.Namespace) -> None: - host_key_check = HostKeyCheck.from_str(args.host_key_check) + host_key_check = args.host_key_check machine = Machine(flake=args.flake, name=args.machine) opts = HardwareGenerateOptions( machine=machine, diff --git a/pkgs/clan-cli/clan_cli/machines/install.py b/pkgs/clan-cli/clan_cli/machines/install.py index d54e58f1e..739a3cea5 100644 --- a/pkgs/clan-cli/clan_cli/machines/install.py +++ b/pkgs/clan-cli/clan_cli/machines/install.py @@ -12,7 +12,7 @@ from clan_lib.cmd import Log, RunOpts, run from clan_lib.errors import ClanError from clan_lib.machines.machines import Machine from clan_lib.nix import nix_shell -from clan_lib.ssh.remote import HostKeyCheck, Remote +from clan_lib.ssh.remote import Remote from clan_cli.completions import ( add_dynamic_completer, @@ -182,11 +182,8 @@ def install_command(args: argparse.Namespace) -> None: password = None machine = Machine(name=args.machine, flake=args.flake, nix_options=args.option) - host_key_check = ( - HostKeyCheck.from_str(args.host_key_check) - if args.host_key_check - else HostKeyCheck.ASK - ) + host_key_check = args.host_key_check + if target_host_str is not None: target_host = Remote.from_deployment_address( machine_name=machine.name, address=target_host_str diff --git a/pkgs/clan-cli/clan_cli/machines/update.py b/pkgs/clan-cli/clan_cli/machines/update.py index 23d4d9bb2..37ecc3e28 100644 --- a/pkgs/clan-cli/clan_cli/machines/update.py +++ b/pkgs/clan-cli/clan_cli/machines/update.py @@ -14,7 +14,7 @@ from clan_lib.colors import AnsiColor from clan_lib.errors import ClanError from clan_lib.machines.machines import Machine from clan_lib.nix import nix_command, nix_config, nix_metadata -from clan_lib.ssh.remote import HostKeyCheck, Remote +from clan_lib.ssh.remote import Remote from clan_cli.completions import ( add_dynamic_completer, @@ -271,7 +271,7 @@ def update_command(args: argparse.Namespace) -> None: ] ) - host_key_check = HostKeyCheck.from_str(args.host_key_check) + host_key_check = args.host_key_check with AsyncRuntime() as runtime: for machine in machines: if args.target_host: diff --git a/pkgs/clan-cli/clan_cli/ssh/deploy_info.py b/pkgs/clan-cli/clan_cli/ssh/deploy_info.py index 12e96a9e2..d12e10065 100644 --- a/pkgs/clan-cli/clan_cli/ssh/deploy_info.py +++ b/pkgs/clan-cli/clan_cli/ssh/deploy_info.py @@ -136,7 +136,7 @@ def ssh_shell_from_deploy(deploy_info: DeployInfo) -> None: def ssh_command_parse(args: argparse.Namespace) -> DeployInfo | None: - host_key_check = HostKeyCheck.from_str(args.host_key_check) + host_key_check = args.host_key_check if args.json: json_file = Path(args.json) if json_file.is_file(): diff --git a/pkgs/clan-cli/clan_cli/ssh/host_key.py b/pkgs/clan-cli/clan_cli/ssh/host_key.py deleted file mode 100644 index 1caf84e9d..000000000 --- a/pkgs/clan-cli/clan_cli/ssh/host_key.py +++ /dev/null @@ -1,46 +0,0 @@ -# Adapted from https://github.com/numtide/deploykit - -from enum import Enum - -from clan_lib.errors import ClanError - - -class HostKeyCheck(Enum): - # Strictly check ssh host keys, prompt for unknown ones - STRICT = 0 - # Ask for confirmation on first use - ASK = 1 - # Trust on ssh keys on first use - TOFU = 2 - # Do not check ssh host keys - NONE = 3 - - @staticmethod - def from_str(label: str) -> "HostKeyCheck": - if label.upper() in HostKeyCheck.__members__: - return HostKeyCheck[label.upper()] - msg = f"Invalid choice: {label}." - description = "Choose from: " + ", ".join(HostKeyCheck.__members__) - raise ClanError(msg, description=description) - - def __str__(self) -> str: - return self.name.lower() - - def to_ssh_opt(self) -> list[str]: - match self: - case HostKeyCheck.STRICT: - return ["-o", "StrictHostKeyChecking=yes"] - case HostKeyCheck.ASK: - return [] - case HostKeyCheck.TOFU: - return ["-o", "StrictHostKeyChecking=accept-new"] - case HostKeyCheck.NONE: - return [ - "-o", - "StrictHostKeyChecking=no", - "-o", - "UserKnownHostsFile=/dev/null", - ] - case _: - msg = "Invalid HostKeyCheck" - raise ClanError(msg) diff --git a/pkgs/clan-cli/clan_cli/ssh/test_deploy_info.py b/pkgs/clan-cli/clan_cli/ssh/test_deploy_info.py index 859d580a3..3a4996f7d 100644 --- a/pkgs/clan-cli/clan_cli/ssh/test_deploy_info.py +++ b/pkgs/clan-cli/clan_cli/ssh/test_deploy_info.py @@ -4,7 +4,7 @@ from pathlib import Path import pytest from clan_lib.cmd import RunOpts, run from clan_lib.nix import nix_shell -from clan_lib.ssh.remote import HostKeyCheck, Remote +from clan_lib.ssh.remote import Remote from clan_cli.ssh.deploy_info import DeployInfo, find_reachable_host @@ -23,7 +23,7 @@ def test_qrcode_scan(temp_dir: Path) -> None: run(cmd, RunOpts(input=data.encode())) # Call the qrcode_scan function - deploy_info = DeployInfo.from_qr_code(img_path, HostKeyCheck.NONE) + deploy_info = DeployInfo.from_qr_code(img_path, "none") host = deploy_info.addrs[0] assert host.address == "192.168.122.86" @@ -46,7 +46,7 @@ def test_qrcode_scan(temp_dir: Path) -> None: def test_from_json() -> None: data = '{"pass":"scabbed-defender-headlock","tor":"qjeerm4r6t55hcfum4pinnvscn5njlw2g3k7ilqfuu7cdt3ahaxhsbid.onion","addrs":["192.168.122.86"]}' - deploy_info = DeployInfo.from_json(json.loads(data), HostKeyCheck.NONE) + deploy_info = DeployInfo.from_json(json.loads(data), "none") host = deploy_info.addrs[0] assert host.password == "scabbed-defender-headlock" @@ -69,9 +69,7 @@ def test_from_json() -> None: @pytest.mark.with_core def test_find_reachable_host(hosts: list[Remote]) -> None: host = hosts[0] - deploy_info = DeployInfo.from_hostnames( - ["172.19.1.2", host.ssh_url()], HostKeyCheck.NONE - ) + deploy_info = DeployInfo.from_hostnames(["172.19.1.2", host.ssh_url()], "none") assert deploy_info.addrs[0].address == "172.19.1.2" diff --git a/pkgs/clan-cli/clan_cli/tests/hosts.py b/pkgs/clan-cli/clan_cli/tests/hosts.py index 4da85848e..84c32e720 100644 --- a/pkgs/clan-cli/clan_cli/tests/hosts.py +++ b/pkgs/clan-cli/clan_cli/tests/hosts.py @@ -3,7 +3,6 @@ import pwd from pathlib import Path import pytest -from clan_cli.ssh.host_key import HostKeyCheck from clan_cli.tests.sshd import Sshd from clan_lib.ssh.remote import Remote @@ -17,7 +16,7 @@ def hosts(sshd: Sshd) -> list[Remote]: port=sshd.port, user=login, private_key=Path(sshd.key), - host_key_check=HostKeyCheck.NONE, + host_key_check="none", command_prefix="local_test", ) ] diff --git a/pkgs/clan-cli/clan_lib/ssh/remote.py b/pkgs/clan-cli/clan_lib/ssh/remote.py index a547e7367..6732f2456 100644 --- a/pkgs/clan-cli/clan_lib/ssh/remote.py +++ b/pkgs/clan-cli/clan_lib/ssh/remote.py @@ -15,13 +15,12 @@ from shlex import quote from tempfile import TemporaryDirectory from typing import Literal -from clan_cli.ssh.host_key import HostKeyCheck - from clan_lib.api import API from clan_lib.cmd import ClanCmdError, ClanCmdTimeoutError, CmdOut, RunOpts, run from clan_lib.colors import AnsiColor from clan_lib.errors import ClanError # Assuming these are available from clan_lib.nix import nix_shell +from clan_lib.ssh.host_key import HostKeyCheck, hostkey_to_ssh_opts from clan_lib.ssh.parse import parse_deployment_address from clan_lib.ssh.sudo_askpass_proxy import SudoAskpassProxy @@ -40,7 +39,7 @@ class Remote: private_key: Path | None = None password: str | None = None forward_agent: bool = True - host_key_check: HostKeyCheck = HostKeyCheck.ASK + host_key_check: HostKeyCheck = "ask" verbose_ssh: bool = False ssh_options: dict[str, str] = field(default_factory=dict) tor_socks: bool = False @@ -334,7 +333,7 @@ class Remote: ssh_opts.extend(["-p", str(self.port)]) for k, v in self.ssh_options.items(): ssh_opts.extend(["-o", f"{k}={shlex.quote(v)}"]) - ssh_opts.extend(self.host_key_check.to_ssh_opt()) + ssh_opts.extend(hostkey_to_ssh_opts(self.host_key_check)) if self.private_key: ssh_opts.extend(["-i", str(self.private_key)]) diff --git a/pkgs/clan-cli/clan_lib/ssh/remote_test.py b/pkgs/clan-cli/clan_lib/ssh/remote_test.py index 34baa952e..dae2f623c 100644 --- a/pkgs/clan-cli/clan_lib/ssh/remote_test.py +++ b/pkgs/clan-cli/clan_lib/ssh/remote_test.py @@ -4,7 +4,6 @@ from collections.abc import Generator from typing import Any, NamedTuple import pytest -from clan_cli.ssh.host_key import HostKeyCheck from clan_lib.async_run import AsyncRuntime from clan_lib.cmd import ClanCmdTimeoutError, Log, RunOpts @@ -114,7 +113,7 @@ def test_parse_deployment_address( result = Remote.from_deployment_address( machine_name=machine_name, address=test_addr, - ).override(host_key_check=HostKeyCheck.STRICT) + ).override(host_key_check="strict") if expected_exception: return @@ -132,7 +131,7 @@ def test_parse_deployment_address( def test_parse_ssh_options() -> None: addr = "root@example.com:2222?IdentityFile=/path/to/private/key&StrictRemoteKeyChecking=yes" host = Remote.from_deployment_address(machine_name="foo", address=addr).override( - host_key_check=HostKeyCheck.STRICT + host_key_check="strict" ) assert host.address == "example.com" assert host.port == 2222 diff --git a/pkgs/clan-cli/clan_lib/tests/test_create.py b/pkgs/clan-cli/clan_lib/tests/test_create.py index a3a3a185d..ddb697d8d 100644 --- a/pkgs/clan-cli/clan_lib/tests/test_create.py +++ b/pkgs/clan-cli/clan_lib/tests/test_create.py @@ -13,7 +13,6 @@ from clan_cli.machines.create import create_machine from clan_cli.secrets.key import generate_key from clan_cli.secrets.sops import maybe_get_admin_public_keys from clan_cli.secrets.users import add_user -from clan_cli.ssh.host_key import HostKeyCheck from clan_cli.vars.generate import generate_vars_for_machine, get_generators_closure from clan_lib.api.disk import hw_main_disk_options, set_machine_disk_schema @@ -198,7 +197,7 @@ def test_clan_create_api( clan_dir_flake.invalidate_cache() target_host = machine.target_host().override( - private_key=private_key, host_key_check=HostKeyCheck.NONE + private_key=private_key, host_key_check="none" ) result = can_ssh_login(target_host) assert result == "Online", f"Machine {machine.name} is not online"