clan-cli: Refactor ssh part 2, Refactor custom_logger

This commit is contained in:
Qubasa
2024-11-22 22:08:50 +01:00
parent 05b31c7195
commit 8866a85765
23 changed files with 713 additions and 1255 deletions

View File

@@ -1,8 +1,8 @@
import argparse import argparse
import json import json
import subprocess
from dataclasses import dataclass from dataclasses import dataclass
from clan_cli.cmd import Log
from clan_cli.completions import ( from clan_cli.completions import (
add_dynamic_completer, add_dynamic_completer,
complete_backup_providers_for_machine, complete_backup_providers_for_machine,
@@ -23,7 +23,7 @@ def list_provider(machine: Machine, provider: str) -> list[Backup]:
backup_metadata = json.loads(machine.eval_nix("config.clan.core.backups")) backup_metadata = json.loads(machine.eval_nix("config.clan.core.backups"))
proc = machine.target_host.run( proc = machine.target_host.run(
[backup_metadata["providers"][provider]["list"]], [backup_metadata["providers"][provider]["list"]],
stdout=subprocess.PIPE, log=Log.STDERR,
check=False, check=False,
) )
if proc.returncode != 0: if proc.returncode != 0:

View File

@@ -1,7 +1,7 @@
import argparse import argparse
import json import json
import subprocess
from clan_cli.cmd import Log
from clan_cli.completions import ( from clan_cli.completions import (
add_dynamic_completer, add_dynamic_completer,
complete_backup_providers_for_machine, complete_backup_providers_for_machine,
@@ -28,7 +28,7 @@ def restore_service(machine: Machine, name: str, provider: str, service: str) ->
if pre_restore := backup_folders[service]["preRestoreCommand"]: if pre_restore := backup_folders[service]["preRestoreCommand"]:
proc = machine.target_host.run( proc = machine.target_host.run(
[pre_restore], [pre_restore],
stdout=subprocess.PIPE, log=Log.STDERR,
extra_env=env, extra_env=env,
) )
if proc.returncode != 0: if proc.returncode != 0:
@@ -37,7 +37,7 @@ def restore_service(machine: Machine, name: str, provider: str, service: str) ->
proc = machine.target_host.run( proc = machine.target_host.run(
[backup_metadata["providers"][provider]["restore"]], [backup_metadata["providers"][provider]["restore"]],
stdout=subprocess.PIPE, log=Log.STDERR,
extra_env=env, extra_env=env,
) )
if proc.returncode != 0: if proc.returncode != 0:
@@ -47,7 +47,7 @@ def restore_service(machine: Machine, name: str, provider: str, service: str) ->
if post_restore := backup_folders[service]["postRestoreCommand"]: if post_restore := backup_folders[service]["postRestoreCommand"]:
proc = machine.target_host.run( proc = machine.target_host.run(
[post_restore], [post_restore],
stdout=subprocess.PIPE, log=Log.STDERR,
extra_env=env, extra_env=env,
) )
if proc.returncode != 0: if proc.returncode != 0:

View File

@@ -1,11 +1,12 @@
import contextlib import contextlib
import logging import logging
import math
import os import os
import select import select
import shlex import shlex
import signal import signal
import subprocess import subprocess
import sys import time
import timeit import timeit
import weakref import weakref
from collections.abc import Iterator from collections.abc import Iterator
@@ -14,12 +15,25 @@ from enum import Enum
from pathlib import Path from pathlib import Path
from typing import IO, Any from typing import IO, Any
from clan_cli.errors import ClanError, indent_command from clan_cli.custom_logger import print_trace
from clan_cli.errors import ClanCmdError, ClanError, CmdOut, indent_command
from .custom_logger import get_callers cmdlog = logging.getLogger(__name__)
from .errors import ClanCmdError, CmdOut
logger = logging.getLogger(__name__)
class ClanCmdTimeoutError(ClanError):
timeout: float
def __init__(
self,
msg: str | None = None,
*,
description: str | None = None,
location: str | None = None,
timeout: float,
) -> None:
self.timeout = timeout
super().__init__(msg, description=description, location=location)
class Log(Enum): class Log(Enum):
@@ -30,14 +44,31 @@ class Log(Enum):
def handle_io( def handle_io(
process: subprocess.Popen, input_bytes: bytes | None, log: Log process: subprocess.Popen,
log: Log,
cmdlog: logging.Logger,
prefix: str,
*,
input_bytes: bytes | None,
stdout: IO[bytes] | None,
stderr: IO[bytes] | None,
timeout: float = math.inf,
) -> tuple[str, str]: ) -> tuple[str, str]:
rlist = [process.stdout, process.stderr] rlist = [process.stdout, process.stderr]
wlist = [process.stdin] if input_bytes is not None else [] wlist = [process.stdin] if input_bytes is not None else []
stdout_buf = b"" stdout_buf = b""
stderr_buf = b"" stderr_buf = b""
start = time.time()
# Loop until no more data is available
while len(rlist) != 0 or len(wlist) != 0: while len(rlist) != 0 or len(wlist) != 0:
# Check if the command has timed out
if time.time() - start > timeout:
msg = f"Command timed out after {timeout} seconds"
description = prefix
raise ClanCmdTimeoutError(msg=msg, description=description, timeout=timeout)
# Wait for data to be available
readlist, writelist, _ = select.select(rlist, wlist, [], 0.1) readlist, writelist, _ = select.select(rlist, wlist, [], 0.1)
if len(readlist) == 0 and len(writelist) == 0: if len(readlist) == 0 and len(writelist) == 0:
if process.poll() is None: if process.poll() is None:
@@ -45,6 +76,7 @@ def handle_io(
# Process has exited # Process has exited
break break
# Function to handle file descriptors
def handle_fd(fd: IO[Any] | None, readlist: list[IO[Any]]) -> bytes: def handle_fd(fd: IO[Any] | None, readlist: list[IO[Any]]) -> bytes:
if fd and fd in readlist: if fd and fd in readlist:
read = os.read(fd.fileno(), 4096) read = os.read(fd.fileno(), 4096)
@@ -53,19 +85,36 @@ def handle_io(
rlist.remove(fd) rlist.remove(fd)
return b"" return b""
#
# Process stdout
#
ret = handle_fd(process.stdout, readlist) ret = handle_fd(process.stdout, readlist)
if ret and log in [Log.STDOUT, Log.BOTH]: if ret and log in [Log.STDOUT, Log.BOTH]:
sys.stdout.buffer.write(ret) lines = ret.decode("utf-8", "replace").rstrip("\n").split("\n")
sys.stdout.flush() for line in lines:
cmdlog.info(line, extra={"command_prefix": prefix})
if ret and stdout:
stdout.write(ret)
stdout.flush()
#
# Process stderr
#
stdout_buf += ret stdout_buf += ret
ret = handle_fd(process.stderr, readlist) ret = handle_fd(process.stderr, readlist)
if ret and log in [Log.STDERR, Log.BOTH]: if ret and log in [Log.STDERR, Log.BOTH]:
sys.stderr.buffer.write(ret) lines = ret.decode("utf-8", "replace").rstrip("\n").split("\n")
sys.stderr.flush() for line in lines:
cmdlog.error(line, extra={"command_prefix": prefix})
if ret and stderr:
stderr.write(ret)
stderr.flush()
stderr_buf += ret stderr_buf += ret
#
# Process stdin
#
if process.stdin in writelist: if process.stdin in writelist:
if input_bytes: if input_bytes:
try: try:
@@ -168,42 +217,35 @@ def run(
cmd: list[str], cmd: list[str],
*, *,
input: bytes | None = None, # noqa: A002 input: bytes | None = None, # noqa: A002
stdout: IO[bytes] | None = None,
stderr: IO[bytes] | None = None,
env: dict[str, str] | None = None, env: dict[str, str] | None = None,
cwd: Path | None = None, cwd: Path | None = None,
log: Log = Log.STDERR, log: Log = Log.STDERR,
logger: logging.Logger = cmdlog,
prefix: str | None = None,
check: bool = True, check: bool = True,
error_msg: str | None = None, error_msg: str | None = None,
needs_user_terminal: bool = False, needs_user_terminal: bool = False,
timeout: float = math.inf,
shell: bool = False,
) -> CmdOut: ) -> CmdOut:
if cwd is None: if cwd is None:
cwd = Path.cwd() cwd = Path.cwd()
def print_trace(msg: str) -> None: if prefix is None:
trace_depth = int(os.environ.get("TRACE_DEPTH", "0")) prefix = "localhost"
callers = get_callers(3, 4 + trace_depth)
if "run_no_stdout" in callers[0]:
callers = callers[1:]
else:
callers.pop()
if len(callers) == 1:
callers_str = f"Caller: {callers[0]}\n"
else:
callers_str = "\n".join(
f"{i+1}: {caller}" for i, caller in enumerate(callers)
)
callers_str = f"Callers:\n{callers_str}"
logger.debug(f"{msg} \n{callers_str}")
if input: if input:
if any(not ch.isprintable() for ch in input.decode("ascii", "replace")): if any(not ch.isprintable() for ch in input.decode("ascii", "replace")):
filtered_input = "<<binary_blob>>" filtered_input = "<<binary_blob>>"
else: else:
filtered_input = input.decode("ascii", "replace") filtered_input = input.decode("ascii", "replace")
print_trace(f"$: echo '{filtered_input}' | {indent_command(cmd)}") print_trace(
f"$: echo '{filtered_input}' | {indent_command(cmd)}", logger, prefix
)
elif logger.isEnabledFor(logging.DEBUG): elif logger.isEnabledFor(logging.DEBUG):
print_trace(f"$: {indent_command(cmd)}") print_trace(f"$: {indent_command(cmd)}", logger, prefix)
start = timeit.default_timer() start = timeit.default_timer()
with ExitStack() as stack: with ExitStack() as stack:
@@ -217,6 +259,7 @@ def run(
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
start_new_session=not needs_user_terminal, start_new_session=not needs_user_terminal,
shell=shell,
) )
) )
@@ -226,7 +269,16 @@ def run(
else: else:
stack.enter_context(terminate_process_group(process)) stack.enter_context(terminate_process_group(process))
stdout_buf, stderr_buf = handle_io(process, input, log) stdout_buf, stderr_buf = handle_io(
process,
log,
prefix=prefix,
cmdlog=logger,
timeout=timeout,
input_bytes=input,
stdout=stdout,
stderr=stderr,
)
process.wait() process.wait()
global TIME_TABLE global TIME_TABLE
@@ -256,9 +308,12 @@ def run_no_stdout(
env: dict[str, str] | None = None, env: dict[str, str] | None = None,
cwd: Path | None = None, cwd: Path | None = None,
log: Log = Log.STDERR, log: Log = Log.STDERR,
logger: logging.Logger = cmdlog,
prefix: str | None = None,
check: bool = True, check: bool = True,
error_msg: str | None = None, error_msg: str | None = None,
needs_user_terminal: bool = False, needs_user_terminal: bool = False,
shell: bool = False,
) -> CmdOut: ) -> CmdOut:
""" """
Like run, but automatically suppresses stdout, if not in DEBUG log level. Like run, but automatically suppresses stdout, if not in DEBUG log level.
@@ -274,6 +329,8 @@ def run_no_stdout(
env=env, env=env,
log=log, log=log,
check=check, check=check,
prefix=prefix,
error_msg=error_msg, error_msg=error_msg,
needs_user_terminal=needs_user_terminal, needs_user_terminal=needs_user_terminal,
shell=shell,
) )

View File

@@ -0,0 +1,2 @@
from .colors import * # noqa
from .csscolors import * # noqa

View File

@@ -0,0 +1,180 @@
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import re
from functools import partial
from .csscolors import parse_rgb
# ANSI color names. There is also a "default"
COLORS: tuple[str, ...] = (
"black",
"red",
"green",
"yellow",
"blue",
"magenta",
"cyan",
"white",
)
# ANSI style names
STYLES: tuple[str, ...] = (
"none",
"bold",
"faint",
"italic",
"underline",
"blink",
"blink2",
"negative",
"concealed",
"crossed",
)
def is_string(obj: str | bytes) -> bool:
"""
Is the given object a string?
"""
return isinstance(obj, str)
def _join(*values: int | str) -> str:
"""
Join a series of values with semicolons. The values
are either integers or strings, so stringify each for
good measure. Worth breaking out as its own function
because semicolon-joined lists are core to ANSI coding.
"""
return ";".join(str(v) for v in values)
def color_code(spec: str | int | tuple[int, int, int] | list, base: int) -> str:
"""
Workhorse of encoding a color. Give preference to named colors from
ANSI, then to specific numeric or tuple specs. If those don't work,
try looking up CSS color names or parsing CSS color specifications
(hex or rgb).
:param str|int|tuple|list spec: Unparsed color specification
:param int base: Either 30 or 40, signifying the base value
for color encoding (foreground and background respectively).
Low values are added directly to the base. Higher values use `
base + 8` (i.e. 38 or 48) then extended codes.
:returns: Discovered ANSI color encoding.
:rtype: str
:raises: ValueError if cannot parse the color spec.
"""
if isinstance(spec, str | bytes):
spec = spec.strip().lower()
if spec == "default":
return _join(base + 9)
if spec in COLORS:
return _join(base + COLORS.index(spec))
if isinstance(spec, int) and 0 <= spec <= 255:
return _join(base + 8, 5, spec)
if isinstance(spec, tuple | list):
return _join(base + 8, 2, _join(*spec))
rgb = parse_rgb(str(spec))
# parse_rgb raises ValueError if cannot parse spec
# or returns an rgb tuple if it can
return _join(base + 8, 2, _join(*rgb))
def color(
s: str | None = None,
fg: str | int | tuple[int, int, int] | None = None,
bg: str | int | tuple[int, int, int] | None = None,
style: str | None = None,
reset: bool = True,
) -> str:
"""
Add ANSI colors and styles to a string.
:param str s: String to format.
:param str|int|tuple fg: Foreground color specification.
:param str|int|tuple bg: Background color specification.
:param str: Style names, separated by '+'
:returns: Formatted string.
:rtype: str (or unicode in Python 2, if s is unicode)
"""
codes: list[int | str] = []
if fg:
codes.append(color_code(fg, 30))
if bg:
codes.append(color_code(bg, 40))
if style:
for style_part in style.split("+"):
if style_part in STYLES:
codes.append(STYLES.index(style_part))
else:
msg = f'Invalid style "{style_part}"'
raise ValueError(msg)
if not s:
s = ""
if codes:
if reset:
template = "\x1b[{0}m{1}\x1b[0m"
else:
template = "\x1b[{0}m{1}"
return template.format(_join(*codes), s)
return s
def strip_color(s: str) -> str:
"""
Remove ANSI color/style sequences from a string. The set of all possible
ANSI sequences is large, so does not try to strip every possible one. But
does strip some outliers seen not just in text generated by this module, but
by other ANSI colorizers in the wild. Those include `\x1b[K` (aka EL or
erase to end of line) and `\x1b[m`, a terse version of the more common
`\x1b[0m`.
"""
return re.sub("\x1b\\[(K|.*?m)", "", s)
def ansilen(s: str) -> int:
"""
Given a string with embedded ANSI codes, what would its
length be without those codes?
"""
return len(strip_color(s))
# Foreground color shortcuts
black = partial(color, fg="black")
red = partial(color, fg="red")
green = partial(color, fg="green")
yellow = partial(color, fg="yellow")
blue = partial(color, fg="blue")
magenta = partial(color, fg="magenta")
cyan = partial(color, fg="cyan")
white = partial(color, fg="white")
# Style shortcuts
bold = partial(color, style="bold")
none = partial(color, style="none")
faint = partial(color, style="faint")
italic = partial(color, style="italic")
underline = partial(color, style="underline")
blink = partial(color, style="blink")
blink2 = partial(color, style="blink2")
negative = partial(color, style="negative")
concealed = partial(color, style="concealed")
crossed = partial(color, style="crossed")

View File

@@ -0,0 +1,183 @@
"""
Map of CSS color names to RGB integer values.
"""
import re
css_colors: dict[str, tuple[int, int, int]] = {
"aliceblue": (240, 248, 255),
"antiquewhite": (250, 235, 215),
"aqua": (0, 255, 255),
"aquamarine": (127, 255, 212),
"azure": (240, 255, 255),
"beige": (245, 245, 220),
"bisque": (255, 228, 196),
"black": (0, 0, 0),
"blanchedalmond": (255, 235, 205),
"blue": (0, 0, 255),
"blueviolet": (138, 43, 226),
"brown": (165, 42, 42),
"burlywood": (222, 184, 135),
"cadetblue": (95, 158, 160),
"chartreuse": (127, 255, 0),
"chocolate": (210, 105, 30),
"coral": (255, 127, 80),
"cornflowerblue": (100, 149, 237),
"cornsilk": (255, 248, 220),
"crimson": (220, 20, 60),
"cyan": (0, 255, 255),
"darkblue": (0, 0, 139),
"darkcyan": (0, 139, 139),
"darkgoldenrod": (184, 134, 11),
"darkgray": (169, 169, 169),
"darkgreen": (0, 100, 0),
"darkgrey": (169, 169, 169),
"darkkhaki": (189, 183, 107),
"darkmagenta": (139, 0, 139),
"darkolivegreen": (85, 107, 47),
"darkorange": (255, 140, 0),
"darkorchid": (153, 50, 204),
"darkred": (139, 0, 0),
"darksalmon": (233, 150, 122),
"darkseagreen": (143, 188, 143),
"darkslateblue": (72, 61, 139),
"darkslategray": (47, 79, 79),
"darkslategrey": (47, 79, 79),
"darkturquoise": (0, 206, 209),
"darkviolet": (148, 0, 211),
"deeppink": (255, 20, 147),
"deepskyblue": (0, 191, 255),
"dimgray": (105, 105, 105),
"dimgrey": (105, 105, 105),
"dodgerblue": (30, 144, 255),
"firebrick": (178, 34, 34),
"floralwhite": (255, 250, 240),
"forestgreen": (34, 139, 34),
"fuchsia": (255, 0, 255),
"gainsboro": (220, 220, 220),
"ghostwhite": (248, 248, 255),
"gold": (255, 215, 0),
"goldenrod": (218, 165, 32),
"gray": (128, 128, 128),
"green": (0, 128, 0),
"greenyellow": (173, 255, 47),
"grey": (128, 128, 128),
"honeydew": (240, 255, 240),
"hotpink": (255, 105, 180),
"indianred": (205, 92, 92),
"indigo": (75, 0, 130),
"ivory": (255, 255, 240),
"khaki": (240, 230, 140),
"lavender": (230, 230, 250),
"lavenderblush": (255, 240, 245),
"lawngreen": (124, 252, 0),
"lemonchiffon": (255, 250, 205),
"lightblue": (173, 216, 230),
"lightcoral": (240, 128, 128),
"lightcyan": (224, 255, 255),
"lightgoldenrodyellow": (250, 250, 210),
"lightgray": (211, 211, 211),
"lightgreen": (144, 238, 144),
"lightgrey": (211, 211, 211),
"lightpink": (255, 182, 193),
"lightsalmon": (255, 160, 122),
"lightseagreen": (32, 178, 170),
"lightskyblue": (135, 206, 250),
"lightslategray": (119, 136, 153),
"lightslategrey": (119, 136, 153),
"lightsteelblue": (176, 196, 222),
"lightyellow": (255, 255, 224),
"lime": (0, 255, 0),
"limegreen": (50, 205, 50),
"linen": (250, 240, 230),
"magenta": (255, 0, 255),
"maroon": (128, 0, 0),
"mediumaquamarine": (102, 205, 170),
"mediumblue": (0, 0, 205),
"mediumorchid": (186, 85, 211),
"mediumpurple": (147, 112, 219),
"mediumseagreen": (60, 179, 113),
"mediumslateblue": (123, 104, 238),
"mediumspringgreen": (0, 250, 154),
"mediumturquoise": (72, 209, 204),
"mediumvioletred": (199, 21, 133),
"midnightblue": (25, 25, 112),
"mintcream": (245, 255, 250),
"mistyrose": (255, 228, 225),
"moccasin": (255, 228, 181),
"navajowhite": (255, 222, 173),
"navy": (0, 0, 128),
"oldlace": (253, 245, 230),
"olive": (128, 128, 0),
"olivedrab": (107, 142, 35),
"orange": (255, 165, 0),
"orangered": (255, 69, 0),
"orchid": (218, 112, 214),
"palegoldenrod": (238, 232, 170),
"palegreen": (152, 251, 152),
"paleturquoise": (175, 238, 238),
"palevioletred": (219, 112, 147),
"papayawhip": (255, 239, 213),
"peachpuff": (255, 218, 185),
"peru": (205, 133, 63),
"pink": (255, 192, 203),
"plum": (221, 160, 221),
"powderblue": (176, 224, 230),
"purple": (128, 0, 128),
"rebeccapurple": (102, 51, 153),
"red": (255, 0, 0),
"rosybrown": (188, 143, 143),
"royalblue": (65, 105, 225),
"saddlebrown": (139, 69, 19),
"salmon": (250, 128, 114),
"sandybrown": (244, 164, 96),
"seagreen": (46, 139, 87),
"seashell": (255, 245, 238),
"sienna": (160, 82, 45),
"silver": (192, 192, 192),
"skyblue": (135, 206, 235),
"slateblue": (106, 90, 205),
"slategray": (112, 128, 144),
"slategrey": (112, 128, 144),
"snow": (255, 250, 250),
"springgreen": (0, 255, 127),
"steelblue": (70, 130, 180),
"tan": (210, 180, 140),
"teal": (0, 128, 128),
"thistle": (216, 191, 216),
"tomato": (255, 99, 71),
"turquoise": (64, 224, 208),
"violet": (238, 130, 238),
"wheat": (245, 222, 179),
"white": (255, 255, 255),
"whitesmoke": (245, 245, 245),
"yellow": (255, 255, 0),
"yellowgreen": (154, 205, 50),
}
def parse_rgb(s: str) -> tuple[int, ...]:
s = s.strip().replace(" ", "").lower()
# simple lookup
rgb = css_colors.get(s)
if rgb is not None:
return rgb
# 6-digit hex
match = re.match("#([a-f0-9]{6})$", s)
if match:
core = match.group(1)
return tuple(int(core[i : i + 2], 16) for i in range(0, 6, 2))
# 3-digit hex
match = re.match("#([a-f0-9]{3})$", s)
if match:
return tuple(int(c * 2, 16) for c in match.group(1))
# rgb(x,y,z)
match = re.match(r"rgb\((\d+,\d+,\d+)\)", s)
if match:
return tuple(int(v) for v in match.group(1).split(","))
msg = f"Could not parse color '{s}'"
raise ValueError(msg)

View File

@@ -1,61 +1,74 @@
import inspect import inspect
import logging import logging
import os import os
from collections.abc import Callable import sys
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
grey = "\x1b[38;20m" from clan_cli.colors import color, css_colors
yellow = "\x1b[33;20m"
red = "\x1b[31;20m" # https://no-color.org
bold_red = "\x1b[31;1m" DISABLE_COLOR = not sys.stderr.isatty() or os.environ.get("NO_COLOR", "") != ""
green = "\u001b[32m"
blue = "\u001b[34m"
def get_formatter(color: str) -> Callable[[logging.LogRecord, bool], logging.Formatter]: def _get_filepath(record: logging.LogRecord) -> Path:
def myformatter( try:
record: logging.LogRecord, with_location: bool filepath = Path(record.pathname).resolve()
) -> logging.Formatter: filepath = Path("~", filepath.relative_to(Path.home()))
reset = "\x1b[0m" except Exception:
filepath = Path(record.pathname)
try: return filepath
filepath = Path(record.pathname).resolve()
filepath = Path("~", filepath.relative_to(Path.home()))
except Exception:
filepath = Path(record.pathname)
if not with_location:
return logging.Formatter(f"{color}%(levelname)s{reset}: %(message)s")
return logging.Formatter(
f"{color}%(levelname)s{reset}: %(message)s\nLocation: {filepath}:%(lineno)d::%(funcName)s\n"
)
return myformatter
FORMATTER = { class PrefixFormatter(logging.Formatter):
logging.DEBUG: get_formatter(blue), """
logging.INFO: get_formatter(green), print errors in red and warnings in yellow
logging.WARNING: get_formatter(yellow), """
logging.ERROR: get_formatter(red),
logging.CRITICAL: get_formatter(bold_red),
}
def __init__(self, trace_prints: bool = False, default_prefix: str = "") -> None:
self.default_prefix = default_prefix
self.trace_prints = trace_prints
class CustomFormatter(logging.Formatter):
def __init__(self, log_locations: bool) -> None:
super().__init__() super().__init__()
self.log_locations = log_locations self.hostnames: list[str] = []
self.hostname_color_offset = 1 # first host shouldn't get aggressive red
def format(self, record: logging.LogRecord) -> str: def format(self, record: logging.LogRecord) -> str:
return FORMATTER[record.levelno](record, self.log_locations).format(record) filepath = _get_filepath(record)
if record.levelno == logging.DEBUG:
color_str = "blue"
elif record.levelno == logging.ERROR:
color_str = "red"
elif record.levelno == logging.WARNING:
color_str = "yellow"
else:
color_str = None
class ThreadFormatter(logging.Formatter): command_prefix = getattr(record, "command_prefix", self.default_prefix)
def format(self, record: logging.LogRecord) -> str:
return FORMATTER[record.levelno](record, False).format(record) if not DISABLE_COLOR:
prefix_color = self.hostname_colorcode(command_prefix)
format_str = color(f"[{command_prefix}]", fg=prefix_color)
format_str += color(" %(message)s", fg=color_str)
else:
format_str = f"[{command_prefix}] %(message)s"
if self.trace_prints:
format_str += f"\nSource: {filepath}:%(lineno)d::%(funcName)s\n"
return logging.Formatter(format_str).format(record)
def hostname_colorcode(self, hostname: str) -> tuple[int, int, int]:
try:
index = self.hostnames.index(hostname)
except ValueError:
self.hostnames += [hostname]
index = self.hostnames.index(hostname)
coloroffset = (index + self.hostname_color_offset) % len(css_colors)
colorcode = list(css_colors.values())[coloroffset]
return colorcode
def get_callers(start: int = 2, end: int = 2) -> list[str]: def get_callers(start: int = 2, end: int = 2) -> list[str]:
@@ -103,7 +116,28 @@ def get_callers(start: int = 2, end: int = 2) -> list[str]:
return callers return callers
def setup_logging(level: Any, root_log_name: str = __name__.split(".")[0]) -> None: def print_trace(msg: str, logger: logging.Logger, prefix: str) -> None:
trace_depth = int(os.environ.get("TRACE_DEPTH", "0"))
callers = get_callers(3, 4 + trace_depth)
if "run_no_stdout" in callers[0]:
callers = callers[1:]
else:
callers.pop()
if len(callers) == 1:
callers_str = f"Caller: {callers[0]}\n"
else:
callers_str = "\n".join(f"{i+1}: {caller}" for i, caller in enumerate(callers))
callers_str = f"Callers:\n{callers_str}"
logger.debug(f"{msg} \n{callers_str}", extra={"command_prefix": prefix})
def setup_logging(
level: Any,
root_log_name: str = __name__.split(".")[0],
default_prefix: str = "clan",
) -> None:
# Get the root logger and set its level # Get the root logger and set its level
main_logger = logging.getLogger(root_log_name) main_logger = logging.getLogger(root_log_name)
main_logger.setLevel(level) main_logger.setLevel(level)
@@ -113,10 +147,6 @@ def setup_logging(level: Any, root_log_name: str = __name__.split(".")[0]) -> No
# Create and add your custom handler # Create and add your custom handler
default_handler.setLevel(level) default_handler.setLevel(level)
trace_depth = bool(int(os.environ.get("TRACE_DEPTH", "0"))) trace_prints = bool(int(os.environ.get("TRACE_PRINT", "0")))
default_handler.setFormatter(CustomFormatter(trace_depth)) default_handler.setFormatter(PrefixFormatter(trace_prints, default_prefix))
main_logger.addHandler(default_handler) main_logger.addHandler(default_handler)
# Set logging level for other modules used by this module
logging.getLogger("asyncio").setLevel(logging.INFO)
logging.getLogger("httpx").setLevel(level=logging.WARNING)

View File

@@ -74,7 +74,7 @@ def indent_command(command_list: list[str]) -> str:
# Indent after the next argument # Indent after the next argument
formatted_command.append(" ") formatted_command.append(" ")
i += 1 i += 1
formatted_command.append(shlex.quote(command_list[i])) formatted_command.append(command_list[i])
if i < len(command_list) - 1: if i < len(command_list) - 1:
# Add line continuation only if it's not the last argument # Add line continuation only if it's not the last argument

View File

@@ -204,7 +204,7 @@ def generate_facts(
machine, service, regenerate, tmpdir, prompt machine, service, regenerate, tmpdir, prompt
) )
except (OSError, ClanError): except (OSError, ClanError):
log.exception(f"Failed to generate facts for {machine.name}") machine.error("Failed to generate facts")
errors += 1 errors += 1
if errors > 0: if errors > 0:
msg = ( msg = (
@@ -213,7 +213,7 @@ def generate_facts(
raise ClanError(msg) raise ClanError(msg)
if not was_regenerated: if not was_regenerated:
print("All secrets and facts are already up to date") machine.info("All secrets and facts are already up to date")
return was_regenerated return was_regenerated

View File

@@ -3,6 +3,7 @@ import subprocess
from pathlib import Path from pathlib import Path
from typing import override from typing import override
from clan_cli.cmd import Log
from clan_cli.machines.machines import Machine from clan_cli.machines.machines import Machine
from clan_cli.nix import nix_shell from clan_cli.nix import nix_shell
@@ -97,8 +98,8 @@ class SecretStore(SecretStoreBase):
remote_hash = self.machine.target_host.run( remote_hash = self.machine.target_host.run(
# TODO get the path to the secrets from the machine # TODO get the path to the secrets from the machine
["cat", f"{self.machine.secrets_upload_directory}/.pass_info"], ["cat", f"{self.machine.secrets_upload_directory}/.pass_info"],
log=Log.STDERR,
check=False, check=False,
stdout=subprocess.PIPE,
).stdout.strip() ).stdout.strip()
if not remote_hash: if not remote_hash:

View File

@@ -48,6 +48,18 @@ class Machine:
def __repr__(self) -> str: def __repr__(self) -> str:
return str(self) return str(self)
def debug(self, msg: str, *args: Any, **kwargs: Any) -> None:
kwargs.update({"extra": {"command_prefix": self.name}})
log.debug(msg, *args, **kwargs)
def info(self, msg: str, *args: Any, **kwargs: Any) -> None:
kwargs.update({"extra": {"command_prefix": self.name}})
log.info(msg, *args, **kwargs)
def error(self, msg: str, *args: Any, **kwargs: Any) -> None:
kwargs.update({"extra": {"command_prefix": self.name}})
log.error(msg, *args, **kwargs)
@property @property
def system(self) -> str: def system(self) -> str:
# We filter out function attributes because they are not serializable. # We filter out function attributes because they are not serializable.

View File

@@ -64,7 +64,7 @@ def upload_sources(machine: Machine, always_upload_source: bool = False) -> str:
path, path,
] ]
) )
run(cmd, env=env, error_msg="failed to upload sources") run(cmd, env=env, error_msg="failed to upload sources", prefix=machine.name)
return path return path
# Slow path: we need to upload all sources to the remote machine # Slow path: we need to upload all sources to the remote machine
@@ -78,7 +78,6 @@ def upload_sources(machine: Machine, always_upload_source: bool = False) -> str:
flake_url, flake_url,
] ]
) )
log.info("run %s", shlex.join(cmd))
proc = run(cmd, env=env, error_msg="failed to upload sources") proc = run(cmd, env=env, error_msg="failed to upload sources")
try: try:

