clan-cli: Replace HostGroup and MachineGroup with generic AsyncRuntime class. Propagate cmd prefix over thread local. Close threads on CTRL+C
This commit is contained in:
@@ -434,8 +434,8 @@ def main() -> None:
|
||||
else:
|
||||
log.error("%s", e) # noqa: TRY400
|
||||
sys.exit(1)
|
||||
except KeyboardInterrupt:
|
||||
log.warning("Interrupted by user")
|
||||
except KeyboardInterrupt as ex:
|
||||
log.warning("Interrupted by user", exc_info=ex)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
|
||||
312
pkgs/clan-cli/clan_cli/async_run.py
Normal file
312
pkgs/clan-cli/clan_cli/async_run.py
Normal file
@@ -0,0 +1,312 @@
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import types
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import IO, Any, Generic, ParamSpec, TypeVar
|
||||
|
||||
from clan_cli.errors import ClanError
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Why did we create a custom AsyncRuntime instead of using asyncio?
|
||||
#
|
||||
# The AsyncRuntime class allows us to run functions in separate threads for asynchronous
|
||||
# execution without requiring the use of the "async" keyword. By using this approach,
|
||||
# functions can gracefully handle cancellation by checking if get_async_ctx().cancel is True.
|
||||
#
|
||||
# There was some resistance to using asyncio, partly due to challenges we faced when
|
||||
# implementing it in our first web interface. Threads felt simpler and more familiar
|
||||
# for our use case.
|
||||
#
|
||||
# That said, asyncio is generally more efficient because it uses non-blocking I/O
|
||||
# and avoids issues with Python's Global Interpreter Lock (GIL), which can limit the
|
||||
# performance of threads in CPU-heavy workloads.
|
||||
#
|
||||
# Using threads works well for us because most of the time is spent waiting for commands
|
||||
# or external processes to finish, rather than performing heavy computing tasks.
|
||||
#
|
||||
# Note: Starting with Python 3.13 (experimental; officially supported from 3.14), a free-threaded build can run without the GIL to enable true parallelism.
|
||||
# However, disabling the GIL introduces a 10-40% performance cost for Python code
|
||||
# due to the overhead of additional locking.
|
||||
|
||||
# Define generics for return type and call signature
|
||||
R = TypeVar("R") # Return type of the callable
|
||||
P = ParamSpec("P") # Parameters of the callable
|
||||
|
||||
|
||||
@dataclass
class AsyncResult(Generic[R]):
    """
    Outcome of a callable: either the value it returned or the exception it raised.
    """

    # The return value on success, the raised exception on failure
    _result: R | Exception

    @property
    def error(self) -> Exception | None:
        """
        The exception raised by the callable, or None if it did not raise.
        """
        outcome = self._result
        return outcome if isinstance(outcome, Exception) else None

    @property
    def result(self) -> R:
        """
        The value returned by the callable.

        Re-raises the stored exception if the callable failed.
        """
        outcome = self._result
        if not isinstance(outcome, Exception):
            return outcome
        raise outcome
|
||||
|
||||
|
||||
@dataclass
|
||||
class AsyncContext:
|
||||
"""
|
||||
This class stores thread-local data.
|
||||
"""
|
||||
|
||||
prefix: str | None = None # prefix for logging
|
||||
stdout: IO[bytes] | None = None # stdout of subprocesses
|
||||
stderr: IO[bytes] | None = None # stderr of subprocesses
|
||||
cancel: bool = False # Used to signal cancellation of task
|
||||
|
||||
|
||||
@dataclass
class AsyncOpts:
    """
    Per-task settings passed to AsyncRuntime.async_run.
    """

    # Task identifier; async_run fills in a random one when this is None
    tid: str | None = None
    # Whether AsyncRuntime.check_all looks at this task
    check: bool = True
    # Context installed in the thread that runs the task
    async_ctx: AsyncContext = field(default_factory=AsyncContext)
|
||||
|
||||
|
||||
ASYNC_CTX_THREAD_LOCAL = threading.local()
|
||||
|
||||
|
||||
def is_async_cancelled() -> bool:
    """
    Tell whether the cancel flag of the calling thread's AsyncContext is set.
    """
    current_ctx = get_async_ctx()
    return current_ctx.cancel
|
||||
|
||||
|
||||
def get_async_ctx() -> AsyncContext:
    """
    Return the AsyncContext of the calling thread.

    A fresh default AsyncContext is created and stored if the thread has none yet.
    """
    try:
        return ASYNC_CTX_THREAD_LOCAL.async_ctx
    except AttributeError:
        # First access from this thread: install a default context
        fresh_ctx = AsyncContext()
        ASYNC_CTX_THREAD_LOCAL.async_ctx = fresh_ctx
        return fresh_ctx
|
||||
|
||||
|
||||
def set_async_ctx(ctx: AsyncContext) -> None:
    """
    Store ctx as the AsyncContext of the calling thread.
    """
    # Only an attribute is assigned, so no `global` declaration is required
    ASYNC_CTX_THREAD_LOCAL.async_ctx = ctx
