Merge pull request 'clan-cli: Refactor ssh part 2, Refactor custom_logger' (#2473) from Qubasa/clan-core:Qubasa-main into main

This commit is contained in:
clan-bot
2024-11-22 21:29:10 +00:00
24 changed files with 714 additions and 1256 deletions

View File

@@ -1,8 +1,8 @@
import argparse
import json
import subprocess
from dataclasses import dataclass
from clan_cli.cmd import Log
from clan_cli.completions import (
add_dynamic_completer,
complete_backup_providers_for_machine,
@@ -23,7 +23,7 @@ def list_provider(machine: Machine, provider: str) -> list[Backup]:
backup_metadata = json.loads(machine.eval_nix("config.clan.core.backups"))
proc = machine.target_host.run(
[backup_metadata["providers"][provider]["list"]],
stdout=subprocess.PIPE,
log=Log.STDERR,
check=False,
)
if proc.returncode != 0:

View File

@@ -1,7 +1,7 @@
import argparse
import json
import subprocess
from clan_cli.cmd import Log
from clan_cli.completions import (
add_dynamic_completer,
complete_backup_providers_for_machine,
@@ -28,7 +28,7 @@ def restore_service(machine: Machine, name: str, provider: str, service: str) ->
if pre_restore := backup_folders[service]["preRestoreCommand"]:
proc = machine.target_host.run(
[pre_restore],
stdout=subprocess.PIPE,
log=Log.STDERR,
extra_env=env,
)
if proc.returncode != 0:
@@ -37,7 +37,7 @@ def restore_service(machine: Machine, name: str, provider: str, service: str) ->
proc = machine.target_host.run(
[backup_metadata["providers"][provider]["restore"]],
stdout=subprocess.PIPE,
log=Log.STDERR,
extra_env=env,
)
if proc.returncode != 0:
@@ -47,7 +47,7 @@ def restore_service(machine: Machine, name: str, provider: str, service: str) ->
if post_restore := backup_folders[service]["postRestoreCommand"]:
proc = machine.target_host.run(
[post_restore],
stdout=subprocess.PIPE,
log=Log.STDERR,
extra_env=env,
)
if proc.returncode != 0:

View File

@@ -1,11 +1,12 @@
import contextlib
import logging
import math
import os
import select
import shlex
import signal
import subprocess
import sys
import time
import timeit
import weakref
from collections.abc import Iterator
@@ -14,12 +15,25 @@ from enum import Enum
from pathlib import Path
from typing import IO, Any
from clan_cli.errors import ClanError, indent_command
from clan_cli.custom_logger import print_trace
from clan_cli.errors import ClanCmdError, ClanError, CmdOut, indent_command
from .custom_logger import get_callers
from .errors import ClanCmdError, CmdOut
cmdlog = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
class ClanCmdTimeoutError(ClanError):
timeout: float
def __init__(
self,
msg: str | None = None,
*,
description: str | None = None,
location: str | None = None,
timeout: float,
) -> None:
self.timeout = timeout
super().__init__(msg, description=description, location=location)
class Log(Enum):
@@ -30,14 +44,31 @@ class Log(Enum):
def handle_io(
process: subprocess.Popen, input_bytes: bytes | None, log: Log
process: subprocess.Popen,
log: Log,
cmdlog: logging.Logger,
prefix: str,
*,
input_bytes: bytes | None,
stdout: IO[bytes] | None,
stderr: IO[bytes] | None,
timeout: float = math.inf,
) -> tuple[str, str]:
rlist = [process.stdout, process.stderr]
wlist = [process.stdin] if input_bytes is not None else []
stdout_buf = b""
stderr_buf = b""
start = time.time()
# Loop until no more data is available
while len(rlist) != 0 or len(wlist) != 0:
# Check if the command has timed out
if time.time() - start > timeout:
msg = f"Command timed out after {timeout} seconds"
description = prefix
raise ClanCmdTimeoutError(msg=msg, description=description, timeout=timeout)
# Wait for data to be available
readlist, writelist, _ = select.select(rlist, wlist, [], 0.1)
if len(readlist) == 0 and len(writelist) == 0:
if process.poll() is None:
@@ -45,6 +76,7 @@ def handle_io(
# Process has exited
break
# Function to handle file descriptors
def handle_fd(fd: IO[Any] | None, readlist: list[IO[Any]]) -> bytes:
if fd and fd in readlist:
read = os.read(fd.fileno(), 4096)
@@ -53,19 +85,36 @@ def handle_io(
rlist.remove(fd)
return b""
#
# Process stdout
#
ret = handle_fd(process.stdout, readlist)
if ret and log in [Log.STDOUT, Log.BOTH]:
sys.stdout.buffer.write(ret)
sys.stdout.flush()
lines = ret.decode("utf-8", "replace").rstrip("\n").split("\n")
for line in lines:
cmdlog.info(line, extra={"command_prefix": prefix})
if ret and stdout:
stdout.write(ret)
stdout.flush()
#
# Process stderr
#
stdout_buf += ret
ret = handle_fd(process.stderr, readlist)
if ret and log in [Log.STDERR, Log.BOTH]:
sys.stderr.buffer.write(ret)
sys.stderr.flush()
lines = ret.decode("utf-8", "replace").rstrip("\n").split("\n")
for line in lines:
cmdlog.error(line, extra={"command_prefix": prefix})
if ret and stderr:
stderr.write(ret)
stderr.flush()
stderr_buf += ret
#
# Process stdin
#
if process.stdin in writelist:
if input_bytes:
try:
@@ -168,42 +217,35 @@ def run(
cmd: list[str],
*,
input: bytes | None = None, # noqa: A002
stdout: IO[bytes] | None = None,
stderr: IO[bytes] | None = None,
env: dict[str, str] | None = None,
cwd: Path | None = None,
log: Log = Log.STDERR,
logger: logging.Logger = cmdlog,
prefix: str | None = None,
check: bool = True,
error_msg: str | None = None,
needs_user_terminal: bool = False,
timeout: float = math.inf,
shell: bool = False,
) -> CmdOut:
if cwd is None:
cwd = Path.cwd()
def print_trace(msg: str) -> None:
trace_depth = int(os.environ.get("TRACE_DEPTH", "0"))
callers = get_callers(3, 4 + trace_depth)
if "run_no_stdout" in callers[0]:
callers = callers[1:]
else:
callers.pop()
if len(callers) == 1:
callers_str = f"Caller: {callers[0]}\n"
else:
callers_str = "\n".join(
f"{i+1}: {caller}" for i, caller in enumerate(callers)
)
callers_str = f"Callers:\n{callers_str}"
logger.debug(f"{msg} \n{callers_str}")
if prefix is None:
prefix = "localhost"
if input:
if any(not ch.isprintable() for ch in input.decode("ascii", "replace")):
filtered_input = "<<binary_blob>>"
else:
filtered_input = input.decode("ascii", "replace")
print_trace(f"$: echo '{filtered_input}' | {indent_command(cmd)}")
print_trace(
f"$: echo '{filtered_input}' | {indent_command(cmd)}", logger, prefix
)
elif logger.isEnabledFor(logging.DEBUG):
print_trace(f"$: {indent_command(cmd)}")
print_trace(f"$: {indent_command(cmd)}", logger, prefix)
start = timeit.default_timer()
with ExitStack() as stack:
@@ -217,6 +259,7 @@ def run(
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
start_new_session=not needs_user_terminal,
shell=shell,
)
)
@@ -226,7 +269,16 @@ def run(
else:
stack.enter_context(terminate_process_group(process))
stdout_buf, stderr_buf = handle_io(process, input, log)
stdout_buf, stderr_buf = handle_io(
process,
log,
prefix=prefix,
cmdlog=logger,
timeout=timeout,
input_bytes=input,
stdout=stdout,
stderr=stderr,
)
process.wait()
global TIME_TABLE
@@ -256,9 +308,12 @@ def run_no_stdout(
env: dict[str, str] | None = None,
cwd: Path | None = None,
log: Log = Log.STDERR,
logger: logging.Logger = cmdlog,
prefix: str | None = None,
check: bool = True,
error_msg: str | None = None,
needs_user_terminal: bool = False,
shell: bool = False,
) -> CmdOut:
"""
Like run, but automatically suppresses stdout, if not in DEBUG log level.
@@ -274,6 +329,8 @@ def run_no_stdout(
env=env,
log=log,
check=check,
prefix=prefix,
error_msg=error_msg,
needs_user_terminal=needs_user_terminal,
shell=shell,
)

View File

@@ -0,0 +1,2 @@
from .colors import * # noqa
from .csscolors import * # noqa

View File

@@ -0,0 +1,180 @@
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import re
from functools import partial
from .csscolors import parse_rgb
# ANSI color names. There is also a "default"
COLORS: tuple[str, ...] = (
"black",
"red",
"green",
"yellow",
"blue",
"magenta",
"cyan",
"white",
)
# ANSI style names
STYLES: tuple[str, ...] = (
"none",
"bold",
"faint",
"italic",
"underline",
"blink",
"blink2",
"negative",
"concealed",
"crossed",
)
def is_string(obj: str | bytes) -> bool:
"""
Is the given object a string?
"""
return isinstance(obj, str)
def _join(*values: int | str) -> str:
"""
Join a series of values with semicolons. The values
are either integers or strings, so stringify each for
good measure. Worth breaking out as its own function
because semicolon-joined lists are core to ANSI coding.
"""
return ";".join(str(v) for v in values)
def color_code(spec: str | int | tuple[int, int, int] | list, base: int) -> str:
"""
Workhorse of encoding a color. Give preference to named colors from
ANSI, then to specific numeric or tuple specs. If those don't work,
try looking up CSS color names or parsing CSS color specifications
(hex or rgb).
:param str|int|tuple|list spec: Unparsed color specification
:param int base: Either 30 or 40, signifying the base value
for color encoding (foreground and background respectively).
Low values are added directly to the base. Higher values use `
base + 8` (i.e. 38 or 48) then extended codes.
:returns: Discovered ANSI color encoding.
:rtype: str
:raises: ValueError if cannot parse the color spec.
"""
if isinstance(spec, str | bytes):
spec = spec.strip().lower()
if spec == "default":
return _join(base + 9)
if spec in COLORS:
return _join(base + COLORS.index(spec))
if isinstance(spec, int) and 0 <= spec <= 255:
return _join(base + 8, 5, spec)
if isinstance(spec, tuple | list):
return _join(base + 8, 2, _join(*spec))
rgb = parse_rgb(str(spec))
# parse_rgb raises ValueError if cannot parse spec
# or returns an rgb tuple if it can
return _join(base + 8, 2, _join(*rgb))
def color(
s: str | None = None,
fg: str | int | tuple[int, int, int] | None = None,
bg: str | int | tuple[int, int, int] | None = None,
style: str | None = None,
reset: bool = True,
) -> str:
"""
Add ANSI colors and styles to a string.
:param str s: String to format.
:param str|int|tuple fg: Foreground color specification.
:param str|int|tuple bg: Background color specification.
:param str: Style names, separated by '+'
:returns: Formatted string.
:rtype: str (or unicode in Python 2, if s is unicode)
"""
codes: list[int | str] = []
if fg:
codes.append(color_code(fg, 30))
if bg:
codes.append(color_code(bg, 40))
if style:
for style_part in style.split("+"):
if style_part in STYLES:
codes.append(STYLES.index(style_part))
else:
msg = f'Invalid style "{style_part}"'
raise ValueError(msg)
if not s:
s = ""
if codes:
if reset:
template = "\x1b[{0}m{1}\x1b[0m"
else:
template = "\x1b[{0}m{1}"
return template.format(_join(*codes), s)
return s
def strip_color(s: str) -> str:
"""
Remove ANSI color/style sequences from a string. The set of all possible
ANSI sequences is large, so does not try to strip every possible one. But
does strip some outliers seen not just in text generated by this module, but
by other ANSI colorizers in the wild. Those include `\x1b[K` (aka EL or
erase to end of line) and `\x1b[m`, a terse version of the more common
`\x1b[0m`.
"""
return re.sub("\x1b\\[(K|.*?m)", "", s)
def ansilen(s: str) -> int:
"""
Given a string with embedded ANSI codes, what would its
length be without those codes?
"""
return len(strip_color(s))
# Foreground color shortcuts
black = partial(color, fg="black")
red = partial(color, fg="red")
green = partial(color, fg="green")
yellow = partial(color, fg="yellow")
blue = partial(color, fg="blue")
magenta = partial(color, fg="magenta")
cyan = partial(color, fg="cyan")
white = partial(color, fg="white")
# Style shortcuts
bold = partial(color, style="bold")
none = partial(color, style="none")
faint = partial(color, style="faint")
italic = partial(color, style="italic")
underline = partial(color, style="underline")
blink = partial(color, style="blink")
blink2 = partial(color, style="blink2")
negative = partial(color, style="negative")
concealed = partial(color, style="concealed")
crossed = partial(color, style="crossed")

View File

@@ -0,0 +1,183 @@
"""
Map of CSS color names to RGB integer values.
"""
import re
css_colors: dict[str, tuple[int, int, int]] = {
"aliceblue": (240, 248, 255),
"antiquewhite": (250, 235, 215),
"aqua": (0, 255, 255),
"aquamarine": (127, 255, 212),
"azure": (240, 255, 255),
"beige": (245, 245, 220),
"bisque": (255, 228, 196),
"black": (0, 0, 0),
"blanchedalmond": (255, 235, 205),
"blue": (0, 0, 255),
"blueviolet": (138, 43, 226),
"brown": (165, 42, 42),
"burlywood": (222, 184, 135),
"cadetblue": (95, 158, 160),
"chartreuse": (127, 255, 0),
"chocolate": (210, 105, 30),
"coral": (255, 127, 80),
"cornflowerblue": (100, 149, 237),
"cornsilk": (255, 248, 220),
"crimson": (220, 20, 60),
"cyan": (0, 255, 255),
"darkblue": (0, 0, 139),
"darkcyan": (0, 139, 139),
"darkgoldenrod": (184, 134, 11),
"darkgray": (169, 169, 169),
"darkgreen": (0, 100, 0),
"darkgrey": (169, 169, 169),
"darkkhaki": (189, 183, 107),
"darkmagenta": (139, 0, 139),
"darkolivegreen": (85, 107, 47),
"darkorange": (255, 140, 0),
"darkorchid": (153, 50, 204),
"darkred": (139, 0, 0),
"darksalmon": (233, 150, 122),
"darkseagreen": (143, 188, 143),
"darkslateblue": (72, 61, 139),
"darkslategray": (47, 79, 79),
"darkslategrey": (47, 79, 79),
"darkturquoise": (0, 206, 209),
"darkviolet": (148, 0, 211),
"deeppink": (255, 20, 147),
"deepskyblue": (0, 191, 255),
"dimgray": (105, 105, 105),
"dimgrey": (105, 105, 105),
"dodgerblue": (30, 144, 255),
"firebrick": (178, 34, 34),
"floralwhite": (255, 250, 240),
"forestgreen": (34, 139, 34),
"fuchsia": (255, 0, 255),
"gainsboro": (220, 220, 220),
"ghostwhite": (248, 248, 255),
"gold": (255, 215, 0),
"goldenrod": (218, 165, 32),
"gray": (128, 128, 128),
"green": (0, 128, 0),
"greenyellow": (173, 255, 47),
"grey": (128, 128, 128),
"honeydew": (240, 255, 240),
"hotpink": (255, 105, 180),
"indianred": (205, 92, 92),
"indigo": (75, 0, 130),
"ivory": (255, 255, 240),
"khaki": (240, 230, 140),
"lavender": (230, 230, 250),
"lavenderblush": (255, 240, 245),
"lawngreen": (124, 252, 0),
"lemonchiffon": (255, 250, 205),
"lightblue": (173, 216, 230),
"lightcoral": (240, 128, 128),
"lightcyan": (224, 255, 255),
"lightgoldenrodyellow": (250, 250, 210),
"lightgray": (211, 211, 211),
"lightgreen": (144, 238, 144),
"lightgrey": (211, 211, 211),
"lightpink": (255, 182, 193),
"lightsalmon": (255, 160, 122),
"lightseagreen": (32, 178, 170),
"lightskyblue": (135, 206, 250),
"lightslategray": (119, 136, 153),
"lightslategrey": (119, 136, 153),
"lightsteelblue": (176, 196, 222),
"lightyellow": (255, 255, 224),
"lime": (0, 255, 0),
"limegreen": (50, 205, 50),
"linen": (250, 240, 230),
"magenta": (255, 0, 255),
"maroon": (128, 0, 0),
"mediumaquamarine": (102, 205, 170),
"mediumblue": (0, 0, 205),
"mediumorchid": (186, 85, 211),
"mediumpurple": (147, 112, 219),
"mediumseagreen": (60, 179, 113),
"mediumslateblue": (123, 104, 238),
"mediumspringgreen": (0, 250, 154),
"mediumturquoise": (72, 209, 204),
"mediumvioletred": (199, 21, 133),
"midnightblue": (25, 25, 112),
"mintcream": (245, 255, 250),
"mistyrose": (255, 228, 225),
"moccasin": (255, 228, 181),
"navajowhite": (255, 222, 173),
"navy": (0, 0, 128),
"oldlace": (253, 245, 230),
"olive": (128, 128, 0),
"olivedrab": (107, 142, 35),
"orange": (255, 165, 0),
"orangered": (255, 69, 0),
"orchid": (218, 112, 214),
"palegoldenrod": (238, 232, 170),
"palegreen": (152, 251, 152),
"paleturquoise": (175, 238, 238),
"palevioletred": (219, 112, 147),
"papayawhip": (255, 239, 213),
"peachpuff": (255, 218, 185),
"peru": (205, 133, 63),
"pink": (255, 192, 203),
"plum": (221, 160, 221),
"powderblue": (176, 224, 230),
"purple": (128, 0, 128),
"rebeccapurple": (102, 51, 153),
"red": (255, 0, 0),
"rosybrown": (188, 143, 143),
"royalblue": (65, 105, 225),
"saddlebrown": (139, 69, 19),
"salmon": (250, 128, 114),
"sandybrown": (244, 164, 96),
"seagreen": (46, 139, 87),
"seashell": (255, 245, 238),
"sienna": (160, 82, 45),
"silver": (192, 192, 192),
"skyblue": (135, 206, 235),
"slateblue": (106, 90, 205),
"slategray": (112, 128, 144),
"slategrey": (112, 128, 144),
"snow": (255, 250, 250),
"springgreen": (0, 255, 127),
"steelblue": (70, 130, 180),
"tan": (210, 180, 140),
"teal": (0, 128, 128),
"thistle": (216, 191, 216),
"tomato": (255, 99, 71),
"turquoise": (64, 224, 208),
"violet": (238, 130, 238),
"wheat": (245, 222, 179),
"white": (255, 255, 255),
"whitesmoke": (245, 245, 245),
"yellow": (255, 255, 0),
"yellowgreen": (154, 205, 50),
}
def parse_rgb(s: str) -> tuple[int, ...]:
s = s.strip().replace(" ", "").lower()
# simple lookup
rgb = css_colors.get(s)
if rgb is not None:
return rgb
# 6-digit hex
match = re.match("#([a-f0-9]{6})$", s)
if match:
core = match.group(1)
return tuple(int(core[i : i + 2], 16) for i in range(0, 6, 2))
# 3-digit hex
match = re.match("#([a-f0-9]{3})$", s)
if match:
return tuple(int(c * 2, 16) for c in match.group(1))
# rgb(x,y,z)
match = re.match(r"rgb\((\d+,\d+,\d+)\)", s)
if match:
return tuple(int(v) for v in match.group(1).split(","))
msg = f"Could not parse color '{s}'"
raise ValueError(msg)

View File

@@ -1,61 +1,74 @@
import inspect
import logging
import os
from collections.abc import Callable
import sys
from pathlib import Path
from typing import Any
grey = "\x1b[38;20m"
yellow = "\x1b[33;20m"
red = "\x1b[31;20m"
bold_red = "\x1b[31;1m"
green = "\u001b[32m"
blue = "\u001b[34m"
from clan_cli.colors import color, css_colors
# https://no-color.org
DISABLE_COLOR = not sys.stderr.isatty() or os.environ.get("NO_COLOR", "") != ""
def get_formatter(color: str) -> Callable[[logging.LogRecord, bool], logging.Formatter]:
def myformatter(
record: logging.LogRecord, with_location: bool
) -> logging.Formatter:
reset = "\x1b[0m"
def _get_filepath(record: logging.LogRecord) -> Path:
try:
filepath = Path(record.pathname).resolve()
filepath = Path("~", filepath.relative_to(Path.home()))
except Exception:
filepath = Path(record.pathname)
if not with_location:
return logging.Formatter(f"{color}%(levelname)s{reset}: %(message)s")
return logging.Formatter(
f"{color}%(levelname)s{reset}: %(message)s\nLocation: {filepath}:%(lineno)d::%(funcName)s\n"
)
return myformatter
return filepath
FORMATTER = {
logging.DEBUG: get_formatter(blue),
logging.INFO: get_formatter(green),
logging.WARNING: get_formatter(yellow),
logging.ERROR: get_formatter(red),
logging.CRITICAL: get_formatter(bold_red),
}
class PrefixFormatter(logging.Formatter):
"""
print errors in red and warnings in yellow
"""
def __init__(self, trace_prints: bool = False, default_prefix: str = "") -> None:
self.default_prefix = default_prefix
self.trace_prints = trace_prints
class CustomFormatter(logging.Formatter):
def __init__(self, log_locations: bool) -> None:
super().__init__()
self.log_locations = log_locations
self.hostnames: list[str] = []
self.hostname_color_offset = 1 # first host shouldn't get aggressive red
def format(self, record: logging.LogRecord) -> str:
return FORMATTER[record.levelno](record, self.log_locations).format(record)
filepath = _get_filepath(record)
if record.levelno == logging.DEBUG:
color_str = "blue"
elif record.levelno == logging.ERROR:
color_str = "red"
elif record.levelno == logging.WARNING:
color_str = "yellow"
else:
color_str = None
class ThreadFormatter(logging.Formatter):
def format(self, record: logging.LogRecord) -> str:
return FORMATTER[record.levelno](record, False).format(record)
command_prefix = getattr(record, "command_prefix", self.default_prefix)
if not DISABLE_COLOR:
prefix_color = self.hostname_colorcode(command_prefix)
format_str = color(f"[{command_prefix}]", fg=prefix_color)
format_str += color(" %(message)s", fg=color_str)
else:
format_str = f"[{command_prefix}] %(message)s"
if self.trace_prints:
format_str += f"\nSource: {filepath}:%(lineno)d::%(funcName)s\n"
return logging.Formatter(format_str).format(record)
def hostname_colorcode(self, hostname: str) -> tuple[int, int, int]:
try:
index = self.hostnames.index(hostname)
except ValueError:
self.hostnames += [hostname]
index = self.hostnames.index(hostname)
coloroffset = (index + self.hostname_color_offset) % len(css_colors)
colorcode = list(css_colors.values())[coloroffset]
return colorcode
def get_callers(start: int = 2, end: int = 2) -> list[str]:
@@ -103,7 +116,28 @@ def get_callers(start: int = 2, end: int = 2) -> list[str]:
return callers
def setup_logging(level: Any, root_log_name: str = __name__.split(".")[0]) -> None:
def print_trace(msg: str, logger: logging.Logger, prefix: str) -> None:
trace_depth = int(os.environ.get("TRACE_DEPTH", "0"))
callers = get_callers(3, 4 + trace_depth)
if "run_no_stdout" in callers[0]:
callers = callers[1:]
else:
callers.pop()
if len(callers) == 1:
callers_str = f"Caller: {callers[0]}\n"
else:
callers_str = "\n".join(f"{i+1}: {caller}" for i, caller in enumerate(callers))
callers_str = f"Callers:\n{callers_str}"
logger.debug(f"{msg} \n{callers_str}", extra={"command_prefix": prefix})
def setup_logging(
level: Any,
root_log_name: str = __name__.split(".")[0],
default_prefix: str = "clan",
) -> None:
# Get the root logger and set its level
main_logger = logging.getLogger(root_log_name)
main_logger.setLevel(level)
@@ -113,10 +147,6 @@ def setup_logging(level: Any, root_log_name: str = __name__.split(".")[0]) -> No
# Create and add your custom handler
default_handler.setLevel(level)
trace_depth = bool(int(os.environ.get("TRACE_DEPTH", "0")))
default_handler.setFormatter(CustomFormatter(trace_depth))
trace_prints = bool(int(os.environ.get("TRACE_PRINT", "0")))
default_handler.setFormatter(PrefixFormatter(trace_prints, default_prefix))
main_logger.addHandler(default_handler)
# Set logging level for other modules used by this module
logging.getLogger("asyncio").setLevel(logging.INFO)
logging.getLogger("httpx").setLevel(level=logging.WARNING)

View File

@@ -74,7 +74,7 @@ def indent_command(command_list: list[str]) -> str:
# Indent after the next argument
formatted_command.append(" ")
i += 1
formatted_command.append(shlex.quote(command_list[i]))
formatted_command.append(command_list[i])
if i < len(command_list) - 1:
# Add line continuation only if it's not the last argument

View File

@@ -204,7 +204,7 @@ def generate_facts(
machine, service, regenerate, tmpdir, prompt
)
except (OSError, ClanError):
log.exception(f"Failed to generate facts for {machine.name}")
machine.error("Failed to generate facts")
errors += 1
if errors > 0:
msg = (
@@ -213,7 +213,7 @@ def generate_facts(
raise ClanError(msg)
if not was_regenerated:
print("All secrets and facts are already up to date")
machine.info("All secrets and facts are already up to date")
return was_regenerated

View File

@@ -3,6 +3,7 @@ import subprocess
from pathlib import Path
from typing import override
from clan_cli.cmd import Log
from clan_cli.machines.machines import Machine
from clan_cli.nix import nix_shell
@@ -97,8 +98,8 @@ class SecretStore(SecretStoreBase):
remote_hash = self.machine.target_host.run(
# TODO get the path to the secrets from the machine
["cat", f"{self.machine.secrets_upload_directory}/.pass_info"],
log=Log.STDERR,
check=False,
stdout=subprocess.PIPE,
).stdout.strip()
if not remote_hash:

View File

@@ -1,4 +1,4 @@
#!/usr/bin/env bash
jsonSchema=$(nix build .#schemas.inventory-schema-abstract --print-out-paths)/schema.json
nix run .#classgen "$jsonSchema" "$PKG_ROOT/clan_cli/inventory/classes.py"
nix run .#classgen "$jsonSchema" "$PKG_ROOT/clan_cli/inventory/classes.py" -- --stop-at "Service"

View File

@@ -48,6 +48,18 @@ class Machine:
def __repr__(self) -> str:
return str(self)
def debug(self, msg: str, *args: Any, **kwargs: Any) -> None:
kwargs.update({"extra": {"command_prefix": self.name}})
log.debug(msg, *args, **kwargs)
def info(self, msg: str, *args: Any, **kwargs: Any) -> None:
kwargs.update({"extra": {"command_prefix": self.name}})
log.info(msg, *args, **kwargs)
def error(self, msg: str, *args: Any, **kwargs: Any) -> None:
kwargs.update({"extra": {"command_prefix": self.name}})
log.error(msg, *args, **kwargs)
@property
def system(self) -> str:
# We filter out function attributes because they are not serializable.

View File

@@ -64,7 +64,7 @@ def upload_sources(machine: Machine, always_upload_source: bool = False) -> str:
path,
]
)
run(cmd, env=env, error_msg="failed to upload sources")
run(cmd, env=env, error_msg="failed to upload sources", prefix=machine.name)
return path
# Slow path: we need to upload all sources to the remote machine
@@ -78,7 +78,6 @@ def upload_sources(machine: Machine, always_upload_source: bool = False) -> str:
flake_url,
]
)
log.info("run %s", shlex.join(cmd))
proc = run(cmd, env=env, error_msg="failed to upload sources")
try:

View File

@@ -1,47 +1,28 @@
# Adapted from https://github.com/numtide/deploykit
import fcntl
import logging
import math
import os
import select
import shlex
import subprocess
import tarfile
import time
from collections.abc import Iterator
from contextlib import ExitStack, contextmanager
from pathlib import Path
from shlex import quote
from tempfile import TemporaryDirectory
from typing import IO, Any
from clan_cli.cmd import Log, terminate_process_group
from clan_cli.cmd import Log
from clan_cli.cmd import run as local_run
from clan_cli.errors import ClanError
from clan_cli.ssh.host_key import HostKeyCheck
from clan_cli.ssh.logger import cmdlog
FILE = None | int
cmdlog = logging.getLogger(__name__)
# Seconds until a message is printed when _run produces no output.
NO_OUTPUT_TIMEOUT = 20
@contextmanager
def _pipe() -> Iterator[tuple[IO[str], IO[str]]]:
(pipe_r, pipe_w) = os.pipe()
read_end = os.fdopen(pipe_r, "r")
write_end = os.fdopen(pipe_w, "w")
try:
fl = fcntl.fcntl(read_end, fcntl.F_GETFL)
fcntl.fcntl(read_end, fcntl.F_SETFL, fl | os.O_NONBLOCK)
yield (read_end, write_end)
finally:
read_end.close()
write_end.close()
class Host:
def __init__(
self,
@@ -101,196 +82,56 @@ class Host:
host = f"[{host}]"
return f"{self.user or 'root'}@{host}"
def _prefix_output(
self,
displayed_cmd: str,
print_std_fd: IO[str] | None,
print_err_fd: IO[str] | None,
stdout: IO[str] | None,
stderr: IO[str] | None,
timeout: float = math.inf,
) -> tuple[str, str]:
rlist = []
if print_std_fd is not None:
rlist.append(print_std_fd)
if print_err_fd is not None:
rlist.append(print_err_fd)
if stdout is not None:
rlist.append(stdout)
if stderr is not None:
rlist.append(stderr)
print_std_buf = ""
print_err_buf = ""
stdout_buf = ""
stderr_buf = ""
start = time.time()
last_output = time.time()
while len(rlist) != 0:
readlist, _, _ = select.select(
rlist, [], [], min(timeout, NO_OUTPUT_TIMEOUT)
)
def print_from(
print_fd: IO[str], print_buf: str, is_err: bool = False
) -> tuple[float, str]:
read = os.read(print_fd.fileno(), 4096)
if len(read) == 0:
rlist.remove(print_fd)
print_buf += read.decode("utf-8")
if (read == b"" and len(print_buf) != 0) or "\n" in print_buf:
# print and empty the print_buf, if the stream is draining,
# but there is still something in the buffer or on newline.
lines = print_buf.rstrip("\n").split("\n")
for line in lines:
if not is_err:
cmdlog.info(
line, extra={"command_prefix": self.command_prefix}
)
else:
cmdlog.error(
line, extra={"command_prefix": self.command_prefix}
)
print_buf = ""
last_output = time.time()
return (last_output, print_buf)
if print_std_fd in readlist and print_std_fd is not None:
(last_output, print_std_buf) = print_from(
print_std_fd, print_std_buf, is_err=False
)
if print_err_fd in readlist and print_err_fd is not None:
(last_output, print_err_buf) = print_from(
print_err_fd, print_err_buf, is_err=True
)
now = time.time()
elapsed = now - start
if now - last_output > NO_OUTPUT_TIMEOUT:
elapsed_msg = time.strftime("%H:%M:%S", time.gmtime(elapsed))
cmdlog.warning(
f"still waiting for '{displayed_cmd}' to finish... ({elapsed_msg} elapsed)",
extra={"command_prefix": self.command_prefix},
)
def handle_fd(fd: IO[Any] | None, readlist: list[IO[Any]]) -> str:
if fd and fd in readlist:
read = os.read(fd.fileno(), 4096)
if len(read) == 0:
rlist.remove(fd)
else:
return read.decode("utf-8")
return ""
stdout_buf += handle_fd(stdout, readlist)
stderr_buf += handle_fd(stderr, readlist)
if now - last_output >= timeout:
break
return stdout_buf, stderr_buf
def _run(
self,
cmd: list[str],
displayed_cmd: str,
shell: bool,
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
*,
stdout: IO[bytes] | None = None,
stderr: IO[bytes] | None = None,
input: bytes | None = None, # noqa: A002
env: dict[str, str] | None = None,
cwd: Path | None = None,
log: Log = Log.BOTH,
check: bool = True,
timeout: float = math.inf,
error_msg: str | None = None,
needs_user_terminal: bool = False,
shell: bool = False,
timeout: float = math.inf,
) -> subprocess.CompletedProcess[str]:
if extra_env is None:
extra_env = {}
with ExitStack() as stack:
read_std_fd, write_std_fd = (None, None)
read_err_fd, write_err_fd = (None, None)
if stdout is None or stderr is None:
read_std_fd, write_std_fd = stack.enter_context(_pipe())
read_err_fd, write_err_fd = stack.enter_context(_pipe())
if stdout is None:
stdout_read = None
stdout_write = write_std_fd
elif stdout == subprocess.PIPE:
stdout_read, stdout_write = stack.enter_context(_pipe())
else:
msg = f"unsupported value for stdout parameter: {stdout}"
raise ClanError(msg)
if stderr is None:
stderr_read = None
stderr_write = write_err_fd
elif stderr == subprocess.PIPE:
stderr_read, stderr_write = stack.enter_context(_pipe())
else:
msg = f"unsupported value for stderr parameter: {stderr}"
raise ClanError(msg)
env = os.environ.copy()
env.update(extra_env)
with subprocess.Popen(
res = local_run(
cmd,
text=True,
shell=shell,
stdout=stdout_write,
stderr=stderr_write,
stdout=stdout,
prefix=self.command_prefix,
timeout=timeout,
stderr=stderr,
input=input,
env=env,
cwd=cwd,
start_new_session=not needs_user_terminal,
) as p:
if not needs_user_terminal:
stack.enter_context(terminate_process_group(p))
if write_std_fd is not None:
write_std_fd.close()
if write_err_fd is not None:
write_err_fd.close()
if stdout == subprocess.PIPE:
assert stdout_write is not None
stdout_write.close()
if stderr == subprocess.PIPE:
assert stderr_write is not None
stderr_write.close()
start = time.time()
stdout_data, stderr_data = self._prefix_output(
displayed_cmd,
read_std_fd,
read_err_fd,
stdout_read,
stderr_read,
timeout,
)
ret = p.wait(timeout=max(0, timeout - (time.time() - start)))
if ret != 0:
if check:
msg = f"Command {shlex.join(cmd)} failed with return code {ret}"
raise ClanError(msg)
cmdlog.warning(
f"[Command failed: {ret}] {displayed_cmd}",
extra={"command_prefix": self.command_prefix},
log=log,
logger=cmdlog,
check=check,
error_msg=error_msg,
needs_user_terminal=needs_user_terminal,
)
return subprocess.CompletedProcess(
cmd, ret, stdout=stdout_data, stderr=stderr_data
args=res.command_list,
returncode=res.returncode,
stdout=res.stdout,
stderr=res.stderr,
)
msg = "unreachable"
raise RuntimeError(msg)
def run_local(
self,
cmd: str | list[str],
stdout: FILE = None,
stderr: FILE = None,
cmd: list[str],
stdout: IO[bytes] | None = None,
stderr: IO[bytes] | None = None,
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
cwd: None | Path = None,
check: bool = True,
timeout: float = math.inf,
shell: bool = False,
log: Log = Log.BOTH,
) -> subprocess.CompletedProcess[str]:
"""
Command to run locally for the host
@@ -304,38 +145,38 @@ class Host:
@return subprocess.CompletedProcess result of the command
"""
if extra_env is None:
extra_env = {}
shell = False
if isinstance(cmd, str):
cmd = [cmd]
shell = True
env = os.environ.copy()
if extra_env:
env.update(extra_env)
displayed_cmd = " ".join(cmd)
cmdlog.info(f"$ {displayed_cmd}", extra={"command_prefix": self.command_prefix})
return self._run(
cmd,
displayed_cmd,
shell=shell,
stdout=stdout,
stderr=stderr,
extra_env=extra_env,
env=env,
cwd=cwd,
check=check,
timeout=timeout,
log=log,
)
def run(
self,
cmd: str | list[str],
stdout: FILE = None,
stderr: FILE = None,
cmd: list[str],
stdout: IO[bytes] | None = None,
stderr: IO[bytes] | None = None,
become_root: bool = False,
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
cwd: None | Path = None,
check: bool = True,
timeout: float = math.inf,
verbose_ssh: bool = False,
tty: bool = False,
shell: bool = False,
log: Log = Log.BOTH,
) -> subprocess.CompletedProcess[str]:
"""
Command to run on the host via ssh
@@ -353,48 +194,50 @@ class Host:
"""
if extra_env is None:
extra_env = {}
# If we are not root and we need to become root, prepend sudo
sudo = ""
if become_root and self.user != "root":
sudo = "sudo -- "
# Quote all added environment variables
env_vars = []
for k, v in extra_env.items():
env_vars.append(f"{shlex.quote(k)}={shlex.quote(v)}")
# Build a pretty command for logging
displayed_cmd = ""
export_cmd = ""
if env_vars:
export_cmd = f"export {' '.join(env_vars)}; "
displayed_cmd += export_cmd
if isinstance(cmd, list):
displayed_cmd += " ".join(cmd)
else:
displayed_cmd += cmd
cmdlog.info(f"$ {displayed_cmd}", extra={"command_prefix": self.command_prefix})
# Build the ssh command
bash_cmd = export_cmd
bash_args = []
if isinstance(cmd, list):
bash_cmd += 'exec "$@"'
bash_args += cmd
if shell:
bash_cmd += " ".join(cmd)
else:
bash_cmd += cmd
bash_cmd += 'exec "$@"'
# FIXME we assume bash to be present here? Should be documented...
ssh_cmd = [
*self.ssh_cmd(verbose_ssh=verbose_ssh, tty=tty),
"--",
f"{sudo}bash -c {quote(bash_cmd)} -- {' '.join(map(quote, bash_args))}",
f"{sudo}bash -c {quote(bash_cmd)} -- {' '.join(map(quote, cmd))}",
]
# Run the ssh command
return self._run(
ssh_cmd,
displayed_cmd,
shell=False,
stdout=stdout,
stderr=stderr,
log=log,
cwd=cwd,
check=check,
timeout=timeout,
# all ssh commands can potentially ask for a password
needs_user_terminal=True,
needs_user_terminal=True, # ssh asks for a password
)
def nix_ssh_env(self, env: dict[str, str] | None) -> dict[str, str]:
@@ -464,13 +307,19 @@ class Host:
"tar",
"-C",
str(remote_dest),
"-xvzf",
"-xzf",
"-",
]
# TODO accept `input` to be an IO object instead of bytes so that we don't have to read the tarfile into memory.
with tar_path.open("rb") as f:
local_run(cmd, input=f.read(), log=Log.BOTH, needs_user_terminal=True)
local_run(
cmd,
input=f.read(),
log=Log.BOTH,
needs_user_terminal=True,
prefix=self.command_prefix,
)
@property
def ssh_cmd_opts(

View File

@@ -1,15 +1,18 @@
import logging
import math
from collections.abc import Callable
from pathlib import Path
from threading import Thread
from typing import Any
from typing import IO, Any
from clan_cli.cmd import Log
from clan_cli.errors import ClanError
from clan_cli.ssh import T
from clan_cli.ssh.host import FILE, Host
from clan_cli.ssh.logger import cmdlog
from clan_cli.ssh.host import Host
from clan_cli.ssh.results import HostResult, Results
cmdlog = logging.getLogger(__name__)
def _worker(
func: Callable[[Host], T],
@@ -35,17 +38,19 @@ class HostGroup:
def _run_local(
self,
cmd: str | list[str],
cmd: list[str],
host: Host,
results: Results,
stdout: FILE = None,
stderr: FILE = None,
stdout: IO[bytes] | None = None,
stderr: IO[bytes] | None = None,
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
cwd: None | Path = None,
check: bool = True,
verbose_ssh: bool = False,
timeout: float = math.inf,
shell: bool = False,
tty: bool = False,
log: Log = Log.BOTH,
) -> None:
if extra_env is None:
extra_env = {}
@@ -58,6 +63,8 @@ class HostGroup:
cwd=cwd,
check=check,
timeout=timeout,
shell=shell,
log=log,
)
results.append(HostResult(host, proc))
except Exception as e:
@@ -65,17 +72,19 @@ class HostGroup:
def _run_remote(
self,
cmd: str | list[str],
cmd: list[str],
host: Host,
results: Results,
stdout: FILE = None,
stderr: FILE = None,
stdout: IO[bytes] | None = None,
stderr: IO[bytes] | None = None,
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
verbose_ssh: bool = False,
timeout: float = math.inf,
tty: bool = False,
shell: bool = False,
log: Log = Log.BOTH,
) -> None:
if cwd is not None:
msg = "cwd is not supported for remote commands"
@@ -93,6 +102,8 @@ class HostGroup:
verbose_ssh=verbose_ssh,
timeout=timeout,
tty=tty,
shell=shell,
log=log,
)
results.append(HostResult(host, proc))
except Exception as e:
@@ -114,16 +125,18 @@ class HostGroup:
def _run(
self,
cmd: str | list[str],
cmd: list[str],
local: bool = False,
stdout: FILE = None,
stderr: FILE = None,
stdout: IO[bytes] | None = None,
stderr: IO[bytes] | None = None,
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
timeout: float = math.inf,
verbose_ssh: bool = False,
tty: bool = False,
shell: bool = False,
log: Log = Log.BOTH,
) -> Results:
if extra_env is None:
extra_env = {}
@@ -145,6 +158,8 @@ class HostGroup:
"timeout": timeout,
"verbose_ssh": verbose_ssh,
"tty": tty,
"shell": shell,
"log": log,
},
)
thread.start()
@@ -160,15 +175,17 @@ class HostGroup:
def run(
self,
cmd: str | list[str],
stdout: FILE = None,
stderr: FILE = None,
cmd: list[str],
stdout: IO[bytes] | None = None,
stderr: IO[bytes] | None = None,
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
verbose_ssh: bool = False,
timeout: float = math.inf,
tty: bool = False,
log: Log = Log.BOTH,
shell: bool = False,
) -> Results:
"""
Command to run on the remote host via ssh
@@ -184,6 +201,7 @@ class HostGroup:
extra_env = {}
return self._run(
cmd,
shell=shell,
stdout=stdout,
stderr=stderr,
extra_env=extra_env,
@@ -192,17 +210,20 @@ class HostGroup:
verbose_ssh=verbose_ssh,
timeout=timeout,
tty=tty,
log=log,
)
def run_local(
self,
cmd: str | list[str],
stdout: FILE = None,
stderr: FILE = None,
cmd: list[str],
stdout: IO[bytes] | None = None,
stderr: IO[bytes] | None = None,
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
timeout: float = math.inf,
shell: bool = False,
log: Log = Log.BOTH,
) -> Results:
"""
Command to run locally for each host in the group in parallel
@@ -226,6 +247,8 @@ class HostGroup:
cwd=cwd,
check=check,
timeout=timeout,
shell=shell,
log=log,
)
def run_function(

View File

@@ -1,87 +0,0 @@
# Adapted from https://github.com/numtide/deploykit
import logging
import os
import sys
# https://no-color.org
DISABLE_COLOR = not sys.stderr.isatty() or os.environ.get("NO_COLOR", "") != ""
def ansi_color(color: int) -> str:
return f"\x1b[{color}m"
class CommandFormatter(logging.Formatter):
"""
print errors in red and warnings in yellow
"""
def __init__(self) -> None:
super().__init__(
"%(prefix_color)s[%(command_prefix)s]%(color_reset)s %(color)s%(message)s%(color_reset)s"
)
self.hostnames: list[str] = []
self.hostname_color_offset = 1 # first host shouldn't get aggressive red
def format(self, record: logging.LogRecord) -> str:
colorcode = 0
if record.levelno == logging.ERROR:
colorcode = 31 # red
if record.levelno == logging.WARNING:
colorcode = 33 # yellow
color, prefix_color, color_reset = "", "", ""
if not DISABLE_COLOR:
command_prefix = getattr(record, "command_prefix", "")
color = ansi_color(colorcode)
prefix_color = ansi_color(self.hostname_colorcode(command_prefix))
color_reset = "\x1b[0m"
record.color = color
record.prefix_color = prefix_color
record.color_reset = color_reset
return super().format(record)
def hostname_colorcode(self, hostname: str) -> int:
try:
index = self.hostnames.index(hostname)
except ValueError:
self.hostnames += [hostname]
index = self.hostnames.index(hostname)
return 31 + (index + self.hostname_color_offset) % 7
def setup_loggers() -> tuple[logging.Logger, logging.Logger]:
# If we use the default logger here (logging.error etc) or a logger called
# "deploykit", then cmdlog messages are also posted on the default logger.
# To avoid this message duplication, we set up a main and command logger
# and use a "deploykit" main logger.
kitlog = logging.getLogger("deploykit.main")
kitlog.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(logging.Formatter())
kitlog.addHandler(ch)
# use specific logger for command outputs
cmdlog = logging.getLogger("deploykit.command")
cmdlog.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(CommandFormatter())
cmdlog.addHandler(ch)
return (kitlog, cmdlog)
# loggers for: general deploykit, command output
kitlog, cmdlog = setup_loggers()
info = kitlog.info
warn = kitlog.warning
error = kitlog.error

View File

@@ -1,790 +0,0 @@
# Adapted from https://github.com/numtide/deploykit
import fcntl
import math
import os
import select
import shlex
import subprocess
import tarfile
import time
from collections.abc import Callable, Iterator
from contextlib import ExitStack, contextmanager
from pathlib import Path
from shlex import quote
from tempfile import TemporaryDirectory
from threading import Thread
from typing import IO, Any, Generic, TypeVar
from clan_cli.cmd import Log, terminate_process_group
from clan_cli.cmd import run as local_run
from clan_cli.errors import ClanError
from clan_cli.ssh.host_key import HostKeyCheck
from clan_cli.ssh.logger import cmdlog
FILE = None | int
# Seconds until a message is printed when _run produces no output.
NO_OUTPUT_TIMEOUT = 20
@contextmanager
def _pipe() -> Iterator[tuple[IO[str], IO[str]]]:
(pipe_r, pipe_w) = os.pipe()
read_end = os.fdopen(pipe_r, "r")
write_end = os.fdopen(pipe_w, "w")
try:
fl = fcntl.fcntl(read_end, fcntl.F_GETFL)
fcntl.fcntl(read_end, fcntl.F_SETFL, fl | os.O_NONBLOCK)
yield (read_end, write_end)
finally:
read_end.close()
write_end.close()
class Host:
def __init__(
self,
host: str,
user: str | None = None,
port: int | None = None,
key: str | None = None,
forward_agent: bool = False,
command_prefix: str | None = None,
host_key_check: HostKeyCheck = HostKeyCheck.ASK,
meta: dict[str, Any] | None = None,
verbose_ssh: bool = False,
ssh_options: dict[str, str] | None = None,
) -> None:
"""
Creates a Host
@host the hostname to connect to via ssh
@port the port to connect to via ssh
@forward_agent: whether to forward ssh agent
@command_prefix: string to prefix each line of the command output with, defaults to host
@host_key_check: whether to check ssh host keys
@verbose_ssh: Enables verbose logging on ssh connections
@meta: meta attributes associated with the host. Those can be accessed in custom functions passed to `run_function`
"""
if ssh_options is None:
ssh_options = {}
if meta is None:
meta = {}
self.host = host
self.user = user
self.port = port
self.key = key
if command_prefix:
self.command_prefix = command_prefix
else:
self.command_prefix = host
self.forward_agent = forward_agent
self.host_key_check = host_key_check
self.meta = meta
self.verbose_ssh = verbose_ssh
self._ssh_options = ssh_options
def __repr__(self) -> str:
return str(self)
def __str__(self) -> str:
return f"{self.user}@{self.host}" + str(self.port if self.port else "")
@property
def target(self) -> str:
return f"{self.user or 'root'}@{self.host}"
@property
def target_for_rsync(self) -> str:
host = self.host
if ":" in host:
host = f"[{host}]"
return f"{self.user or 'root'}@{host}"
def _prefix_output(
self,
displayed_cmd: str,
print_std_fd: IO[str] | None,
print_err_fd: IO[str] | None,
stdout: IO[str] | None,
stderr: IO[str] | None,
timeout: float = math.inf,
) -> tuple[str, str]:
rlist = []
if print_std_fd is not None:
rlist.append(print_std_fd)
if print_err_fd is not None:
rlist.append(print_err_fd)
if stdout is not None:
rlist.append(stdout)
if stderr is not None:
rlist.append(stderr)
print_std_buf = ""
print_err_buf = ""
stdout_buf = ""
stderr_buf = ""
start = time.time()
last_output = time.time()
while len(rlist) != 0:
readlist, _, _ = select.select(
rlist, [], [], min(timeout, NO_OUTPUT_TIMEOUT)
)
def print_from(
print_fd: IO[str], print_buf: str, is_err: bool = False
) -> tuple[float, str]:
read = os.read(print_fd.fileno(), 4096)
if len(read) == 0:
rlist.remove(print_fd)
print_buf += read.decode("utf-8")
if (read == b"" and len(print_buf) != 0) or "\n" in print_buf:
# print and empty the print_buf, if the stream is draining,
# but there is still something in the buffer or on newline.
lines = print_buf.rstrip("\n").split("\n")
for line in lines:
if not is_err:
cmdlog.info(
line, extra={"command_prefix": self.command_prefix}
)
else:
cmdlog.error(
line, extra={"command_prefix": self.command_prefix}
)
print_buf = ""
last_output = time.time()
return (last_output, print_buf)
if print_std_fd in readlist and print_std_fd is not None:
(last_output, print_std_buf) = print_from(
print_std_fd, print_std_buf, is_err=False
)
if print_err_fd in readlist and print_err_fd is not None:
(last_output, print_err_buf) = print_from(
print_err_fd, print_err_buf, is_err=True
)
now = time.time()
elapsed = now - start
if now - last_output > NO_OUTPUT_TIMEOUT:
elapsed_msg = time.strftime("%H:%M:%S", time.gmtime(elapsed))
cmdlog.warning(
f"still waiting for '{displayed_cmd}' to finish... ({elapsed_msg} elapsed)",
extra={"command_prefix": self.command_prefix},
)
def handle_fd(fd: IO[Any] | None, readlist: list[IO[Any]]) -> str:
if fd and fd in readlist:
read = os.read(fd.fileno(), 4096)
if len(read) == 0:
rlist.remove(fd)
else:
return read.decode("utf-8")
return ""
stdout_buf += handle_fd(stdout, readlist)
stderr_buf += handle_fd(stderr, readlist)
if now - last_output >= timeout:
break
return stdout_buf, stderr_buf
def _run(
self,
cmd: list[str],
displayed_cmd: str,
shell: bool,
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
timeout: float = math.inf,
needs_user_terminal: bool = False,
) -> subprocess.CompletedProcess[str]:
if extra_env is None:
extra_env = {}
with ExitStack() as stack:
read_std_fd, write_std_fd = (None, None)
read_err_fd, write_err_fd = (None, None)
if stdout is None or stderr is None:
read_std_fd, write_std_fd = stack.enter_context(_pipe())
read_err_fd, write_err_fd = stack.enter_context(_pipe())
if stdout is None:
stdout_read = None
stdout_write = write_std_fd
elif stdout == subprocess.PIPE:
stdout_read, stdout_write = stack.enter_context(_pipe())
else:
msg = f"unsupported value for stdout parameter: {stdout}"
raise ClanError(msg)
if stderr is None:
stderr_read = None
stderr_write = write_err_fd
elif stderr == subprocess.PIPE:
stderr_read, stderr_write = stack.enter_context(_pipe())
else:
msg = f"unsupported value for stderr parameter: {stderr}"
raise ClanError(msg)
env = os.environ.copy()
env.update(extra_env)
with subprocess.Popen(
cmd,
text=True,
shell=shell,
stdout=stdout_write,
stderr=stderr_write,
env=env,
cwd=cwd,
start_new_session=not needs_user_terminal,
) as p:
if not needs_user_terminal:
stack.enter_context(terminate_process_group(p))
if write_std_fd is not None:
write_std_fd.close()
if write_err_fd is not None:
write_err_fd.close()
if stdout == subprocess.PIPE:
assert stdout_write is not None
stdout_write.close()
if stderr == subprocess.PIPE:
assert stderr_write is not None
stderr_write.close()
start = time.time()
stdout_data, stderr_data = self._prefix_output(
displayed_cmd,
read_std_fd,
read_err_fd,
stdout_read,
stderr_read,
timeout,
)
ret = p.wait(timeout=max(0, timeout - (time.time() - start)))
if ret != 0:
if check:
msg = f"Command {shlex.join(cmd)} failed with return code {ret}"
raise ClanError(msg)
cmdlog.warning(
f"[Command failed: {ret}] {displayed_cmd}",
extra={"command_prefix": self.command_prefix},
)
return subprocess.CompletedProcess(
cmd, ret, stdout=stdout_data, stderr=stderr_data
)
msg = "unreachable"
raise RuntimeError(msg)
def run_local(
self,
cmd: str | list[str],
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
timeout: float = math.inf,
) -> subprocess.CompletedProcess[str]:
"""
Command to run locally for the host
@cmd the command to run
@stdout if not None stdout of the command will be redirected to this file i.e. stdout=subprocess.PIPE
@stderr if not None stderr of the command will be redirected to this file i.e. stderr=subprocess.PIPE
@extra_env environment variables to override when running the command
@cwd current working directory to run the process in
@timeout: Timeout in seconds for the command to complete
@return subprocess.CompletedProcess result of the command
"""
if extra_env is None:
extra_env = {}
shell = False
if isinstance(cmd, str):
cmd = [cmd]
shell = True
displayed_cmd = " ".join(cmd)
cmdlog.info(f"$ {displayed_cmd}", extra={"command_prefix": self.command_prefix})
return self._run(
cmd,
displayed_cmd,
shell=shell,
stdout=stdout,
stderr=stderr,
extra_env=extra_env,
cwd=cwd,
check=check,
timeout=timeout,
)
def run(
self,
cmd: str | list[str],
stdout: FILE = None,
stderr: FILE = None,
become_root: bool = False,
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
timeout: float = math.inf,
verbose_ssh: bool = False,
tty: bool = False,
) -> subprocess.CompletedProcess[str]:
"""
Command to run on the host via ssh
@cmd the command to run
@stdout if not None stdout of the command will be redirected to this file i.e. stdout=subprocss.PIPE
@stderr if not None stderr of the command will be redirected to this file i.e. stderr=subprocess.PIPE
@become_root if the ssh_user is not root than sudo is prepended
@extra_env environment variables to override when running the command
@cwd current working directory to run the process in
@verbose_ssh: Enables verbose logging on ssh connections
@timeout: Timeout in seconds for the command to complete
@return subprocess.CompletedProcess result of the ssh command
"""
if extra_env is None:
extra_env = {}
sudo = ""
if become_root and self.user != "root":
sudo = "sudo -- "
env_vars = []
for k, v in extra_env.items():
env_vars.append(f"{shlex.quote(k)}={shlex.quote(v)}")
displayed_cmd = ""
export_cmd = ""
if env_vars:
export_cmd = f"export {' '.join(env_vars)}; "
displayed_cmd += export_cmd
if isinstance(cmd, list):
displayed_cmd += " ".join(cmd)
else:
displayed_cmd += cmd
cmdlog.info(f"$ {displayed_cmd}", extra={"command_prefix": self.command_prefix})
bash_cmd = export_cmd
bash_args = []
if isinstance(cmd, list):
bash_cmd += 'exec "$@"'
bash_args += cmd
else:
bash_cmd += cmd
# FIXME we assume bash to be present here? Should be documented...
ssh_cmd = [
*self.ssh_cmd(verbose_ssh=verbose_ssh, tty=tty),
"--",
f"{sudo}bash -c {quote(bash_cmd)} -- {' '.join(map(quote, bash_args))}",
]
return self._run(
ssh_cmd,
displayed_cmd,
shell=False,
stdout=stdout,
stderr=stderr,
cwd=cwd,
check=check,
timeout=timeout,
# all ssh commands can potentially ask for a password
needs_user_terminal=True,
)
def nix_ssh_env(self, env: dict[str, str] | None) -> dict[str, str]:
if env is None:
env = {}
env["NIX_SSHOPTS"] = " ".join(self.ssh_cmd_opts)
return env
def upload(
self,
local_src: Path, # must be a directory
remote_dest: Path, # must be a directory
file_user: str = "root",
file_group: str = "root",
dir_mode: int = 0o700,
file_mode: int = 0o400,
) -> None:
# check if the remote destination is a directory (no suffix)
if remote_dest.suffix:
msg = "Only directories are allowed"
raise ClanError(msg)
if not local_src.is_dir():
msg = "Only directories are allowed"
raise ClanError(msg)
# Create the tarball from the temporary directory
with TemporaryDirectory(prefix="facts-upload-") as tardir:
tar_path = Path(tardir) / "upload.tar.gz"
# We set the permissions of the files and directories in the tarball to read only and owned by root
# As first uploading the tarball and then changing the permissions can lead an attacker to
# do a race condition attack
with tarfile.open(str(tar_path), "w:gz") as tar:
for root, dirs, files in local_src.walk():
for mdir in dirs:
dir_path = Path(root) / mdir
tarinfo = tar.gettarinfo(
dir_path, arcname=str(dir_path.relative_to(str(local_src)))
)
tarinfo.mode = dir_mode
tarinfo.uname = file_user
tarinfo.gname = file_group
tar.addfile(tarinfo)
for file in files:
file_path = Path(root) / file
tarinfo = tar.gettarinfo(
file_path,
arcname=str(file_path.relative_to(str(local_src))),
)
tarinfo.mode = file_mode
tarinfo.uname = file_user
tarinfo.gname = file_group
with file_path.open("rb") as f:
tar.addfile(tarinfo, f)
cmd = [
*self.ssh_cmd(),
"rm",
"-r",
str(remote_dest),
";",
"mkdir",
f"--mode={dir_mode:o}",
"-p",
str(remote_dest),
"&&",
"tar",
"-C",
str(remote_dest),
"-xvzf",
"-",
]
# TODO accept `input` to be an IO object instead of bytes so that we don't have to read the tarfile into memory.
with tar_path.open("rb") as f:
local_run(cmd, input=f.read(), log=Log.BOTH, needs_user_terminal=True)
@property
def ssh_cmd_opts(
self,
) -> list[str]:
ssh_opts = ["-A"] if self.forward_agent else []
for k, v in self._ssh_options.items():
ssh_opts.extend(["-o", f"{k}={shlex.quote(v)}"])
ssh_opts.extend(self.host_key_check.to_ssh_opt())
return ssh_opts
def ssh_cmd(
self,
verbose_ssh: bool = False,
tty: bool = False,
) -> list[str]:
ssh_opts = self.ssh_cmd_opts
if verbose_ssh or self.verbose_ssh:
ssh_opts.extend(["-v"])
if tty:
ssh_opts.extend(["-t"])
if self.port:
ssh_opts.extend(["-p", str(self.port)])
if self.key:
ssh_opts.extend(["-i", self.key])
return [
"ssh",
self.target,
*ssh_opts,
]
T = TypeVar("T")
class HostResult(Generic[T]):
def __init__(self, host: Host, result: T | Exception) -> None:
self.host = host
self._result = result
@property
def error(self) -> Exception | None:
"""
Returns an error if the command failed
"""
if isinstance(self._result, Exception):
return self._result
return None
@property
def result(self) -> T:
"""
Unwrap the result
"""
if isinstance(self._result, Exception):
raise self._result
return self._result
Results = list[HostResult[subprocess.CompletedProcess[str]]]
def _worker(
func: Callable[[Host], T],
host: Host,
results: list[HostResult[T]],
idx: int,
) -> None:
try:
results[idx] = HostResult(host, func(host))
except Exception as e:
results[idx] = HostResult(host, e)
class HostGroup:
def __init__(self, hosts: list[Host]) -> None:
self.hosts = hosts
def __repr__(self) -> str:
return str(self)
def __str__(self) -> str:
return f"HostGroup({self.hosts})"
def _run_local(
self,
cmd: str | list[str],
host: Host,
results: Results,
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
verbose_ssh: bool = False,
timeout: float = math.inf,
tty: bool = False,
) -> None:
if extra_env is None:
extra_env = {}
try:
proc = host.run_local(
cmd,
stdout=stdout,
stderr=stderr,
extra_env=extra_env,
cwd=cwd,
check=check,
timeout=timeout,
)
results.append(HostResult(host, proc))
except Exception as e:
results.append(HostResult(host, e))
def _run_remote(
self,
cmd: str | list[str],
host: Host,
results: Results,
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
verbose_ssh: bool = False,
timeout: float = math.inf,
tty: bool = False,
) -> None:
if cwd is not None:
msg = "cwd is not supported for remote commands"
raise ClanError(msg)
if extra_env is None:
extra_env = {}
try:
proc = host.run(
cmd,
stdout=stdout,
stderr=stderr,
extra_env=extra_env,
cwd=cwd,
check=check,
verbose_ssh=verbose_ssh,
timeout=timeout,
tty=tty,
)
results.append(HostResult(host, proc))
except Exception as e:
results.append(HostResult(host, e))
def _reraise_errors(self, results: list[HostResult[Any]]) -> None:
errors = 0
for result in results:
e = result.error
if e:
cmdlog.error(
f"failed with: {e}",
extra={"command_prefix": result.host.command_prefix},
)
errors += 1
if errors > 0:
msg = f"{errors} hosts failed with an error. Check the logs above"
raise ClanError(msg)
def _run(
self,
cmd: str | list[str],
local: bool = False,
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
timeout: float = math.inf,
verbose_ssh: bool = False,
tty: bool = False,
) -> Results:
if extra_env is None:
extra_env = {}
results: Results = []
threads = []
for host in self.hosts:
fn = self._run_local if local else self._run_remote
thread = Thread(
target=fn,
kwargs={
"results": results,
"cmd": cmd,
"host": host,
"stdout": stdout,
"stderr": stderr,
"extra_env": extra_env,
"cwd": cwd,
"check": check,
"timeout": timeout,
"verbose_ssh": verbose_ssh,
"tty": tty,
},
)
thread.start()
threads.append(thread)
for thread in threads:
thread.join()
if check:
self._reraise_errors(results)
return results
def run(
self,
cmd: str | list[str],
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
verbose_ssh: bool = False,
timeout: float = math.inf,
tty: bool = False,
) -> Results:
"""
Command to run on the remote host via ssh
@stdout if not None stdout of the command will be redirected to this file i.e. stdout=subprocss.PIPE
@stderr if not None stderr of the command will be redirected to this file i.e. stderr=subprocess.PIPE
@cwd current working directory to run the process in
@verbose_ssh: Enables verbose logging on ssh connections
@timeout: Timeout in seconds for the command to complete
@return a lists of tuples containing Host and the result of the command for this Host
"""
if extra_env is None:
extra_env = {}
return self._run(
cmd,
stdout=stdout,
stderr=stderr,
extra_env=extra_env,
cwd=cwd,
check=check,
verbose_ssh=verbose_ssh,
timeout=timeout,
tty=tty,
)
def run_local(
self,
cmd: str | list[str],
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
timeout: float = math.inf,
) -> Results:
"""
Command to run locally for each host in the group in parallel
@cmd the command to run
@stdout if not None stdout of the command will be redirected to this file i.e. stdout=subprocss.PIPE
@stderr if not None stderr of the command will be redirected to this file i.e. stderr=subprocess.PIPE
@cwd current working directory to run the process in
@extra_env environment variables to override when running the command
@timeout: Timeout in seconds for the command to complete
@return a lists of tuples containing Host and the result of the command for this Host
"""
if extra_env is None:
extra_env = {}
return self._run(
cmd,
local=True,
stdout=stdout,
stderr=stderr,
extra_env=extra_env,
cwd=cwd,
check=check,
timeout=timeout,
)
def run_function(
self, func: Callable[[Host], T], check: bool = True
) -> list[HostResult[T]]:
"""
Function to run for each host in the group in parallel
@func the function to call
"""
threads = []
results: list[HostResult[T]] = [
HostResult(h, ClanError(f"No result set for thread {i}"))
for (i, h) in enumerate(self.hosts)
]
for i, host in enumerate(self.hosts):
thread = Thread(
target=_worker,
args=(func, host, results, i),
)
threads.append(thread)
for thread in threads:
thread.start()
for thread in threads:
thread.join()
if check:
self._reraise_errors(results)
return results
def filter(self, pred: Callable[[Host], bool]) -> "HostGroup":
"""Return a new Group with the results filtered by the predicate"""
return HostGroup(list(filter(pred, self.hosts)))

View File

@@ -418,14 +418,15 @@ def generate_vars(
)
machine.flush_caches()
except Exception as exc:
log.error(f"Failed to generate facts for {machine.name}: {exc}") # noqa
machine.error(f"Failed to generate facts: {exc}")
errors += [exc]
if len(errors) > 0:
msg = f"Failed to generate facts for {len(errors)} hosts. Check the logs above"
raise ClanError(msg) from errors[0]
if not was_regenerated:
print("All vars are already up to date")
machine.info("All vars are already up to date")
return was_regenerated

View File

@@ -1,13 +1,12 @@
import io
import logging
import os
import subprocess
import tarfile
from itertools import chain
from pathlib import Path
from typing import override
from clan_cli.cmd import run
from clan_cli.cmd import Log, run
from clan_cli.machines.machines import Machine
from clan_cli.nix import nix_shell
@@ -135,8 +134,8 @@ class SecretStore(SecretStoreBase):
remote_hash = self.machine.target_host.run(
# TODO get the path to the secrets from the machine
["cat", f"{self.machine.secret_vars_upload_directory}/.pass_info"],
log=Log.STDERR,
check=False,
stdout=subprocess.PIPE,
).stdout.strip()
if not remote_hash:

View File

@@ -4,6 +4,7 @@ import json
import logging
import os
import subprocess
import sys
import time
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor
@@ -339,7 +340,16 @@ def run_vm(
) as vm,
ThreadPoolExecutor(max_workers=1) as executor,
):
future = executor.submit(handle_io, vm.process, input_bytes=None, log=Log.BOTH)
future = executor.submit(
handle_io,
vm.process,
cmdlog=log,
prefix=f"[{vm_config.machine_name}] ",
stdout=sys.stdout.buffer,
stderr=sys.stderr.buffer,
input_bytes=None,
log=Log.BOTH,
)
args: list[str] = vm.process.args # type: ignore[assignment]
if runtime_config.command is not None:

View File

@@ -1,30 +1,19 @@
import argparse
import logging
import os
import shlex
from clan_cli import create_parser
from clan_cli.custom_logger import get_callers
from clan_cli.custom_logger import print_trace
log = logging.getLogger(__name__)
def print_trace(msg: str) -> None:
trace_depth = int(os.environ.get("TRACE_DEPTH", "0"))
callers = get_callers(2, 2 + trace_depth)
if "run_no_stdout" in callers[0]:
callers = get_callers(3, 3 + trace_depth)
callers_str = "\n".join(f"{i+1}: {caller}" for i, caller in enumerate(callers))
log.debug(f"{msg} \nCallers: \n{callers_str}")
def run(args: list[str]) -> argparse.Namespace:
parser = create_parser(prog="clan")
parsed = parser.parse_args(args)
cmd = shlex.join(["clan", *args])
print_trace(f"$ {cmd}")
print_trace(f"$ {cmd}", log, "localhost")
if hasattr(parsed, "func"):
parsed.func(parsed)
return parsed

View File

@@ -58,6 +58,6 @@ def test_secrets_upload(
# the flake defines this path as the location where the sops key should be installed
sops_key = test_flake_with_core.path / "facts" / "key.txt"
# breakpoint()
assert sops_key.exists()
assert sops_key.read_text() == age_keys[0].privkey

View File

@@ -1,5 +1,4 @@
import subprocess
from clan_cli.cmd import Log
from clan_cli.ssh.host import Host
from clan_cli.ssh.host_group import HostGroup
@@ -8,21 +7,21 @@ hosts = HostGroup([Host("some_host")])
def test_run_environment() -> None:
p2 = hosts.run_local(
"echo $env_var", extra_env={"env_var": "true"}, stdout=subprocess.PIPE
["echo $env_var"], extra_env={"env_var": "true"}, shell=True, log=Log.STDERR
)
assert p2[0].result.stdout == "true\n"
p3 = hosts.run_local(["env"], extra_env={"env_var": "true"}, stdout=subprocess.PIPE)
p3 = hosts.run_local(["env"], extra_env={"env_var": "true"}, log=Log.STDERR)
assert "env_var=true" in p3[0].result.stdout
def test_run_local() -> None:
hosts.run_local("echo hello")
hosts.run_local(["echo", "hello"])
def test_timeout() -> None:
try:
hosts.run_local("sleep 10", timeout=0.01)
hosts.run_local(["sleep", "10"], timeout=0.01)
except Exception:
pass
else:
@@ -32,8 +31,8 @@ def test_timeout() -> None:
def test_run_function() -> None:
def some_func(h: Host) -> bool:
p = h.run_local("echo hello", stdout=subprocess.PIPE)
return p.stdout == "hello\n"
par = h.run_local(["echo", "hello"], log=Log.STDERR)
return par.stdout == "hello\n"
res = hosts.run_function(some_func)
assert res[0].result
@@ -41,7 +40,7 @@ def test_run_function() -> None:
def test_run_exception() -> None:
try:
hosts.run_local("exit 1")
hosts.run_local(["exit 1"], shell=True)
except Exception:
pass
else:
@@ -51,7 +50,7 @@ def test_run_exception() -> None:
def test_run_function_exception() -> None:
def some_func(h: Host) -> None:
h.run_local("exit 1")
h.run_local(["exit 1"], shell=True)
try:
hosts.run_function(some_func)
@@ -63,5 +62,5 @@ def test_run_function_exception() -> None:
def test_run_local_non_shell() -> None:
p2 = hosts.run_local(["echo", "1"], stdout=subprocess.PIPE)
p2 = hosts.run_local(["echo", "1"], log=Log.STDERR)
assert p2[0].result.stdout == "1\n"

View File

@@ -1,6 +1,7 @@
import subprocess
import pytest
from clan_cli.cmd import Log
from clan_cli.errors import ClanError
from clan_cli.ssh.host import Host
from clan_cli.ssh.host_group import HostGroup
@@ -22,27 +23,27 @@ def test_parse_ipv6() -> None:
def test_run(host_group: HostGroup) -> None:
proc = host_group.run("echo hello", stdout=subprocess.PIPE)
proc = host_group.run_local(["echo", "hello"], log=Log.STDERR)
assert proc[0].result.stdout == "hello\n"
def test_run_environment(host_group: HostGroup) -> None:
p1 = host_group.run(
"echo $env_var", stdout=subprocess.PIPE, extra_env={"env_var": "true"}
["echo $env_var"], extra_env={"env_var": "true"}, shell=True, log=Log.STDERR
)
assert p1[0].result.stdout == "true\n"
p2 = host_group.run(["env"], stdout=subprocess.PIPE, extra_env={"env_var": "true"})
p2 = host_group.run(["env"], log=Log.STDERR, extra_env={"env_var": "true"})
assert "env_var=true" in p2[0].result.stdout
def test_run_no_shell(host_group: HostGroup) -> None:
proc = host_group.run(["echo", "$hello"], stdout=subprocess.PIPE)
proc = host_group.run(["echo", "$hello"], log=Log.STDERR)
assert proc[0].result.stdout == "$hello\n"
def test_run_function(host_group: HostGroup) -> None:
def some_func(h: Host) -> bool:
p = h.run("echo hello", stdout=subprocess.PIPE)
p = h.run(["echo", "hello"])
return p.stdout == "hello\n"
res = host_group.run_function(some_func)
@@ -51,7 +52,7 @@ def test_run_function(host_group: HostGroup) -> None:
def test_timeout(host_group: HostGroup) -> None:
try:
host_group.run_local("sleep 10", timeout=0.01)
host_group.run_local(["sleep", "10"], timeout=0.01)
except Exception:
pass
else:
@@ -60,11 +61,11 @@ def test_timeout(host_group: HostGroup) -> None:
def test_run_exception(host_group: HostGroup) -> None:
r = host_group.run("exit 1", check=False)
r = host_group.run(["exit 1"], check=False, shell=True)
assert r[0].result.returncode == 1
try:
host_group.run("exit 1")
host_group.run(["exit 1"], shell=True)
except Exception:
pass
else:
@@ -74,7 +75,7 @@ def test_run_exception(host_group: HostGroup) -> None:
def test_run_function_exception(host_group: HostGroup) -> None:
def some_func(h: Host) -> subprocess.CompletedProcess[str]:
return h.run_local("exit 1")
return h.run_local(["exit 1"], shell=True)
try:
host_group.run_function(some_func)