Merge pull request 'properly support verbatim ipv6 addresses' (#2242) from ipv6-ftw into main

This commit is contained in:
clan-bot
2024-10-10 16:06:06 +00:00
6 changed files with 72 additions and 30 deletions

View File

@@ -122,10 +122,7 @@ def zerotier_controller() -> Iterator[ZerotierController]:
str(home),
]
with subprocess.Popen(
cmd,
preexec_fn=os.setsid,
) as p:
with subprocess.Popen(cmd, start_new_session=True) as p:
process_group = os.getpgid(p.pid)
try:
print(

View File

@@ -1,20 +1,25 @@
import datetime
import contextlib
import logging
import os
import select
import shlex
import signal
import subprocess
import sys
import timeit
import weakref
from datetime import timedelta
from collections.abc import Iterator
from contextlib import contextmanager
from enum import Enum
from pathlib import Path
from typing import IO, Any
from clan_cli.errors import ClanError
from .custom_logger import get_caller
from .errors import ClanCmdError, CmdOut
glog = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
class Log(Enum):
@@ -60,6 +65,27 @@ def handle_output(process: subprocess.Popen, log: Log) -> tuple[str, str]:
return stdout_buf.decode("utf-8", "replace"), stderr_buf.decode("utf-8", "replace")
@contextmanager
def terminate_process_group(process: subprocess.Popen) -> Iterator[None]:
process_group = os.getpgid(process.pid)
if process_group == os.getpgid(os.getpid()):
msg = "Bug! Refusing to terminate the current process group"
raise ClanError(msg)
try:
yield
finally:
try:
os.killpg(process_group, signal.SIGTERM)
try:
with contextlib.suppress(subprocess.TimeoutExpired):
# give the process time to terminate
process.wait(3)
finally:
os.killpg(process_group, signal.SIGKILL)
except ProcessLookupError: # process already terminated
pass
class TimeTable:
"""
This class is used to store the time taken by each command
@@ -67,7 +93,7 @@ class TimeTable:
"""
def __init__(self) -> None:
self.table: dict[str, timedelta] = {}
self.table: dict[str, float] = {}
weakref.finalize(self, self.table_print)
def table_print(self) -> None:
@@ -80,14 +106,14 @@ class TimeTable:
for k, v in sorted_table:
# Check if timedelta is greater than 1 second
if v.total_seconds() > 1:
if v > 1:
# Print in red
print(f"\033[91mTook {v}s\033[0m for command: '{k}'")
else:
# Print in default color
print(f"Took {v} for command: '{k}'")
def add(self, cmd: str, time: timedelta) -> None:
def add(self, cmd: str, time: float) -> None:
if cmd in self.table:
self.table[cmd] += time
else:
@@ -112,30 +138,33 @@ def run(
if cwd is None:
cwd = Path.cwd()
if input:
glog.debug(
logger.debug(
f"""$: echo "{input.decode('utf-8', 'replace')}" | {shlex.join(cmd)} \nCaller: {get_caller()}"""
)
else:
glog.debug(f"$: {shlex.join(cmd)} \nCaller: {get_caller()}")
tstart = datetime.datetime.now(tz=datetime.UTC)
logger.debug(f"$: {shlex.join(cmd)} \nCaller: {get_caller()}")
start = timeit.default_timer()
# Start the subprocess
with subprocess.Popen(
with (
subprocess.Popen(
cmd,
cwd=str(cwd),
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
) as process:
start_new_session=True,
) as process,
terminate_process_group(process),
):
stdout_buf, stderr_buf = handle_output(process, log)
if input:
process.communicate(input)
tend = datetime.datetime.now(tz=datetime.UTC)
global TIME_TABLE
if TIME_TABLE:
TIME_TABLE.add(shlex.join(cmd), tend - tstart)
TIME_TABLE.add(shlex.join(cmd), start - timeit.default_timer())
# Wait for the subprocess to finish
cmd_out = CmdOut(

View File

@@ -37,7 +37,7 @@ def upload_secrets(machine: Machine) -> None:
"--delete",
"--chmod=D700,F600",
f"{tempdir!s}/",
f"{host.target}:{machine.secrets_upload_directory}/",
f"{host.target_for_rsync}:{machine.secrets_upload_directory}/",
],
),
log=Log.BOTH,

View File

@@ -18,6 +18,7 @@ from shlex import quote
from threading import Thread
from typing import IO, Any, Generic, TypeVar
from clan_cli.cmd import terminate_process_group
from clan_cli.errors import ClanError
# https://no-color.org
@@ -218,6 +219,13 @@ class Host:
def target(self) -> str:
return f"{self.user or 'root'}@{self.host}"
@property
def target_for_rsync(self) -> str:
host = self.host
if ":" in host:
host = f"[{host}]"
return f"{self.user or 'root'}@{host}"
def _prefix_output(
self,
displayed_cmd: str,
@@ -287,7 +295,7 @@ class Host:
elapsed = now - start
if now - last_output > NO_OUTPUT_TIMEOUT:
elapsed_msg = time.strftime("%H:%M:%S", time.gmtime(elapsed))
cmdlog.warn(
cmdlog.warning(
f"still waiting for '{displayed_cmd}' to finish... ({elapsed_msg} elapsed)",
extra={"command_prefix": self.command_prefix},
)
@@ -359,7 +367,9 @@ class Host:
stderr=stderr_write,
env=env,
cwd=cwd,
start_new_session=True,
) as p:
stack.enter_context(terminate_process_group(p))
if write_std_fd is not None:
write_std_fd.close()
if write_err_fd is not None:
@@ -380,11 +390,7 @@ class Host:
stderr_read,
timeout,
)
try:
ret = p.wait(timeout=max(0, timeout - (time.time() - start)))
except subprocess.TimeoutExpired:
p.kill()
raise
if ret != 0:
if check:
raise subprocess.CalledProcessError(
@@ -845,6 +851,10 @@ def parse_deployment_address(
meta = {}
parts = host.split("@")
user: str | None = None
# count the number of : in the hostname
if host.count(":") > 1 and not host.startswith("["):
msg = f"Invalid hostname: {host}. IPv6 addresses must be enclosed in brackets , e.g. [::1]"
raise ClanError(msg)
if len(parts) > 1:
user = parts[0]
hostname = parts[1]

View File

@@ -38,7 +38,7 @@ def upload_secrets(machine: Machine) -> None:
"--delete",
"--chmod=D700,F600",
f"{tempdir!s}/",
f"{host.user}@{host.host}:{machine.secrets_upload_directory}/",
f"{host.target_for_rsync}:{machine.secrets_upload_directory}/",
],
),
log=Log.BOTH,

View File

@@ -1,5 +1,7 @@
import subprocess
import pytest
from clan_cli.errors import ClanError
from clan_cli.ssh import Host, HostGroup, HostKeyCheck, parse_deployment_address
@@ -11,6 +13,10 @@ def test_parse_ipv6() -> None:
assert host.host == "fe80::1%eth0"
assert host.port is None
with pytest.raises(ClanError):
# We instruct the user to use brackets for IPv6 addresses
host = parse_deployment_address("foo", "fe80::1%eth0", HostKeyCheck.STRICT)
def test_run(host_group: HostGroup) -> None:
proc = host_group.run("echo hello", stdout=subprocess.PIPE)