diff --git a/nixosModules/clanCore/zerotier/generate.py b/nixosModules/clanCore/zerotier/generate.py
index 75d969472..1df7ea168 100644
--- a/nixosModules/clanCore/zerotier/generate.py
+++ b/nixosModules/clanCore/zerotier/generate.py
@@ -122,10 +122,7 @@ def zerotier_controller() -> Iterator[ZerotierController]:
             str(home),
         ]
 
-        with subprocess.Popen(
-            cmd,
-            preexec_fn=os.setsid,
-        ) as p:
+        with subprocess.Popen(cmd, start_new_session=True) as p:
             process_group = os.getpgid(p.pid)
             try:
                 print(
diff --git a/pkgs/clan-cli/clan_cli/cmd.py b/pkgs/clan-cli/clan_cli/cmd.py
index d95e7db0e..e1395bf13 100644
--- a/pkgs/clan-cli/clan_cli/cmd.py
+++ b/pkgs/clan-cli/clan_cli/cmd.py
@@ -1,20 +1,25 @@
-import datetime
+import contextlib
 import logging
 import os
 import select
 import shlex
+import signal
 import subprocess
 import sys
+import timeit
 import weakref
-from datetime import timedelta
+from collections.abc import Iterator
+from contextlib import contextmanager
 from enum import Enum
 from pathlib import Path
 from typing import IO, Any
 
+from clan_cli.errors import ClanError
+
 from .custom_logger import get_caller
 from .errors import ClanCmdError, CmdOut
 
-glog = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)
 
 
 class Log(Enum):
@@ -60,6 +65,35 @@ def handle_output(process: subprocess.Popen, log: Log) -> tuple[str, str]:
     return stdout_buf.decode("utf-8", "replace"), stderr_buf.decode("utf-8", "replace")
 
 
+@contextmanager
+def terminate_process_group(process: subprocess.Popen) -> Iterator[None]:
+    """
+    Terminate the process group of `process` when the context exits.
+
+    On exit the group is sent SIGTERM, `process` gets up to three seconds
+    to exit, and the group is then sent SIGKILL. A group that no longer
+    exists is ignored. Raises ClanError if `process` shares the process
+    group of the current process.
+    """
+    process_group = os.getpgid(process.pid)
+    if process_group == os.getpgid(os.getpid()):
+        msg = "Bug! Refusing to terminate the current process group"
+        raise ClanError(msg)
+    try:
+        yield
+    finally:
+        try:
+            os.killpg(process_group, signal.SIGTERM)
+            try:
+                with contextlib.suppress(subprocess.TimeoutExpired):
+                    # give the process time to terminate
+                    process.wait(3)
+            finally:
+                os.killpg(process_group, signal.SIGKILL)
+        except ProcessLookupError:  # process already terminated
+            pass
+
+
 class TimeTable:
     """
     This class is used to store the time taken by each command
@@ -67,7 +101,7 @@ class TimeTable:
     """
 
     def __init__(self) -> None:
-        self.table: dict[str, timedelta] = {}
+        self.table: dict[str, float] = {}
         weakref.finalize(self, self.table_print)
 
     def table_print(self) -> None:
@@ -80,14 +114,14 @@ class TimeTable:
 
         for k, v in sorted_table:
             # Check if timedelta is greater than 1 second
-            if v.total_seconds() > 1:
+            if v > 1:
                 # Print in red
                 print(f"\033[91mTook {v}s\033[0m for command: '{k}'")
             else:
                 # Print in default color
                 print(f"Took {v} for command: '{k}'")
 
-    def add(self, cmd: str, time: timedelta) -> None:
+    def add(self, cmd: str, time: float) -> None:
         if cmd in self.table:
             self.table[cmd] += time
         else:
@@ -112,30 +146,33 @@ def run(
     if cwd is None:
         cwd = Path.cwd()
     if input:
-        glog.debug(
+        logger.debug(
             f"""$: echo "{input.decode('utf-8', 'replace')}" | {shlex.join(cmd)} \nCaller: {get_caller()}"""
         )
     else:
-        glog.debug(f"$: {shlex.join(cmd)} \nCaller: {get_caller()}")
-    tstart = datetime.datetime.now(tz=datetime.UTC)
+        logger.debug(f"$: {shlex.join(cmd)} \nCaller: {get_caller()}")
+    start = timeit.default_timer()
 
     # Start the subprocess
-    with subprocess.Popen(
-        cmd,
-        cwd=str(cwd),
-        env=env,
-        stdout=subprocess.PIPE,
-        stderr=subprocess.PIPE,
-    ) as process:
+    with (
+        subprocess.Popen(
+            cmd,
+            cwd=str(cwd),
+            env=env,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            start_new_session=True,
+        ) as process,
+        terminate_process_group(process),
+    ):
         stdout_buf, stderr_buf = handle_output(process, log)
 
         if input:
             process.communicate(input)
-    tend = datetime.datetime.now(tz=datetime.UTC)
 
     global TIME_TABLE
     if TIME_TABLE:
-        TIME_TABLE.add(shlex.join(cmd), tend - tstart)
+        TIME_TABLE.add(shlex.join(cmd), timeit.default_timer() - start)
 
     # Wait for the subprocess to finish
     cmd_out = CmdOut(
diff --git a/pkgs/clan-cli/clan_cli/facts/upload.py b/pkgs/clan-cli/clan_cli/facts/upload.py
index 1e3a35289..fc687c988 100644
--- a/pkgs/clan-cli/clan_cli/facts/upload.py
+++ b/pkgs/clan-cli/clan_cli/facts/upload.py
@@ -37,7 +37,7 @@ def upload_secrets(machine: Machine) -> None:
                     "--delete",
                     "--chmod=D700,F600",
                     f"{tempdir!s}/",
-                    f"{host.target}:{machine.secrets_upload_directory}/",
+                    f"{host.target_for_rsync}:{machine.secrets_upload_directory}/",
                 ],
             ),
             log=Log.BOTH,
diff --git a/pkgs/clan-cli/clan_cli/ssh/__init__.py b/pkgs/clan-cli/clan_cli/ssh/__init__.py
index 27d0a8c47..8c47194b4 100644
--- a/pkgs/clan-cli/clan_cli/ssh/__init__.py
+++ b/pkgs/clan-cli/clan_cli/ssh/__init__.py
@@ -18,6 +18,7 @@ from shlex import quote
 from threading import Thread
 from typing import IO, Any, Generic, TypeVar
 
+from clan_cli.cmd import terminate_process_group
 from clan_cli.errors import ClanError
 
 # https://no-color.org
@@ -218,6 +219,14 @@ class Host:
     def target(self) -> str:
         return f"{self.user or 'root'}@{self.host}"
 
+    @property
+    def target_for_rsync(self) -> str:
+        """Like `target`, but a host containing ':' (e.g. an IPv6 literal) is wrapped in brackets."""
+        host = self.host
+        if ":" in host:
+            host = f"[{host}]"
+        return f"{self.user or 'root'}@{host}"
+
     def _prefix_output(
         self,
         displayed_cmd: str,
@@ -287,7 +296,7 @@ class Host:
                 elapsed = now - start
                 if now - last_output > NO_OUTPUT_TIMEOUT:
                     elapsed_msg = time.strftime("%H:%M:%S", time.gmtime(elapsed))
-                    cmdlog.warn(
+                    cmdlog.warning(
                         f"still waiting for '{displayed_cmd}' to finish... ({elapsed_msg} elapsed)",
                         extra={"command_prefix": self.command_prefix},
                     )
@@ -359,7 +368,9 @@ class Host:
                 stderr=stderr_write,
                 env=env,
                 cwd=cwd,
+                start_new_session=True,
             ) as p:
+                stack.enter_context(terminate_process_group(p))
                 if write_std_fd is not None:
                     write_std_fd.close()
                 if write_err_fd is not None:
@@ -845,6 +856,10 @@ def parse_deployment_address(
         meta = {}
     parts = host.split("@")
     user: str | None = None
+    # Validate only the part after the optional "user@" prefix
+    if parts[-1].count(":") > 1 and not parts[-1].startswith("["):
+        msg = f"Invalid hostname: {host}. IPv6 addresses must be enclosed in brackets , e.g. [::1]"
+        raise ClanError(msg)
     if len(parts) > 1:
         user = parts[0]
         hostname = parts[1]
diff --git a/pkgs/clan-cli/clan_cli/vars/upload.py b/pkgs/clan-cli/clan_cli/vars/upload.py
index 3047aad15..f735db942 100644
--- a/pkgs/clan-cli/clan_cli/vars/upload.py
+++ b/pkgs/clan-cli/clan_cli/vars/upload.py
@@ -38,7 +38,7 @@ def upload_secrets(machine: Machine) -> None:
                     "--delete",
                     "--chmod=D700,F600",
                     f"{tempdir!s}/",
-                    f"{host.user}@{host.host}:{machine.secrets_upload_directory}/",
+                    f"{host.target_for_rsync}:{machine.secrets_upload_directory}/",
                 ],
             ),
             log=Log.BOTH,
diff --git a/pkgs/clan-cli/tests/test_ssh_remote.py b/pkgs/clan-cli/tests/test_ssh_remote.py
index 04dd2e3fd..3516c1476 100644
--- a/pkgs/clan-cli/tests/test_ssh_remote.py
+++ b/pkgs/clan-cli/tests/test_ssh_remote.py
@@ -1,5 +1,7 @@
 import subprocess
 
+import pytest
+from clan_cli.errors import ClanError
 from clan_cli.ssh import Host, HostGroup, HostKeyCheck, parse_deployment_address
 
 
@@ -11,6 +13,10 @@ def test_parse_ipv6() -> None:
     assert host.host == "fe80::1%eth0"
     assert host.port is None
 
+    with pytest.raises(ClanError):
+        # We instruct the user to use brackets for IPv6 addresses
+        host = parse_deployment_address("foo", "fe80::1%eth0", HostKeyCheck.STRICT)
+
 
 def test_run(host_group: HostGroup) -> None:
     proc = host_group.run("echo hello", stdout=subprocess.PIPE)