clan-cli: Simplify HostKeyCheck to a Literal instead of an Enum
This commit is contained in:
@@ -12,7 +12,7 @@ from clan_lib.errors import ClanCmdError, ClanError
|
|||||||
from clan_lib.git import commit_file
|
from clan_lib.git import commit_file
|
||||||
from clan_lib.machines.machines import Machine
|
from clan_lib.machines.machines import Machine
|
||||||
from clan_lib.nix import nix_config, nix_eval
|
from clan_lib.nix import nix_config, nix_eval
|
||||||
from clan_lib.ssh.remote import HostKeyCheck, Remote
|
from clan_lib.ssh.remote import Remote
|
||||||
|
|
||||||
from clan_cli.completions import add_dynamic_completer, complete_machines
|
from clan_cli.completions import add_dynamic_completer, complete_machines
|
||||||
|
|
||||||
@@ -158,7 +158,7 @@ def generate_machine_hardware_info(
|
|||||||
|
|
||||||
|
|
||||||
def update_hardware_config_command(args: argparse.Namespace) -> None:
|
def update_hardware_config_command(args: argparse.Namespace) -> None:
|
||||||
host_key_check = HostKeyCheck.from_str(args.host_key_check)
|
host_key_check = args.host_key_check
|
||||||
machine = Machine(flake=args.flake, name=args.machine)
|
machine = Machine(flake=args.flake, name=args.machine)
|
||||||
opts = HardwareGenerateOptions(
|
opts = HardwareGenerateOptions(
|
||||||
machine=machine,
|
machine=machine,
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from clan_lib.cmd import Log, RunOpts, run
|
|||||||
from clan_lib.errors import ClanError
|
from clan_lib.errors import ClanError
|
||||||
from clan_lib.machines.machines import Machine
|
from clan_lib.machines.machines import Machine
|
||||||
from clan_lib.nix import nix_shell
|
from clan_lib.nix import nix_shell
|
||||||
from clan_lib.ssh.remote import HostKeyCheck, Remote
|
from clan_lib.ssh.remote import Remote
|
||||||
|
|
||||||
from clan_cli.completions import (
|
from clan_cli.completions import (
|
||||||
add_dynamic_completer,
|
add_dynamic_completer,
|
||||||
@@ -182,11 +182,8 @@ def install_command(args: argparse.Namespace) -> None:
|
|||||||
password = None
|
password = None
|
||||||
|
|
||||||
machine = Machine(name=args.machine, flake=args.flake, nix_options=args.option)
|
machine = Machine(name=args.machine, flake=args.flake, nix_options=args.option)
|
||||||
host_key_check = (
|
host_key_check = args.host_key_check
|
||||||
HostKeyCheck.from_str(args.host_key_check)
|
|
||||||
if args.host_key_check
|
|
||||||
else HostKeyCheck.ASK
|
|
||||||
)
|
|
||||||
if target_host_str is not None:
|
if target_host_str is not None:
|
||||||
target_host = Remote.from_deployment_address(
|
target_host = Remote.from_deployment_address(
|
||||||
machine_name=machine.name, address=target_host_str
|
machine_name=machine.name, address=target_host_str
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from clan_lib.colors import AnsiColor
|
|||||||
from clan_lib.errors import ClanError
|
from clan_lib.errors import ClanError
|
||||||
from clan_lib.machines.machines import Machine
|
from clan_lib.machines.machines import Machine
|
||||||
from clan_lib.nix import nix_command, nix_config, nix_metadata
|
from clan_lib.nix import nix_command, nix_config, nix_metadata
|
||||||
from clan_lib.ssh.remote import HostKeyCheck, Remote
|
from clan_lib.ssh.remote import Remote
|
||||||
|
|
||||||
from clan_cli.completions import (
|
from clan_cli.completions import (
|
||||||
add_dynamic_completer,
|
add_dynamic_completer,
|
||||||
@@ -271,7 +271,7 @@ def update_command(args: argparse.Namespace) -> None:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
host_key_check = HostKeyCheck.from_str(args.host_key_check)
|
host_key_check = args.host_key_check
|
||||||
with AsyncRuntime() as runtime:
|
with AsyncRuntime() as runtime:
|
||||||
for machine in machines:
|
for machine in machines:
|
||||||
if args.target_host:
|
if args.target_host:
|
||||||
|
|||||||
@@ -136,7 +136,7 @@ def ssh_shell_from_deploy(deploy_info: DeployInfo) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def ssh_command_parse(args: argparse.Namespace) -> DeployInfo | None:
|
def ssh_command_parse(args: argparse.Namespace) -> DeployInfo | None:
|
||||||
host_key_check = HostKeyCheck.from_str(args.host_key_check)
|
host_key_check = args.host_key_check
|
||||||
if args.json:
|
if args.json:
|
||||||
json_file = Path(args.json)
|
json_file = Path(args.json)
|
||||||
if json_file.is_file():
|
if json_file.is_file():
|
||||||
|
|||||||
@@ -1,46 +0,0 @@
|
|||||||
# Adapted from https://github.com/numtide/deploykit
|
|
||||||
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from clan_lib.errors import ClanError
|
|
||||||
|
|
||||||
|
|
||||||
class HostKeyCheck(Enum):
|
|
||||||
# Strictly check ssh host keys, prompt for unknown ones
|
|
||||||
STRICT = 0
|
|
||||||
# Ask for confirmation on first use
|
|
||||||
ASK = 1
|
|
||||||
# Trust on ssh keys on first use
|
|
||||||
TOFU = 2
|
|
||||||
# Do not check ssh host keys
|
|
||||||
NONE = 3
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_str(label: str) -> "HostKeyCheck":
|
|
||||||
if label.upper() in HostKeyCheck.__members__:
|
|
||||||
return HostKeyCheck[label.upper()]
|
|
||||||
msg = f"Invalid choice: {label}."
|
|
||||||
description = "Choose from: " + ", ".join(HostKeyCheck.__members__)
|
|
||||||
raise ClanError(msg, description=description)
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return self.name.lower()
|
|
||||||
|
|
||||||
def to_ssh_opt(self) -> list[str]:
|
|
||||||
match self:
|
|
||||||
case HostKeyCheck.STRICT:
|
|
||||||
return ["-o", "StrictHostKeyChecking=yes"]
|
|
||||||
case HostKeyCheck.ASK:
|
|
||||||
return []
|
|
||||||
case HostKeyCheck.TOFU:
|
|
||||||
return ["-o", "StrictHostKeyChecking=accept-new"]
|
|
||||||
case HostKeyCheck.NONE:
|
|
||||||
return [
|
|
||||||
"-o",
|
|
||||||
"StrictHostKeyChecking=no",
|
|
||||||
"-o",
|
|
||||||
"UserKnownHostsFile=/dev/null",
|
|
||||||
]
|
|
||||||
case _:
|
|
||||||
msg = "Invalid HostKeyCheck"
|
|
||||||
raise ClanError(msg)
|
|
||||||
@@ -4,7 +4,7 @@ from pathlib import Path
|
|||||||
import pytest
|
import pytest
|
||||||
from clan_lib.cmd import RunOpts, run
|
from clan_lib.cmd import RunOpts, run
|
||||||
from clan_lib.nix import nix_shell
|
from clan_lib.nix import nix_shell
|
||||||
from clan_lib.ssh.remote import HostKeyCheck, Remote
|
from clan_lib.ssh.remote import Remote
|
||||||
|
|
||||||
from clan_cli.ssh.deploy_info import DeployInfo, find_reachable_host
|
from clan_cli.ssh.deploy_info import DeployInfo, find_reachable_host
|
||||||
|
|
||||||
@@ -23,7 +23,7 @@ def test_qrcode_scan(temp_dir: Path) -> None:
|
|||||||
run(cmd, RunOpts(input=data.encode()))
|
run(cmd, RunOpts(input=data.encode()))
|
||||||
|
|
||||||
# Call the qrcode_scan function
|
# Call the qrcode_scan function
|
||||||
deploy_info = DeployInfo.from_qr_code(img_path, HostKeyCheck.NONE)
|
deploy_info = DeployInfo.from_qr_code(img_path, "none")
|
||||||
|
|
||||||
host = deploy_info.addrs[0]
|
host = deploy_info.addrs[0]
|
||||||
assert host.address == "192.168.122.86"
|
assert host.address == "192.168.122.86"
|
||||||
@@ -46,7 +46,7 @@ def test_qrcode_scan(temp_dir: Path) -> None:
|
|||||||
|
|
||||||
def test_from_json() -> None:
|
def test_from_json() -> None:
|
||||||
data = '{"pass":"scabbed-defender-headlock","tor":"qjeerm4r6t55hcfum4pinnvscn5njlw2g3k7ilqfuu7cdt3ahaxhsbid.onion","addrs":["192.168.122.86"]}'
|
data = '{"pass":"scabbed-defender-headlock","tor":"qjeerm4r6t55hcfum4pinnvscn5njlw2g3k7ilqfuu7cdt3ahaxhsbid.onion","addrs":["192.168.122.86"]}'
|
||||||
deploy_info = DeployInfo.from_json(json.loads(data), HostKeyCheck.NONE)
|
deploy_info = DeployInfo.from_json(json.loads(data), "none")
|
||||||
|
|
||||||
host = deploy_info.addrs[0]
|
host = deploy_info.addrs[0]
|
||||||
assert host.password == "scabbed-defender-headlock"
|
assert host.password == "scabbed-defender-headlock"
|
||||||
@@ -69,9 +69,7 @@ def test_from_json() -> None:
|
|||||||
@pytest.mark.with_core
|
@pytest.mark.with_core
|
||||||
def test_find_reachable_host(hosts: list[Remote]) -> None:
|
def test_find_reachable_host(hosts: list[Remote]) -> None:
|
||||||
host = hosts[0]
|
host = hosts[0]
|
||||||
deploy_info = DeployInfo.from_hostnames(
|
deploy_info = DeployInfo.from_hostnames(["172.19.1.2", host.ssh_url()], "none")
|
||||||
["172.19.1.2", host.ssh_url()], HostKeyCheck.NONE
|
|
||||||
)
|
|
||||||
|
|
||||||
assert deploy_info.addrs[0].address == "172.19.1.2"
|
assert deploy_info.addrs[0].address == "172.19.1.2"
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import pwd
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from clan_cli.ssh.host_key import HostKeyCheck
|
|
||||||
from clan_cli.tests.sshd import Sshd
|
from clan_cli.tests.sshd import Sshd
|
||||||
from clan_lib.ssh.remote import Remote
|
from clan_lib.ssh.remote import Remote
|
||||||
|
|
||||||
@@ -17,7 +16,7 @@ def hosts(sshd: Sshd) -> list[Remote]:
|
|||||||
port=sshd.port,
|
port=sshd.port,
|
||||||
user=login,
|
user=login,
|
||||||
private_key=Path(sshd.key),
|
private_key=Path(sshd.key),
|
||||||
host_key_check=HostKeyCheck.NONE,
|
host_key_check="none",
|
||||||
command_prefix="local_test",
|
command_prefix="local_test",
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -15,13 +15,12 @@ from shlex import quote
|
|||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from clan_cli.ssh.host_key import HostKeyCheck
|
|
||||||
|
|
||||||
from clan_lib.api import API
|
from clan_lib.api import API
|
||||||
from clan_lib.cmd import ClanCmdError, ClanCmdTimeoutError, CmdOut, RunOpts, run
|
from clan_lib.cmd import ClanCmdError, ClanCmdTimeoutError, CmdOut, RunOpts, run
|
||||||
from clan_lib.colors import AnsiColor
|
from clan_lib.colors import AnsiColor
|
||||||
from clan_lib.errors import ClanError # Assuming these are available
|
from clan_lib.errors import ClanError # Assuming these are available
|
||||||
from clan_lib.nix import nix_shell
|
from clan_lib.nix import nix_shell
|
||||||
|
from clan_lib.ssh.host_key import HostKeyCheck, hostkey_to_ssh_opts
|
||||||
from clan_lib.ssh.parse import parse_deployment_address
|
from clan_lib.ssh.parse import parse_deployment_address
|
||||||
from clan_lib.ssh.sudo_askpass_proxy import SudoAskpassProxy
|
from clan_lib.ssh.sudo_askpass_proxy import SudoAskpassProxy
|
||||||
|
|
||||||
@@ -40,7 +39,7 @@ class Remote:
|
|||||||
private_key: Path | None = None
|
private_key: Path | None = None
|
||||||
password: str | None = None
|
password: str | None = None
|
||||||
forward_agent: bool = True
|
forward_agent: bool = True
|
||||||
host_key_check: HostKeyCheck = HostKeyCheck.ASK
|
host_key_check: HostKeyCheck = "ask"
|
||||||
verbose_ssh: bool = False
|
verbose_ssh: bool = False
|
||||||
ssh_options: dict[str, str] = field(default_factory=dict)
|
ssh_options: dict[str, str] = field(default_factory=dict)
|
||||||
tor_socks: bool = False
|
tor_socks: bool = False
|
||||||
@@ -334,7 +333,7 @@ class Remote:
|
|||||||
ssh_opts.extend(["-p", str(self.port)])
|
ssh_opts.extend(["-p", str(self.port)])
|
||||||
for k, v in self.ssh_options.items():
|
for k, v in self.ssh_options.items():
|
||||||
ssh_opts.extend(["-o", f"{k}={shlex.quote(v)}"])
|
ssh_opts.extend(["-o", f"{k}={shlex.quote(v)}"])
|
||||||
ssh_opts.extend(self.host_key_check.to_ssh_opt())
|
ssh_opts.extend(hostkey_to_ssh_opts(self.host_key_check))
|
||||||
if self.private_key:
|
if self.private_key:
|
||||||
ssh_opts.extend(["-i", str(self.private_key)])
|
ssh_opts.extend(["-i", str(self.private_key)])
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ from collections.abc import Generator
|
|||||||
from typing import Any, NamedTuple
|
from typing import Any, NamedTuple
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from clan_cli.ssh.host_key import HostKeyCheck
|
|
||||||
|
|
||||||
from clan_lib.async_run import AsyncRuntime
|
from clan_lib.async_run import AsyncRuntime
|
||||||
from clan_lib.cmd import ClanCmdTimeoutError, Log, RunOpts
|
from clan_lib.cmd import ClanCmdTimeoutError, Log, RunOpts
|
||||||
@@ -114,7 +113,7 @@ def test_parse_deployment_address(
|
|||||||
result = Remote.from_deployment_address(
|
result = Remote.from_deployment_address(
|
||||||
machine_name=machine_name,
|
machine_name=machine_name,
|
||||||
address=test_addr,
|
address=test_addr,
|
||||||
).override(host_key_check=HostKeyCheck.STRICT)
|
).override(host_key_check="strict")
|
||||||
|
|
||||||
if expected_exception:
|
if expected_exception:
|
||||||
return
|
return
|
||||||
@@ -132,7 +131,7 @@ def test_parse_deployment_address(
|
|||||||
def test_parse_ssh_options() -> None:
|
def test_parse_ssh_options() -> None:
|
||||||
addr = "root@example.com:2222?IdentityFile=/path/to/private/key&StrictRemoteKeyChecking=yes"
|
addr = "root@example.com:2222?IdentityFile=/path/to/private/key&StrictRemoteKeyChecking=yes"
|
||||||
host = Remote.from_deployment_address(machine_name="foo", address=addr).override(
|
host = Remote.from_deployment_address(machine_name="foo", address=addr).override(
|
||||||
host_key_check=HostKeyCheck.STRICT
|
host_key_check="strict"
|
||||||
)
|
)
|
||||||
assert host.address == "example.com"
|
assert host.address == "example.com"
|
||||||
assert host.port == 2222
|
assert host.port == 2222
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from clan_cli.machines.create import create_machine
|
|||||||
from clan_cli.secrets.key import generate_key
|
from clan_cli.secrets.key import generate_key
|
||||||
from clan_cli.secrets.sops import maybe_get_admin_public_keys
|
from clan_cli.secrets.sops import maybe_get_admin_public_keys
|
||||||
from clan_cli.secrets.users import add_user
|
from clan_cli.secrets.users import add_user
|
||||||
from clan_cli.ssh.host_key import HostKeyCheck
|
|
||||||
from clan_cli.vars.generate import generate_vars_for_machine, get_generators_closure
|
from clan_cli.vars.generate import generate_vars_for_machine, get_generators_closure
|
||||||
|
|
||||||
from clan_lib.api.disk import hw_main_disk_options, set_machine_disk_schema
|
from clan_lib.api.disk import hw_main_disk_options, set_machine_disk_schema
|
||||||
@@ -198,7 +197,7 @@ def test_clan_create_api(
|
|||||||
clan_dir_flake.invalidate_cache()
|
clan_dir_flake.invalidate_cache()
|
||||||
|
|
||||||
target_host = machine.target_host().override(
|
target_host = machine.target_host().override(
|
||||||
private_key=private_key, host_key_check=HostKeyCheck.NONE
|
private_key=private_key, host_key_check="none"
|
||||||
)
|
)
|
||||||
result = can_ssh_login(target_host)
|
result = can_ssh_login(target_host)
|
||||||
assert result == "Online", f"Machine {machine.name} is not online"
|
assert result == "Online", f"Machine {machine.name} is not online"
|
||||||
|
|||||||
Reference in New Issue
Block a user