View File

@@ -1,47 +1,28 @@
# Adapted from https://github.com/numtide/deploykit # Adapted from https://github.com/numtide/deploykit
import fcntl import logging
import math import math
import os import os
import select
import shlex import shlex
import subprocess import subprocess
import tarfile import tarfile
import time
from collections.abc import Iterator
from contextlib import ExitStack, contextmanager
from pathlib import Path from pathlib import Path
from shlex import quote from shlex import quote
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import IO, Any from typing import IO, Any
from clan_cli.cmd import Log, terminate_process_group from clan_cli.cmd import Log
from clan_cli.cmd import run as local_run from clan_cli.cmd import run as local_run
from clan_cli.errors import ClanError from clan_cli.errors import ClanError
from clan_cli.ssh.host_key import HostKeyCheck from clan_cli.ssh.host_key import HostKeyCheck
from clan_cli.ssh.logger import cmdlog
FILE = None | int cmdlog = logging.getLogger(__name__)
# Seconds until a message is printed when _run produces no output. # Seconds until a message is printed when _run produces no output.
NO_OUTPUT_TIMEOUT = 20 NO_OUTPUT_TIMEOUT = 20
@contextmanager
def _pipe() -> Iterator[tuple[IO[str], IO[str]]]:
(pipe_r, pipe_w) = os.pipe()
read_end = os.fdopen(pipe_r, "r")
write_end = os.fdopen(pipe_w, "w")
try:
fl = fcntl.fcntl(read_end, fcntl.F_GETFL)
fcntl.fcntl(read_end, fcntl.F_SETFL, fl | os.O_NONBLOCK)
yield (read_end, write_end)
finally:
read_end.close()
write_end.close()
class Host: class Host:
def __init__( def __init__(
self, self,
@@ -101,196 +82,56 @@ class Host:
host = f"[{host}]" host = f"[{host}]"
return f"{self.user or 'root'}@{host}" return f"{self.user or 'root'}@{host}"
def _prefix_output(
self,
displayed_cmd: str,
print_std_fd: IO[str] | None,
print_err_fd: IO[str] | None,
stdout: IO[str] | None,
stderr: IO[str] | None,
timeout: float = math.inf,
) -> tuple[str, str]:
rlist = []
if print_std_fd is not None:
rlist.append(print_std_fd)
if print_err_fd is not None:
rlist.append(print_err_fd)
if stdout is not None:
rlist.append(stdout)
if stderr is not None:
rlist.append(stderr)
print_std_buf = ""
print_err_buf = ""
stdout_buf = ""
stderr_buf = ""
start = time.time()
last_output = time.time()
while len(rlist) != 0:
readlist, _, _ = select.select(
rlist, [], [], min(timeout, NO_OUTPUT_TIMEOUT)
)
def print_from(
print_fd: IO[str], print_buf: str, is_err: bool = False
) -> tuple[float, str]:
read = os.read(print_fd.fileno(), 4096)
if len(read) == 0:
rlist.remove(print_fd)
print_buf += read.decode("utf-8")
if (read == b"" and len(print_buf) != 0) or "\n" in print_buf:
# print and empty the print_buf, if the stream is draining,
# but there is still something in the buffer or on newline.
lines = print_buf.rstrip("\n").split("\n")
for line in lines:
if not is_err:
cmdlog.info(
line, extra={"command_prefix": self.command_prefix}
)
else:
cmdlog.error(
line, extra={"command_prefix": self.command_prefix}
)
print_buf = ""
last_output = time.time()
return (last_output, print_buf)
if print_std_fd in readlist and print_std_fd is not None:
(last_output, print_std_buf) = print_from(
print_std_fd, print_std_buf, is_err=False
)
if print_err_fd in readlist and print_err_fd is not None:
(last_output, print_err_buf) = print_from(
print_err_fd, print_err_buf, is_err=True
)
now = time.time()
elapsed = now - start
if now - last_output > NO_OUTPUT_TIMEOUT:
elapsed_msg = time.strftime("%H:%M:%S", time.gmtime(elapsed))
cmdlog.warning(
f"still waiting for '{displayed_cmd}' to finish... ({elapsed_msg} elapsed)",
extra={"command_prefix": self.command_prefix},
)
def handle_fd(fd: IO[Any] | None, readlist: list[IO[Any]]) -> str:
if fd and fd in readlist:
read = os.read(fd.fileno(), 4096)
if len(read) == 0:
rlist.remove(fd)
else:
return read.decode("utf-8")
return ""
stdout_buf += handle_fd(stdout, readlist)
stderr_buf += handle_fd(stderr, readlist)
if now - last_output >= timeout:
break
return stdout_buf, stderr_buf
def _run( def _run(
self, self,
cmd: list[str], cmd: list[str],
displayed_cmd: str, *,
shell: bool, stdout: IO[bytes] | None = None,
stdout: FILE = None, stderr: IO[bytes] | None = None,
stderr: FILE = None, input: bytes | None = None, # noqa: A002
extra_env: dict[str, str] | None = None, env: dict[str, str] | None = None,
cwd: None | str | Path = None, cwd: Path | None = None,
log: Log = Log.BOTH,
check: bool = True, check: bool = True,
timeout: float = math.inf, error_msg: str | None = None,
needs_user_terminal: bool = False, needs_user_terminal: bool = False,
shell: bool = False,
timeout: float = math.inf,
) -> subprocess.CompletedProcess[str]: ) -> subprocess.CompletedProcess[str]:
if extra_env is None: res = local_run(
extra_env = {} cmd,
with ExitStack() as stack: shell=shell,
read_std_fd, write_std_fd = (None, None) stdout=stdout,
read_err_fd, write_err_fd = (None, None) prefix=self.command_prefix,
timeout=timeout,
if stdout is None or stderr is None: stderr=stderr,
read_std_fd, write_std_fd = stack.enter_context(_pipe()) input=input,
read_err_fd, write_err_fd = stack.enter_context(_pipe()) env=env,
cwd=cwd,
if stdout is None: log=log,
stdout_read = None logger=cmdlog,
stdout_write = write_std_fd check=check,
elif stdout == subprocess.PIPE: error_msg=error_msg,
stdout_read, stdout_write = stack.enter_context(_pipe()) needs_user_terminal=needs_user_terminal,
else: )
msg = f"unsupported value for stdout parameter: {stdout}" return subprocess.CompletedProcess(
raise ClanError(msg) args=res.command_list,
returncode=res.returncode,
if stderr is None: stdout=res.stdout,
stderr_read = None stderr=res.stderr,
stderr_write = write_err_fd )
elif stderr == subprocess.PIPE:
stderr_read, stderr_write = stack.enter_context(_pipe())
else:
msg = f"unsupported value for stderr parameter: {stderr}"
raise ClanError(msg)
env = os.environ.copy()
env.update(extra_env)
with subprocess.Popen(
cmd,
text=True,
shell=shell,
stdout=stdout_write,
stderr=stderr_write,
env=env,
cwd=cwd,
start_new_session=not needs_user_terminal,
) as p:
if not needs_user_terminal:
stack.enter_context(terminate_process_group(p))
if write_std_fd is not None:
write_std_fd.close()
if write_err_fd is not None:
write_err_fd.close()
if stdout == subprocess.PIPE:
assert stdout_write is not None
stdout_write.close()
if stderr == subprocess.PIPE:
assert stderr_write is not None
stderr_write.close()
start = time.time()
stdout_data, stderr_data = self._prefix_output(
displayed_cmd,
read_std_fd,
read_err_fd,
stdout_read,
stderr_read,
timeout,
)
ret = p.wait(timeout=max(0, timeout - (time.time() - start)))
if ret != 0:
if check:
msg = f"Command {shlex.join(cmd)} failed with return code {ret}"
raise ClanError(msg)
cmdlog.warning(
f"[Command failed: {ret}] {displayed_cmd}",
extra={"command_prefix": self.command_prefix},
)
return subprocess.CompletedProcess(
cmd, ret, stdout=stdout_data, stderr=stderr_data
)
msg = "unreachable"
raise RuntimeError(msg)
def run_local( def run_local(
self, self,
cmd: str | list[str], cmd: list[str],
stdout: FILE = None, stdout: IO[bytes] | None = None,
stderr: FILE = None, stderr: IO[bytes] | None = None,
extra_env: dict[str, str] | None = None, extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None, cwd: None | Path = None,
check: bool = True, check: bool = True,
timeout: float = math.inf, timeout: float = math.inf,
shell: bool = False,
log: Log = Log.BOTH,
) -> subprocess.CompletedProcess[str]: ) -> subprocess.CompletedProcess[str]:
""" """
Command to run locally for the host Command to run locally for the host
@@ -304,38 +145,38 @@ class Host:
@return subprocess.CompletedProcess result of the command @return subprocess.CompletedProcess result of the command
""" """
if extra_env is None: env = os.environ.copy()
extra_env = {} if extra_env:
shell = False env.update(extra_env)
if isinstance(cmd, str):
cmd = [cmd]
shell = True
displayed_cmd = " ".join(cmd) displayed_cmd = " ".join(cmd)
cmdlog.info(f"$ {displayed_cmd}", extra={"command_prefix": self.command_prefix}) cmdlog.info(f"$ {displayed_cmd}", extra={"command_prefix": self.command_prefix})
return self._run( return self._run(
cmd, cmd,
displayed_cmd,
shell=shell, shell=shell,
stdout=stdout, stdout=stdout,
stderr=stderr, stderr=stderr,
extra_env=extra_env, env=env,
cwd=cwd, cwd=cwd,
check=check, check=check,
timeout=timeout, timeout=timeout,
log=log,
) )
def run( def run(
self, self,
cmd: str | list[str], cmd: list[str],
stdout: FILE = None, stdout: IO[bytes] | None = None,
stderr: FILE = None, stderr: IO[bytes] | None = None,
become_root: bool = False, become_root: bool = False,
extra_env: dict[str, str] | None = None, extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None, cwd: None | Path = None,
check: bool = True, check: bool = True,
timeout: float = math.inf, timeout: float = math.inf,
verbose_ssh: bool = False, verbose_ssh: bool = False,
tty: bool = False, tty: bool = False,
shell: bool = False,
log: Log = Log.BOTH,
) -> subprocess.CompletedProcess[str]: ) -> subprocess.CompletedProcess[str]:
""" """
Command to run on the host via ssh Command to run on the host via ssh
@@ -353,48 +194,50 @@ class Host:
""" """
if extra_env is None: if extra_env is None:
extra_env = {} extra_env = {}
# If we are not root and we need to become root, prepend sudo
sudo = "" sudo = ""
if become_root and self.user != "root": if become_root and self.user != "root":
sudo = "sudo -- " sudo = "sudo -- "
# Quote all added environment variables
env_vars = [] env_vars = []
for k, v in extra_env.items(): for k, v in extra_env.items():
env_vars.append(f"{shlex.quote(k)}={shlex.quote(v)}") env_vars.append(f"{shlex.quote(k)}={shlex.quote(v)}")
# Build a pretty command for logging
displayed_cmd = "" displayed_cmd = ""
export_cmd = "" export_cmd = ""
if env_vars: if env_vars:
export_cmd = f"export {' '.join(env_vars)}; " export_cmd = f"export {' '.join(env_vars)}; "
displayed_cmd += export_cmd displayed_cmd += export_cmd
if isinstance(cmd, list): displayed_cmd += " ".join(cmd)
displayed_cmd += " ".join(cmd)
else:
displayed_cmd += cmd
cmdlog.info(f"$ {displayed_cmd}", extra={"command_prefix": self.command_prefix}) cmdlog.info(f"$ {displayed_cmd}", extra={"command_prefix": self.command_prefix})
# Build the ssh command
bash_cmd = export_cmd bash_cmd = export_cmd
bash_args = [] if shell:
if isinstance(cmd, list): bash_cmd += " ".join(cmd)
bash_cmd += 'exec "$@"'
bash_args += cmd
else: else:
bash_cmd += cmd bash_cmd += 'exec "$@"'
# FIXME we assume bash to be present here? Should be documented... # FIXME we assume bash to be present here? Should be documented...
ssh_cmd = [ ssh_cmd = [
*self.ssh_cmd(verbose_ssh=verbose_ssh, tty=tty), *self.ssh_cmd(verbose_ssh=verbose_ssh, tty=tty),
"--", "--",
f"{sudo}bash -c {quote(bash_cmd)} -- {' '.join(map(quote, bash_args))}", f"{sudo}bash -c {quote(bash_cmd)} -- {' '.join(map(quote, cmd))}",
] ]
# Run the ssh command
return self._run( return self._run(
ssh_cmd, ssh_cmd,
displayed_cmd,
shell=False, shell=False,
stdout=stdout, stdout=stdout,
stderr=stderr, stderr=stderr,
log=log,
cwd=cwd, cwd=cwd,
check=check, check=check,
timeout=timeout, timeout=timeout,
# all ssh commands can potentially ask for a password needs_user_terminal=True, # ssh asks for a password
needs_user_terminal=True,
) )
def nix_ssh_env(self, env: dict[str, str] | None) -> dict[str, str]: def nix_ssh_env(self, env: dict[str, str] | None) -> dict[str, str]:
@@ -464,13 +307,19 @@ class Host:
"tar", "tar",
"-C", "-C",
str(remote_dest), str(remote_dest),
"-xvzf", "-xzf",
"-", "-",
] ]
# TODO accept `input` to be an IO object instead of bytes so that we don't have to read the tarfile into memory. # TODO accept `input` to be an IO object instead of bytes so that we don't have to read the tarfile into memory.
with tar_path.open("rb") as f: with tar_path.open("rb") as f:
local_run(cmd, input=f.read(), log=Log.BOTH, needs_user_terminal=True) local_run(
cmd,
input=f.read(),
log=Log.BOTH,
needs_user_terminal=True,
prefix=self.command_prefix,
)
@property @property
def ssh_cmd_opts( def ssh_cmd_opts(

View File

@@ -1,15 +1,18 @@
import logging
import math import math
from collections.abc import Callable from collections.abc import Callable
from pathlib import Path from pathlib import Path
from threading import Thread from threading import Thread
from typing import Any from typing import IO, Any
from clan_cli.cmd import Log
from clan_cli.errors import ClanError from clan_cli.errors import ClanError
from clan_cli.ssh import T from clan_cli.ssh import T
from clan_cli.ssh.host import FILE, Host from clan_cli.ssh.host import Host
from clan_cli.ssh.logger import cmdlog
from clan_cli.ssh.results import HostResult, Results from clan_cli.ssh.results import HostResult, Results
cmdlog = logging.getLogger(__name__)
def _worker( def _worker(
func: Callable[[Host], T], func: Callable[[Host], T],
@@ -35,17 +38,19 @@ class HostGroup:
def _run_local( def _run_local(
self, self,
cmd: str | list[str], cmd: list[str],
host: Host, host: Host,
results: Results, results: Results,
stdout: FILE = None, stdout: IO[bytes] | None = None,
stderr: FILE = None, stderr: IO[bytes] | None = None,
extra_env: dict[str, str] | None = None, extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None, cwd: None | Path = None,
check: bool = True, check: bool = True,
verbose_ssh: bool = False, verbose_ssh: bool = False,
timeout: float = math.inf, timeout: float = math.inf,
shell: bool = False,
tty: bool = False, tty: bool = False,
log: Log = Log.BOTH,
) -> None: ) -> None:
if extra_env is None: if extra_env is None:
extra_env = {} extra_env = {}
@@ -58,6 +63,8 @@ class HostGroup:
cwd=cwd, cwd=cwd,
check=check, check=check,
timeout=timeout, timeout=timeout,
shell=shell,
log=log,
) )
results.append(HostResult(host, proc)) results.append(HostResult(host, proc))
except Exception as e: except Exception as e:
@@ -65,17 +72,19 @@ class HostGroup:
def _run_remote( def _run_remote(
self, self,
cmd: str | list[str], cmd: list[str],
host: Host, host: Host,
results: Results, results: Results,
stdout: FILE = None, stdout: IO[bytes] | None = None,
stderr: FILE = None, stderr: IO[bytes] | None = None,
extra_env: dict[str, str] | None = None, extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None, cwd: None | str | Path = None,
check: bool = True, check: bool = True,
verbose_ssh: bool = False, verbose_ssh: bool = False,
timeout: float = math.inf, timeout: float = math.inf,
tty: bool = False, tty: bool = False,
shell: bool = False,
log: Log = Log.BOTH,
) -> None: ) -> None:
if cwd is not None: if cwd is not None:
msg = "cwd is not supported for remote commands" msg = "cwd is not supported for remote commands"
@@ -93,6 +102,8 @@ class HostGroup:
verbose_ssh=verbose_ssh, verbose_ssh=verbose_ssh,
timeout=timeout, timeout=timeout,
tty=tty, tty=tty,
shell=shell,
log=log,
) )
results.append(HostResult(host, proc)) results.append(HostResult(host, proc))
except Exception as e: except Exception as e:
@@ -114,16 +125,18 @@ class HostGroup:
def _run( def _run(
self, self,
cmd: str | list[str], cmd: list[str],
local: bool = False, local: bool = False,
stdout: FILE = None, stdout: IO[bytes] | None = None,
stderr: FILE = None, stderr: IO[bytes] | None = None,
extra_env: dict[str, str] | None = None, extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None, cwd: None | str | Path = None,
check: bool = True, check: bool = True,
timeout: float = math.inf, timeout: float = math.inf,
verbose_ssh: bool = False, verbose_ssh: bool = False,
tty: bool = False, tty: bool = False,
shell: bool = False,
log: Log = Log.BOTH,
) -> Results: ) -> Results:
if extra_env is None: if extra_env is None:
extra_env = {} extra_env = {}
@@ -145,6 +158,8 @@ class HostGroup:
"timeout": timeout, "timeout": timeout,
"verbose_ssh": verbose_ssh, "verbose_ssh": verbose_ssh,
"tty": tty, "tty": tty,
"shell": shell,
"log": log,
}, },
) )
thread.start() thread.start()
@@ -160,15 +175,17 @@ class HostGroup:
def run( def run(
self, self,
cmd: str | list[str], cmd: list[str],
stdout: FILE = None, stdout: IO[bytes] | None = None,
stderr: FILE = None, stderr: IO[bytes] | None = None,
extra_env: dict[str, str] | None = None, extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None, cwd: None | str | Path = None,
check: bool = True, check: bool = True,
verbose_ssh: bool = False, verbose_ssh: bool = False,
timeout: float = math.inf, timeout: float = math.inf,
tty: bool = False, tty: bool = False,
log: Log = Log.BOTH,
shell: bool = False,
) -> Results: ) -> Results:
""" """
Command to run on the remote host via ssh Command to run on the remote host via ssh
@@ -184,6 +201,7 @@ class HostGroup:
extra_env = {} extra_env = {}
return self._run( return self._run(
cmd, cmd,
shell=shell,
stdout=stdout, stdout=stdout,
stderr=stderr, stderr=stderr,
extra_env=extra_env, extra_env=extra_env,
@@ -192,17 +210,20 @@ class HostGroup:
verbose_ssh=verbose_ssh, verbose_ssh=verbose_ssh,
timeout=timeout, timeout=timeout,
tty=tty, tty=tty,
log=log,
) )
def run_local( def run_local(
self, self,
cmd: str | list[str], cmd: list[str],
stdout: FILE = None, stdout: IO[bytes] | None = None,
stderr: FILE = None, stderr: IO[bytes] | None = None,
extra_env: dict[str, str] | None = None, extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None, cwd: None | str | Path = None,
check: bool = True, check: bool = True,
timeout: float = math.inf, timeout: float = math.inf,
shell: bool = False,
log: Log = Log.BOTH,
) -> Results: ) -> Results:
""" """
Command to run locally for each host in the group in parallel Command to run locally for each host in the group in parallel
@@ -226,6 +247,8 @@ class HostGroup:
cwd=cwd, cwd=cwd,
check=check, check=check,
timeout=timeout, timeout=timeout,
shell=shell,
log=log,
) )
def run_function( def run_function(

View File

@@ -1,87 +0,0 @@
# Adapted from https://github.com/numtide/deploykit
import logging
import os
import sys
# https://no-color.org
DISABLE_COLOR = not sys.stderr.isatty() or os.environ.get("NO_COLOR", "") != ""
def ansi_color(color: int) -> str:
return f"\x1b[{color}m"
class CommandFormatter(logging.Formatter):
"""
print errors in red and warnings in yellow
"""
def __init__(self) -> None:
super().__init__(
"%(prefix_color)s[%(command_prefix)s]%(color_reset)s %(color)s%(message)s%(color_reset)s"
)
self.hostnames: list[str] = []
self.hostname_color_offset = 1 # first host shouldn't get aggressive red
def format(self, record: logging.LogRecord) -> str:
colorcode = 0
if record.levelno == logging.ERROR:
colorcode = 31 # red
if record.levelno == logging.WARNING:
colorcode = 33 # yellow
color, prefix_color, color_reset = "", "", ""
if not DISABLE_COLOR:
command_prefix = getattr(record, "command_prefix", "")
color = ansi_color(colorcode)
prefix_color = ansi_color(self.hostname_colorcode(command_prefix))
color_reset = "\x1b[0m"
record.color = color
record.prefix_color = prefix_color
record.color_reset = color_reset
return super().format(record)
def hostname_colorcode(self, hostname: str) -> int:
try:
index = self.hostnames.index(hostname)
except ValueError:
self.hostnames += [hostname]
index = self.hostnames.index(hostname)
return 31 + (index + self.hostname_color_offset) % 7
def setup_loggers() -> tuple[logging.Logger, logging.Logger]:
# If we use the default logger here (logging.error etc) or a logger called
# "deploykit", then cmdlog messages are also posted on the default logger.
# To avoid this message duplication, we set up a main and command logger
# and use a "deploykit" main logger.
kitlog = logging.getLogger("deploykit.main")
kitlog.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(logging.Formatter())
kitlog.addHandler(ch)
# use specific logger for command outputs
cmdlog = logging.getLogger("deploykit.command")
cmdlog.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(CommandFormatter())
cmdlog.addHandler(ch)
return (kitlog, cmdlog)
# loggers for: general deploykit, command output
kitlog, cmdlog = setup_loggers()
info = kitlog.info
warn = kitlog.warning
error = kitlog.error

View File

@@ -1,790 +0,0 @@
# Adapted from https://github.com/numtide/deploykit
import fcntl
import math
import os
import select
import shlex
import subprocess
import tarfile
import time
from collections.abc import Callable, Iterator
from contextlib import ExitStack, contextmanager
from pathlib import Path
from shlex import quote
from tempfile import TemporaryDirectory
from threading import Thread
from typing import IO, Any, Generic, TypeVar
from clan_cli.cmd import Log, terminate_process_group
from clan_cli.cmd import run as local_run
from clan_cli.errors import ClanError
from clan_cli.ssh.host_key import HostKeyCheck
from clan_cli.ssh.logger import cmdlog
FILE = None | int
# Seconds until a message is printed when _run produces no output.
NO_OUTPUT_TIMEOUT = 20
@contextmanager
def _pipe() -> Iterator[tuple[IO[str], IO[str]]]:
(pipe_r, pipe_w) = os.pipe()
read_end = os.fdopen(pipe_r, "r")
write_end = os.fdopen(pipe_w, "w")
try:
fl = fcntl.fcntl(read_end, fcntl.F_GETFL)
fcntl.fcntl(read_end, fcntl.F_SETFL, fl | os.O_NONBLOCK)
yield (read_end, write_end)
finally:
read_end.close()
write_end.close()
class Host:
def __init__(
self,
host: str,
user: str | None = None,
port: int | None = None,
key: str | None = None,
forward_agent: bool = False,
command_prefix: str | None = None,
host_key_check: HostKeyCheck = HostKeyCheck.ASK,
meta: dict[str, Any] | None = None,
verbose_ssh: bool = False,
ssh_options: dict[str, str] | None = None,
) -> None:
"""
Creates a Host
@host the hostname to connect to via ssh
@port the port to connect to via ssh
@forward_agent: whether to forward ssh agent
@command_prefix: string to prefix each line of the command output with, defaults to host
@host_key_check: whether to check ssh host keys
@verbose_ssh: Enables verbose logging on ssh connections
@meta: meta attributes associated with the host. Those can be accessed in custom functions passed to `run_function`
"""
if ssh_options is None:
ssh_options = {}
if meta is None:
meta = {}
self.host = host
self.user = user
self.port = port
self.key = key
if command_prefix:
self.command_prefix = command_prefix
else:
self.command_prefix = host
self.forward_agent = forward_agent
self.host_key_check = host_key_check
self.meta = meta
self.verbose_ssh = verbose_ssh
self._ssh_options = ssh_options
def __repr__(self) -> str:
return str(self)
def __str__(self) -> str:
return f"{self.user}@{self.host}" + str(self.port if self.port else "")
@property
def target(self) -> str:
return f"{self.user or 'root'}@{self.host}"
@property
def target_for_rsync(self) -> str:
host = self.host
if ":" in host:
host = f"[{host}]"
return f"{self.user or 'root'}@{host}"
def _prefix_output(
self,
displayed_cmd: str,
print_std_fd: IO[str] | None,
print_err_fd: IO[str] | None,
stdout: IO[str] | None,
stderr: IO[str] | None,
timeout: float = math.inf,
) -> tuple[str, str]:
rlist = []
if print_std_fd is not None:
rlist.append(print_std_fd)
if print_err_fd is not None:
rlist.append(print_err_fd)
if stdout is not None:
rlist.append(stdout)
if stderr is not None:
rlist.append(stderr)
print_std_buf = ""
print_err_buf = ""
stdout_buf = ""
stderr_buf = ""
start = time.time()
last_output = time.time()
while len(rlist) != 0:
readlist, _, _ = select.select(
rlist, [], [], min(timeout, NO_OUTPUT_TIMEOUT)
)
def print_from(
print_fd: IO[str], print_buf: str, is_err: bool = False
) -> tuple[float, str]:
read = os.read(print_fd.fileno(), 4096)
if len(read) == 0:
rlist.remove(print_fd)
print_buf += read.decode("utf-8")
if (read == b"" and len(print_buf) != 0) or "\n" in print_buf:
# print and empty the print_buf, if the stream is draining,
# but there is still something in the buffer or on newline.
lines = print_buf.rstrip("\n").split("\n")
for line in lines:
if not is_err:
cmdlog.info(
line, extra={"command_prefix": self.command_prefix}
)
else:
cmdlog.error(
line, extra={"command_prefix": self.command_prefix}
)
print_buf = ""
last_output = time.time()
return (last_output, print_buf)
if print_std_fd in readlist and print_std_fd is not None:
(last_output, print_std_buf) = print_from(
print_std_fd, print_std_buf, is_err=False
)
if print_err_fd in readlist and print_err_fd is not None:
(last_output, print_err_buf) = print_from(
print_err_fd, print_err_buf, is_err=True
)
now = time.time()
elapsed = now - start
if now - last_output > NO_OUTPUT_TIMEOUT:
elapsed_msg = time.strftime("%H:%M:%S", time.gmtime(elapsed))
cmdlog.warning(
f"still waiting for '{displayed_cmd}' to finish... ({elapsed_msg} elapsed)",
extra={"command_prefix": self.command_prefix},
)
def handle_fd(fd: IO[Any] | None, readlist: list[IO[Any]]) -> str:
if fd and fd in readlist:
read = os.read(fd.fileno(), 4096)
if len(read) == 0:
rlist.remove(fd)
else:
return read.decode("utf-8")
return ""
stdout_buf += handle_fd(stdout, readlist)
stderr_buf += handle_fd(stderr, readlist)
if now - last_output >= timeout:
break
return stdout_buf, stderr_buf
def _run(
self,
cmd: list[str],
displayed_cmd: str,
shell: bool,
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
timeout: float = math.inf,
needs_user_terminal: bool = False,
) -> subprocess.CompletedProcess[str]:
if extra_env is None:
extra_env = {}
with ExitStack() as stack:
read_std_fd, write_std_fd = (None, None)
read_err_fd, write_err_fd = (None, None)
if stdout is None or stderr is None:
read_std_fd, write_std_fd = stack.enter_context(_pipe())
read_err_fd, write_err_fd = stack.enter_context(_pipe())
if stdout is None:
stdout_read = None
stdout_write = write_std_fd
elif stdout == subprocess.PIPE:
stdout_read, stdout_write = stack.enter_context(_pipe())
else:
msg = f"unsupported value for stdout parameter: {stdout}"
raise ClanError(msg)
if stderr is None:
stderr_read = None
stderr_write = write_err_fd
elif stderr == subprocess.PIPE:
stderr_read, stderr_write = stack.enter_context(_pipe())
else:
msg = f"unsupported value for stderr parameter: {stderr}"
raise ClanError(msg)
env = os.environ.copy()
env.update(extra_env)
with subprocess.Popen(
cmd,
text=True,
shell=shell,
stdout=stdout_write,
stderr=stderr_write,
env=env,
cwd=cwd,
start_new_session=not needs_user_terminal,
) as p:
if not needs_user_terminal:
stack.enter_context(terminate_process_group(p))
if write_std_fd is not None:
write_std_fd.close()
if write_err_fd is not None:
write_err_fd.close()
if stdout == subprocess.PIPE:
assert stdout_write is not None
stdout_write.close()
if stderr == subprocess.PIPE:
assert stderr_write is not None
stderr_write.close()
start = time.time()
stdout_data, stderr_data = self._prefix_output(
displayed_cmd,
read_std_fd,
read_err_fd,
stdout_read,
stderr_read,
timeout,
)
ret = p.wait(timeout=max(0, timeout - (time.time() - start)))
if ret != 0:
if check:
msg = f"Command {shlex.join(cmd)} failed with return code {ret}"
raise ClanError(msg)
cmdlog.warning(
f"[Command failed: {ret}] {displayed_cmd}",
extra={"command_prefix": self.command_prefix},
)
return subprocess.CompletedProcess(
cmd, ret, stdout=stdout_data, stderr=stderr_data
)
msg = "unreachable"
raise RuntimeError(msg)
def run_local(
self,
cmd: str | list[str],
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
timeout: float = math.inf,
) -> subprocess.CompletedProcess[str]:
"""
Command to run locally for the host
@cmd the command to run
@stdout if not None stdout of the command will be redirected to this file i.e. stdout=subprocess.PIPE
@stderr if not None stderr of the command will be redirected to this file i.e. stderr=subprocess.PIPE
@extra_env environment variables to override when running the command
@cwd current working directory to run the process in
@timeout: Timeout in seconds for the command to complete
@return subprocess.CompletedProcess result of the command
"""
if extra_env is None:
extra_env = {}
shell = False
if isinstance(cmd, str):
cmd = [cmd]
shell = True
displayed_cmd = " ".join(cmd)
cmdlog.info(f"$ {displayed_cmd}", extra={"command_prefix": self.command_prefix})
return self._run(
cmd,
displayed_cmd,
shell=shell,
stdout=stdout,
stderr=stderr,
extra_env=extra_env,
cwd=cwd,
check=check,
timeout=timeout,
)
def run(
self,
cmd: str | list[str],
stdout: FILE = None,
stderr: FILE = None,
become_root: bool = False,
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
timeout: float = math.inf,
verbose_ssh: bool = False,
tty: bool = False,
) -> subprocess.CompletedProcess[str]:
"""
Command to run on the host via ssh
@cmd the command to run
@stdout if not None stdout of the command will be redirected to this file i.e. stdout=subprocss.PIPE
@stderr if not None stderr of the command will be redirected to this file i.e. stderr=subprocess.PIPE
@become_root if the ssh_user is not root than sudo is prepended
@extra_env environment variables to override when running the command
@cwd current working directory to run the process in
@verbose_ssh: Enables verbose logging on ssh connections
@timeout: Timeout in seconds for the command to complete
@return subprocess.CompletedProcess result of the ssh command
"""
if extra_env is None:
extra_env = {}
sudo = ""
if become_root and self.user != "root":
sudo = "sudo -- "
env_vars = []
for k, v in extra_env.items():
env_vars.append(f"{shlex.quote(k)}={shlex.quote(v)}")
displayed_cmd = ""
export_cmd = ""
if env_vars:
export_cmd = f"export {' '.join(env_vars)}; "
displayed_cmd += export_cmd
if isinstance(cmd, list):
displayed_cmd += " ".join(cmd)
else:
displayed_cmd += cmd
cmdlog.info(f"$ {displayed_cmd}", extra={"command_prefix": self.command_prefix})
bash_cmd = export_cmd
bash_args = []
if isinstance(cmd, list):
bash_cmd += 'exec "$@"'
bash_args += cmd
else:
bash_cmd += cmd
# FIXME we assume bash to be present here? Should be documented...
ssh_cmd = [
*self.ssh_cmd(verbose_ssh=verbose_ssh, tty=tty),
"--",
f"{sudo}bash -c {quote(bash_cmd)} -- {' '.join(map(quote, bash_args))}",
]
return self._run(
ssh_cmd,
displayed_cmd,
shell=False,
stdout=stdout,
stderr=stderr,
cwd=cwd,
check=check,
timeout=timeout,
# all ssh commands can potentially ask for a password
needs_user_terminal=True,
)
def nix_ssh_env(self, env: dict[str, str] | None) -> dict[str, str]:
if env is None:
env = {}
env["NIX_SSHOPTS"] = " ".join(self.ssh_cmd_opts)
return env
def upload(
self,
local_src: Path, # must be a directory
remote_dest: Path, # must be a directory
file_user: str = "root",
file_group: str = "root",
dir_mode: int = 0o700,
file_mode: int = 0o400,
) -> None:
# check if the remote destination is a directory (no suffix)
if remote_dest.suffix:
msg = "Only directories are allowed"
raise ClanError(msg)
if not local_src.is_dir():
msg = "Only directories are allowed"
raise ClanError(msg)
# Create the tarball from the temporary directory
with TemporaryDirectory(prefix="facts-upload-") as tardir:
tar_path = Path(tardir) / "upload.tar.gz"
# We set the permissions of the files and directories in the tarball to read only and owned by root
# As first uploading the tarball and then changing the permissions can lead an attacker to
# do a race condition attack
with tarfile.open(str(tar_path), "w:gz") as tar:
for root, dirs, files in local_src.walk():
for mdir in dirs:
dir_path = Path(root) / mdir
tarinfo = tar.gettarinfo(
dir_path, arcname=str(dir_path.relative_to(str(local_src)))
)
tarinfo.mode = dir_mode
tarinfo.uname = file_user
tarinfo.gname = file_group
tar.addfile(tarinfo)
for file in files:
file_path = Path(root) / file
tarinfo = tar.gettarinfo(
file_path,
arcname=str(file_path.relative_to(str(local_src))),
)
tarinfo.mode = file_mode
tarinfo.uname = file_user
tarinfo.gname = file_group
with file_path.open("rb") as f:
tar.addfile(tarinfo, f)
cmd = [
*self.ssh_cmd(),
"rm",
"-r",
str(remote_dest),
";",
"mkdir",
f"--mode={dir_mode:o}",
"-p",
str(remote_dest),
"&&",
"tar",
"-C",
str(remote_dest),
"-xvzf",
"-",
]
# TODO accept `input` to be an IO object instead of bytes so that we don't have to read the tarfile into memory.
with tar_path.open("rb") as f:
local_run(cmd, input=f.read(), log=Log.BOTH, needs_user_terminal=True)
@property
def ssh_cmd_opts(
self,
) -> list[str]:
ssh_opts = ["-A"] if self.forward_agent else []
for k, v in self._ssh_options.items():
ssh_opts.extend(["-o", f"{k}={shlex.quote(v)}"])
ssh_opts.extend(self.host_key_check.to_ssh_opt())
return ssh_opts
def ssh_cmd(
self,
verbose_ssh: bool = False,
tty: bool = False,
) -> list[str]:
ssh_opts = self.ssh_cmd_opts
if verbose_ssh or self.verbose_ssh:
ssh_opts.extend(["-v"])
if tty:
ssh_opts.extend(["-t"])
if self.port:
ssh_opts.extend(["-p", str(self.port)])
if self.key:
ssh_opts.extend(["-i", self.key])
return [
"ssh",
self.target,
*ssh_opts,
]
T = TypeVar("T")
class HostResult(Generic[T]):
def __init__(self, host: Host, result: T | Exception) -> None:
self.host = host
self._result = result
@property
def error(self) -> Exception | None:
"""
Returns an error if the command failed
"""
if isinstance(self._result, Exception):
return self._result
return None
@property
def result(self) -> T:
"""
Unwrap the result
"""
if isinstance(self._result, Exception):
raise self._result
return self._result
Results = list[HostResult[subprocess.CompletedProcess[str]]]
def _worker(
func: Callable[[Host], T],
host: Host,
results: list[HostResult[T]],
idx: int,
) -> None:
try:
results[idx] = HostResult(host, func(host))
except Exception as e:
results[idx] = HostResult(host, e)
class HostGroup:
def __init__(self, hosts: list[Host]) -> None:
self.hosts = hosts
def __repr__(self) -> str:
return str(self)
def __str__(self) -> str:
return f"HostGroup({self.hosts})"
def _run_local(
self,
cmd: str | list[str],
host: Host,
results: Results,
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
verbose_ssh: bool = False,
timeout: float = math.inf,
tty: bool = False,
) -> None:
if extra_env is None:
extra_env = {}
try:
proc = host.run_local(
cmd,
stdout=stdout,
stderr=stderr,
extra_env=extra_env,
cwd=cwd,
check=check,
timeout=timeout,
)
results.append(HostResult(host, proc))
except Exception as e:
results.append(HostResult(host, e))
def _run_remote(
self,
cmd: str | list[str],
host: Host,
results: Results,
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
verbose_ssh: bool = False,
timeout: float = math.inf,
tty: bool = False,
) -> None:
if cwd is not None:
msg = "cwd is not supported for remote commands"
raise ClanError(msg)
if extra_env is None:
extra_env = {}
try:
proc = host.run(
cmd,
stdout=stdout,
stderr=stderr,
extra_env=extra_env,
cwd=cwd,
check=check,
verbose_ssh=verbose_ssh,
timeout=timeout,
tty=tty,
)
results.append(HostResult(host, proc))
except Exception as e:
results.append(HostResult(host, e))
def _reraise_errors(self, results: list[HostResult[Any]]) -> None:
errors = 0
for result in results:
e = result.error
if e:
cmdlog.error(
f"failed with: {e}",
extra={"command_prefix": result.host.command_prefix},
)
errors += 1
if errors > 0:
msg = f"{errors} hosts failed with an error. Check the logs above"
raise ClanError(msg)
def _run(
self,
cmd: str | list[str],
local: bool = False,
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
timeout: float = math.inf,
verbose_ssh: bool = False,
tty: bool = False,
) -> Results:
if extra_env is None:
extra_env = {}
results: Results = []
threads = []
for host in self.hosts:
fn = self._run_local if local else self._run_remote
thread = Thread(
target=fn,
kwargs={
"results": results,
"cmd": cmd,
"host": host,
"stdout": stdout,
"stderr": stderr,
"extra_env": extra_env,
"cwd": cwd,
"check": check,
"timeout": timeout,
"verbose_ssh": verbose_ssh,
"tty": tty,
},
)
thread.start()
threads.append(thread)
for thread in threads:
thread.join()
if check:
self._reraise_errors(results)
return results
def run(
self,
cmd: str | list[str],
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
verbose_ssh: bool = False,
timeout: float = math.inf,
tty: bool = False,
) -> Results:
"""
Command to run on the remote host via ssh
@stdout if not None stdout of the command will be redirected to this file i.e. stdout=subprocss.PIPE
@stderr if not None stderr of the command will be redirected to this file i.e. stderr=subprocess.PIPE
@cwd current working directory to run the process in
@verbose_ssh: Enables verbose logging on ssh connections
@timeout: Timeout in seconds for the command to complete
@return a lists of tuples containing Host and the result of the command for this Host
"""
if extra_env is None:
extra_env = {}
return self._run(
cmd,
stdout=stdout,
stderr=stderr,
extra_env=extra_env,
cwd=cwd,
check=check,
verbose_ssh=verbose_ssh,
timeout=timeout,
tty=tty,
)
def run_local(
self,
cmd: str | list[str],
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
timeout: float = math.inf,
) -> Results:
"""
Command to run locally for each host in the group in parallel
@cmd the command to run
@stdout if not None stdout of the command will be redirected to this file i.e. stdout=subprocss.PIPE
@stderr if not None stderr of the command will be redirected to this file i.e. stderr=subprocess.PIPE
@cwd current working directory to run the process in
@extra_env environment variables to override when running the command
@timeout: Timeout in seconds for the command to complete
@return a lists of tuples containing Host and the result of the command for this Host
"""
if extra_env is None:
extra_env = {}
return self._run(
cmd,
local=True,
stdout=stdout,
stderr=stderr,
extra_env=extra_env,
cwd=cwd,
check=check,
timeout=timeout,
)
def run_function(
self, func: Callable[[Host], T], check: bool = True
) -> list[HostResult[T]]:
"""
Function to run for each host in the group in parallel
@func the function to call
"""
threads = []
results: list[HostResult[T]] = [
HostResult(h, ClanError(f"No result set for thread {i}"))
for (i, h) in enumerate(self.hosts)
]
for i, host in enumerate(self.hosts):
thread = Thread(
target=_worker,
args=(func, host, results, i),
)
threads.append(thread)
for thread in threads:
thread.start()
for thread in threads:
thread.join()
if check:
self._reraise_errors(results)
return results
def filter(self, pred: Callable[[Host], bool]) -> "HostGroup":
"""Return a new Group with the results filtered by the predicate"""
return HostGroup(list(filter(pred, self.hosts)))

View File

@@ -418,14 +418,15 @@ def generate_vars(
) )
machine.flush_caches() machine.flush_caches()
except Exception as exc: except Exception as exc:
log.error(f"Failed to generate facts for {machine.name}: {exc}") # noqa machine.error(f"Failed to generate facts: {exc}")
errors += [exc] errors += [exc]
if len(errors) > 0: if len(errors) > 0:
msg = f"Failed to generate facts for {len(errors)} hosts. Check the logs above" msg = f"Failed to generate facts for {len(errors)} hosts. Check the logs above"
raise ClanError(msg) from errors[0] raise ClanError(msg) from errors[0]
if not was_regenerated: if not was_regenerated:
print("All vars are already up to date") machine.info("All vars are already up to date")
return was_regenerated return was_regenerated

View File

@@ -1,13 +1,12 @@
import io import io
import logging import logging
import os import os
import subprocess
import tarfile import tarfile
from itertools import chain from itertools import chain
from pathlib import Path from pathlib import Path
from typing import override from typing import override
from clan_cli.cmd import run from clan_cli.cmd import Log, run
from clan_cli.machines.machines import Machine from clan_cli.machines.machines import Machine
from clan_cli.nix import nix_shell from clan_cli.nix import nix_shell
@@ -135,8 +134,8 @@ class SecretStore(SecretStoreBase):
remote_hash = self.machine.target_host.run( remote_hash = self.machine.target_host.run(
# TODO get the path to the secrets from the machine # TODO get the path to the secrets from the machine
["cat", f"{self.machine.secret_vars_upload_directory}/.pass_info"], ["cat", f"{self.machine.secret_vars_upload_directory}/.pass_info"],
log=Log.STDERR,
check=False, check=False,
stdout=subprocess.PIPE,
).stdout.strip() ).stdout.strip()
if not remote_hash: if not remote_hash:

View File

@@ -4,6 +4,7 @@ import json
import logging import logging
import os import os
import subprocess import subprocess
import sys
import time import time
from collections.abc import Iterator from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
@@ -339,7 +340,16 @@ def run_vm(
) as vm, ) as vm,
ThreadPoolExecutor(max_workers=1) as executor, ThreadPoolExecutor(max_workers=1) as executor,
): ):
future = executor.submit(handle_io, vm.process, input_bytes=None, log=Log.BOTH) future = executor.submit(
handle_io,
vm.process,
cmdlog=log,
prefix=f"[{vm_config.machine_name}] ",
stdout=sys.stdout.buffer,
stderr=sys.stderr.buffer,
input_bytes=None,
log=Log.BOTH,
)
args: list[str] = vm.process.args # type: ignore[assignment] args: list[str] = vm.process.args # type: ignore[assignment]
if runtime_config.command is not None: if runtime_config.command is not None:

