Merge pull request 'properly support verbatim ipv6 addresses' (#2242) from ipv6-ftw into main
This commit is contained in:
@@ -122,10 +122,7 @@ def zerotier_controller() -> Iterator[ZerotierController]:
|
||||
str(home),
|
||||
]
|
||||
|
||||
with subprocess.Popen(
|
||||
cmd,
|
||||
preexec_fn=os.setsid,
|
||||
) as p:
|
||||
with subprocess.Popen(cmd, start_new_session=True) as p:
|
||||
process_group = os.getpgid(p.pid)
|
||||
try:
|
||||
print(
|
||||
|
||||
@@ -1,20 +1,25 @@
|
||||
import datetime
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import select
|
||||
import shlex
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import timeit
|
||||
import weakref
|
||||
from datetime import timedelta
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import IO, Any
|
||||
|
||||
from clan_cli.errors import ClanError
|
||||
|
||||
from .custom_logger import get_caller
|
||||
from .errors import ClanCmdError, CmdOut
|
||||
|
||||
glog = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Log(Enum):
|
||||
@@ -60,6 +65,27 @@ def handle_output(process: subprocess.Popen, log: Log) -> tuple[str, str]:
|
||||
return stdout_buf.decode("utf-8", "replace"), stderr_buf.decode("utf-8", "replace")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def terminate_process_group(process: subprocess.Popen) -> Iterator[None]:
|
||||
process_group = os.getpgid(process.pid)
|
||||
if process_group == os.getpgid(os.getpid()):
|
||||
msg = "Bug! Refusing to terminate the current process group"
|
||||
raise ClanError(msg)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
try:
|
||||
os.killpg(process_group, signal.SIGTERM)
|
||||
try:
|
||||
with contextlib.suppress(subprocess.TimeoutExpired):
|
||||
# give the process time to terminate
|
||||
process.wait(3)
|
||||
finally:
|
||||
os.killpg(process_group, signal.SIGKILL)
|
||||
except ProcessLookupError: # process already terminated
|
||||
pass
|
||||
|
||||
|
||||
class TimeTable:
|
||||
"""
|
||||
This class is used to store the time taken by each command
|
||||
@@ -67,7 +93,7 @@ class TimeTable:
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.table: dict[str, timedelta] = {}
|
||||
self.table: dict[str, float] = {}
|
||||
weakref.finalize(self, self.table_print)
|
||||
|
||||
def table_print(self) -> None:
|
||||
@@ -80,14 +106,14 @@ class TimeTable:
|
||||
|
||||
for k, v in sorted_table:
|
||||
# Check if timedelta is greater than 1 second
|
||||
if v.total_seconds() > 1:
|
||||
if v > 1:
|
||||
# Print in red
|
||||
print(f"\033[91mTook {v}s\033[0m for command: '{k}'")
|
||||
else:
|
||||
# Print in default color
|
||||
print(f"Took {v} for command: '{k}'")
|
||||
|
||||
def add(self, cmd: str, time: timedelta) -> None:
|
||||
def add(self, cmd: str, time: float) -> None:
|
||||
if cmd in self.table:
|
||||
self.table[cmd] += time
|
||||
else:
|
||||
@@ -112,30 +138,33 @@ def run(
|
||||
if cwd is None:
|
||||
cwd = Path.cwd()
|
||||
if input:
|
||||
glog.debug(
|
||||
logger.debug(
|
||||
f"""$: echo "{input.decode('utf-8', 'replace')}" | {shlex.join(cmd)} \nCaller: {get_caller()}"""
|
||||
)
|
||||
else:
|
||||
glog.debug(f"$: {shlex.join(cmd)} \nCaller: {get_caller()}")
|
||||
tstart = datetime.datetime.now(tz=datetime.UTC)
|
||||
logger.debug(f"$: {shlex.join(cmd)} \nCaller: {get_caller()}")
|
||||
start = timeit.default_timer()
|
||||
|
||||
# Start the subprocess
|
||||
with subprocess.Popen(
|
||||
with (
|
||||
subprocess.Popen(
|
||||
cmd,
|
||||
cwd=str(cwd),
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
) as process:
|
||||
start_new_session=True,
|
||||
) as process,
|
||||
terminate_process_group(process),
|
||||
):
|
||||
stdout_buf, stderr_buf = handle_output(process, log)
|
||||
|
||||
if input:
|
||||
process.communicate(input)
|
||||
tend = datetime.datetime.now(tz=datetime.UTC)
|
||||
|
||||
global TIME_TABLE
|
||||
if TIME_TABLE:
|
||||
TIME_TABLE.add(shlex.join(cmd), tend - tstart)
|
||||
TIME_TABLE.add(shlex.join(cmd), start - timeit.default_timer())
|
||||
|
||||
# Wait for the subprocess to finish
|
||||
cmd_out = CmdOut(
|
||||
|
||||
@@ -37,7 +37,7 @@ def upload_secrets(machine: Machine) -> None:
|
||||
"--delete",
|
||||
"--chmod=D700,F600",
|
||||
f"{tempdir!s}/",
|
||||
f"{host.target}:{machine.secrets_upload_directory}/",
|
||||
f"{host.target_for_rsync}:{machine.secrets_upload_directory}/",
|
||||
],
|
||||
),
|
||||
log=Log.BOTH,
|
||||
|
||||
@@ -18,6 +18,7 @@ from shlex import quote
|
||||
from threading import Thread
|
||||
from typing import IO, Any, Generic, TypeVar
|
||||
|
||||
from clan_cli.cmd import terminate_process_group
|
||||
from clan_cli.errors import ClanError
|
||||
|
||||
# https://no-color.org
|
||||
@@ -218,6 +219,13 @@ class Host:
|
||||
def target(self) -> str:
|
||||
return f"{self.user or 'root'}@{self.host}"
|
||||
|
||||
@property
|
||||
def target_for_rsync(self) -> str:
|
||||
host = self.host
|
||||
if ":" in host:
|
||||
host = f"[{host}]"
|
||||
return f"{self.user or 'root'}@{host}"
|
||||
|
||||
def _prefix_output(
|
||||
self,
|
||||
displayed_cmd: str,
|
||||
@@ -287,7 +295,7 @@ class Host:
|
||||
elapsed = now - start
|
||||
if now - last_output > NO_OUTPUT_TIMEOUT:
|
||||
elapsed_msg = time.strftime("%H:%M:%S", time.gmtime(elapsed))
|
||||
cmdlog.warn(
|
||||
cmdlog.warning(
|
||||
f"still waiting for '{displayed_cmd}' to finish... ({elapsed_msg} elapsed)",
|
||||
extra={"command_prefix": self.command_prefix},
|
||||
)
|
||||
@@ -359,7 +367,9 @@ class Host:
|
||||
stderr=stderr_write,
|
||||
env=env,
|
||||
cwd=cwd,
|
||||
start_new_session=True,
|
||||
) as p:
|
||||
stack.enter_context(terminate_process_group(p))
|
||||
if write_std_fd is not None:
|
||||
write_std_fd.close()
|
||||
if write_err_fd is not None:
|
||||
@@ -380,11 +390,7 @@ class Host:
|
||||
stderr_read,
|
||||
timeout,
|
||||
)
|
||||
try:
|
||||
ret = p.wait(timeout=max(0, timeout - (time.time() - start)))
|
||||
except subprocess.TimeoutExpired:
|
||||
p.kill()
|
||||
raise
|
||||
if ret != 0:
|
||||
if check:
|
||||
raise subprocess.CalledProcessError(
|
||||
@@ -845,6 +851,10 @@ def parse_deployment_address(
|
||||
meta = {}
|
||||
parts = host.split("@")
|
||||
user: str | None = None
|
||||
# count the number of : in the hostname
|
||||
if host.count(":") > 1 and not host.startswith("["):
|
||||
msg = f"Invalid hostname: {host}. IPv6 addresses must be enclosed in brackets , e.g. [::1]"
|
||||
raise ClanError(msg)
|
||||
if len(parts) > 1:
|
||||
user = parts[0]
|
||||
hostname = parts[1]
|
||||
|
||||
@@ -38,7 +38,7 @@ def upload_secrets(machine: Machine) -> None:
|
||||
"--delete",
|
||||
"--chmod=D700,F600",
|
||||
f"{tempdir!s}/",
|
||||
f"{host.user}@{host.host}:{machine.secrets_upload_directory}/",
|
||||
f"{host.target_for_rsync}:{machine.secrets_upload_directory}/",
|
||||
],
|
||||
),
|
||||
log=Log.BOTH,
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import subprocess
|
||||
|
||||
import pytest
|
||||
from clan_cli.errors import ClanError
|
||||
from clan_cli.ssh import Host, HostGroup, HostKeyCheck, parse_deployment_address
|
||||
|
||||
|
||||
@@ -11,6 +13,10 @@ def test_parse_ipv6() -> None:
|
||||
assert host.host == "fe80::1%eth0"
|
||||
assert host.port is None
|
||||
|
||||
with pytest.raises(ClanError):
|
||||
# We instruct the user to use brackets for IPv6 addresses
|
||||
host = parse_deployment_address("foo", "fe80::1%eth0", HostKeyCheck.STRICT)
|
||||
|
||||
|
||||
def test_run(host_group: HostGroup) -> None:
|
||||
proc = host_group.run("echo hello", stdout=subprocess.PIPE)
|
||||
|
||||
Reference in New Issue
Block a user