Files
clan-core/pkgs/clan-cli/tests/test_ssh_remote.py
2025-03-11 11:27:04 +01:00

220 lines
6.4 KiB
Python

import contextlib
from collections.abc import Generator
from typing import Any, NamedTuple
import pytest
from clan_cli.async_run import AsyncRuntime
from clan_cli.cmd import ClanCmdTimeoutError, Log, RunOpts
from clan_cli.errors import ClanError, CmdOut
from clan_cli.ssh.host import Host
from clan_cli.ssh.host_key import HostKeyCheck
from clan_cli.ssh.parse import parse_deployment_address
class ParseTestCase(NamedTuple):
test_addr: str = ""
expected_host: str = ""
expected_port: int | None = None
expected_user: str = ""
expected_options: dict[str, str] = {} # noqa: RUF012
expected_exception: type[Exception] | None = None
parse_deployment_address_test_cases = (
(
"host_only",
ParseTestCase(test_addr="example.com", expected_host="example.com"),
),
(
"host_user_port",
ParseTestCase(
test_addr="user@example.com:22",
expected_host="example.com",
expected_user="user",
expected_port=22,
),
),
(
"cannot_parse_user_host_port",
ParseTestCase(test_addr="foo@bar@wat", expected_exception=ClanError),
),
(
"missing_hostname",
ParseTestCase(test_addr="foo@:2222", expected_exception=ClanError),
),
(
"invalid_ipv6",
ParseTestCase(test_addr="user@fe80::1%eth0", expected_exception=ClanError),
),
(
"valid_ipv6_without_port",
ParseTestCase(test_addr="[fe80::1%eth0]", expected_host="fe80::1%eth0"),
),
(
"valid_ipv6_with_port",
ParseTestCase(
test_addr="[fe80::1%eth0]:222",
expected_host="fe80::1%eth0",
expected_port=222,
),
),
(
"empty_options",
ParseTestCase(test_addr="example.com?", expected_host="example.com"),
),
(
"option_with_missing_value",
ParseTestCase(test_addr="example.com?foo", expected_exception=ClanError),
),
(
"options_with_@",
ParseTestCase(
test_addr="user@example.com?ProxyJump=root@foo&IdentityFile=/key",
expected_host="example.com",
expected_user="user",
expected_options={
"IdentityFile": "/key",
"ProxyJump": "root@foo",
},
),
),
)
@pytest.mark.parametrize(
argnames=ParseTestCase._fields,
argvalues=(case for _, case in parse_deployment_address_test_cases),
ids=(name for name, _ in parse_deployment_address_test_cases),
)
def test_parse_deployment_address(
test_addr: str,
expected_host: str,
expected_port: int | None,
expected_user: str,
expected_options: dict[str, str],
expected_exception: type[Exception] | None,
) -> None:
if expected_exception:
maybe_check_exception = pytest.raises(expected_exception)
else:
@contextlib.contextmanager
def noop() -> Generator[None, Any, None]:
yield
maybe_check_exception = noop() # type: ignore
with maybe_check_exception:
machine_name = "foo"
result = parse_deployment_address(machine_name, test_addr, HostKeyCheck.STRICT)
if expected_exception:
return
assert result.host == expected_host
assert result.port == expected_port
assert result.user == expected_user
assert result.ssh_options == expected_options
def test_parse_ssh_options() -> None:
addr = "root@example.com:2222?IdentityFile=/path/to/private/key&StrictHostKeyChecking=yes"
host = parse_deployment_address("foo", addr, HostKeyCheck.STRICT)
assert host.host == "example.com"
assert host.port == 2222
assert host.user == "root"
assert host.ssh_options["IdentityFile"] == "/path/to/private/key"
assert host.ssh_options["StrictHostKeyChecking"] == "yes"
def test_run(hosts: list[Host], runtime: AsyncRuntime) -> None:
for host in hosts:
proc = runtime.async_run(
None, host.run_local, ["echo", "hello"], RunOpts(log=Log.STDERR)
)
assert proc.wait().result.stdout == "hello\n"
def test_run_environment(hosts: list[Host], runtime: AsyncRuntime) -> None:
for host in hosts:
proc = runtime.async_run(
None,
host.run_local,
["echo $env_var"],
RunOpts(shell=True, log=Log.STDERR),
extra_env={"env_var": "true"},
)
assert proc.wait().result.stdout == "true\n"
for host in hosts:
p2 = runtime.async_run(
None,
host.run_local,
["env"],
RunOpts(log=Log.STDERR),
extra_env={"env_var": "true"},
)
assert "env_var=true" in p2.wait().result.stdout
def test_run_no_shell(hosts: list[Host], runtime: AsyncRuntime) -> None:
for host in hosts:
proc = runtime.async_run(
None, host.run_local, ["echo", "hello"], RunOpts(log=Log.STDERR)
)
assert proc.wait().result.stdout == "hello\n"
def test_run_function(hosts: list[Host], runtime: AsyncRuntime) -> None:
def some_func(h: Host) -> bool:
p = h.run(["echo", "hello"])
return p.stdout == "hello\n"
for host in hosts:
proc = runtime.async_run(None, some_func, host)
assert proc.wait().result
def test_timeout(hosts: list[Host], runtime: AsyncRuntime) -> None:
for host in hosts:
proc = runtime.async_run(
None, host.run_local, ["sleep", "10"], RunOpts(timeout=0.01)
)
error = proc.wait().error
assert isinstance(error, ClanCmdTimeoutError)
def test_run_exception(hosts: list[Host], runtime: AsyncRuntime) -> None:
for host in hosts:
proc = runtime.async_run(
None, host.run_local, ["exit 1"], RunOpts(shell=True, check=False)
)
assert proc.wait().result.returncode == 1
try:
for host in hosts:
runtime.async_run(None, host.run_local, ["exit 1"], RunOpts(shell=True))
runtime.join_all()
runtime.check_all()
except Exception:
pass
else:
msg = "should have raised Exception"
raise AssertionError(msg)
def test_run_function_exception(hosts: list[Host], runtime: AsyncRuntime) -> None:
def some_func(h: Host) -> CmdOut:
return h.run_local(["exit 1"], RunOpts(shell=True))
try:
for host in hosts:
runtime.async_run(None, some_func, host)
runtime.join_all()
runtime.check_all()
except Exception:
pass
else:
msg = "should have raised Exception"
raise AssertionError(msg)