View File

@@ -1,30 +1,19 @@
import argparse import argparse
import logging import logging
import os
import shlex import shlex
from clan_cli import create_parser from clan_cli import create_parser
from clan_cli.custom_logger import get_callers from clan_cli.custom_logger import print_trace
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def print_trace(msg: str) -> None:
trace_depth = int(os.environ.get("TRACE_DEPTH", "0"))
callers = get_callers(2, 2 + trace_depth)
if "run_no_stdout" in callers[0]:
callers = get_callers(3, 3 + trace_depth)
callers_str = "\n".join(f"{i+1}: {caller}" for i, caller in enumerate(callers))
log.debug(f"{msg} \nCallers: \n{callers_str}")
def run(args: list[str]) -> argparse.Namespace: def run(args: list[str]) -> argparse.Namespace:
parser = create_parser(prog="clan") parser = create_parser(prog="clan")
parsed = parser.parse_args(args) parsed = parser.parse_args(args)
cmd = shlex.join(["clan", *args]) cmd = shlex.join(["clan", *args])
print_trace(f"$ {cmd}") print_trace(f"$ {cmd}", log, "localhost")
if hasattr(parsed, "func"): if hasattr(parsed, "func"):
parsed.func(parsed) parsed.func(parsed)
return parsed return parsed

