Merge pull request 'properly support verbatim ipv6 addresses' (#2242) from ipv6-ftw into main

This commit is contained in:
clan-bot
2024-10-10 16:06:06 +00:00
6 changed files with 72 additions and 30 deletions

View File

@@ -122,10 +122,7 @@ def zerotier_controller() -> Iterator[ZerotierController]:
str(home), str(home),
] ]
with subprocess.Popen( with subprocess.Popen(cmd, start_new_session=True) as p:
cmd,
preexec_fn=os.setsid,
) as p:
process_group = os.getpgid(p.pid) process_group = os.getpgid(p.pid)
try: try:
print( print(

View File

@@ -1,20 +1,25 @@
import datetime import contextlib
import logging import logging
import os import os
import select import select
import shlex import shlex
import signal
import subprocess import subprocess
import sys import sys
import timeit
import weakref import weakref
from datetime import timedelta from collections.abc import Iterator
from contextlib import contextmanager
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import IO, Any from typing import IO, Any
from clan_cli.errors import ClanError
from .custom_logger import get_caller from .custom_logger import get_caller
from .errors import ClanCmdError, CmdOut from .errors import ClanCmdError, CmdOut
glog = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Log(Enum): class Log(Enum):
@@ -60,6 +65,27 @@ def handle_output(process: subprocess.Popen, log: Log) -> tuple[str, str]:
return stdout_buf.decode("utf-8", "replace"), stderr_buf.decode("utf-8", "replace") return stdout_buf.decode("utf-8", "replace"), stderr_buf.decode("utf-8", "replace")
@contextmanager
def terminate_process_group(process: subprocess.Popen) -> Iterator[None]:
process_group = os.getpgid(process.pid)
if process_group == os.getpgid(os.getpid()):
msg = "Bug! Refusing to terminate the current process group"
raise ClanError(msg)
try:
yield
finally:
try:
os.killpg(process_group, signal.SIGTERM)
try:
with contextlib.suppress(subprocess.TimeoutExpired):
# give the process time to terminate
process.wait(3)
finally:
os.killpg(process_group, signal.SIGKILL)
except ProcessLookupError: # process already terminated
pass
class TimeTable: class TimeTable:
""" """
This class is used to store the time taken by each command This class is used to store the time taken by each command
@@ -67,7 +93,7 @@ class TimeTable:
""" """
def __init__(self) -> None: def __init__(self) -> None:
self.table: dict[str, timedelta] = {} self.table: dict[str, float] = {}
weakref.finalize(self, self.table_print) weakref.finalize(self, self.table_print)
def table_print(self) -> None: def table_print(self) -> None:
@@ -80,14 +106,14 @@ class TimeTable:
for k, v in sorted_table: for k, v in sorted_table:
# Check if timedelta is greater than 1 second # Check if timedelta is greater than 1 second
if v.total_seconds() > 1: if v > 1:
# Print in red # Print in red
print(f"\033[91mTook {v}s\033[0m for command: '{k}'") print(f"\033[91mTook {v}s\033[0m for command: '{k}'")
else: else:
# Print in default color # Print in default color
print(f"Took {v} for command: '{k}'") print(f"Took {v} for command: '{k}'")
def add(self, cmd: str, time: timedelta) -> None: def add(self, cmd: str, time: float) -> None:
if cmd in self.table: if cmd in self.table:
self.table[cmd] += time self.table[cmd] += time
else: else:
@@ -112,30 +138,33 @@ def run(
if cwd is None: if cwd is None:
cwd = Path.cwd() cwd = Path.cwd()
if input: if input:
glog.debug( logger.debug(
f"""$: echo "{input.decode('utf-8', 'replace')}" | {shlex.join(cmd)} \nCaller: {get_caller()}""" f"""$: echo "{input.decode('utf-8', 'replace')}" | {shlex.join(cmd)} \nCaller: {get_caller()}"""
) )
else: else:
glog.debug(f"$: {shlex.join(cmd)} \nCaller: {get_caller()}") logger.debug(f"$: {shlex.join(cmd)} \nCaller: {get_caller()}")
tstart = datetime.datetime.now(tz=datetime.UTC) start = timeit.default_timer()
# Start the subprocess # Start the subprocess
with subprocess.Popen( with (
cmd, subprocess.Popen(
cwd=str(cwd), cmd,
env=env, cwd=str(cwd),
stdout=subprocess.PIPE, env=env,
stderr=subprocess.PIPE, stdout=subprocess.PIPE,
) as process: stderr=subprocess.PIPE,
start_new_session=True,
) as process,
terminate_process_group(process),
):
stdout_buf, stderr_buf = handle_output(process, log) stdout_buf, stderr_buf = handle_output(process, log)
if input: if input:
process.communicate(input) process.communicate(input)
tend = datetime.datetime.now(tz=datetime.UTC)
global TIME_TABLE global TIME_TABLE
if TIME_TABLE: if TIME_TABLE:
TIME_TABLE.add(shlex.join(cmd), tend - tstart) TIME_TABLE.add(shlex.join(cmd), start - timeit.default_timer())
# Wait for the subprocess to finish # Wait for the subprocess to finish
cmd_out = CmdOut( cmd_out = CmdOut(

View File

@@ -37,7 +37,7 @@ def upload_secrets(machine: Machine) -> None:
"--delete", "--delete",
"--chmod=D700,F600", "--chmod=D700,F600",
f"{tempdir!s}/", f"{tempdir!s}/",
f"{host.target}:{machine.secrets_upload_directory}/", f"{host.target_for_rsync}:{machine.secrets_upload_directory}/",
], ],
), ),
log=Log.BOTH, log=Log.BOTH,

View File

@@ -18,6 +18,7 @@ from shlex import quote
from threading import Thread from threading import Thread
from typing import IO, Any, Generic, TypeVar from typing import IO, Any, Generic, TypeVar
from clan_cli.cmd import terminate_process_group
from clan_cli.errors import ClanError from clan_cli.errors import ClanError
# https://no-color.org # https://no-color.org
@@ -218,6 +219,13 @@ class Host:
def target(self) -> str: def target(self) -> str:
return f"{self.user or 'root'}@{self.host}" return f"{self.user or 'root'}@{self.host}"
@property
def target_for_rsync(self) -> str:
host = self.host
if ":" in host:
host = f"[{host}]"
return f"{self.user or 'root'}@{host}"
def _prefix_output( def _prefix_output(
self, self,
displayed_cmd: str, displayed_cmd: str,
@@ -287,7 +295,7 @@ class Host:
elapsed = now - start elapsed = now - start
if now - last_output > NO_OUTPUT_TIMEOUT: if now - last_output > NO_OUTPUT_TIMEOUT:
elapsed_msg = time.strftime("%H:%M:%S", time.gmtime(elapsed)) elapsed_msg = time.strftime("%H:%M:%S", time.gmtime(elapsed))
cmdlog.warn( cmdlog.warning(
f"still waiting for '{displayed_cmd}' to finish... ({elapsed_msg} elapsed)", f"still waiting for '{displayed_cmd}' to finish... ({elapsed_msg} elapsed)",
extra={"command_prefix": self.command_prefix}, extra={"command_prefix": self.command_prefix},
) )
@@ -359,7 +367,9 @@ class Host:
stderr=stderr_write, stderr=stderr_write,
env=env, env=env,
cwd=cwd, cwd=cwd,
start_new_session=True,
) as p: ) as p:
stack.enter_context(terminate_process_group(p))
if write_std_fd is not None: if write_std_fd is not None:
write_std_fd.close() write_std_fd.close()
if write_err_fd is not None: if write_err_fd is not None:
@@ -380,11 +390,7 @@ class Host:
stderr_read, stderr_read,
timeout, timeout,
) )
try: ret = p.wait(timeout=max(0, timeout - (time.time() - start)))
ret = p.wait(timeout=max(0, timeout - (time.time() - start)))
except subprocess.TimeoutExpired:
p.kill()
raise
if ret != 0: if ret != 0:
if check: if check:
raise subprocess.CalledProcessError( raise subprocess.CalledProcessError(
@@ -845,6 +851,10 @@ def parse_deployment_address(
meta = {} meta = {}
parts = host.split("@") parts = host.split("@")
user: str | None = None user: str | None = None
# count the number of : in the hostname
if host.count(":") > 1 and not host.startswith("["):
msg = f"Invalid hostname: {host}. IPv6 addresses must be enclosed in brackets , e.g. [::1]"
raise ClanError(msg)
if len(parts) > 1: if len(parts) > 1:
user = parts[0] user = parts[0]
hostname = parts[1] hostname = parts[1]

View File

@@ -38,7 +38,7 @@ def upload_secrets(machine: Machine) -> None:
"--delete", "--delete",
"--chmod=D700,F600", "--chmod=D700,F600",
f"{tempdir!s}/", f"{tempdir!s}/",
f"{host.user}@{host.host}:{machine.secrets_upload_directory}/", f"{host.target_for_rsync}:{machine.secrets_upload_directory}/",
], ],
), ),
log=Log.BOTH, log=Log.BOTH,

View File

@@ -1,5 +1,7 @@
import subprocess import subprocess
import pytest
from clan_cli.errors import ClanError
from clan_cli.ssh import Host, HostGroup, HostKeyCheck, parse_deployment_address from clan_cli.ssh import Host, HostGroup, HostKeyCheck, parse_deployment_address
@@ -11,6 +13,10 @@ def test_parse_ipv6() -> None:
assert host.host == "fe80::1%eth0" assert host.host == "fe80::1%eth0"
assert host.port is None assert host.port is None
with pytest.raises(ClanError):
# We instruct the user to use brackets for IPv6 addresses
host = parse_deployment_address("foo", "fe80::1%eth0", HostKeyCheck.STRICT)
def test_run(host_group: HostGroup) -> None: def test_run(host_group: HostGroup) -> None:
proc = host_group.run("echo hello", stdout=subprocess.PIPE) proc = host_group.run("echo hello", stdout=subprocess.PIPE)