From a17baa4861cb93496587b2aa0bd6a1bebed63347 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rg=20Thalheim?= Date: Wed, 9 Aug 2023 16:38:08 +0200 Subject: [PATCH] add test for remote ssh commands --- pkgs/clan-cli/default.nix | 6 +- pkgs/clan-cli/tests/command.py | 60 +++++++++++++ pkgs/clan-cli/tests/conftest.py | 10 ++- pkgs/clan-cli/tests/getpwnam-preload.c | 27 ++++++ pkgs/clan-cli/tests/ports.py | 46 ++++++++++ pkgs/clan-cli/tests/root.py | 4 +- pkgs/clan-cli/tests/sshd.py | 117 +++++++++++++++++++++++++ pkgs/clan-cli/tests/test_ssh_remote.py | 90 +++++++++++++++++++ 8 files changed, 356 insertions(+), 4 deletions(-) create mode 100644 pkgs/clan-cli/tests/command.py create mode 100644 pkgs/clan-cli/tests/getpwnam-preload.c create mode 100644 pkgs/clan-cli/tests/ports.py create mode 100644 pkgs/clan-cli/tests/sshd.py create mode 100644 pkgs/clan-cli/tests/test_ssh_remote.py diff --git a/pkgs/clan-cli/default.nix b/pkgs/clan-cli/default.nix index f9569ef78..1348e325f 100644 --- a/pkgs/clan-cli/default.nix +++ b/pkgs/clan-cli/default.nix @@ -16,6 +16,8 @@ , pytest , pytest-cov , pytest-subprocess +, openssh +, stdenv , wheel }: let @@ -26,6 +28,8 @@ let pytest-cov pytest-subprocess mypy + openssh + stdenv.cc ]; checkPython = python3.withPackages (_ps: dependencies ++ testDependencies); @@ -50,7 +54,7 @@ python3.pkgs.buildPythonPackage { ''; clan-pytest = runCommand "clan-tests" { - nativeBuildInputs = [ age zerotierone bubblewrap sops nix ]; + nativeBuildInputs = [ age zerotierone bubblewrap sops nix openssh stdenv.cc ]; } '' cp -r ${./.} ./src chmod +w -R ./src diff --git a/pkgs/clan-cli/tests/command.py b/pkgs/clan-cli/tests/command.py new file mode 100644 index 000000000..72551ba31 --- /dev/null +++ b/pkgs/clan-cli/tests/command.py @@ -0,0 +1,60 @@ +import os +import signal +import subprocess +from typing import IO, Any, Dict, Iterator, List, Union + +import pytest + +_FILE = Union[None, int, IO[Any]] + + +class Command: + def __init__(self) -> None: + self.processes: List[subprocess.Popen[str]] = [] + + def run( + self, + command: List[str], + extra_env: Dict[str, str] = {}, + stdin: _FILE = None, + stdout: _FILE = None, + stderr: _FILE = None, + ) -> subprocess.Popen[str]: + env = os.environ.copy() + env.update(extra_env) + # We start a new session here so that we can than more reliably kill all childs as well + p = subprocess.Popen( + command, + env=env, + start_new_session=True, + stdout=stdout, + stderr=stderr, + stdin=stdin, + text=True, + ) + self.processes.append(p) + return p + + def terminate(self) -> None: + # Stop in reverse order in case there are dependencies. + # We just kill all processes as quickly as possible because we don't + # care about corrupted state and want to make tests fasts. + for p in reversed(self.processes): + try: + os.killpg(os.getpgid(p.pid), signal.SIGKILL) + except OSError: + pass + + +@pytest.fixture +def command() -> Iterator[Command]: + """ + Starts a background command. The process is automatically terminated in the end. + >>> p = command.run(["some", "daemon"]) + >>> print(p.pid) + """ + c = Command() + try: + yield c + finally: + c.terminate() diff --git a/pkgs/clan-cli/tests/conftest.py b/pkgs/clan-cli/tests/conftest.py index ec743b128..356df9508 100644 --- a/pkgs/clan-cli/tests/conftest.py +++ b/pkgs/clan-cli/tests/conftest.py @@ -3,4 +3,12 @@ import sys sys.path.append(os.path.join(os.path.dirname(__file__), "helpers")) -pytest_plugins = ["temporary_dir", "clan_flake", "root", "age_keys"] +pytest_plugins = [ + "temporary_dir", + "clan_flake", + "root", + "age_keys", + "sshd", + "command", + "ports", +] diff --git a/pkgs/clan-cli/tests/getpwnam-preload.c b/pkgs/clan-cli/tests/getpwnam-preload.c new file mode 100644 index 000000000..d88aa87f8 --- /dev/null +++ b/pkgs/clan-cli/tests/getpwnam-preload.c @@ -0,0 +1,27 @@ +#define _GNU_SOURCE +#include +#include +#include +#include +#include +#include + +typedef struct passwd *(*getpwnam_type)(const char *name); + +struct passwd *getpwnam(const char *name) { + struct passwd *pw; + getpwnam_type orig_getpwnam; + orig_getpwnam = (getpwnam_type)dlsym(RTLD_NEXT, "getpwnam"); + pw = orig_getpwnam(name); + + if (pw) { + const char *shell = getenv("LOGIN_SHELL"); + if (!shell) { + fprintf(stderr, "no LOGIN_SHELL set\n"); + exit(1); + } + fprintf(stderr, "SHELL:%s\n", shell); + pw->pw_shell = strdup(shell); + } + return pw; +} diff --git a/pkgs/clan-cli/tests/ports.py b/pkgs/clan-cli/tests/ports.py new file mode 100644 index 000000000..dba5f50ed --- /dev/null +++ b/pkgs/clan-cli/tests/ports.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 + +import socket + +import pytest + +NEXT_PORT = 10000 + + +def check_port(port: int) -> bool: + tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + with tcp, udp: + try: + tcp.bind(("127.0.0.1", port)) + udp.bind(("127.0.0.1", port)) + return True + except socket.error: + return False + + +def check_port_range(port_range: range) -> bool: + for port in port_range: + if not check_port(port): + return False + return True + + +class Ports: + def allocate(self, num: int) -> int: + """ + Allocates + """ + global NEXT_PORT + while NEXT_PORT + num <= 65535: + start = NEXT_PORT + NEXT_PORT += num + if not check_port_range(range(start, NEXT_PORT)): + continue + return start + raise Exception("cannot find enough free port") + + +@pytest.fixture +def ports() -> Ports: + return Ports() diff --git a/pkgs/clan-cli/tests/root.py b/pkgs/clan-cli/tests/root.py index 5855b523e..c881ce026 100644 --- a/pkgs/clan-cli/tests/root.py +++ b/pkgs/clan-cli/tests/root.py @@ -6,7 +6,7 @@ TEST_ROOT = Path(__file__).parent.resolve() PROJECT_ROOT = TEST_ROOT.parent -@pytest.fixture +@pytest.fixture(scope="session") def project_root() -> Path: """ Root directory of the tests @@ -14,7 +14,7 @@ def project_root() -> Path: return PROJECT_ROOT -@pytest.fixture +@pytest.fixture(scope="session") def test_root() -> Path: """ Root directory of the tests diff --git a/pkgs/clan-cli/tests/sshd.py b/pkgs/clan-cli/tests/sshd.py new file mode 100644 index 000000000..9b7cfc74b --- /dev/null +++ b/pkgs/clan-cli/tests/sshd.py @@ -0,0 +1,117 @@ +import os +import shutil +import subprocess +import time +from pathlib import Path +from sys import platform +from tempfile import TemporaryDirectory +from typing import Iterator, Optional + +import pytest +from command import Command +from ports import Ports + + +class Sshd: + def __init__(self, port: int, proc: subprocess.Popen[str], key: str) -> None: + self.port = port + self.proc = proc + self.key = key + + +class SshdConfig: + def __init__(self, path: str, key: str, preload_lib: Optional[str]) -> None: + self.path = path + self.key = key + self.preload_lib = preload_lib + + +@pytest.fixture(scope="session") +def sshd_config(project_root: Path, test_root: Path) -> Iterator[SshdConfig]: + # FIXME, if any parent of `project_root` is world-writable than sshd will refuse it. + with TemporaryDirectory(dir=project_root) as _dir: + dir = Path(_dir) + host_key = dir / "host_ssh_host_ed25519_key" + subprocess.run( + [ + "ssh-keygen", + "-t", + "ed25519", + "-f", + host_key, + "-N", + "", + ], + check=True, + ) + + sshd_config = dir / "sshd_config" + sshd_config.write_text( + f""" + HostKey {host_key} + LogLevel DEBUG3 + # In the nix build sandbox we don't get any meaningful PATH after login + SetEnv PATH={os.environ.get("PATH", "")} + MaxStartups 64:30:256 + AuthorizedKeysFile {host_key}.pub + """ + ) + + lib_path = None + if platform == "linux": + # This enforces a login shell by overriding the login shell of `getpwnam(3)` + lib_path = str(dir / "libgetpwnam-preload.so") + subprocess.run( + [ + os.environ.get("CC", "cc"), + "-shared", + "-o", + lib_path, + str(test_root / "getpwnam-preload.c"), + ], + check=True, + ) + + yield SshdConfig(str(sshd_config), str(host_key), lib_path) + + +@pytest.fixture +def sshd(sshd_config: SshdConfig, command: Command, ports: Ports) -> Iterator[Sshd]: + port = ports.allocate(1) + sshd = shutil.which("sshd") + assert sshd is not None, "no sshd binary found" + env = {} + if sshd_config.preload_lib is not None: + bash = shutil.which("bash") + assert bash is not None + env = dict(LD_PRELOAD=str(sshd_config.preload_lib), LOGIN_SHELL=bash) + proc = command.run( + [sshd, "-f", sshd_config.path, "-D", "-p", str(port)], extra_env=env + ) + + while True: + if ( + subprocess.run( + [ + "ssh", + "-o", + "StrictHostKeyChecking=no", + "-o", + "UserKnownHostsFile=/dev/null", + "-i", + sshd_config.key, + "localhost", + "-p", + str(port), + "true", + ] + ).returncode + == 0 + ): + yield Sshd(port, proc, sshd_config.key) + return + else: + rc = proc.poll() + if rc is not None: + raise Exception(f"sshd processes was terminated with {rc}") + time.sleep(0.1) diff --git a/pkgs/clan-cli/tests/test_ssh_remote.py b/pkgs/clan-cli/tests/test_ssh_remote.py new file mode 100644 index 000000000..5885906a8 --- /dev/null +++ b/pkgs/clan-cli/tests/test_ssh_remote.py @@ -0,0 +1,90 @@ +import os +import pwd +import subprocess + +from sshd import Sshd + +from clan_cli.ssh import Group, Host, HostKeyCheck + + +def deploy_group(sshd: Sshd) -> Group: + login = pwd.getpwuid(os.getuid()).pw_name + return Group( + [ + Host( + "127.0.0.1", + port=sshd.port, + user=login, + key=sshd.key, + host_key_check=HostKeyCheck.NONE, + ) + ] + ) + + +def test_run(sshd: Sshd) -> None: + g = deploy_group(sshd) + proc = g.run("echo hello", stdout=subprocess.PIPE) + assert proc[0].result.stdout == "hello\n" + + +def test_run_environment(sshd: Sshd) -> None: + g = deploy_group(sshd) + p1 = g.run("echo $env_var", stdout=subprocess.PIPE, extra_env=dict(env_var="true")) + assert p1[0].result.stdout == "true\n" + p2 = g.run(["env"], stdout=subprocess.PIPE, extra_env=dict(env_var="true")) + assert "env_var=true" in p2[0].result.stdout + + +def test_run_no_shell(sshd: Sshd) -> None: + g = deploy_group(sshd) + proc = g.run(["echo", "$hello"], stdout=subprocess.PIPE) + assert proc[0].result.stdout == "$hello\n" + + +def test_run_function(sshd: Sshd) -> None: + def some_func(h: Host) -> bool: + p = h.run("echo hello", stdout=subprocess.PIPE) + return p.stdout == "hello\n" + + g = deploy_group(sshd) + res = g.run_function(some_func) + assert res[0].result + + +def test_timeout(sshd: Sshd) -> None: + g = deploy_group(sshd) + try: + g.run_local("sleep 10", timeout=0.01) + except Exception: + pass + else: + assert False, "should have raised TimeoutExpired" + + +def test_run_exception(sshd: Sshd) -> None: + g = deploy_group(sshd) + + r = g.run("exit 1", check=False) + assert r[0].result.returncode == 1 + + try: + g.run("exit 1") + except Exception: + pass + else: + assert False, "should have raised Exception" + + +def test_run_function_exception(sshd: Sshd) -> None: + def some_func(h: Host) -> subprocess.CompletedProcess[str]: + return h.run_local("exit 1") + + g = deploy_group(sshd) + + try: + g.run_function(some_func) + except Exception: + pass + else: + assert False, "should have raised Exception"