diff --git a/nixosModules/clanCore/zerotier/generate.py b/nixosModules/clanCore/zerotier/generate.py index 75d969472..1df7ea168 100644 --- a/nixosModules/clanCore/zerotier/generate.py +++ b/nixosModules/clanCore/zerotier/generate.py @@ -122,10 +122,7 @@ def zerotier_controller() -> Iterator[ZerotierController]: str(home), ] - with subprocess.Popen( - cmd, - preexec_fn=os.setsid, - ) as p: + with subprocess.Popen(cmd, start_new_session=True) as p: process_group = os.getpgid(p.pid) try: print( diff --git a/pkgs/clan-cli/clan_cli/cmd.py b/pkgs/clan-cli/clan_cli/cmd.py index d95e7db0e..38414caeb 100644 --- a/pkgs/clan-cli/clan_cli/cmd.py +++ b/pkgs/clan-cli/clan_cli/cmd.py @@ -1,16 +1,21 @@ -import datetime +import contextlib import logging import os import select import shlex +import signal import subprocess import sys +import timeit import weakref -from datetime import timedelta +from collections.abc import Iterator +from contextlib import contextmanager from enum import Enum from pathlib import Path from typing import IO, Any +from clan_cli.errors import ClanError + from .custom_logger import get_caller from .errors import ClanCmdError, CmdOut @@ -60,6 +65,27 @@ def handle_output(process: subprocess.Popen, log: Log) -> tuple[str, str]: return stdout_buf.decode("utf-8", "replace"), stderr_buf.decode("utf-8", "replace") +@contextmanager +def terminate_process_group(process: subprocess.Popen) -> Iterator[None]: + process_group = os.getpgid(process.pid) + if process_group == os.getpgid(os.getpid()): + msg = "Bug! Refusing to terminate the current process group" + raise ClanError(msg) + try: + yield + finally: + try: + os.killpg(process_group, signal.SIGTERM) + try: + with contextlib.suppress(subprocess.TimeoutExpired): + # give the process time to terminate + process.wait(3) + finally: + os.killpg(process_group, signal.SIGKILL) + except ProcessLookupError: # process already terminated + pass + + class TimeTable: """ This class is used to store the time taken by each command @@ -120,13 +146,17 @@ def run( tstart = datetime.datetime.now(tz=datetime.UTC) # Start the subprocess - with subprocess.Popen( - cmd, - cwd=str(cwd), - env=env, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) as process: + with ( + subprocess.Popen( + cmd, + cwd=str(cwd), + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + start_new_session=True, + ) as process, + terminate_process_group(process), + ): stdout_buf, stderr_buf = handle_output(process, log) if input: diff --git a/pkgs/clan-cli/clan_cli/ssh/__init__.py b/pkgs/clan-cli/clan_cli/ssh/__init__.py index e66a54d27..8c47194b4 100644 --- a/pkgs/clan-cli/clan_cli/ssh/__init__.py +++ b/pkgs/clan-cli/clan_cli/ssh/__init__.py @@ -18,6 +18,7 @@ from shlex import quote from threading import Thread from typing import IO, Any, Generic, TypeVar +from clan_cli.cmd import terminate_process_group from clan_cli.errors import ClanError # https://no-color.org @@ -294,7 +295,7 @@ class Host: elapsed = now - start if now - last_output > NO_OUTPUT_TIMEOUT: elapsed_msg = time.strftime("%H:%M:%S", time.gmtime(elapsed)) - cmdlog.warn( + cmdlog.warning( f"still waiting for '{displayed_cmd}' to finish... ({elapsed_msg} elapsed)", extra={"command_prefix": self.command_prefix}, ) @@ -366,7 +367,9 @@ class Host: stderr=stderr_write, env=env, cwd=cwd, + start_new_session=True, ) as p: + stack.enter_context(terminate_process_group(p)) if write_std_fd is not None: write_std_fd.close() if write_err_fd is not None: @@ -387,11 +390,7 @@ class Host: stderr_read, timeout, ) - try: - ret = p.wait(timeout=max(0, timeout - (time.time() - start))) - except subprocess.TimeoutExpired: - p.kill() - raise + ret = p.wait(timeout=max(0, timeout - (time.time() - start))) if ret != 0: if check: raise subprocess.CalledProcessError(