View File

@@ -58,6 +58,6 @@ def test_secrets_upload(
# the flake defines this path as the location where the sops key should be installed # the flake defines this path as the location where the sops key should be installed
sops_key = test_flake_with_core.path / "facts" / "key.txt" sops_key = test_flake_with_core.path / "facts" / "key.txt"
# breakpoint()
assert sops_key.exists() assert sops_key.exists()
assert sops_key.read_text() == age_keys[0].privkey assert sops_key.read_text() == age_keys[0].privkey

View File

@@ -1,5 +1,4 @@
import subprocess from clan_cli.cmd import Log
from clan_cli.ssh.host import Host from clan_cli.ssh.host import Host
from clan_cli.ssh.host_group import HostGroup from clan_cli.ssh.host_group import HostGroup
@@ -8,21 +7,21 @@ hosts = HostGroup([Host("some_host")])
def test_run_environment() -> None: def test_run_environment() -> None:
p2 = hosts.run_local( p2 = hosts.run_local(
"echo $env_var", extra_env={"env_var": "true"}, stdout=subprocess.PIPE ["echo $env_var"], extra_env={"env_var": "true"}, shell=True, log=Log.STDERR
) )
assert p2[0].result.stdout == "true\n" assert p2[0].result.stdout == "true\n"
p3 = hosts.run_local(["env"], extra_env={"env_var": "true"}, stdout=subprocess.PIPE) p3 = hosts.run_local(["env"], extra_env={"env_var": "true"}, log=Log.STDERR)
assert "env_var=true" in p3[0].result.stdout assert "env_var=true" in p3[0].result.stdout
def test_run_local() -> None: def test_run_local() -> None:
hosts.run_local("echo hello") hosts.run_local(["echo", "hello"])
def test_timeout() -> None: def test_timeout() -> None:
try: try:
hosts.run_local("sleep 10", timeout=0.01) hosts.run_local(["sleep", "10"], timeout=0.01)
except Exception: except Exception:
pass pass
else: else:
@@ -32,8 +31,8 @@ def test_timeout() -> None:
def test_run_function() -> None: def test_run_function() -> None:
def some_func(h: Host) -> bool: def some_func(h: Host) -> bool:
p = h.run_local("echo hello", stdout=subprocess.PIPE) par = h.run_local(["echo", "hello"], log=Log.STDERR)
return p.stdout == "hello\n" return par.stdout == "hello\n"
res = hosts.run_function(some_func) res = hosts.run_function(some_func)
assert res[0].result assert res[0].result
@@ -41,7 +40,7 @@ def test_run_function() -> None:
def test_run_exception() -> None: def test_run_exception() -> None:
try: try:
hosts.run_local("exit 1") hosts.run_local(["exit 1"], shell=True)
except Exception: except Exception:
pass pass
else: else:
@@ -51,7 +50,7 @@ def test_run_exception() -> None:
def test_run_function_exception() -> None: def test_run_function_exception() -> None:
def some_func(h: Host) -> None: def some_func(h: Host) -> None:
h.run_local("exit 1") h.run_local(["exit 1"], shell=True)
try: try:
hosts.run_function(some_func) hosts.run_function(some_func)
@@ -63,5 +62,5 @@ def test_run_function_exception() -> None:
def test_run_local_non_shell() -> None: def test_run_local_non_shell() -> None:
p2 = hosts.run_local(["echo", "1"], stdout=subprocess.PIPE) p2 = hosts.run_local(["echo", "1"], log=Log.STDERR)
assert p2[0].result.stdout == "1\n" assert p2[0].result.stdout == "1\n"

View File

@@ -1,6 +1,7 @@
import subprocess import subprocess
import pytest import pytest
from clan_cli.cmd import Log
from clan_cli.errors import ClanError from clan_cli.errors import ClanError
from clan_cli.ssh.host import Host from clan_cli.ssh.host import Host
from clan_cli.ssh.host_group import HostGroup from clan_cli.ssh.host_group import HostGroup
@@ -22,27 +23,27 @@ def test_parse_ipv6() -> None:
def test_run(host_group: HostGroup) -> None: def test_run(host_group: HostGroup) -> None:
proc = host_group.run("echo hello", stdout=subprocess.PIPE) proc = host_group.run_local(["echo", "hello"], log=Log.STDERR)
assert proc[0].result.stdout == "hello\n" assert proc[0].result.stdout == "hello\n"
def test_run_environment(host_group: HostGroup) -> None: def test_run_environment(host_group: HostGroup) -> None:
p1 = host_group.run( p1 = host_group.run(
"echo $env_var", stdout=subprocess.PIPE, extra_env={"env_var": "true"} ["echo $env_var"], extra_env={"env_var": "true"}, shell=True, log=Log.STDERR
) )
assert p1[0].result.stdout == "true\n" assert p1[0].result.stdout == "true\n"
p2 = host_group.run(["env"], stdout=subprocess.PIPE, extra_env={"env_var": "true"}) p2 = host_group.run(["env"], log=Log.STDERR, extra_env={"env_var": "true"})
assert "env_var=true" in p2[0].result.stdout assert "env_var=true" in p2[0].result.stdout
def test_run_no_shell(host_group: HostGroup) -> None: def test_run_no_shell(host_group: HostGroup) -> None:
proc = host_group.run(["echo", "$hello"], stdout=subprocess.PIPE) proc = host_group.run(["echo", "$hello"], log=Log.STDERR)
assert proc[0].result.stdout == "$hello\n" assert proc[0].result.stdout == "$hello\n"
def test_run_function(host_group: HostGroup) -> None: def test_run_function(host_group: HostGroup) -> None:
def some_func(h: Host) -> bool: def some_func(h: Host) -> bool:
p = h.run("echo hello", stdout=subprocess.PIPE) p = h.run(["echo", "hello"])
return p.stdout == "hello\n" return p.stdout == "hello\n"
res = host_group.run_function(some_func) res = host_group.run_function(some_func)
@@ -51,7 +52,7 @@ def test_run_function(host_group: HostGroup) -> None:
def test_timeout(host_group: HostGroup) -> None: def test_timeout(host_group: HostGroup) -> None:
try: try:
host_group.run_local("sleep 10", timeout=0.01) host_group.run_local(["sleep", "10"], timeout=0.01)
except Exception: except Exception:
pass pass
else: else:
@@ -60,11 +61,11 @@ def test_timeout(host_group: HostGroup) -> None:
def test_run_exception(host_group: HostGroup) -> None: def test_run_exception(host_group: HostGroup) -> None:
r = host_group.run("exit 1", check=False) r = host_group.run(["exit 1"], check=False, shell=True)
assert r[0].result.returncode == 1 assert r[0].result.returncode == 1
try: try:
host_group.run("exit 1") host_group.run(["exit 1"], shell=True)
except Exception: except Exception:
pass pass
else: else:
@@ -74,7 +75,7 @@ def test_run_exception(host_group: HostGroup) -> None:
def test_run_function_exception(host_group: HostGroup) -> None: def test_run_function_exception(host_group: HostGroup) -> None:
def some_func(h: Host) -> subprocess.CompletedProcess[str]: def some_func(h: Host) -> subprocess.CompletedProcess[str]:
return h.run_local("exit 1") return h.run_local(["exit 1"], shell=True)
try: try:
host_group.run_function(some_func) host_group.run_function(some_func)