Revert "make host key check an enum instead of an literal type"
This reverts commit 3d83017acd.
This commit is contained in:
@@ -1,15 +1,15 @@
|
||||
# Adapted from https://github.com/numtide/deploykit
|
||||
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
from clan_lib.errors import ClanError
|
||||
|
||||
|
||||
class HostKeyCheck(Enum):
|
||||
STRICT = "strict" # Strictly check ssh host keys, prompt for unknown ones
|
||||
ASK = "ask" # Ask for confirmation on first use
|
||||
TOFU = "tofu" # Trust on ssh keys on first use
|
||||
NONE = "none" # Do not check ssh host keys
|
||||
HostKeyCheck = Literal[
|
||||
"strict", # Strictly check ssh host keys, prompt for unknown ones
|
||||
"ask", # Ask for confirmation on first use
|
||||
"tofu", # Trust on ssh keys on first use
|
||||
"none", # Do not check ssh host keys
|
||||
]
|
||||
|
||||
|
||||
def hostkey_to_ssh_opts(host_key_check: HostKeyCheck) -> list[str]:
|
||||
@@ -17,13 +17,13 @@ def hostkey_to_ssh_opts(host_key_check: HostKeyCheck) -> list[str]:
|
||||
Convert a HostKeyCheck value to SSH options.
|
||||
"""
|
||||
match host_key_check:
|
||||
case HostKeyCheck.STRICT:
|
||||
case "strict":
|
||||
return ["-o", "StrictHostKeyChecking=yes"]
|
||||
case HostKeyCheck.ASK:
|
||||
case "ask":
|
||||
return []
|
||||
case HostKeyCheck.TOFU:
|
||||
case "tofu":
|
||||
return ["-o", "StrictHostKeyChecking=accept-new"]
|
||||
case HostKeyCheck.NONE:
|
||||
case "none":
|
||||
return [
|
||||
"-o",
|
||||
"StrictHostKeyChecking=no",
|
||||
|
||||
@@ -38,7 +38,7 @@ class Remote:
|
||||
private_key: Path | None = None
|
||||
password: str | None = None
|
||||
forward_agent: bool = True
|
||||
host_key_check: HostKeyCheck = HostKeyCheck.ASK
|
||||
host_key_check: HostKeyCheck = "ask"
|
||||
verbose_ssh: bool = False
|
||||
ssh_options: dict[str, str] = field(default_factory=dict)
|
||||
tor_socks: bool = False
|
||||
|
||||
@@ -8,7 +8,6 @@ import pytest
|
||||
from clan_lib.async_run import AsyncRuntime
|
||||
from clan_lib.cmd import ClanCmdTimeoutError, Log, RunOpts
|
||||
from clan_lib.errors import ClanError, CmdOut
|
||||
from clan_lib.ssh.host_key import HostKeyCheck
|
||||
from clan_lib.ssh.remote import Remote
|
||||
from clan_lib.ssh.sudo_askpass_proxy import SudoAskpassProxy
|
||||
|
||||
@@ -114,7 +113,7 @@ def test_parse_deployment_address(
|
||||
result = Remote.from_ssh_uri(
|
||||
machine_name=machine_name,
|
||||
address=test_addr,
|
||||
).override(host_key_check=HostKeyCheck.STRICT)
|
||||
).override(host_key_check="strict")
|
||||
|
||||
if expected_exception:
|
||||
return
|
||||
@@ -132,7 +131,7 @@ def test_parse_deployment_address(
|
||||
def test_parse_ssh_options() -> None:
|
||||
addr = "root@example.com:2222?IdentityFile=/path/to/private/key&StrictRemoteKeyChecking=yes"
|
||||
host = Remote.from_ssh_uri(machine_name="foo", address=addr).override(
|
||||
host_key_check=HostKeyCheck.STRICT
|
||||
host_key_check="strict"
|
||||
)
|
||||
assert host.address == "example.com"
|
||||
assert host.port == 2222
|
||||
|
||||
Reference in New Issue
Block a user