Merge pull request 'properly support verbatim ipv6 addresses' (#2242) from ipv6-ftw into main
This commit is contained in:
@@ -122,10 +122,7 @@ def zerotier_controller() -> Iterator[ZerotierController]:
|
|||||||
str(home),
|
str(home),
|
||||||
]
|
]
|
||||||
|
|
||||||
with subprocess.Popen(
|
with subprocess.Popen(cmd, start_new_session=True) as p:
|
||||||
cmd,
|
|
||||||
preexec_fn=os.setsid,
|
|
||||||
) as p:
|
|
||||||
process_group = os.getpgid(p.pid)
|
process_group = os.getpgid(p.pid)
|
||||||
try:
|
try:
|
||||||
print(
|
print(
|
||||||
|
|||||||
@@ -1,20 +1,25 @@
|
|||||||
import datetime
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import select
|
import select
|
||||||
import shlex
|
import shlex
|
||||||
|
import signal
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
import timeit
|
||||||
import weakref
|
import weakref
|
||||||
from datetime import timedelta
|
from collections.abc import Iterator
|
||||||
|
from contextlib import contextmanager
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import IO, Any
|
from typing import IO, Any
|
||||||
|
|
||||||
|
from clan_cli.errors import ClanError
|
||||||
|
|
||||||
from .custom_logger import get_caller
|
from .custom_logger import get_caller
|
||||||
from .errors import ClanCmdError, CmdOut
|
from .errors import ClanCmdError, CmdOut
|
||||||
|
|
||||||
glog = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Log(Enum):
|
class Log(Enum):
|
||||||
@@ -60,6 +65,27 @@ def handle_output(process: subprocess.Popen, log: Log) -> tuple[str, str]:
|
|||||||
return stdout_buf.decode("utf-8", "replace"), stderr_buf.decode("utf-8", "replace")
|
return stdout_buf.decode("utf-8", "replace"), stderr_buf.decode("utf-8", "replace")
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def terminate_process_group(process: subprocess.Popen) -> Iterator[None]:
|
||||||
|
process_group = os.getpgid(process.pid)
|
||||||
|
if process_group == os.getpgid(os.getpid()):
|
||||||
|
msg = "Bug! Refusing to terminate the current process group"
|
||||||
|
raise ClanError(msg)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
os.killpg(process_group, signal.SIGTERM)
|
||||||
|
try:
|
||||||
|
with contextlib.suppress(subprocess.TimeoutExpired):
|
||||||
|
# give the process time to terminate
|
||||||
|
process.wait(3)
|
||||||
|
finally:
|
||||||
|
os.killpg(process_group, signal.SIGKILL)
|
||||||
|
except ProcessLookupError: # process already terminated
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TimeTable:
|
class TimeTable:
|
||||||
"""
|
"""
|
||||||
This class is used to store the time taken by each command
|
This class is used to store the time taken by each command
|
||||||
@@ -67,7 +93,7 @@ class TimeTable:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.table: dict[str, timedelta] = {}
|
self.table: dict[str, float] = {}
|
||||||
weakref.finalize(self, self.table_print)
|
weakref.finalize(self, self.table_print)
|
||||||
|
|
||||||
def table_print(self) -> None:
|
def table_print(self) -> None:
|
||||||
@@ -80,14 +106,14 @@ class TimeTable:
|
|||||||
|
|
||||||
for k, v in sorted_table:
|
for k, v in sorted_table:
|
||||||
# Check if timedelta is greater than 1 second
|
# Check if timedelta is greater than 1 second
|
||||||
if v.total_seconds() > 1:
|
if v > 1:
|
||||||
# Print in red
|
# Print in red
|
||||||
print(f"\033[91mTook {v}s\033[0m for command: '{k}'")
|
print(f"\033[91mTook {v}s\033[0m for command: '{k}'")
|
||||||
else:
|
else:
|
||||||
# Print in default color
|
# Print in default color
|
||||||
print(f"Took {v} for command: '{k}'")
|
print(f"Took {v} for command: '{k}'")
|
||||||
|
|
||||||
def add(self, cmd: str, time: timedelta) -> None:
|
def add(self, cmd: str, time: float) -> None:
|
||||||
if cmd in self.table:
|
if cmd in self.table:
|
||||||
self.table[cmd] += time
|
self.table[cmd] += time
|
||||||
else:
|
else:
|
||||||
@@ -112,30 +138,33 @@ def run(
|
|||||||
if cwd is None:
|
if cwd is None:
|
||||||
cwd = Path.cwd()
|
cwd = Path.cwd()
|
||||||
if input:
|
if input:
|
||||||
glog.debug(
|
logger.debug(
|
||||||
f"""$: echo "{input.decode('utf-8', 'replace')}" | {shlex.join(cmd)} \nCaller: {get_caller()}"""
|
f"""$: echo "{input.decode('utf-8', 'replace')}" | {shlex.join(cmd)} \nCaller: {get_caller()}"""
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
glog.debug(f"$: {shlex.join(cmd)} \nCaller: {get_caller()}")
|
logger.debug(f"$: {shlex.join(cmd)} \nCaller: {get_caller()}")
|
||||||
tstart = datetime.datetime.now(tz=datetime.UTC)
|
start = timeit.default_timer()
|
||||||
|
|
||||||
# Start the subprocess
|
# Start the subprocess
|
||||||
with subprocess.Popen(
|
with (
|
||||||
cmd,
|
subprocess.Popen(
|
||||||
cwd=str(cwd),
|
cmd,
|
||||||
env=env,
|
cwd=str(cwd),
|
||||||
stdout=subprocess.PIPE,
|
env=env,
|
||||||
stderr=subprocess.PIPE,
|
stdout=subprocess.PIPE,
|
||||||
) as process:
|
stderr=subprocess.PIPE,
|
||||||
|
start_new_session=True,
|
||||||
|
) as process,
|
||||||
|
terminate_process_group(process),
|
||||||
|
):
|
||||||
stdout_buf, stderr_buf = handle_output(process, log)
|
stdout_buf, stderr_buf = handle_output(process, log)
|
||||||
|
|
||||||
if input:
|
if input:
|
||||||
process.communicate(input)
|
process.communicate(input)
|
||||||
tend = datetime.datetime.now(tz=datetime.UTC)
|
|
||||||
|
|
||||||
global TIME_TABLE
|
global TIME_TABLE
|
||||||
if TIME_TABLE:
|
if TIME_TABLE:
|
||||||
TIME_TABLE.add(shlex.join(cmd), tend - tstart)
|
TIME_TABLE.add(shlex.join(cmd), start - timeit.default_timer())
|
||||||
|
|
||||||
# Wait for the subprocess to finish
|
# Wait for the subprocess to finish
|
||||||
cmd_out = CmdOut(
|
cmd_out = CmdOut(
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ def upload_secrets(machine: Machine) -> None:
|
|||||||
"--delete",
|
"--delete",
|
||||||
"--chmod=D700,F600",
|
"--chmod=D700,F600",
|
||||||
f"{tempdir!s}/",
|
f"{tempdir!s}/",
|
||||||
f"{host.target}:{machine.secrets_upload_directory}/",
|
f"{host.target_for_rsync}:{machine.secrets_upload_directory}/",
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
log=Log.BOTH,
|
log=Log.BOTH,
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from shlex import quote
|
|||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import IO, Any, Generic, TypeVar
|
from typing import IO, Any, Generic, TypeVar
|
||||||
|
|
||||||
|
from clan_cli.cmd import terminate_process_group
|
||||||
from clan_cli.errors import ClanError
|
from clan_cli.errors import ClanError
|
||||||
|
|
||||||
# https://no-color.org
|
# https://no-color.org
|
||||||
@@ -218,6 +219,13 @@ class Host:
|
|||||||
def target(self) -> str:
|
def target(self) -> str:
|
||||||
return f"{self.user or 'root'}@{self.host}"
|
return f"{self.user or 'root'}@{self.host}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def target_for_rsync(self) -> str:
|
||||||
|
host = self.host
|
||||||
|
if ":" in host:
|
||||||
|
host = f"[{host}]"
|
||||||
|
return f"{self.user or 'root'}@{host}"
|
||||||
|
|
||||||
def _prefix_output(
|
def _prefix_output(
|
||||||
self,
|
self,
|
||||||
displayed_cmd: str,
|
displayed_cmd: str,
|
||||||
@@ -287,7 +295,7 @@ class Host:
|
|||||||
elapsed = now - start
|
elapsed = now - start
|
||||||
if now - last_output > NO_OUTPUT_TIMEOUT:
|
if now - last_output > NO_OUTPUT_TIMEOUT:
|
||||||
elapsed_msg = time.strftime("%H:%M:%S", time.gmtime(elapsed))
|
elapsed_msg = time.strftime("%H:%M:%S", time.gmtime(elapsed))
|
||||||
cmdlog.warn(
|
cmdlog.warning(
|
||||||
f"still waiting for '{displayed_cmd}' to finish... ({elapsed_msg} elapsed)",
|
f"still waiting for '{displayed_cmd}' to finish... ({elapsed_msg} elapsed)",
|
||||||
extra={"command_prefix": self.command_prefix},
|
extra={"command_prefix": self.command_prefix},
|
||||||
)
|
)
|
||||||
@@ -359,7 +367,9 @@ class Host:
|
|||||||
stderr=stderr_write,
|
stderr=stderr_write,
|
||||||
env=env,
|
env=env,
|
||||||
cwd=cwd,
|
cwd=cwd,
|
||||||
|
start_new_session=True,
|
||||||
) as p:
|
) as p:
|
||||||
|
stack.enter_context(terminate_process_group(p))
|
||||||
if write_std_fd is not None:
|
if write_std_fd is not None:
|
||||||
write_std_fd.close()
|
write_std_fd.close()
|
||||||
if write_err_fd is not None:
|
if write_err_fd is not None:
|
||||||
@@ -380,11 +390,7 @@ class Host:
|
|||||||
stderr_read,
|
stderr_read,
|
||||||
timeout,
|
timeout,
|
||||||
)
|
)
|
||||||
try:
|
ret = p.wait(timeout=max(0, timeout - (time.time() - start)))
|
||||||
ret = p.wait(timeout=max(0, timeout - (time.time() - start)))
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
p.kill()
|
|
||||||
raise
|
|
||||||
if ret != 0:
|
if ret != 0:
|
||||||
if check:
|
if check:
|
||||||
raise subprocess.CalledProcessError(
|
raise subprocess.CalledProcessError(
|
||||||
@@ -845,6 +851,10 @@ def parse_deployment_address(
|
|||||||
meta = {}
|
meta = {}
|
||||||
parts = host.split("@")
|
parts = host.split("@")
|
||||||
user: str | None = None
|
user: str | None = None
|
||||||
|
# count the number of : in the hostname
|
||||||
|
if host.count(":") > 1 and not host.startswith("["):
|
||||||
|
msg = f"Invalid hostname: {host}. IPv6 addresses must be enclosed in brackets , e.g. [::1]"
|
||||||
|
raise ClanError(msg)
|
||||||
if len(parts) > 1:
|
if len(parts) > 1:
|
||||||
user = parts[0]
|
user = parts[0]
|
||||||
hostname = parts[1]
|
hostname = parts[1]
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ def upload_secrets(machine: Machine) -> None:
|
|||||||
"--delete",
|
"--delete",
|
||||||
"--chmod=D700,F600",
|
"--chmod=D700,F600",
|
||||||
f"{tempdir!s}/",
|
f"{tempdir!s}/",
|
||||||
f"{host.user}@{host.host}:{machine.secrets_upload_directory}/",
|
f"{host.target_for_rsync}:{machine.secrets_upload_directory}/",
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
log=Log.BOTH,
|
log=Log.BOTH,
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
import subprocess
|
import subprocess
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from clan_cli.errors import ClanError
|
||||||
from clan_cli.ssh import Host, HostGroup, HostKeyCheck, parse_deployment_address
|
from clan_cli.ssh import Host, HostGroup, HostKeyCheck, parse_deployment_address
|
||||||
|
|
||||||
|
|
||||||
@@ -11,6 +13,10 @@ def test_parse_ipv6() -> None:
|
|||||||
assert host.host == "fe80::1%eth0"
|
assert host.host == "fe80::1%eth0"
|
||||||
assert host.port is None
|
assert host.port is None
|
||||||
|
|
||||||
|
with pytest.raises(ClanError):
|
||||||
|
# We instruct the user to use brackets for IPv6 addresses
|
||||||
|
host = parse_deployment_address("foo", "fe80::1%eth0", HostKeyCheck.STRICT)
|
||||||
|
|
||||||
|
|
||||||
def test_run(host_group: HostGroup) -> None:
|
def test_run(host_group: HostGroup) -> None:
|
||||||
proc = host_group.run("echo hello", stdout=subprocess.PIPE)
|
proc = host_group.run("echo hello", stdout=subprocess.PIPE)
|
||||||
|
|||||||
Reference in New Issue
Block a user