|
||||
|
||||
|
||||
class AsyncThread(threading.Thread, Generic[P, R]):
    """
    Thread that runs a single callable and records its outcome as an AsyncResult.
    """

    function: Callable[P, R]
    args: Any
    kwargs: Any
    result: AsyncResult[R] | None
    finished: bool
    condition: threading.Condition
    async_opts: AsyncOpts

    def __init__(
        self,
        async_opts: AsyncOpts,
        condition: threading.Condition,
        function: Callable[P, R],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> None:
        """
        Prepare a thread that will call function(*args, **kwargs).

        async_opts carries the AsyncContext installed in the thread;
        condition is notified once the thread is done.
        """
        super().__init__()
        self.async_opts = async_opts
        # Condition variable shared with whoever waits for this thread
        self.condition = condition
        self.function = function
        self.args = args
        self.kwargs = kwargs
        # Holds the return value or the raised exception once run() is done
        self.result = None
        # Becomes True when run() has finished, successfully or not
        self.finished = False

    def run(self) -> None:
        """
        Thread body: install the async context, call the function, store the outcome.
        """
        try:
            set_async_ctx(self.async_opts.async_ctx)
            value = self.function(*self.args, **self.kwargs)
            self.result = AsyncResult(_result=value)
        except Exception as exc:
            self.result = AsyncResult(_result=exc)
        finally:
            self.finished = True
            # The condition's lock must be held while notifying
            with self.condition:
                # Wake up everybody waiting for a task to complete
                self.condition.notify_all()
|
||||
|
||||
|
||||
@dataclass
class AsyncFuture(Generic[R]):
    """
    Handle for a task registered in an AsyncRuntime, identified by its task id.
    """

    _tid: str
    _runtime: "AsyncRuntime"

    def wait(self) -> AsyncResult[R]:
        """
        Block until the task's thread has ended and return its result.

        Raises ClanError if the runtime has no task with this id or if no
        result could be retrieved after joining.
        """
        tasks = self._runtime.tasks
        if self._tid not in tasks:
            msg = f"No task with the name '{self._tid}' exists."
            raise ClanError(msg)

        tasks[self._tid].join()

        outcome = self.get_result()
        if outcome is None:
            msg = f"Task '{self._tid}' unexpectedly returned None."
            raise ClanError(msg)
        return outcome

    def get_result(self) -> AsyncResult[R] | None:
        """
        Return the result of the task if it has finished, otherwise None.

        A finished task is removed from the runtime's task list when its
        result is fetched. Raises ClanError if the task id is unknown or the
        finished task carries no result.
        """
        tasks = self._runtime.tasks
        if self._tid not in tasks:
            msg = f"No task with the name '{self._tid}' exists."
            raise ClanError(msg)

        task = tasks[self._tid]
        if not task.finished:
            return None

        # Take the result and drop the task from the runtime
        outcome = task.result
        del tasks[self._tid]

        if outcome is None:
            msg = f"The result for task '{self._tid}' is unexpectedly None."
            raise ClanError(msg)

        return outcome
|
||||
|
||||
|
||||
@dataclass
class AsyncRuntime:
    """
    Runs callables in background threads and tracks them by task id.

    Used as a context manager, it marks every unfinished task as cancelled on
    exit; it does not join the threads itself.
    """

    # Started tasks, keyed by task id
    tasks: dict[str, AsyncThread[Any, Any]] = field(default_factory=dict)
    # Handed to every AsyncThread, which notifies it when the thread is done
    condition: threading.Condition = field(default_factory=threading.Condition)

    def async_run(
        self,
        opts: AsyncOpts | None,
        function: Callable[P, R],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> AsyncFuture[R]:
        """
        Run the given function asynchronously in a thread with a specific name and arguments.
        The function's static typing is preserved.

        When opts is None default options are used. When opts.tid is None a
        random id is generated and written back into opts.
        Raises ClanError if a task with the same id is still registered.
        """
        if opts is None:
            opts = AsyncOpts()

        if opts.tid is None:
            opts.tid = uuid.uuid4().hex

        if opts.tid in self.tasks:
            msg = f"A task with the name '{opts.tid}' is already running."
            raise ClanError(msg)

        # Create and start the new AsyncThread
        thread = AsyncThread(opts, self.condition, function, *args, **kwargs)
        self.tasks[opts.tid] = thread
        thread.start()
        return AsyncFuture(opts.tid, self)

    def join_all(self) -> None:
        """
        Wait for all tasks to finish
        """
        with self.condition:
            # Re-check after every wake-up; each finishing thread notifies the condition
            while any(not task.finished for task in self.tasks.values()):
                self.condition.wait()

    def check_all(self) -> None:
        """
        Check if there were any errors

        Logs the error of every finished task that has `check` enabled and
        raises ClanError if at least one such task failed. Tasks that finished
        without an error are not counted.
        """
        err_count = 0

        for name, task in self.tasks.items():
            if not (task.finished and task.async_opts.check):
                continue

            assert task.result is not None
            error = task.result.error
            if error is None:
                # The task succeeded; nothing to report
                continue

            # Attach the traceback only when debug logging is enabled
            exc_info = error if log.isEnabledFor(logging.DEBUG) else None
            log.error(
                f"failed with error: {error}",
                extra={"command_prefix": name},
                exc_info=exc_info,
            )
            err_count += 1

        if err_count > 0:
            msg = f"{err_count} hosts failed with an error. Check the logs above"
            raise ClanError(msg)

    def __enter__(self) -> "AsyncRuntime":
        """
        Enter the runtime context related to this object.
        """
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> None:
        """
        Exit the runtime context related to this object.
        Sets async_ctx.cancel to True to signal cancellation.

        Only unfinished tasks are flagged. Exceptions raised inside the
        `with` body are not suppressed.
        """
        for name, task in self.tasks.items():
            if not task.finished:
                task.async_opts.async_ctx.cancel = True
                log.debug(f"Canceling task {name}")
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
    runtime = AsyncRuntime()

    def add(a: int, b: int) -> int:
        # Succeeds immediately
        return a + b

    def concatenate(a: str, b: str) -> str:
        # Always fails after a short delay, to demonstrate error reporting
        time.sleep(1)
        msg = "Hello World"
        raise ClanError(msg)

    with runtime:
        sum_future = runtime.async_run(None, add, 1, 2)
        failing_future = runtime.async_run(None, concatenate, "Hello ", "World")

        print(sum_future.wait().result)  # Output: 3
        print(failing_future.wait().error)  # Output: Hello World
|
||||
@@ -17,6 +17,7 @@ from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import IO, Any
|
||||
|
||||
from clan_cli.async_run import get_async_ctx, is_async_cancelled
|
||||
from clan_cli.colors import Color
|
||||
from clan_cli.custom_logger import print_trace
|
||||
from clan_cli.errors import ClanCmdError, ClanError, CmdOut, indent_command
|
||||
@@ -87,7 +88,8 @@ def handle_io(
|
||||
stdout_extra = {}
|
||||
stderr_extra = {}
|
||||
if prefix:
|
||||
stdout_extra["command_prefix"] = stderr_extra["command_prefix"] = prefix
|
||||
stdout_extra["command_prefix"] = prefix
|
||||
stderr_extra["command_prefix"] = prefix
|
||||
if msg_color and msg_color.stderr:
|
||||
stdout_extra["color"] = msg_color.stderr.value
|
||||
if msg_color and msg_color.stdout:
|
||||
@@ -101,6 +103,13 @@ def handle_io(
|
||||
description = prefix
|
||||
raise ClanCmdTimeoutError(msg=msg, description=description, timeout=timeout)
|
||||
|
||||
# Check if the command has been cancelled
|
||||
if is_async_cancelled():
|
||||
cmdlog.warning("Command cancelled", extra=stderr_extra)
|
||||
|
||||
# Terminate process
|
||||
break
|
||||
|
||||
# Wait for data to be available
|
||||
readlist, writelist, _ = select.select(rlist, wlist, [], 0.1)
|
||||
if len(readlist) == 0 and len(writelist) == 0:
|
||||
@@ -116,7 +125,7 @@ def handle_io(
|
||||
|
||||
# If Log.STDOUT is set, log the stdout output
|
||||
if ret and log in [Log.STDOUT, Log.BOTH]:
|
||||
lines = ret.decode("utf-8", "replace").rstrip("\n").split("\n")
|
||||
lines = ret.decode("utf-8", "replace").rstrip("\n").rstrip().split("\n")
|
||||
for line in lines:
|
||||
cmdlog.info(line, extra=stdout_extra)
|
||||
|
||||
@@ -133,7 +142,7 @@ def handle_io(
|
||||
|
||||
# If Log.STDERR is set, log the stderr output
|
||||
if ret and log in [Log.STDERR, Log.BOTH]:
|
||||
lines = ret.decode("utf-8", "replace").rstrip("\n").split("\n")
|
||||
lines = ret.decode("utf-8", "replace").rstrip("\n").rstrip().split("\n")
|
||||
for line in lines:
|
||||
cmdlog.info(line, extra=stderr_extra)
|
||||
|
||||
@@ -273,6 +282,17 @@ def run(
|
||||
if options.cwd is None:
|
||||
options.cwd = Path.cwd()
|
||||
|
||||
async_ctx = get_async_ctx()
|
||||
# Fill in the options from the thread-local data
|
||||
# if they are not set in the options
|
||||
if async_ctx:
|
||||
if options.prefix is None:
|
||||
options.prefix = async_ctx.prefix
|
||||
if options.stdout is None:
|
||||
options.stdout = async_ctx.stdout
|
||||
if options.stderr is None:
|
||||
options.stderr = async_ctx.stderr
|
||||
|
||||
if options.input:
|
||||
if any(not ch.isprintable() for ch in options.input.decode("ascii", "replace")):
|
||||
filtered_input = "<<binary_blob>>"
|
||||
@@ -318,6 +338,7 @@ def run(
|
||||
stdout=options.stdout,
|
||||
stderr=options.stderr,
|
||||
)
|
||||
if not is_async_cancelled():
|
||||
process.wait()
|
||||
|
||||
global TIME_TABLE
|
||||
@@ -336,7 +357,9 @@ def run(
|
||||
)
|
||||
|
||||
if options.check and process.returncode != 0:
|
||||
raise ClanCmdError(cmd_out)
|
||||
err = ClanCmdError(cmd_out)
|
||||
err.msg = "Command has been cancelled"
|
||||
raise err
|
||||
|
||||
return cmd_out
|
||||
|
||||
@@ -359,3 +382,6 @@ def run_no_stdout(
|
||||
cmd,
|
||||
opts,
|
||||
)
|
||||
|
||||
|
||||
# type: ignore
|
||||
|
||||
@@ -59,6 +59,9 @@ class PrefixFormatter(logging.Formatter):
|
||||
|
||||
# If command_prefix is set, color the prefix with a unique color.
|
||||
elif command_prefix:
|
||||
command_prefix = command_prefix[
|
||||
:10
|
||||
] # Truncate the command prefix to 10 characters.
|
||||
prefix_color = self.hostname_colorcode(command_prefix)
|
||||
format_str = color_by_tuple(f"[{command_prefix}]", fg=prefix_color)
|
||||
format_str += color_by_tuple(" %(message)s", fg=msg_color)
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
from typing import TypeVar
|
||||
|
||||
from clan_cli.ssh.host import Host
|
||||
from clan_cli.ssh.host_group import HostGroup, HostResult
|
||||
|
||||
from .machines import Machine
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class MachineGroup:
    """
    A list of machines together with a HostGroup built from their target hosts.
    """

    def __init__(self, machines: list[Machine]) -> None:
        target_hosts = [machine.target_host for machine in machines]
        self.machines = machines
        self.group = HostGroup(target_hosts)

    def __repr__(self) -> str:
        return str(self)

    def __str__(self) -> str:
        return f"MachineGroup({self.group})"

    def run_function(
        self, func: Callable[[Machine], T], check: bool = True
    ) -> list[HostResult[T]]:
        """
        Function to run for each host in the group in parallel

        @func the function to call; it receives the machine stored in the
              host's meta data under the "machine" key
        """

        def call_with_machine(host: Host) -> T:
            machine = host.meta["machine"]
            return func(machine)

        return self.group.run_function(call_with_machine, check=check)
|
||||
@@ -6,6 +6,7 @@ import shlex
|
||||
import sys
|
||||
|
||||
from clan_cli.api import API
|
||||
from clan_cli.async_run import AsyncContext, AsyncOpts, AsyncRuntime, is_async_cancelled
|
||||
from clan_cli.clan_uri import FlakeId
|
||||
from clan_cli.cmd import MsgColor, RunOpts, run
|
||||
from clan_cli.colors import AnsiColor
|
||||
@@ -24,7 +25,6 @@ from clan_cli.vars.generate import generate_vars
|
||||
from clan_cli.vars.upload import upload_secret_vars
|
||||
|
||||
from .inventory import get_all_machines, get_selected_machines
|
||||
from .machine_group import MachineGroup
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -113,10 +113,10 @@ def update_machines(base_path: str, machines: list[InventoryMachine]) -> None:
|
||||
# m.override_build_host = machine.deploy.buildHost
|
||||
group_machines.append(m)
|
||||
|
||||
deploy_machine(MachineGroup(group_machines))
|
||||
deploy_machine(group_machines)
|
||||
|
||||
|
||||
def deploy_machine(machines: MachineGroup) -> None:
|
||||
def deploy_machine(machines: list[Machine]) -> None:
|
||||
"""
|
||||
Deploy to all hosts in parallel
|
||||
"""
|
||||
@@ -163,11 +163,9 @@ def deploy_machine(machines: MachineGroup) -> None:
|
||||
RunOpts(check=False, msg_color=MsgColor(stderr=AnsiColor.DEFAULT)),
|
||||
extra_env=env,
|
||||
)
|
||||
ret = host.run(
|
||||
switch_cmd,
|
||||
RunOpts(check=False, msg_color=MsgColor(stderr=AnsiColor.DEFAULT)),
|
||||
extra_env=env,
|
||||
)
|
||||
|
||||
if is_async_cancelled():
|
||||
return
|
||||
|
||||
# if the machine is mobile, we retry to deploy with the mobile workaround method
|
||||
is_mobile = machine.deployment.get("nixosMobileWorkaround", False)
|
||||
@@ -187,13 +185,21 @@ def deploy_machine(machines: MachineGroup) -> None:
|
||||
extra_env=env,
|
||||
)
|
||||
|
||||
if len(machines.group.hosts) > 1:
|
||||
machines.run_function(deploy)
|
||||
else:
|
||||
deploy(machines.machines[0])
|
||||
with AsyncRuntime() as runtime:
|
||||
for machine in machines:
|
||||
machine.info(f"Updating {machine.name}")
|
||||
runtime.async_run(
|
||||
AsyncOpts(
|
||||
tid=machine.name, async_ctx=AsyncContext(prefix=machine.name)
|
||||
),
|
||||
deploy,
|
||||
machine,
|
||||
)
|
||||
runtime.join_all()
|
||||
|
||||
|
||||
def update(args: argparse.Namespace) -> None:
|
||||
def update_command(args: argparse.Namespace) -> None:
|
||||
try:
|
||||
if args.flake is None:
|
||||
msg = "Could not find clan flake toplevel directory"
|
||||
raise ClanError(msg)
|
||||
@@ -241,8 +247,10 @@ def update(args: argparse.Namespace) -> None:
|
||||
machine.override_build_host = args.build_host
|
||||
machine.host_key_check = HostKeyCheck.from_str(args.host_key_check)
|
||||
|
||||
host_group = MachineGroup(machines)
|
||||
deploy_machine(host_group)
|
||||
deploy_machine(machines)
|
||||
except KeyboardInterrupt:
|
||||
log.warning("Interrupted by user")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def register_update_parser(parser: argparse.ArgumentParser) -> None:
|
||||
@@ -272,4 +280,4 @@ def register_update_parser(parser: argparse.ArgumentParser) -> None:
|
||||
type=str,
|
||||
help="Address of the machine to build the flake, in the format of user@host:1234.",
|
||||
)
|
||||
parser.set_defaults(func=update)
|
||||
parser.set_defaults(func=update_command)
|
||||
|
||||
@@ -49,7 +49,7 @@ def nix_add_to_gcroots(nix_path: Path, dest: Path) -> None:
|
||||
|
||||
|
||||
def nix_config() -> dict[str, Any]:
|
||||
cmd = nix_command(["show-config", "--json"])
|
||||
cmd = nix_command(["config", "show", "--json"])
|
||||
proc = run_no_stdout(cmd)
|
||||
data = json.loads(proc.stdout)
|
||||
config = {}
|
||||
|
||||
@@ -1,214 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from threading import Thread
|
||||
from typing import Any
|
||||
|
||||
from clan_cli.cmd import RunOpts
|
||||
from clan_cli.errors import ClanError
|
||||
from clan_cli.ssh import T
|
||||
from clan_cli.ssh.host import Host
|
||||
from clan_cli.ssh.results import HostResult, Results
|
||||
|
||||
cmdlog = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _worker(
    func: Callable[[Host], T],
    host: Host,
    results: list[HostResult[T]],
    idx: int,
) -> None:
    """
    Call func(host) and store the outcome, value or exception, at results[idx].
    """
    try:
        value = func(host)
        results[idx] = HostResult(host, value)
    except Exception as exc:
        results[idx] = HostResult(host, exc)
|
||||
|
||||
|
||||
@dataclass
class HostGroup:
    """
    A list of hosts on which commands or functions are run in parallel,
    one thread per host.
    """

    hosts: list[Host]

    def _run_local(
        self,
        *,
        cmd: list[str],
        opts: RunOpts,
        extra_env: dict[str, str] | None,
        host: Host,
        results: Results,
        verbose_ssh: bool = False,
        tty: bool = False,
    ) -> None:
        """
        Thread target: run cmd through host.run_local and append the outcome to results.

        verbose_ssh and tty are accepted so both thread targets take the same
        keyword arguments; they are not used here.
        """
        try:
            proc = host.run_local(
                cmd,
                opts,
                extra_env,
            )
            results.append(HostResult(host, proc))
        except Exception as e:
            results.append(HostResult(host, e))

    def _run_remote(
        self,
        cmd: list[str],
        opts: RunOpts,
        host: Host,
        results: Results,
        extra_env: dict[str, str] | None = None,
        verbose_ssh: bool = False,
        tty: bool = False,
    ) -> None:
        """
        Thread target: run cmd through host.run and append the outcome to results.
        """
        try:
            proc = host.run(
                cmd,
                extra_env=extra_env,
                verbose_ssh=verbose_ssh,
                opts=opts,
                tty=tty,
            )
            results.append(HostResult(host, proc))
        except Exception as e:
            results.append(HostResult(host, e))

    def _reraise_errors(self, results: list[HostResult[Any]]) -> None:
        """
        Log every failed result and raise a ClanError if there was at least one.

        The raised error is chained to the first failure found, so the cause
        is always a real exception.
        """
        failures = [result for result in results if result.error]
        for failed in failures:
            cmdlog.error(
                f"failed with: {failed.error}",
                extra={"command_prefix": failed.host.command_prefix},
            )
        if failures:
            msg = f"{len(failures)} hosts failed with an error. Check the logs above"
            raise ClanError(msg) from failures[0].error

    def _run(
        self,
        cmd: list[str],
        opts: RunOpts | None = None,
        local: bool = False,
        extra_env: dict[str, str] | None = None,
        verbose_ssh: bool = False,
        tty: bool = False,
    ) -> Results:
        """
        Run cmd for every host in its own thread and collect the results.

        Results are appended as threads finish, so their order is not tied to
        the order of self.hosts. Raises ClanError when opts.check is set and
        any host failed.
        """
        results: Results = []
        threads = []

        if opts is None:
            opts = RunOpts()

        # Both targets accept the same keyword arguments
        target = self._run_local if local else self._run_remote

        for host in self.hosts:
            thread = Thread(
                target=target,
                kwargs={
                    "cmd": cmd,
                    "opts": opts,
                    "host": host,
                    "results": results,
                    "extra_env": extra_env,
                    "verbose_ssh": verbose_ssh,
                    "tty": tty,
                },
            )
            thread.start()
            threads.append(thread)

        for thread in threads:
            thread.join()

        if opts.check:
            self._reraise_errors(results)

        return results

    def run(
        self,
        cmd: list[str],
        opts: RunOpts | None = None,
        *,
        extra_env: dict[str, str] | None = None,
        verbose_ssh: bool = False,
        tty: bool = False,
    ) -> Results:
        """
        Command to run on the remote host via ssh

        @return a lists of tuples containing Host and the result of the command for this Host
        """
        return self._run(
            cmd,
            opts,
            extra_env=extra_env,
            verbose_ssh=verbose_ssh,
            tty=tty,
        )

    def run_local(
        self,
        cmd: list[str],
        opts: RunOpts | None = None,
        *,
        extra_env: dict[str, str] | None = None,
    ) -> Results:
        """
        Command to run locally for each host in the group in parallel

        @return a lists of tuples containing Host and the result of the command for this Host
        """
        return self._run(
            cmd,
            opts,
            local=True,
            extra_env=extra_env,
        )

    def run_function(
        self, func: Callable[[Host], T], check: bool = True
    ) -> list[HostResult[T]]:
        """
        Function to run for each host in the group in parallel

        The returned list has one entry per host, in the order of self.hosts.
        Raises ClanError when check is set and any host failed.
        """
        # Pre-fill every slot so a thread that never stores a result shows up as an error
        results: list[HostResult[T]] = [
            HostResult(h, ClanError(f"No result set for thread {i}"))
            for (i, h) in enumerate(self.hosts)
        ]
        threads = [
            Thread(
                target=_worker,
                args=(func, host, results, i),
            )
            for i, host in enumerate(self.hosts)
        ]

        for thread in threads:
            thread.start()

        for thread in threads:
            thread.join()
        if check:
            self._reraise_errors(results)
        return results

    def filter(self, pred: Callable[[Host], bool]) -> "HostGroup":
        """Return a new Group with the results filtered by the predicate"""
        return HostGroup(list(filter(pred, self.hosts)))
|
||||
@@ -488,7 +488,8 @@ def generate_vars(
|
||||
raise ClanError(msg) from errors[0][1]
|
||||
|
||||
if not was_regenerated and len(machines) > 0:
|
||||
log.info("All vars are already up to date")
|
||||
for machine in machines:
|
||||
machine.info("All vars are already up to date")
|
||||
|
||||
return was_regenerated
|
||||
|
||||
|
||||
@@ -12,7 +12,8 @@ pytest_plugins = [
|
||||
"sshd",
|
||||
"command",
|
||||
"ports",
|
||||
"host_group",
|
||||
"hosts",
|
||||
"runtime",
|
||||
"fixtures_flakes",
|
||||
"stdout",
|
||||
"nix_config",
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
import os
|
||||
import pwd
|
||||
|
||||
import pytest
|
||||
from clan_cli.ssh.host import Host
|
||||
from clan_cli.ssh.host_group import HostGroup
|
||||
from clan_cli.ssh.host_key import HostKeyCheck
|
||||
from sshd import Sshd
|
||||
|
||||
|
||||
@pytest.fixture
def host_group(sshd: Sshd) -> HostGroup:
    """
    HostGroup with a single host: the sshd fixture on 127.0.0.1, logged in as the current user.
    """
    current_user = pwd.getpwuid(os.getuid()).pw_name
    local_host = Host(
        "127.0.0.1",
        port=sshd.port,
        user=current_user,
        key=sshd.key,
        host_key_check=HostKeyCheck.NONE,
    )
    return HostGroup([local_host])
|
||||
23
pkgs/clan-cli/tests/hosts.py
Normal file
23
pkgs/clan-cli/tests/hosts.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import os
|
||||
import pwd
|
||||
|
||||
import pytest
|
||||
from clan_cli.ssh.host import Host
|
||||
from clan_cli.ssh.host_key import HostKeyCheck
|
||||
from sshd import Sshd
|
||||
|
||||
|
||||
@pytest.fixture
def hosts(sshd: Sshd) -> list[Host]:
    """
    List with a single host: the sshd fixture on 127.0.0.1, logged in as the current user.
    """
    current_user = pwd.getpwuid(os.getuid()).pw_name
    local_host = Host(
        "127.0.0.1",
        port=sshd.port,
        user=current_user,
        key=sshd.key,
        host_key_check=HostKeyCheck.NONE,
    )
    return [local_host]
|
||||
7
pkgs/clan-cli/tests/runtime.py
Normal file
7
pkgs/clan-cli/tests/runtime.py
Normal file
@@ -0,0 +1,7 @@
|
||||
import pytest
|
||||
from clan_cli.async_run import AsyncRuntime
|
||||
|
||||
|
||||
@pytest.fixture
def runtime() -> AsyncRuntime:
    """
    Provide a fresh AsyncRuntime for each test.
    """
    fresh_runtime = AsyncRuntime()
    return fresh_runtime
|
||||
@@ -120,6 +120,7 @@ def test_all_dataclasses() -> None:
|
||||
excludes = [
|
||||
"api/__init__.py",
|
||||
"cmd.py", # We don't want the UI to have access to the cmd module anyway
|
||||
"async_run.py", # We don't want the UI to have access to the async_run module anyway
|
||||
]
|
||||
|
||||
cli_path = Path("clan_cli").resolve()
|
||||
|
||||
@@ -7,7 +7,7 @@ from clan_cli.facts.secret_modules.password_store import SecretStore
|
||||
from clan_cli.machines.facts import machine_get_fact
|
||||
from clan_cli.machines.machines import Machine
|
||||
from clan_cli.nix import nix_shell
|
||||
from clan_cli.ssh.host_group import HostGroup
|
||||
from clan_cli.ssh.host import Host
|
||||
from fixtures_flakes import ClanFlake
|
||||
from helpers import cli
|
||||
|
||||
@@ -17,7 +17,7 @@ def test_upload_secret(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
flake: ClanFlake,
|
||||
temporary_home: Path,
|
||||
host_group: HostGroup,
|
||||
hosts: list[Host],
|
||||
) -> None:
|
||||
flake.clan_modules = [
|
||||
"root-password",
|
||||
@@ -27,7 +27,7 @@ def test_upload_secret(
|
||||
config = flake.machines["vm1"]
|
||||
config["nixpkgs"]["hostPlatform"] = "x86_64-linux"
|
||||
config["clan"]["core"]["networking"]["zerotier"]["controller"]["enable"] = True
|
||||
host = host_group.hosts[0]
|
||||
host = hosts[0]
|
||||
addr = f"{host.user}@{host.host}:{host.port}?StrictHostKeyChecking=no&UserKnownHostsFile=/dev/null&IdentityFile={host.key}"
|
||||
config["clan"]["core"]["networking"]["targetHost"] = addr
|
||||
config["clan"]["user-password"]["user"] = "alice"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from clan_cli.ssh.host_group import HostGroup
|
||||
from clan_cli.ssh.host import Host
|
||||
from fixtures_flakes import FlakeForTest
|
||||
from helpers import cli
|
||||
|
||||
@@ -13,7 +13,7 @@ if TYPE_CHECKING:
|
||||
def test_secrets_upload(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
test_flake_with_core: FlakeForTest,
|
||||
host_group: HostGroup,
|
||||
hosts: list[Host],
|
||||
age_keys: list["KeyPair"],
|
||||
) -> None:
|
||||
monkeypatch.chdir(test_flake_with_core.path)
|
||||
@@ -48,7 +48,7 @@ def test_secrets_upload(
|
||||
)
|
||||
|
||||
flake = test_flake_with_core.path.joinpath("flake.nix")
|
||||
host = host_group.hosts[0]
|
||||
host = hosts[0]
|
||||
addr = f"{host.user}@{host.host}:{host.port}?StrictHostKeyChecking=no&UserKnownHostsFile=/dev/null&IdentityFile={host.key}"
|
||||
new_text = flake.read_text().replace("__CLAN_TARGET_ADDRESS__", addr)
|
||||
|
||||
|
||||
@@ -1,70 +1,49 @@
|
||||
from clan_cli.cmd import Log, RunOpts
|
||||
from clan_cli.async_run import AsyncRuntime
|
||||
from clan_cli.cmd import ClanCmdTimeoutError, Log, RunOpts
|
||||
from clan_cli.ssh.host import Host
|
||||
from clan_cli.ssh.host_group import HostGroup
|
||||
|
||||
hosts = HostGroup([Host("some_host")])
|
||||
host = Host("some_host")
|
||||
|
||||
|
||||
def test_run_environment() -> None:
|
||||
p2 = hosts.run_local(
|
||||
def test_run_environment(runtime: AsyncRuntime) -> None:
|
||||
p2 = runtime.async_run(
|
||||
None,
|
||||
host.run_local,
|
||||
["echo $env_var"],
|
||||
RunOpts(shell=True, log=Log.STDERR),
|
||||
extra_env={"env_var": "true"},
|
||||
)
|
||||
assert p2[0].result.stdout == "true\n"
|
||||
|
||||
p3 = hosts.run_local(
|
||||
["env"], RunOpts(shell=True, log=Log.STDERR), extra_env={"env_var": "true"}
|
||||
assert p2.wait().result.stdout == "true\n"
|
||||
|
||||
p3 = runtime.async_run(
|
||||
None,
|
||||
host.run_local,
|
||||
["env"],
|
||||
RunOpts(shell=True, log=Log.STDERR),
|
||||
extra_env={"env_var": "true"},
|
||||
)
|
||||
assert "env_var=true" in p3[0].result.stdout
|
||||
assert "env_var=true" in p3.wait().result.stdout
|
||||
|
||||
|
||||
def test_run_local() -> None:
|
||||
hosts.run_local(["echo", "hello"])
|
||||
def test_run_local(runtime: AsyncRuntime) -> None:
|
||||
p1 = runtime.async_run(
|
||||
None, host.run_local, ["echo", "hello"], RunOpts(log=Log.STDERR)
|
||||
)
|
||||
assert p1.wait().result.stdout == "hello\n"
|
||||
|
||||
|
||||
def test_timeout() -> None:
|
||||
try:
|
||||
hosts.run_local(["sleep", "10"], RunOpts(timeout=0.01))
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
msg = "should have raised TimeoutExpired"
|
||||
raise AssertionError(msg)
|
||||
def test_timeout(runtime: AsyncRuntime) -> None:
|
||||
p1 = runtime.async_run(None, host.run_local, ["sleep", "10"], RunOpts(timeout=0.01))
|
||||
error = p1.wait().error
|
||||
assert isinstance(error, ClanCmdTimeoutError)
|
||||
|
||||
|
||||
def test_run_function() -> None:
|
||||
def some_func(h: Host) -> bool:
|
||||
par = h.run_local(["echo", "hello"], RunOpts(log=Log.STDERR))
|
||||
return par.stdout == "hello\n"
|
||||
|
||||
res = hosts.run_function(some_func)
|
||||
assert res[0].result
|
||||
def test_run_exception(runtime: AsyncRuntime) -> None:
|
||||
p1 = runtime.async_run(None, host.run_local, ["exit 1"], RunOpts(shell=True))
|
||||
assert p1.wait().error is not None
|
||||
|
||||
|
||||
def test_run_exception() -> None:
|
||||
try:
|
||||
hosts.run_local(["exit 1"], RunOpts(shell=True))
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
msg = "should have raised Exception"
|
||||
raise AssertionError(msg)
|
||||
|
||||
|
||||
def test_run_function_exception() -> None:
|
||||
def some_func(h: Host) -> None:
|
||||
h.run_local(["exit 1"], RunOpts(shell=True))
|
||||
|
||||
try:
|
||||
hosts.run_function(some_func)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
msg = "should have raised Exception"
|
||||
raise AssertionError(msg)
|
||||
|
||||
|
||||
def test_run_local_non_shell() -> None:
|
||||
p2 = hosts.run_local(["echo", "1"], RunOpts(log=Log.STDERR))
|
||||
assert p2[0].result.stdout == "1\n"
|
||||
def test_run_local_non_shell(runtime: AsyncRuntime) -> None:
|
||||
p2 = runtime.async_run(None, host.run_local, ["echo", "1"], RunOpts(log=Log.STDERR))
|
||||
assert p2.wait().result.stdout == "1\n"
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import pytest
|
||||
from clan_cli.cmd import Log, RunOpts
|
||||
from clan_cli.async_run import AsyncRuntime
|
||||
from clan_cli.cmd import ClanCmdTimeoutError, Log, RunOpts
|
||||
from clan_cli.errors import ClanError, CmdOut
|
||||
from clan_cli.ssh.host import Host
|
||||
from clan_cli.ssh.host_group import HostGroup
|
||||
from clan_cli.ssh.host_key import HostKeyCheck
|
||||
from clan_cli.ssh.parse import parse_deployment_address
|
||||
|
||||
@@ -20,52 +20,75 @@ def test_parse_ipv6() -> None:
|
||||
host = parse_deployment_address("foo", "fe80::1%eth0", HostKeyCheck.STRICT)
|
||||
|
||||
|
||||
def test_run(host_group: HostGroup) -> None:
|
||||
proc = host_group.run_local(["echo", "hello"], RunOpts(log=Log.STDERR))
|
||||
assert proc[0].result.stdout == "hello\n"
|
||||
def test_run(hosts: list[Host], runtime: AsyncRuntime) -> None:
|
||||
for host in hosts:
|
||||
proc = runtime.async_run(
|
||||
None, host.run_local, ["echo", "hello"], RunOpts(log=Log.STDERR)
|
||||
)
|
||||
assert proc.wait().result.stdout == "hello\n"
|
||||
|
||||
|
||||
def test_run_environment(host_group: HostGroup) -> None:
|
||||
p1 = host_group.run(
|
||||
def test_run_environment(hosts: list[Host], runtime: AsyncRuntime) -> None:
|
||||
for host in hosts:
|
||||
proc = runtime.async_run(
|
||||
None,
|
||||
host.run_local,
|
||||
["echo $env_var"],
|
||||
RunOpts(shell=True, log=Log.STDERR),
|
||||
extra_env={"env_var": "true"},
|
||||
)
|
||||
assert p1[0].result.stdout == "true\n"
|
||||
p2 = host_group.run(["env"], RunOpts(log=Log.STDERR), extra_env={"env_var": "true"})
|
||||
assert "env_var=true" in p2[0].result.stdout
|
||||
assert proc.wait().result.stdout == "true\n"
|
||||
|
||||
for host in hosts:
|
||||
p2 = runtime.async_run(
|
||||
None,
|
||||
host.run_local,
|
||||
["env"],
|
||||
RunOpts(log=Log.STDERR),
|
||||
extra_env={"env_var": "true"},
|
||||
)
|
||||
assert "env_var=true" in p2.wait().result.stdout
|
||||
|
||||
|
||||
def test_run_no_shell(host_group: HostGroup) -> None:
|
||||
proc = host_group.run(["echo", "$hello"], RunOpts(log=Log.STDERR))
|
||||
assert proc[0].result.stdout == "$hello\n"
|
||||
def test_run_no_shell(hosts: list[Host], runtime: AsyncRuntime) -> None:
|
||||
for host in hosts:
|
||||
proc = runtime.async_run(
|
||||
None, host.run_local, ["echo", "hello"], RunOpts(log=Log.STDERR)
|
||||
)
|
||||
assert proc.wait().result.stdout == "hello\n"
|
||||
|
||||
|
||||
def test_run_function(host_group: HostGroup) -> None:
|
||||
def test_run_function(hosts: list[Host], runtime: AsyncRuntime) -> None:
|
||||
def some_func(h: Host) -> bool:
|
||||
p = h.run(["echo", "hello"])
|
||||
return p.stdout == "hello\n"
|
||||
|
||||
res = host_group.run_function(some_func)
|
||||
assert res[0].result
|
||||
for host in hosts:
|
||||
proc = runtime.async_run(None, some_func, host)
|
||||
assert proc.wait().result
|
||||
|
||||
|
||||
def test_timeout(host_group: HostGroup) -> None:
|
||||
try:
|
||||
host_group.run_local(["sleep", "10"], RunOpts(timeout=0.01))
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
msg = "should have raised TimeoutExpired"
|
||||
raise AssertionError(msg)
|
||||
def test_timeout(hosts: list[Host], runtime: AsyncRuntime) -> None:
|
||||
for host in hosts:
|
||||
proc = runtime.async_run(
|
||||
None, host.run_local, ["sleep", "10"], RunOpts(timeout=0.01)
|
||||
)
|
||||
error = proc.wait().error
|
||||
assert isinstance(error, ClanCmdTimeoutError)
|
||||
|
||||
|
||||
def test_run_exception(host_group: HostGroup) -> None:
|
||||
r = host_group.run(["exit 1"], RunOpts(check=False, shell=True))
|
||||
assert r[0].result.returncode == 1
|
||||
def test_run_exception(hosts: list[Host], runtime: AsyncRuntime) -> None:
|
||||
for host in hosts:
|
||||
proc = runtime.async_run(
|
||||
None, host.run_local, ["exit 1"], RunOpts(shell=True, check=False)
|
||||
)
|
||||
assert proc.wait().result.returncode == 1
|
||||
|
||||
try:
|
||||
host_group.run(["exit 1"], RunOpts(shell=True))
|
||||
for host in hosts:
|
||||
runtime.async_run(None, host.run_local, ["exit 1"], RunOpts(shell=True))
|
||||
runtime.join_all()
|
||||
runtime.check_all()
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
@@ -73,12 +96,15 @@ def test_run_exception(host_group: HostGroup) -> None:
|
||||
raise AssertionError(msg)
|
||||
|
||||
|
||||
def test_run_function_exception(host_group: HostGroup) -> None:
|
||||
def test_run_function_exception(hosts: list[Host], runtime: AsyncRuntime) -> None:
|
||||
def some_func(h: Host) -> CmdOut:
|
||||
return h.run_local(["exit 1"], RunOpts(shell=True))
|
||||
|
||||
try:
|
||||
host_group.run_function(some_func)
|
||||
for host in hosts:
|
||||
runtime.async_run(None, some_func, host)
|
||||
runtime.join_all()
|
||||
runtime.check_all()
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user