make host key check an enum instead of an literal type

this is more typesafe at runtime.
This commit is contained in:
Jörg Thalheim
2025-07-02 17:33:32 +02:00
parent 7f4f11751e
commit 543c518ed0
11 changed files with 63 additions and 44 deletions

View File

@@ -1,15 +1,15 @@
# Adapted from https://github.com/numtide/deploykit
from typing import Literal
from enum import Enum
from clan_lib.errors import ClanError
HostKeyCheck = Literal[
"strict", # Strictly check ssh host keys, prompt for unknown ones
"ask", # Ask for confirmation on first use
"tofu", # Trust on ssh keys on first use
"none", # Do not check ssh host keys
]
class HostKeyCheck(Enum):
STRICT = "strict" # Strictly check ssh host keys, prompt for unknown ones
ASK = "ask" # Ask for confirmation on first use
TOFU = "tofu" # Trust on ssh keys on first use
NONE = "none" # Do not check ssh host keys
def hostkey_to_ssh_opts(host_key_check: HostKeyCheck) -> list[str]:
@@ -17,13 +17,13 @@ def hostkey_to_ssh_opts(host_key_check: HostKeyCheck) -> list[str]:
Convert a HostKeyCheck value to SSH options.
"""
match host_key_check:
case "strict":
case HostKeyCheck.STRICT:
return ["-o", "StrictHostKeyChecking=yes"]
case "ask":
case HostKeyCheck.ASK:
return []
case "tofu":
case HostKeyCheck.TOFU:
return ["-o", "StrictHostKeyChecking=accept-new"]
case "none":
case HostKeyCheck.NONE:
return [
"-o",
"StrictHostKeyChecking=no",

View File

@@ -39,7 +39,7 @@ class Remote:
private_key: Path | None = None
password: str | None = None
forward_agent: bool = True
host_key_check: HostKeyCheck = "ask"
host_key_check: HostKeyCheck = HostKeyCheck.ASK
verbose_ssh: bool = False
ssh_options: dict[str, str] = field(default_factory=dict)
tor_socks: bool = False

View File

@@ -9,6 +9,7 @@ from clan_lib.async_run import AsyncRuntime
from clan_lib.cmd import ClanCmdTimeoutError, Log, RunOpts
from clan_lib.errors import ClanError, CmdOut
from clan_lib.ssh.remote import Remote
from clan_lib.ssh.host_key import HostKeyCheck
from clan_lib.ssh.sudo_askpass_proxy import SudoAskpassProxy
if sys.platform == "darwin":
@@ -113,7 +114,7 @@ def test_parse_deployment_address(
result = Remote.from_ssh_uri(
machine_name=machine_name,
address=test_addr,
).override(host_key_check="strict")
).override(host_key_check=HostKeyCheck.STRICT)
if expected_exception:
return
@@ -131,7 +132,7 @@ def test_parse_deployment_address(
def test_parse_ssh_options() -> None:
addr = "root@example.com:2222?IdentityFile=/path/to/private/key&StrictRemoteKeyChecking=yes"
host = Remote.from_ssh_uri(machine_name="foo", address=addr).override(
host_key_check="strict"
host_key_check=HostKeyCheck.STRICT
)
assert host.address == "example.com"
assert host.port == 2222

View File

@@ -33,6 +33,7 @@ from clan_lib.nix_models.clan import (
)
from clan_lib.nix_models.clan import InventoryMachineDeploy as MachineDeploy
from clan_lib.persist.util import set_value_by_path
from clan_lib.ssh.host_key import HostKeyCheck
from clan_lib.ssh.remote import Remote, can_ssh_login
log = logging.getLogger(__name__)
@@ -198,7 +199,7 @@ def test_clan_create_api(
clan_dir_flake.invalidate_cache()
target_host = machine.target_host().override(
private_key=private_key, host_key_check="none"
private_key=private_key, host_key_check=HostKeyCheck.NONE
)
result = can_ssh_login(target_host)
assert result == "Online", f"Machine {machine.name} is not online"