improve terminating processes on error

This commit is contained in:
Jörg Thalheim
2024-10-10 17:24:25 +02:00
parent d97bda9c0d
commit 71e7ecd49c
3 changed files with 45 additions and 19 deletions

View File

@@ -122,10 +122,7 @@ def zerotier_controller() -> Iterator[ZerotierController]:
str(home), str(home),
] ]
with subprocess.Popen( with subprocess.Popen(cmd, start_new_session=True) as p:
cmd,
preexec_fn=os.setsid,
) as p:
process_group = os.getpgid(p.pid) process_group = os.getpgid(p.pid)
try: try:
print( print(

View File

@@ -1,16 +1,21 @@
import datetime import contextlib
import logging import logging
import os import os
import select import select
import shlex import shlex
import signal
import subprocess import subprocess
import sys import sys
import timeit
import weakref import weakref
from datetime import timedelta from collections.abc import Iterator
from contextlib import contextmanager
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import IO, Any from typing import IO, Any
from clan_cli.errors import ClanError
from .custom_logger import get_caller from .custom_logger import get_caller
from .errors import ClanCmdError, CmdOut from .errors import ClanCmdError, CmdOut
@@ -60,6 +65,27 @@ def handle_output(process: subprocess.Popen, log: Log) -> tuple[str, str]:
return stdout_buf.decode("utf-8", "replace"), stderr_buf.decode("utf-8", "replace") return stdout_buf.decode("utf-8", "replace"), stderr_buf.decode("utf-8", "replace")
@contextmanager
def terminate_process_group(process: subprocess.Popen) -> Iterator[None]:
process_group = os.getpgid(process.pid)
if process_group == os.getpgid(os.getpid()):
msg = "Bug! Refusing to terminate the current process group"
raise ClanError(msg)
try:
yield
finally:
try:
os.killpg(process_group, signal.SIGTERM)
try:
with contextlib.suppress(subprocess.TimeoutExpired):
# give the process time to terminate
process.wait(3)
finally:
os.killpg(process_group, signal.SIGKILL)
except ProcessLookupError: # process already terminated
pass
class TimeTable: class TimeTable:
""" """
This class is used to store the time taken by each command This class is used to store the time taken by each command
@@ -120,13 +146,17 @@ def run(
tstart = datetime.datetime.now(tz=datetime.UTC) tstart = datetime.datetime.now(tz=datetime.UTC)
# Start the subprocess # Start the subprocess
with subprocess.Popen( with (
cmd, subprocess.Popen(
cwd=str(cwd), cmd,
env=env, cwd=str(cwd),
stdout=subprocess.PIPE, env=env,
stderr=subprocess.PIPE, stdout=subprocess.PIPE,
) as process: stderr=subprocess.PIPE,
start_new_session=True,
) as process,
terminate_process_group(process),
):
stdout_buf, stderr_buf = handle_output(process, log) stdout_buf, stderr_buf = handle_output(process, log)
if input: if input:

View File

@@ -18,6 +18,7 @@ from shlex import quote
from threading import Thread from threading import Thread
from typing import IO, Any, Generic, TypeVar from typing import IO, Any, Generic, TypeVar
from clan_cli.cmd import terminate_process_group
from clan_cli.errors import ClanError from clan_cli.errors import ClanError
# https://no-color.org # https://no-color.org
@@ -294,7 +295,7 @@ class Host:
elapsed = now - start elapsed = now - start
if now - last_output > NO_OUTPUT_TIMEOUT: if now - last_output > NO_OUTPUT_TIMEOUT:
elapsed_msg = time.strftime("%H:%M:%S", time.gmtime(elapsed)) elapsed_msg = time.strftime("%H:%M:%S", time.gmtime(elapsed))
cmdlog.warn( cmdlog.warning(
f"still waiting for '{displayed_cmd}' to finish... ({elapsed_msg} elapsed)", f"still waiting for '{displayed_cmd}' to finish... ({elapsed_msg} elapsed)",
extra={"command_prefix": self.command_prefix}, extra={"command_prefix": self.command_prefix},
) )
@@ -366,7 +367,9 @@ class Host:
stderr=stderr_write, stderr=stderr_write,
env=env, env=env,
cwd=cwd, cwd=cwd,
start_new_session=True,
) as p: ) as p:
stack.enter_context(terminate_process_group(p))
if write_std_fd is not None: if write_std_fd is not None:
write_std_fd.close() write_std_fd.close()
if write_err_fd is not None: if write_err_fd is not None:
@@ -387,11 +390,7 @@ class Host:
stderr_read, stderr_read,
timeout, timeout,
) )
try: ret = p.wait(timeout=max(0, timeout - (time.time() - start)))
ret = p.wait(timeout=max(0, timeout - (time.time() - start)))
except subprocess.TimeoutExpired:
p.kill()
raise
if ret != 0: if ret != 0:
if check: if check:
raise subprocess.CalledProcessError( raise subprocess.CalledProcessError(